diff --git a/Cargo.toml b/Cargo.toml index f02c15b..b8f53de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,9 @@ edition = "2021" [dependencies] ccutils = { git = "https://git.zhgsun.com:8089/jiachao2130/ccutils.git", version = "0.1.0" } clap = { version = "4.0", features = ["derive"] } +lazy_static = { version = "1.5" } +reqwest = { version = "0.12", features = ["stream"] } +scraper = { version = "0.19" } serde = { version = "1", features = ["serde_derive"] } tokio = { version = "1.38" } toml = { version = "0.8" } diff --git a/src/cli.rs b/src/cli.rs index 4cc52c6..f76ba97 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -28,4 +28,4 @@ pub(crate) struct Cli { /// 从命令行环境变量读取并转换为 `Cli` pub(crate) fn parse() -> Cli { Cli::parse() -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index 43da8d7..2519390 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,16 @@ +use std::path::Path; + pub use ccutils::Result; mod cli; mod config; +use config::{Config, WebSycnConf}; mod syncer; +pub const MAX_TASKS: usize = 10; +pub const RETRIES: usize = 5; +pub const CONF: &str = "/etc/cuvars/cuweb-syncer"; + pub fn cumain() -> Result<()> { let cli = cli::parse(); @@ -11,9 +18,35 @@ pub fn cumain() -> Result<()> { match (cli.debug, cli.quiet) { (true, _) => std::env::set_var("RUST_LOG", "debug"), (false, true) => std::env::set_var("RUST_LOG", "error"), - _ => {}, + _ => {} } ccutils::set_up_logging()?; + let config = { + if cli.from.is_some() && cli.dest.is_some() { + let mut config = Config::new(); + let websync_conf = + WebSycnConf::new(cli.from.clone().unwrap(), cli.dest.clone().unwrap()); + config.inner.insert("Task".to_string(), websync_conf); + config + } else { + let default = Path::new(CONF); + let conf = if default.is_file() { + CONF + } else { + "cuweb-syncer" + }; + match Config::load(conf) { + Ok(conf) => conf, + Err(e) => { + panic!("Failed to load {conf}: {e}"); + } + } + } + }; + + let rt = ccutils::async_runtime()?; + + let _ = rt.block_on(syncer::run(&config))?; Ok(()) -} \ No newline at end of file +} diff --git a/src/syncer.rs b/src/syncer.rs index e69de29..eb6f39d 100644 --- a/src/syncer.rs +++ b/src/syncer.rs @@ -0,0 +1,249 @@ +use std::{ + sync::{Arc, Mutex}, + time::Duration, +}; + +use crate::{config::Config, MAX_TASKS, RETRIES}; +use ccutils::{ + file::async_download, + tracing::{debug, error, info, warn}, +}; +use lazy_static::lazy_static; +use reqwest::Client; +use scraper::{Html, Selector}; +use tokio::{ + fs, + sync::{ + mpsc::{self, Receiver, Sender}, + Semaphore, + }, + time::sleep, +}; + +lazy_static! { + /// 任务计数器,用于跟踪任务数量 + static ref COUNTER: Arc> = Arc::new(Mutex::new(0)); +} + +/// 服务器结构体,用于管理下载任务 +struct Server { + pub task_sender: Sender, + task_recv: Option>, + res_sender: Sender>, + res_recv: Option>>, +} + +impl Server { + /// 创建一个新的服务器实例 + pub fn new() -> Self { + let (task_sender, task_recv) = mpsc::channel(MAX_TASKS); + let (res_sender, res_recv) = mpsc::channel(100); + let task_recv = Some(task_recv); + let res_recv = Some(res_recv); + + Server { + task_sender, + task_recv, + res_sender, + res_recv, + } + } + + /// 运行服务器,处理下载任务 + pub async fn run(&mut self, config: &Config) -> crate::Result<()> { + let mut task_recv = self.task_recv.take().unwrap(); + let mut res_recv = self.res_recv.take().unwrap(); + + // 只有 `Server::result` 退出,才是正常,即所有的任务均完成、或多次失败后不再重试 + tokio::select! { + _ = self.websync(config) => {}, + _ = Server::dispatcher(&self.task_sender, &mut task_recv, &self.res_sender) => {}, + _ = Server::result(&mut res_recv) => {} + } + Ok(()) + } + + // 读取配置文件并开始执行访问、下载 + async fn websync(&self, config: &Config) -> crate::Result<()> { + let websyncers = &config.inner; + for (task, conf) in websyncers { + info!("Start to run {task} sync task..."); + Server::download_directory( + self.task_sender.clone(), + &conf.from(), + &conf.dest(), + ) + .await?; + } + // 这里暂没想好如何处理,总之不能退出 + sleep(Duration::from_secs(1000000)).await; + + Ok(()) + } + + /// 调度任务,将任务分配给下载器 + async fn dispatcher( + task_sender: &Sender, + task_recv: &mut Receiver, + res_sender: &Sender>, + ) -> crate::Result<()> { + debug!("Server::dispacher starting..."); + // 初始化一个信号量,使最多同时有 `MAX_TASKS` 个下载任务 + let semaphore = Arc::new(Semaphore::new(MAX_TASKS)); + while let Some(mut task) = task_recv.recv().await { + // 失败次数超过 `RETRIES` 后即不再尝试下载 + if task.retries > RETRIES { + error!("Retried {RETRIES} times, failed to download {}", task.url); + let _ = res_sender.send(Err(task)).await; + } else { + task.retries += 1; + // 只有同时存在的任务数小于 `MAX_TASKS` 时才会开始下载 + let _semaphore = semaphore.clone(); + let _permit = _semaphore.acquire().await?; + debug!("avaliable task num: {}", _semaphore.available_permits()); + tokio::spawn(Server::download( + task_sender.clone(), + res_sender.clone(), + task, + )); + // 每次连接间隔 100 ms + sleep(Duration::from_millis(100)).await; + } + } + + Ok(()) + } + + async fn download_directory( + task_sender: Sender, + url: &str, + path: &str, + ) -> crate::Result<()> { + let client = Client::new(); + let response = client.get(url).send().await?.text().await?; + let document = Html::parse_document(&response); + let selector = Selector::parse("a").unwrap(); + + fs::create_dir_all(path).await?; + debug!("Create local directory: {path}"); + for element in document.select(&selector) { + if let Some(href) = element.value().attr("href") { + if href.starts_with("../") { + continue; + } + // 处理目录的情况 + if href.ends_with('/') { + let new_url = format!("{}/{}", url.trim_end_matches('/'), href); + debug!("Found directory: {new_url}"); + let new_path = format!("{}/{}", path.trim_end_matches('/'), href); + // 因为是递归,所以需要 `pin` 起来 + let future = Box::pin(Server::download_directory( + task_sender.clone(), + &new_url, + &new_path, + )); + future.await?; + } else { + // 若是文件,则直接下载 + let file_url = format!("{}/{}", url.trim_end_matches('/'), href); + let file_path = format!("{}/{}", path.trim_end_matches('/'), href); + let task = Task::new(file_url, file_path); + // 此处使用 COUNTER 计数,即分发下载的任务计数 + { + let mut counter = COUNTER.lock().unwrap(); + *counter += 1; + } + task_sender.send(task).await?; + } + } + } + + Ok(()) + } + + /// 下载任务,处理单个文件的下载 + async fn download( + task_sender: Sender, + res_sender: Sender>, + task: Task, + ) -> crate::Result<()> { + match async_download(task.url.clone(), Some(task.target.clone())).await { + // 若下载成功,则发送至 res channel。 + Ok(()) => { + debug!("Successed synced {}", task.url); + match res_sender.send(Ok(task)).await { + Ok(()) => {} + Err(e) => { + error!("Send task failed: {e}"); + } + } + } + // 如果下载失败,则发回 dispatcher 进行重试 + Err(e) => { + warn!("Failed to download {}: {}", task.url, e); + match task_sender.send(task).await { + Ok(()) => {} + Err(e) => { + error!("Send task res failed: {e}"); + } + } + } + } + + Ok(()) + } + + pub async fn result(res_recv: &mut Receiver>) -> crate::Result<()> { + debug!("Server::result() is running..."); + let mut successed = vec![]; + let mut failed = vec![]; + while let Some(res) = res_recv.recv().await { + match res { + Ok(task) => { + successed.push(task); + } + Err(task) => failed.push(task), + } + + // 如果所有的下载任务已完成,则返回并退出 + let counter = COUNTER.lock().unwrap(); + if *counter == successed.len() + failed.len() { + info!( + "All {} tasks done, successed {}, failed {}.", + counter, + successed.len(), + failed.len() + ); + break; + } + } + + Ok(()) + } +} + +struct Task { + pub url: String, + pub target: String, + pub retries: usize, +} + +impl Task { + pub fn new(url: String, target: String) -> Self { + let retries = 0; + Task { + url, + target, + retries, + } + } +} + +/// 开始从配置 `config` 里执行同步任务 +pub(crate) async fn run(config: &Config) -> crate::Result<()> { + info!("CUWEB-SYNCER starting..."); + let mut server = Server::new(); + server.run(config).await?; + + Ok(()) +}