diff --git a/lib/src/repo.rs b/lib/src/repo.rs index d3ab2fc24..9f87bdb5e 100644 --- a/lib/src/repo.rs +++ b/lib/src/repo.rs @@ -20,6 +20,7 @@ use std::collections::HashSet; use std::fmt::Debug; use std::fmt::Formatter; use std::fs; +use std::iter; use std::path::Path; use std::path::PathBuf; use std::slice; @@ -1075,7 +1076,9 @@ impl MutableRepo { } fn update_all_references(&mut self, settings: &UserSettings) -> BackendResult<()> { - for (old_id, new_ids) in self.resolve_rewrite_mapping_with(|_| true) { + let rewrite_mapping = self.resolve_rewrite_mapping_with(|_| true); + self.update_local_branches(&rewrite_mapping); + for (old_id, new_ids) in rewrite_mapping { self.update_references(settings, old_id, new_ids)?; } Ok(()) @@ -1096,45 +1099,29 @@ impl MutableRepo { &old_commit_id, &new_commit_ids, abandoned_old_commit, - )?; - - // Build a map from commit to branches pointing to it, so we don't need to scan - // all branches each time we rebase a commit. - // TODO: We no longer need to do this now that we update branches for all - // commits at once. - let mut branches: HashMap<_, HashSet<_>> = HashMap::new(); - for (branch_name, target) in self.view().local_branches() { - for commit in target.added_ids() { - branches - .entry(commit.clone()) - .or_default() - .insert(branch_name.to_owned()); - } - } - - if let Some(branch_names) = branches.get(&old_commit_id).cloned() { - let mut branch_updates = vec![]; - for branch_name in &branch_names { - let local_target = self.get_local_branch(branch_name); - for old_add in local_target.added_ids() { - if *old_add == old_commit_id { - branch_updates.push(branch_name.clone()); - } - } - } + ) + } + fn update_local_branches(&mut self, rewrite_mapping: &HashMap>) { + let changed_branches = self + .view() + .local_branches() + .flat_map(|(name, target)| { + target.added_ids().filter_map(|id| { + let change = rewrite_mapping.get_key_value(id)?; + Some((name.to_owned(), change)) + }) + }) + .collect_vec(); + for (branch_name, (old_commit_id, new_commit_ids)) in changed_branches { let old_target = RefTarget::normal(old_commit_id.clone()); assert!(!new_commit_ids.is_empty()); let new_target = RefTarget::from_legacy_form( - std::iter::repeat(old_commit_id).take(new_commit_ids.len() - 1), - new_commit_ids, + iter::repeat(old_commit_id.clone()).take(new_commit_ids.len() - 1), + new_commit_ids.iter().cloned(), ); - for branch_name in &branch_updates { - self.merge_local_branch(branch_name, &old_target, &new_target); - } + self.merge_local_branch(&branch_name, &old_target, &new_target); } - - Ok(()) } fn update_wc_commits(