From b41f530d1eb98e72404d5eb8119e9267ffa8a592 Mon Sep 17 00:00:00 2001 From: Azalea Date: Sat, 9 May 2026 21:28:38 -0400 Subject: [PATCH] [O] Parallel everything (#5) --- README.md | 2 +- src/config.rs | 2 +- src/main.rs | 1 + src/parallel.rs | 73 ++++++++++++++++ src/provider.rs | 54 +++++++++++- src/sync.rs | 195 ++++++++++++++++++++++--------------------- src/webhook.rs | 72 +++++----------- tests/unit/config.rs | 1 + 8 files changed, 252 insertions(+), 148 deletions(-) create mode 100644 src/parallel.rs diff --git a/README.md b/README.md index fb68040..28d8ba6 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ Retry only repositories that failed during the previous non-dry-run sync: refray sync --retry-failed ``` -Control parallelism for sync, serve, and webhook commands in config: +Control parallelism for sync, serve, and webhook commands in config. The default is 10 workers: ```toml jobs = 8 diff --git a/src/config.rs b/src/config.rs index 36604ac..e55d717 100644 --- a/src/config.rs +++ b/src/config.rs @@ -9,7 +9,7 @@ use regex::Regex; use serde::{Deserialize, Serialize}; const APP_NAME: &str = "refray"; -pub const DEFAULT_JOBS: usize = 4; +pub const DEFAULT_JOBS: usize = 10; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct Config { diff --git a/src/main.rs b/src/main.rs index 4632683..419b2ff 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod config; mod git; mod interactive; mod logging; +mod parallel; mod provider; mod state; mod sync; diff --git a/src/parallel.rs b/src/parallel.rs new file mode 100644 index 0000000..22c53b0 --- /dev/null +++ b/src/parallel.rs @@ -0,0 +1,73 @@ +use std::collections::VecDeque; +use std::sync::{Arc, Mutex, mpsc}; +use std::thread; + +use anyhow::{Context, Result, bail}; + +pub fn map(items: Vec, jobs: usize, f: F) -> Result> +where + I: Send, + O: Send, + F: Fn(I) -> Result + Sync, +{ + if jobs == 0 { + bail!("jobs must be at least 1"); + } + if items.is_empty() { + return Ok(Vec::new()); + } + + let worker_count = jobs.min(items.len()); + let queue = Arc::new(Mutex::new(VecDeque::from(items))); + let (sender, receiver) = mpsc::channel(); + + thread::scope(|scope| { + for _ in 0..worker_count { + let queue = Arc::clone(&queue); + let sender = sender.clone(); + let f = &f; + scope.spawn(move || { + while let Some(item) = pop_item(&queue) { + if sender.send(f(item)).is_err() { + break; + } + } + }); + } + drop(sender); + + collect_results(receiver) + }) +} + +fn pop_item(queue: &Arc>>) -> Option { + queue + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .pop_front() +} + +fn collect_results(receiver: mpsc::Receiver>) -> Result> { + let mut outputs = Vec::new(); + let mut first_failure = None; + let mut failure_count = 0; + + for result in receiver { + match result { + Ok(output) => outputs.push(output), + Err(error) => { + failure_count += 1; + first_failure.get_or_insert(error); + } + } + } + + match (failure_count, first_failure) { + (0, None) => Ok(outputs), + (1, Some(error)) => Err(error), + (_, Some(error)) => { + Err(error).with_context(|| format!("{failure_count} parallel tasks failed")) + } + _ => unreachable!(), + } +} diff --git a/src/provider.rs b/src/provider.rs index bc3d1a9..3ca2856 100644 --- a/src/provider.rs +++ b/src/provider.rs @@ -1,13 +1,17 @@ use std::collections::HashMap; use anyhow::{Context, Result, anyhow, bail}; +use console::style; use reqwest::blocking::{Client, Response}; use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT}; use serde::Deserialize; use serde_json::json; use url::Url; -use crate::config::{EndpointConfig, NamespaceKind, ProviderKind, SiteConfig, Visibility}; +use crate::config::{ + Config, EndpointConfig, MirrorConfig, NamespaceKind, ProviderKind, RepoNameFilter, SiteConfig, + Visibility, +}; #[derive(Clone, Debug)] pub struct RemoteRepo { @@ -36,6 +40,54 @@ pub struct PullRequestInfo { pub url: Option, } +pub fn list_mirror_repos( + config: &Config, + mirror: &MirrorConfig, + repo_filter: &RepoNameFilter, + jobs: usize, +) -> Result> { + let endpoint_jobs = mirror + .endpoints + .iter() + .cloned() + .enumerate() + .collect::>(); + let worker_count = jobs.min(endpoint_jobs.len()); + if worker_count > 1 { + crate::logln!( + " {} listing repositories with {} workers", + style("jobs").cyan().bold(), + worker_count + ); + } + + let mut listed = crate::parallel::map(endpoint_jobs, jobs, |(index, endpoint)| { + let site = config.site(&endpoint.site).unwrap(); + let client = ProviderClient::new(site)?; + crate::logln!( + " {} {}", + style("list").cyan().bold(), + style(endpoint.label()).dim() + ); + let repos = client + .list_repos(&endpoint) + .with_context(|| format!("failed to list repos for {}", endpoint.label()))?; + let repos = repos + .into_iter() + .filter(|repo| mirror.sync_visibility.matches_private(repo.private)) + .filter(|repo| repo_filter.matches(&repo.name)) + .map(|repo| EndpointRepo { + endpoint: endpoint.clone(), + repo, + }) + .collect::>(); + Ok((index, repos)) + })?; + listed.sort_by_key(|(index, _)| *index); + + Ok(listed.into_iter().flat_map(|(_, repos)| repos).collect()) +} + pub struct ProviderClient<'a> { site: &'a SiteConfig, token: String, diff --git a/src/sync.rs b/src/sync.rs index 7d8ff54..7930167 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -17,7 +17,9 @@ use crate::git::{ is_disabled_repository_error, ls_remote_refs, safe_remote_name, }; use crate::logging; -use crate::provider::{EndpointRepo, ProviderClient, PullRequestRequest, repos_by_name}; +use crate::provider::{ + EndpointRepo, ProviderClient, PullRequestRequest, list_mirror_repos, repos_by_name, +}; use crate::webhook; mod output; @@ -163,7 +165,8 @@ fn sync_group( .unwrap_or(mirror.create_missing); let repo_filter = mirror.repo_filter()?; - let all_endpoint_repos = list_group_repos(context.config, mirror, &repo_filter)?; + let all_endpoint_repos = + list_mirror_repos(context.config, mirror, &repo_filter, context.options.jobs)?; if !context.options.dry_run { webhook::ensure_configured_webhooks( context.config, @@ -258,6 +261,7 @@ fn sync_group( let queue = Arc::new(Mutex::new(repo_jobs)); let (sender, receiver) = mpsc::channel(); let use_status_area = worker_count > 1; + let jobs = context.options.jobs; let _status_guard = use_status_area.then(|| logging::start_status_area(worker_count)); let failures = thread::scope(|scope| { for worker_id in 0..worker_count { @@ -280,6 +284,7 @@ fn sync_group( work_dir, redactor: redactor.clone(), dry_run, + jobs, }; let result = sync_repo( &repo_context, @@ -340,50 +345,19 @@ fn sync_group( }); if create_missing && !context.options.dry_run { - let repos = list_group_repos(context.config, mirror, &repo_filter)?; + let repos = list_mirror_repos(context.config, mirror, &repo_filter, jobs)?; webhook::ensure_configured_webhooks( context.config, mirror, &repos, context.work_dir, - context.options.jobs, + jobs, )?; } Ok(failures) } -fn list_group_repos( - config: &Config, - mirror: &MirrorConfig, - repo_filter: &RepoNameFilter, -) -> Result> { - let mut all_endpoint_repos = Vec::new(); - for endpoint in &mirror.endpoints { - let site = config.site(&endpoint.site).unwrap(); - let client = ProviderClient::new(site)?; - crate::logln!( - " {} {}", - style("list").cyan().bold(), - style(endpoint.label()).dim() - ); - let repos = client - .list_repos(endpoint) - .with_context(|| format!("failed to list repos for {}", endpoint.label()))?; - for repo in repos - .into_iter() - .filter(|repo| mirror.sync_visibility.matches_private(repo.private)) - .filter(|repo| repo_filter.matches(&repo.name)) - { - all_endpoint_repos.push(EndpointRepo { - endpoint: endpoint.clone(), - repo, - }); - } - } - Ok(all_endpoint_repos) -} - fn sync_candidate_repo_names( repos: &HashMap>, ref_state: &RefState, @@ -434,57 +408,71 @@ struct RepoWorkerFailure { } fn ensure_missing_repos( - config: &Config, - mirror: &MirrorConfig, + context: &RepoSyncContext<'_>, repo_name: &str, existing: &mut Vec, create_missing: bool, - dry_run: bool, ) -> Result<()> { let present = existing .iter() .map(|repo| repo.endpoint.clone()) .collect::>(); let template = existing.first().map(|repo| repo.repo.clone()); + let missing = context + .mirror + .endpoints + .iter() + .filter(|endpoint| !present.contains(*endpoint)) + .cloned() + .collect::>(); - for endpoint in &mirror.endpoints { - if present.contains(endpoint) { - continue; - } - if !create_missing { + if !create_missing || context.dry_run { + for endpoint in &missing { + if !create_missing { + crate::logln!( + " {} {} missing on {} ({})", + style("skip").yellow().bold(), + style(repo_name).cyan(), + style(endpoint.label()).dim(), + style("creation disabled").dim() + ); + continue; + } crate::logln!( - " {} {} missing on {} ({})", - style("skip").yellow().bold(), + " {} {} {}", + style("create").green().bold(), style(repo_name).cyan(), - style(endpoint.label()).dim(), - style("creation disabled").dim() + style(format!("on {}", endpoint.label())).dim() ); - continue; } + return Ok(()); + } + let description = template.and_then(|repo| repo.description); + let expected_private = matches!( + &context.mirror.visibility, + crate::config::Visibility::Private + ); + let create_jobs = missing.into_iter().enumerate().collect::>(); + let mut created = crate::parallel::map(create_jobs, context.jobs, |(index, endpoint)| { crate::logln!( " {} {} {}", style("create").green().bold(), style(repo_name).cyan(), style(format!("on {}", endpoint.label())).dim() ); - if dry_run { - continue; - } - let site = config.site(&endpoint.site).unwrap(); + let site = context.config.site(&endpoint.site).unwrap(); let client = ProviderClient::new(site)?; let created = client .create_repo( - endpoint, + &endpoint, repo_name, - &mirror.visibility, - template - .as_ref() - .and_then(|repo| repo.description.as_deref()), + &context.mirror.visibility, + description.as_deref(), ) .with_context(|| format!("failed to create {} on {}", repo_name, endpoint.label()))?; - if created.private != matches!(mirror.visibility, crate::config::Visibility::Private) { + if created.private != expected_private { crate::logln!( " {} created {} on {}, but provider reported a different visibility than requested", style("warn").yellow().bold(), @@ -492,11 +480,16 @@ fn ensure_missing_repos( style(endpoint.label()).dim() ); } - existing.push(EndpointRepo { - endpoint: endpoint.clone(), - repo: created, - }); - } + Ok(( + index, + EndpointRepo { + endpoint, + repo: created, + }, + )) + })?; + created.sort_by_key(|(index, _)| *index); + existing.extend(created.into_iter().map(|(_, repo)| repo)); Ok(()) } @@ -507,6 +500,7 @@ struct RepoSyncContext<'a> { work_dir: &'a Path, redactor: Redactor, dry_run: bool, + jobs: usize, } #[derive(Default)] @@ -592,14 +586,7 @@ fn sync_repo( } } - ensure_missing_repos( - context.config, - context.mirror, - repo_name, - repos, - create_missing, - context.dry_run, - )?; + ensure_missing_repos(context, repo_name, repos, create_missing)?; if repos.len() < 2 { crate::logln!( @@ -729,26 +716,30 @@ fn delete_repos( repos: &[EndpointRepo], target_remotes: &[String], ) -> Result<()> { - for repo in repos { - let remote_name = remote_name_for_endpoint_repo(repo); - if !target_remotes.contains(&remote_name) { - continue; + let delete_jobs = repos + .iter() + .filter(|repo| target_remotes.contains(&remote_name_for_endpoint_repo(repo))) + .cloned() + .collect::>(); + if context.dry_run { + for repo in &delete_jobs { + crate::logln!( + " {} {} {}", + style("would delete").red().bold(), + style(repo_name).cyan(), + style(format!("from {}", repo.endpoint.label())).dim() + ); } + return Ok(()); + } + + crate::parallel::map(delete_jobs, context.jobs, |repo| { crate::logln!( " {} {} {}", - style(if context.dry_run { - "would delete" - } else { - "delete" - }) - .red() - .bold(), + style("delete").red().bold(), style(repo_name).cyan(), style(format!("from {}", repo.endpoint.label())).dim() ); - if context.dry_run { - continue; - } let site = context.config.site(&repo.endpoint.site).unwrap(); let client = ProviderClient::new(site)?; client @@ -760,7 +751,8 @@ fn delete_repos( repo.endpoint.label() ) })?; - } + Ok(()) + })?; Ok(()) } @@ -803,15 +795,20 @@ fn check_remote_refs( repo_name: &str, remotes: &[RemoteSpec], ) -> Result>> { - let mut refs = BTreeMap::new(); - for remote in remotes { + enum RemoteRefCheck { + Found(String, RemoteRefState), + Blocked, + } + + let ref_jobs = remotes.to_vec(); + let results = crate::parallel::map(ref_jobs, context.jobs, |remote| { crate::logln!( " {} {}", style("check refs").cyan().bold(), style(&remote.display).dim() ); - let snapshot = match ls_remote_refs(remote, &context.redactor) { - Ok(snapshot) => snapshot, + match ls_remote_refs(&remote, &context.redactor) { + Ok(snapshot) => Ok(RemoteRefCheck::Found(remote.name, snapshot.into())), Err(error) if is_disabled_repository_error(&error) => { crate::logln!( " {} {} {}", @@ -819,14 +816,22 @@ fn check_remote_refs( style(repo_name).cyan(), style(format!("provider blocked access on {}", remote.display)).dim() ); - return Ok(None); + Ok(RemoteRefCheck::Blocked) } Err(error) => { - return Err(error) - .with_context(|| format!("failed to check refs for {}", remote.display)); + Err(error).with_context(|| format!("failed to check refs for {}", remote.display)) } - }; - refs.insert(remote.name.clone(), snapshot.into()); + } + })?; + + let mut refs = BTreeMap::new(); + for result in results { + match result { + RemoteRefCheck::Found(remote, refs_for_remote) => { + refs.insert(remote, refs_for_remote); + } + RemoteRefCheck::Blocked => return Ok(None), + } } Ok(Some(refs)) } diff --git a/src/webhook.rs b/src/webhook.rs index c09802e..fd28b30 100644 --- a/src/webhook.rs +++ b/src/webhook.rs @@ -18,7 +18,7 @@ use crate::config::{ Config, EndpointConfig, MirrorConfig, ProviderKind, RepoNameFilter, default_work_dir, validate_config, }; -use crate::provider::{EndpointRepo, ProviderClient, RemoteRepo}; +use crate::provider::{EndpointRepo, ProviderClient, RemoteRepo, list_mirror_repos}; use crate::state::{load_toml_or_default, save_toml}; use crate::sync::{SyncOptions, sync_all}; @@ -189,31 +189,17 @@ pub fn install_webhooks(config: &Config, options: WebhookInstallOptions) -> Resu ); let repo_filter = mirror.repo_filter()?; let mut tasks = Vec::new(); - for endpoint in &mirror.endpoints { - let site = config.site(&endpoint.site).unwrap(); - let client = ProviderClient::new(site)?; - crate::logln!( - " {} {}", - style("list").cyan().bold(), - style(endpoint.label()).dim() - ); - let repos = client - .list_repos(endpoint) - .with_context(|| format!("failed to list repos for {}", endpoint.label()))?; - for repo in repos - .into_iter() - .filter(|repo| webhook_repo_matches(mirror, &repo_filter, repo)) - { - tasks.push(WebhookInstallTask { - site: site.clone(), - group: mirror.name.clone(), - endpoint: endpoint.clone(), - repo, - url: options.url.clone(), - secret: options.secret.clone(), - dry_run: options.dry_run, - }); - } + for endpoint_repo in list_mirror_repos(config, mirror, &repo_filter, options.jobs)? { + let site = config.site(&endpoint_repo.endpoint.site).unwrap(); + tasks.push(WebhookInstallTask { + site: site.clone(), + group: mirror.name.clone(), + endpoint: endpoint_repo.endpoint, + repo: endpoint_repo.repo, + url: options.url.clone(), + secret: options.secret.clone(), + dry_run: options.dry_run, + }); } run_install_tasks(tasks, options.jobs, Arc::clone(&state))?; } @@ -242,30 +228,16 @@ pub fn uninstall_webhooks(config: &Config, options: WebhookUninstallOptions) -> style(&mirror.name).bold() ); let repo_filter = mirror.repo_filter()?; - for endpoint in &mirror.endpoints { - let site = config.site(&endpoint.site).unwrap(); - let client = ProviderClient::new(site)?; - crate::logln!( - " {} {}", - style("list").cyan().bold(), - style(endpoint.label()).dim() - ); - let repos = client - .list_repos(endpoint) - .with_context(|| format!("failed to list repos for {}", endpoint.label()))?; - for repo in repos - .into_iter() - .filter(|repo| webhook_repo_matches(mirror, &repo_filter, repo)) - { - tasks.push(WebhookUninstallTask { - group: mirror.name.clone(), - site: site.clone(), - endpoint: endpoint.clone(), - repo, - url: options.url.clone(), - dry_run: options.dry_run, - }); - } + for endpoint_repo in list_mirror_repos(config, mirror, &repo_filter, options.jobs)? { + let site = config.site(&endpoint_repo.endpoint.site).unwrap(); + tasks.push(WebhookUninstallTask { + group: mirror.name.clone(), + site: site.clone(), + endpoint: endpoint_repo.endpoint, + repo: endpoint_repo.repo, + url: options.url.clone(), + dry_run: options.dry_run, + }); } } let removed_keys = run_uninstall_tasks(tasks, options.jobs)?; diff --git a/tests/unit/config.rs b/tests/unit/config.rs index 085415b..594dd17 100644 --- a/tests/unit/config.rs +++ b/tests/unit/config.rs @@ -88,6 +88,7 @@ fn env_token_form_is_rejected() { fn config_defaults_jobs() { let config: Config = toml::from_str("").unwrap(); + assert_eq!(DEFAULT_JOBS, 10); assert_eq!(config.jobs, DEFAULT_JOBS); }