diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 6e23c1e7a0..d7299ca6b5 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -870,7 +870,7 @@ impl Conversation { .messages(cx) .map(|message| SavedMessage { id: message.id, - start: message.range.start, + start: message.offset_range.start, }) .collect(), summary: self @@ -968,7 +968,11 @@ impl Conversation { Role::Assistant => "assistant".into(), Role::System => "system".into(), }, - content: self.buffer.read(cx).text_for_range(message.range).collect(), + content: self + .buffer + .read(cx) + .text_for_range(message.offset_range) + .collect(), name: None, }) }) @@ -1183,10 +1187,19 @@ impl Conversation { .iter() .position(|message| message.id == message_id) { + // Find the next valid message after the one we were given. + let mut next_message_ix = prev_message_ix + 1; + while let Some(next_message) = self.message_anchors.get(next_message_ix) { + if next_message.start.is_valid(self.buffer.read(cx)) { + break; + } + next_message_ix += 1; + } + let start = self.buffer.update(cx, |buffer, cx| { - let offset = self.message_anchors[prev_message_ix + 1..] - .iter() - .find(|message| message.start.is_valid(buffer)) + let offset = self + .message_anchors + .get(next_message_ix) .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1); buffer.edit([(offset..offset, "\n")], None, cx); buffer.anchor_before(offset + 1) @@ -1196,7 +1209,7 @@ impl Conversation { start, }; self.message_anchors - .insert(prev_message_ix + 1, message.clone()); + .insert(next_message_ix, message.clone()); self.messages_metadata.insert( message.id, MessageMetadata { @@ -1221,7 +1234,7 @@ impl Conversation { let end_message = self.message_for_offset(range.end, cx); if let Some((start_message, end_message)) = start_message.zip(end_message) { // Prevent splitting when range spans multiple messages. - if start_message.index != end_message.index { + if start_message.id != end_message.id { return (None, None); } @@ -1230,7 +1243,8 @@ impl Conversation { let mut edited_buffer = false; let mut suffix_start = None; - if range.start > message.range.start && range.end < message.range.end - 1 { + if range.start > message.offset_range.start && range.end < message.offset_range.end - 1 + { if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') { suffix_start = Some(range.end + 1); } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') { @@ -1255,7 +1269,7 @@ impl Conversation { }; self.message_anchors - .insert(message.index + 1, suffix.clone()); + .insert(message.index_range.end + 1, suffix.clone()); self.messages_metadata.insert( suffix.id, MessageMetadata { @@ -1265,49 +1279,52 @@ impl Conversation { }, ); - let new_messages = if range.start == range.end || range.start == message.range.start { - (None, Some(suffix)) - } else { - let mut prefix_end = None; - if range.start > message.range.start && range.end < message.range.end - 1 { - if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { - prefix_end = Some(range.start + 1); - } else if self.buffer.read(cx).reversed_chars_at(range.start).next() - == Some('\n') - { - prefix_end = Some(range.start); - } - } - - let selection = if let Some(prefix_end) = prefix_end { - cx.emit(ConversationEvent::MessagesEdited); - MessageAnchor { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(prefix_end), - } + let new_messages = + if range.start == range.end || range.start == message.offset_range.start { + (None, Some(suffix)) } else { - self.buffer.update(cx, |buffer, cx| { - buffer.edit([(range.start..range.start, "\n")], None, cx) - }); - edited_buffer = true; - MessageAnchor { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(range.end + 1), + let mut prefix_end = None; + if range.start > message.offset_range.start + && range.end < message.offset_range.end - 1 + { + if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { + prefix_end = Some(range.start + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.start).next() + == Some('\n') + { + prefix_end = Some(range.start); + } } - }; - self.message_anchors - .insert(message.index + 1, selection.clone()); - self.messages_metadata.insert( - selection.id, - MessageMetadata { - role, - sent_at: Local::now(), - status: MessageStatus::Done, - }, - ); - (Some(selection), Some(suffix)) - }; + let selection = if let Some(prefix_end) = prefix_end { + cx.emit(ConversationEvent::MessagesEdited); + MessageAnchor { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(prefix_end), + } + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.start..range.start, "\n")], None, cx) + }); + edited_buffer = true; + MessageAnchor { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), + } + }; + + self.message_anchors + .insert(message.index_range.end + 1, selection.clone()); + self.messages_metadata.insert( + selection.id, + MessageMetadata { + role, + sent_at: Local::now(), + status: MessageStatus::Done, + }, + ); + (Some(selection), Some(suffix)) + }; if !edited_buffer { cx.emit(ConversationEvent::MessagesEdited); @@ -1389,7 +1406,7 @@ impl Conversation { while let Some(offset) = offsets.next() { // Locate the message that contains the offset. while current_message.as_ref().map_or(false, |message| { - !message.range.contains(&offset) && messages.peek().is_some() + !message.offset_range.contains(&offset) && messages.peek().is_some() }) { current_message = messages.next(); } @@ -1397,7 +1414,7 @@ impl Conversation { // Skip offsets that are in the same message. while offsets.peek().map_or(false, |offset| { - message.range.contains(offset) || messages.peek().is_none() + message.offset_range.contains(offset) || messages.peek().is_none() }) { offsets.next(); } @@ -1411,15 +1428,17 @@ impl Conversation { let buffer = self.buffer.read(cx); let mut message_anchors = self.message_anchors.iter().enumerate().peekable(); iter::from_fn(move || { - while let Some((ix, message_anchor)) = message_anchors.next() { + while let Some((start_ix, message_anchor)) = message_anchors.next() { let metadata = self.messages_metadata.get(&message_anchor.id)?; let message_start = message_anchor.start.to_offset(buffer); let mut message_end = None; + let mut end_ix = start_ix; while let Some((_, next_message)) = message_anchors.peek() { if next_message.start.is_valid(buffer) { message_end = Some(next_message.start); break; } else { + end_ix += 1; message_anchors.next(); } } @@ -1427,8 +1446,8 @@ impl Conversation { .unwrap_or(language::Anchor::MAX) .to_offset(buffer); return Some(Message { - index: ix, - range: message_start..message_end, + index_range: start_ix..end_ix, + offset_range: message_start..message_end, id: message_anchor.id, anchor: message_anchor.start, role: metadata.role, @@ -1885,11 +1904,11 @@ impl ConversationEditor { let mut copied_text = String::new(); let mut spanned_messages = 0; for message in conversation.messages(cx) { - if message.range.start >= selection.range().end { + if message.offset_range.start >= selection.range().end { break; - } else if message.range.end >= selection.range().start { - let range = cmp::max(message.range.start, selection.range().start) - ..cmp::min(message.range.end, selection.range().end); + } else if message.offset_range.end >= selection.range().start { + let range = cmp::max(message.offset_range.start, selection.range().start) + ..cmp::min(message.offset_range.end, selection.range().end); if !range.is_empty() { spanned_messages += 1; write!(&mut copied_text, "## {}\n\n", message.role).unwrap(); @@ -2005,8 +2024,8 @@ struct MessageAnchor { #[derive(Clone, Debug)] pub struct Message { - range: Range, - index: usize, + offset_range: Range, + index_range: Range, id: MessageId, anchor: language::Anchor, role: Role, @@ -2017,7 +2036,7 @@ pub struct Message { impl Message { fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage { let mut content = format!("[Message {}]\n", self.id.0).to_string(); - content.extend(buffer.text_for_range(self.range.clone())); + content.extend(buffer.text_for_range(self.offset_range.clone())); RequestMessage { role: self.role, content: content.trim_end().into(), @@ -2525,7 +2544,7 @@ mod tests { conversation .read(cx) .messages(cx) - .map(|message| (message.id, message.role, message.range)) + .map(|message| (message.id, message.role, message.offset_range)) .collect() } }