diff --git a/lib/src/repo.rs b/lib/src/repo.rs index 227070814..c918f71a1 100644 --- a/lib/src/repo.rs +++ b/lib/src/repo.rs @@ -16,12 +16,12 @@ use std::collections::{HashMap, HashSet}; use std::fmt::{Debug, Formatter}; -use std::fs; use std::io::ErrorKind; use std::ops::Deref; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; +use std::{fs, slice}; use itertools::Itertools; use once_cell::sync::OnceCell; @@ -915,38 +915,49 @@ impl MutableRepo { } pub fn add_head(&mut self, head: &Commit) { + self.add_heads(slice::from_ref(head)); + } + + pub fn add_heads(&mut self, heads: &[Commit]) { let current_heads = self.view.get_mut().heads(); // Use incremental update for common case of adding a single commit on top a // current head. TODO: Also use incremental update when adding a single // commit on top a non-head. - if head - .parent_ids() - .iter() - .all(|parent_id| current_heads.contains(parent_id)) - { - self.index.add_commit(head); - self.view.get_mut().add_head(head.id()); - for parent_id in head.parent_ids() { - self.view.get_mut().remove_head(parent_id); + match heads { + [] => {} + [head] + if head + .parent_ids() + .iter() + .all(|parent_id| current_heads.contains(parent_id)) => + { + self.index.add_commit(head); + self.view.get_mut().add_head(head.id()); + for parent_id in head.parent_ids() { + self.view.get_mut().remove_head(parent_id); + } } - } else { - let missing_commits = dag_walk::topo_order_forward( - vec![head.clone()], - |commit: &Commit| commit.id().clone(), - |commit: &Commit| -> Vec { - commit - .parent_ids() - .iter() - .filter(|id| !self.index().has_id(id)) - .map(|id| self.store().get_commit(id).unwrap()) - .collect() - }, - ); - for missing_commit in &missing_commits { - self.index.add_commit(missing_commit); + _ => { + let missing_commits = dag_walk::topo_order_forward( + heads.iter().cloned(), + |commit: &Commit| commit.id().clone(), + |commit: &Commit| -> Vec { + commit + .parent_ids() + .iter() + .filter(|id| !self.index().has_id(id)) + .map(|id| self.store().get_commit(id).unwrap()) + .collect() + }, + ); + for missing_commit in &missing_commits { + self.index.add_commit(missing_commit); + } + for head in heads { + self.view.get_mut().add_head(head.id()); + } + self.view.mark_dirty(); } - self.view.get_mut().add_head(head.id()); - self.view.mark_dirty(); } }