From b55c4ae0a3c8692b9b09730692f3940ad00564a2 Mon Sep 17 00:00:00 2001 From: Benjamin Saunders Date: Sun, 6 Nov 2022 09:36:52 -0800 Subject: [PATCH] git: move progress callback into a struct --- lib/src/git.rs | 82 +++++++++++++++++++++++++------------------ lib/tests/test_git.rs | 71 ++++++++++++++++++++++++++++++++----- src/commands.rs | 13 +++---- 3 files changed, 115 insertions(+), 51 deletions(-) diff --git a/lib/src/git.rs b/lib/src/git.rs index 36ef447b2..9b313c52b 100644 --- a/lib/src/git.rs +++ b/lib/src/git.rs @@ -17,7 +17,7 @@ use std::fs::OpenOptions; use std::io::{Read, Write}; use std::sync::Arc; -use git2::{Oid, RemoteCallbacks}; +use git2::Oid; use itertools::Itertools; use thiserror::Error; @@ -280,7 +280,7 @@ pub fn fetch( mut_repo: &mut MutableRepo, git_repo: &git2::Repository, remote_name: &str, - progress: Option<&mut dyn FnMut(&Progress)>, + callbacks: RemoteCallbacks<'_>, ) -> Result, GitFetchError> { let mut remote = git_repo @@ -298,7 +298,7 @@ pub fn fetch( let mut proxy_options = git2::ProxyOptions::new(); proxy_options.auto(); fetch_options.proxy_options(proxy_options); - let callbacks = create_remote_callbacks(progress); + let callbacks = callbacks.into_git(); fetch_options.remote_callbacks(callbacks); let refspec: &[&str] = &[]; remote.download(refspec, Some(&mut fetch_options))?; @@ -435,7 +435,7 @@ fn push_refs( let mut proxy_options = git2::ProxyOptions::new(); proxy_options.auto(); push_options.proxy_options(proxy_options); - let mut callbacks = create_remote_callbacks(None); + let mut callbacks = RemoteCallbacks::default().into_git(); callbacks.push_update_reference(|refname, status| { // The status is Some if the ref update was rejected if status.is_none() { @@ -466,39 +466,53 @@ fn push_refs( } } -fn create_remote_callbacks(progress_cb: Option<&mut dyn FnMut(&Progress)>) -> RemoteCallbacks<'_> { - let mut callbacks = git2::RemoteCallbacks::new(); - if let Some(progress_cb) = progress_cb { - callbacks.transfer_progress(move |progress| { - progress_cb(&Progress { - bytes_downloaded: if progress.received_objects() < progress.total_objects() { - Some(progress.received_bytes() as u64) - } else { - None - }, - overall: (progress.indexed_objects() + progress.indexed_deltas()) as f32 - / (progress.total_objects() + progress.total_deltas()) as f32, +#[non_exhaustive] +#[derive(Default)] +pub struct RemoteCallbacks<'a> { + pub progress: Option<&'a mut dyn FnMut(&Progress)>, +} + +impl<'a> RemoteCallbacks<'a> { + fn into_git(self) -> git2::RemoteCallbacks<'a> { + let mut callbacks = git2::RemoteCallbacks::new(); + if let Some(progress_cb) = self.progress { + callbacks.transfer_progress(move |progress| { + progress_cb(&Progress { + bytes_downloaded: if progress.received_objects() < progress.total_objects() { + Some(progress.received_bytes() as u64) + } else { + None + }, + overall: (progress.indexed_objects() + progress.indexed_deltas()) as f32 + / (progress.total_objects() + progress.total_deltas()) as f32, + }); + true }); - true - }); - } - // TODO: We should expose the callbacks to the caller instead -- the library - // crate shouldn't look in $HOME etc. - callbacks.credentials(|_url, username_from_url, allowed_types| { - if allowed_types.contains(git2::CredentialType::SSH_KEY) { - if std::env::var("SSH_AUTH_SOCK").is_ok() || std::env::var("SSH_AGENT_PID").is_ok() { - return git2::Cred::ssh_key_from_agent(username_from_url.unwrap()); - } - if let Ok(home_dir) = std::env::var("HOME") { - let key_path = std::path::Path::new(&home_dir).join(".ssh").join("id_rsa"); - if key_path.is_file() { - return git2::Cred::ssh_key(username_from_url.unwrap(), None, &key_path, None); + } + // TODO: We should expose the callbacks to the caller instead -- the library + // crate shouldn't look in $HOME etc. + callbacks.credentials(|_url, username_from_url, allowed_types| { + if allowed_types.contains(git2::CredentialType::SSH_KEY) { + if std::env::var("SSH_AUTH_SOCK").is_ok() || std::env::var("SSH_AGENT_PID").is_ok() + { + return git2::Cred::ssh_key_from_agent(username_from_url.unwrap()); + } + if let Ok(home_dir) = std::env::var("HOME") { + let key_path = std::path::Path::new(&home_dir).join(".ssh").join("id_rsa"); + if key_path.is_file() { + return git2::Cred::ssh_key( + username_from_url.unwrap(), + None, + &key_path, + None, + ); + } } } - } - git2::Cred::default() - }); - callbacks + git2::Cred::default() + }); + callbacks + } } pub struct Progress { diff --git a/lib/tests/test_git.rs b/lib/tests/test_git.rs index f66d738e3..f61df1de2 100644 --- a/lib/tests/test_git.rs +++ b/lib/tests/test_git.rs @@ -593,7 +593,13 @@ fn test_fetch_empty_repo() { let test_data = GitRepoData::create(); let mut tx = test_data.repo.start_transaction("test"); - let default_branch = git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap(); + let default_branch = git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "origin", + git::RemoteCallbacks::default(), + ) + .unwrap(); // No default branch and no refs assert_eq!(default_branch, None); assert_eq!(*tx.mut_repo().view().git_refs(), btreemap! {}); @@ -606,7 +612,13 @@ fn test_fetch_initial_commit() { let initial_git_commit = empty_git_commit(&test_data.origin_repo, "refs/heads/main", &[]); let mut tx = test_data.repo.start_transaction("test"); - let default_branch = git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap(); + let default_branch = git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "origin", + git::RemoteCallbacks::default(), + ) + .unwrap(); // No default branch because the origin repo's HEAD wasn't set assert_eq!(default_branch, None); let repo = tx.commit(); @@ -637,7 +649,13 @@ fn test_fetch_success() { let initial_git_commit = empty_git_commit(&test_data.origin_repo, "refs/heads/main", &[]); let mut tx = test_data.repo.start_transaction("test"); - git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap(); + git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "origin", + git::RemoteCallbacks::default(), + ) + .unwrap(); test_data.repo = tx.commit(); test_data.origin_repo.set_head("refs/heads/main").unwrap(); @@ -648,7 +666,13 @@ fn test_fetch_success() { ); let mut tx = test_data.repo.start_transaction("test"); - let default_branch = git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap(); + let default_branch = git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "origin", + git::RemoteCallbacks::default(), + ) + .unwrap(); // The default branch is "main" assert_eq!(default_branch, Some("main".to_string())); let repo = tx.commit(); @@ -679,7 +703,13 @@ fn test_fetch_prune_deleted_ref() { empty_git_commit(&test_data.git_repo, "refs/heads/main", &[]); let mut tx = test_data.repo.start_transaction("test"); - git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap(); + git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "origin", + git::RemoteCallbacks::default(), + ) + .unwrap(); // Test the setup assert!(tx.mut_repo().get_branch("main").is_some()); @@ -690,7 +720,13 @@ fn test_fetch_prune_deleted_ref() { .delete() .unwrap(); // After re-fetching, the branch should be deleted - git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap(); + git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "origin", + git::RemoteCallbacks::default(), + ) + .unwrap(); assert!(tx.mut_repo().get_branch("main").is_none()); } @@ -700,7 +736,13 @@ fn test_fetch_no_default_branch() { let initial_git_commit = empty_git_commit(&test_data.origin_repo, "refs/heads/main", &[]); let mut tx = test_data.repo.start_transaction("test"); - git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap(); + git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "origin", + git::RemoteCallbacks::default(), + ) + .unwrap(); empty_git_commit( &test_data.origin_repo, @@ -715,7 +757,13 @@ fn test_fetch_no_default_branch() { .set_head_detached(initial_git_commit.id()) .unwrap(); - let default_branch = git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap(); + let default_branch = git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "origin", + git::RemoteCallbacks::default(), + ) + .unwrap(); // There is no default branch assert_eq!(default_branch, None); } @@ -725,7 +773,12 @@ fn test_fetch_no_such_remote() { let test_data = GitRepoData::create(); let mut tx = test_data.repo.start_transaction("test"); - let result = git::fetch(tx.mut_repo(), &test_data.git_repo, "invalid-remote", None); + let result = git::fetch( + tx.mut_repo(), + &test_data.git_repo, + "invalid-remote", + git::RemoteCallbacks::default(), + ); assert!(matches!(result, Err(GitFetchError::NoSuchRemote(_)))); } diff --git a/src/commands.rs b/src/commands.rs index b5613128f..42d8e04d0 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -4104,14 +4104,11 @@ fn git_fetch( progress.update(Instant::now(), x); }); } - let result = git::fetch( - mut_repo, - git_repo, - remote_name, - callback - .as_mut() - .map(|x| x as &mut dyn FnMut(&git::Progress)), - ); + let mut callbacks = git::RemoteCallbacks::default(); + callbacks.progress = callback + .as_mut() + .map(|x| x as &mut dyn FnMut(&git::Progress)); + let result = git::fetch(mut_repo, git_repo, remote_name, callbacks); result }