Make LanguageModel::use_any_tool return a stream of chunks (#16262)
Some checks are pending
CI / Check formatting and spelling (push) Waiting to run
CI / (macOS) Run Clippy and tests (push) Waiting to run
CI / (Linux) Run Clippy and tests (push) Waiting to run
CI / (Windows) Run Clippy and tests (push) Waiting to run
CI / Create a macOS bundle (push) Blocked by required conditions
CI / Create a Linux bundle (push) Blocked by required conditions
CI / Create arm64 Linux bundle (push) Blocked by required conditions
Deploy Docs / Deploy Docs (push) Waiting to run
Docs / Check formatting (push) Waiting to run

This PR is a refactor to pave the way for allowing the user to view and
edit workflow step resolutions. I've made tool calls work more like
normal streaming completions for all providers. The `use_any_tool`
method returns a stream of strings (which contain chunks of JSON). I've
also done some minor cleanup of language model providers in general,
removing the duplication around handling streaming responses.

Release Notes:

- N/A
This commit is contained in:
Max Brunsfeld 2024-08-14 18:02:46 -07:00 committed by GitHub
parent 1117d89057
commit 4c390b82fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 253 additions and 400 deletions

View file

@ -5,8 +5,8 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use std::time::Duration;
use std::{pin::Pin, str::FromStr};
use strum::{EnumIter, EnumString};
use thiserror::Error;
@ -241,6 +241,50 @@ pub fn extract_text_from_events(
})
}
pub async fn extract_tool_args_from_events(
tool_name: String,
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
let mut tool_use_index = None;
while let Some(event) = events.next().await {
if let Event::ContentBlockStart {
index,
content_block,
} = event?
{
if let Content::ToolUse { name, .. } = content_block {
if name == tool_name {
tool_use_index = Some(index);
break;
}
}
}
}
let Some(tool_use_index) = tool_use_index else {
return Err(anyhow!("tool not used"));
};
Ok(events.filter_map(move |event| {
let result = match event {
Err(error) => Some(Err(error)),
Ok(Event::ContentBlockDelta { index, delta }) => match delta {
ContentDelta::TextDelta { .. } => None,
ContentDelta::InputJsonDelta { partial_json } => {
if index == tool_use_index {
Some(Ok(partial_json))
} else {
None
}
}
},
_ => None,
};
async move { result }
}))
}
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
pub role: Role,

View file

@ -1,6 +1,6 @@
use crate::{
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
InlineAssistId, InlineAssistant, MessageId, MessageStatus,
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InlineAssistId,
InlineAssistant, MessageId, MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
@ -3342,7 +3342,7 @@ mod tests {
model
.as_fake()
.respond_to_last_tool_use(Ok(serde_json::to_value(tool::WorkflowStepResolution {
.respond_to_last_tool_use(tool::WorkflowStepResolution {
step_title: "Title".into(),
suggestions: vec![tool::WorkflowSuggestion {
path: "/root/hello.rs".into(),
@ -3352,8 +3352,7 @@ mod tests {
description: "Extract a greeting function".into(),
},
}],
})
.unwrap()));
});
// Wait for tool use to be processed.
cx.run_until_parked();
@ -4084,44 +4083,4 @@ mod tool {
symbol: String,
},
}
impl WorkflowSuggestionKind {
pub fn symbol(&self) -> Option<&str> {
match self {
Self::Update { symbol, .. } => Some(symbol),
Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
Self::PrependChild { symbol, .. } => symbol.as_deref(),
Self::AppendChild { symbol, .. } => symbol.as_deref(),
Self::Delete { symbol } => Some(symbol),
Self::Create { .. } => None,
}
}
pub fn description(&self) -> Option<&str> {
match self {
Self::Update { description, .. } => Some(description),
Self::Create { description } => Some(description),
Self::InsertSiblingBefore { description, .. } => Some(description),
Self::InsertSiblingAfter { description, .. } => Some(description),
Self::PrependChild { description, .. } => Some(description),
Self::AppendChild { description, .. } => Some(description),
Self::Delete { .. } => None,
}
}
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
match self {
WorkflowSuggestionKind::InsertSiblingBefore { .. } => {
Some(InitialInsertion::NewlineAfter)
}
WorkflowSuggestionKind::InsertSiblingAfter { .. } => {
Some(InitialInsertion::NewlineBefore)
}
WorkflowSuggestionKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
WorkflowSuggestionKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
_ => None,
}
}
}
}

View file

@ -1280,12 +1280,6 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
})
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum InitialInsertion {
NewlineBefore,
NewlineAfter,
}
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct InlineAssistId(usize);

View file

@ -351,10 +351,13 @@ impl Asset for ImageAsset {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if !response.status().is_success() {
let mut body = String::from_utf8_lossy(&body).into_owned();
let first_line = body.lines().next().unwrap_or("").trim_end();
body.truncate(first_line.len());
return Err(ImageCacheError::BadStatus {
uri,
status: response.status(),
body: String::from_utf8_lossy(&body).into_owned(),
body,
});
}
body

View file

@ -8,7 +8,7 @@ pub mod settings;
use anyhow::Result;
use client::{Client, UserStore};
use futures::{future::BoxFuture, stream::BoxStream};
use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
};
@ -76,7 +76,7 @@ pub trait LanguageModel: Send + Sync {
description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>>;
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
#[cfg(any(test, feature = "test-support"))]
fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
@ -92,10 +92,11 @@ impl dyn LanguageModel {
) -> impl 'static + Future<Output = Result<T>> {
let schema = schemars::schema_for!(T);
let schema_json = serde_json::to_value(&schema).unwrap();
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
async move {
let response = request.await?;
Ok(serde_json::from_value(response)?)
let stream = stream.await?;
let response = stream.try_collect::<String>().await?;
Ok(serde_json::from_str(&response)?)
}
}
}

View file

@ -7,7 +7,7 @@ use anthropic::AnthropicError;
use anyhow::{anyhow, Context as _, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
use gpui::{
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
View, WhiteSpace,
@ -264,29 +264,6 @@ pub fn count_anthropic_tokens(
}
impl AnthropicModel {
fn request_completion(
&self,
request: anthropic::Request,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<anthropic::Response>> {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(state.api_key.clone(), settings.api_url.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
anthropic::complete(http_client.as_ref(), &api_url, &api_key, request)
.await
.context("failed to retrieve completion")
}
.boxed()
}
fn stream_completion(
&self,
request: anthropic::Request,
@ -381,7 +358,7 @@ impl LanguageModel for AnthropicModel {
tool_description: String,
input_schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let mut request = request.into_anthropic(self.model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
@ -392,25 +369,16 @@ impl LanguageModel for AnthropicModel {
input_schema,
}];
let response = self.request_completion(request, cx);
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
response
.content
.into_iter()
.find_map(|content| {
if let anthropic::Content::ToolUse { name, input, .. } = content {
if name == tool_name {
Some(input)
} else {
None
}
} else {
None
}
})
.context("tool not used")
Ok(anthropic::extract_tool_args_from_events(
tool_name,
Box::pin(response.map_err(|e| anyhow!(e))),
)
.await?
.boxed())
})
.boxed()
}

View file

@ -5,18 +5,21 @@ use crate::{
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
use anthropic::AnthropicError;
use anyhow::{anyhow, bail, Context as _, Result};
use anyhow::{anyhow, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, ZedPro};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use futures::{
future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
TryStreamExt as _,
};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
Subscription, Task,
};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::{
@ -451,21 +454,9 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
let body = BufReader::new(response.into_body());
let stream = futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: anthropic::Event = serde_json::from_str(&buffer)
.context("failed to parse Anthropic event")?;
Ok(Some((event, body)))
}
Err(err) => Err(AnthropicError::Other(err.into())),
}
});
Ok(anthropic::extract_text_from_events(stream))
Ok(anthropic::extract_text_from_events(
response_lines(response).map_err(AnthropicError::Other),
))
});
async move {
Ok(future
@ -492,21 +483,7 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
let body = BufReader::new(response.into_body());
let stream = futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: open_ai::ResponseStreamEvent =
serde_json::from_str(&buffer)?;
Ok(Some((event, body)))
}
Err(e) => Err(e.into()),
}
});
Ok(open_ai::extract_text_from_events(stream))
Ok(open_ai::extract_text_from_events(response_lines(response)))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@ -527,21 +504,9 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
let body = BufReader::new(response.into_body());
let stream = futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: google_ai::GenerateContentResponse =
serde_json::from_str(&buffer)?;
Ok(Some((event, body)))
}
Err(e) => Err(e.into()),
}
});
Ok(google_ai::extract_text_from_events(stream))
Ok(google_ai::extract_text_from_events(response_lines(
response,
)))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@ -563,21 +528,7 @@ impl LanguageModel for CloudLanguageModel {
},
)
.await?;
let body = BufReader::new(response.into_body());
let stream = futures::stream::try_unfold(body, move |mut body| async move {
let mut buffer = String::new();
match body.read_line(&mut buffer).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: open_ai::ResponseStreamEvent =
serde_json::from_str(&buffer)?;
Ok(Some((event, body)))
}
Err(e) => Err(e.into()),
}
});
Ok(open_ai::extract_text_from_events(stream))
Ok(open_ai::extract_text_from_events(response_lines(response)))
});
async move { Ok(future.await?.boxed()) }.boxed()
}
@ -591,10 +542,12 @@ impl LanguageModel for CloudLanguageModel {
tool_description: String,
input_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
match &self.model {
CloudModel::Anthropic(model) => {
let client = self.client.clone();
let mut request = request.into_anthropic(model.tool_model_id().into());
request.tool_choice = Some(anthropic::ToolChoice::Tool {
name: tool_name.clone(),
@ -605,7 +558,6 @@ impl LanguageModel for CloudLanguageModel {
input_schema,
}];
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@ -621,70 +573,34 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
let mut tool_use_index = None;
let mut tool_input = String::new();
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
while body.read_line(&mut line).await? > 0 {
let event: anthropic::Event = serde_json::from_str(&line)?;
line.clear();
match event {
anthropic::Event::ContentBlockStart {
content_block,
index,
} => {
if let anthropic::Content::ToolUse { name, .. } = content_block
{
if name == tool_name {
tool_use_index = Some(index);
}
}
}
anthropic::Event::ContentBlockDelta { index, delta } => match delta
{
anthropic::ContentDelta::TextDelta { .. } => {}
anthropic::ContentDelta::InputJsonDelta { partial_json } => {
if Some(index) == tool_use_index {
tool_input.push_str(&partial_json);
}
}
},
anthropic::Event::ContentBlockStop { index } => {
if Some(index) == tool_use_index {
return Ok(serde_json::from_str(&tool_input)?);
}
}
_ => {}
}
}
if tool_use_index.is_some() {
Err(anyhow!("tool content incomplete"))
} else {
Err(anyhow!("tool not used"))
}
Ok(anthropic::extract_tool_args_from_events(
tool_name,
Box::pin(response_lines(response)),
)
.await?
.boxed())
})
.boxed()
}
CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into());
let client = self.client.clone();
let mut function = open_ai::FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
};
let func = open_ai::ToolDefinition::Function {
function: function.clone(),
};
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
// Fill in description and params separately, as they're not needed for tool_choice field.
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
request.tool_choice = Some(open_ai::ToolChoice::Other(
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
},
},
));
request.tools = vec![open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
name: tool_name.clone(),
description: Some(tool_description),
parameters: Some(input_schema),
},
}];
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@ -700,41 +616,12 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
let mut load_state = None;
while body.read_line(&mut line).await? > 0 {
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
line.clear();
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
Ok(open_ai::extract_tool_args_from_events(
tool_name,
Box::pin(response_lines(response)),
)
.await?
.boxed())
})
.boxed()
}
@ -744,22 +631,23 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Zed(model) => {
// All Zed models are OpenAI-based at the time of writing.
let mut request = request.into_open_ai(model.id().into());
let client = self.client.clone();
let mut function = open_ai::FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
};
let func = open_ai::ToolDefinition::Function {
function: function.clone(),
};
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
// Fill in description and params separately, as they're not needed for tool_choice field.
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
request.tool_choice = Some(open_ai::ToolChoice::Other(
open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
},
},
));
request.tools = vec![open_ai::ToolDefinition::Function {
function: open_ai::FunctionDefinition {
name: tool_name.clone(),
description: Some(tool_description),
parameters: Some(input_schema),
},
}];
let llm_api_token = self.llm_api_token.clone();
self.request_limiter
.run(async move {
let response = Self::perform_llm_completion(
@ -775,40 +663,12 @@ impl LanguageModel for CloudLanguageModel {
)
.await?;
let mut body = BufReader::new(response.into_body());
let mut line = String::new();
let mut load_state = None;
while body.read_line(&mut line).await? > 0 {
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
line.clear();
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
Ok(open_ai::extract_tool_args_from_events(
tool_name,
Box::pin(response_lines(response)),
)
.await?
.boxed())
})
.boxed()
}
@ -816,6 +676,25 @@ impl LanguageModel for CloudLanguageModel {
}
}
fn response_lines<T: DeserializeOwned>(
response: Response<AsyncBody>,
) -> impl Stream<Item = Result<T>> {
futures::stream::try_unfold(
(String::new(), BufReader::new(response.into_body())),
move |(mut line, mut body)| async {
match body.read_line(&mut line).await {
Ok(0) => Ok(None),
Ok(_) => {
let event: T = serde_json::from_str(&line)?;
line.clear();
Ok(Some((event, (line, body))))
}
Err(e) => Err(e.into()),
}
},
)
}
impl LlmApiToken {
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
let lock = self.0.upgradable_read().await;

View file

@ -252,7 +252,7 @@ impl LanguageModel for CopilotChatLanguageModel {
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}

View file

@ -3,16 +3,11 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest,
};
use anyhow::Context as _;
use futures::{
channel::{mpsc, oneshot},
future::BoxFuture,
stream::BoxStream,
FutureExt, StreamExt,
};
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
use parking_lot::Mutex;
use serde::Serialize;
use std::sync::Arc;
use ui::WindowContext;
@ -90,7 +85,7 @@ pub struct ToolUseRequest {
#[derive(Default)]
pub struct FakeLanguageModel {
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
}
impl FakeLanguageModel {
@ -130,25 +125,11 @@ impl FakeLanguageModel {
self.end_completion_stream(self.pending_completions().last().unwrap());
}
pub fn respond_to_tool_use(
&self,
tool_call: &ToolUseRequest,
response: Result<serde_json::Value>,
) {
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
if let Some(index) = current_tool_call_txs
.iter()
.position(|(call, _)| call == tool_call)
{
let (_, tx) = current_tool_call_txs.remove(index);
tx.send(response).unwrap();
}
}
pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
let response = serde_json::to_string(&response).unwrap();
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
let (_, tx) = current_tool_call_txs.pop().unwrap();
tx.send(response).unwrap();
tx.unbounded_send(response).unwrap();
}
}
@ -202,8 +183,8 @@ impl LanguageModel for FakeLanguageModel {
description: String,
schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
let (tx, rx) = oneshot::channel();
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
let (tx, rx) = mpsc::unbounded();
let tool_call = ToolUseRequest {
request,
name,
@ -211,7 +192,7 @@ impl LanguageModel for FakeLanguageModel {
schema,
};
self.current_tool_use_txs.lock().push((tool_call, tx));
async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
async move { Ok(rx.map(Ok).boxed()) }.boxed()
}
fn as_fake(&self) -> &Self {

View file

@ -302,7 +302,7 @@ impl LanguageModel for GoogleLanguageModel {
_description: String,
_schema: serde_json::Value,
_cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
future::ready(Err(anyhow!("not implemented"))).boxed()
}
}

View file

@ -6,7 +6,6 @@ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
ChatResponseDelta, OllamaToolCall,
};
use serde_json::Value;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
@ -311,7 +310,7 @@ impl LanguageModel for OllamaLanguageModel {
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
use ollama::{OllamaFunctionTool, OllamaTool};
let function = OllamaFunctionTool {
name: tool_name.clone(),
@ -324,23 +323,19 @@ impl LanguageModel for OllamaLanguageModel {
self.request_limiter
.run(async move {
let response = response.await?;
let ChatMessage::Assistant {
tool_calls,
content,
} = response.message
else {
let ChatMessage::Assistant { tool_calls, .. } = response.message else {
bail!("message does not have an assistant role");
};
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
for call in tool_calls {
let OllamaToolCall::Function(function) = call;
if function.name == tool_name {
return Ok(function.arguments);
return Ok(futures::stream::once(async move {
Ok(function.arguments.to_string())
})
.boxed());
}
}
} else if let Ok(args) = serde_json::from_str::<Value>(&content) {
// Parse content as arguments.
return Ok(args);
} else {
bail!("assistant message does not have any tool calls");
};

View file

@ -1,4 +1,4 @@
use anyhow::{anyhow, bail, Result};
use anyhow::{anyhow, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@ -243,6 +243,7 @@ impl OpenAiLanguageModel {
async move { Ok(future.await?.boxed()) }.boxed()
}
}
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@ -293,55 +294,32 @@ impl LanguageModel for OpenAiLanguageModel {
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let mut request = request.into_open_ai(self.model.id().into());
let mut function = FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
};
let func = ToolDefinition::Function {
function: function.clone(),
};
request.tool_choice = Some(ToolChoice::Other(func.clone()));
// Fill in description and params separately, as they're not needed for tool_choice field.
function.description = Some(tool_description);
function.parameters = Some(schema);
request.tools = vec![ToolDefinition::Function { function }];
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
function: FunctionDefinition {
name: tool_name.clone(),
description: None,
parameters: None,
},
}));
request.tools = vec![ToolDefinition::Function {
function: FunctionDefinition {
name: tool_name.clone(),
description: Some(tool_description),
parameters: Some(schema),
},
}];
let response = self.stream_completion(request, cx);
self.request_limiter
.run(async move {
let mut response = response.await?;
// Call arguments are gonna be streamed in over multiple chunks.
let mut load_state = None;
while let Some(Ok(part)) = response.next().await {
for choice in part.choices {
let Some(tool_calls) = choice.delta.tool_calls else {
continue;
};
for call in tool_calls {
if let Some(func) = call.function {
if func.name.as_deref() == Some(tool_name.as_str()) {
load_state = Some((String::default(), call.index));
}
if let Some((arguments, (output, index))) =
func.arguments.zip(load_state.as_mut())
{
if call.index == *index {
output.push_str(&arguments);
}
}
}
}
}
}
if let Some((arguments, _)) = load_state {
return Ok(serde_json::from_str(&arguments)?);
} else {
bail!("tool not used");
}
let response = response.await?;
Ok(
open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
.await?
.boxed(),
)
})
.boxed()
}

View file

@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use serde_json::{value::RawValue, Value};
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@ -92,7 +92,7 @@ impl Model {
}
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant {
@ -107,16 +107,16 @@ pub enum ChatMessage {
},
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
Function(OllamaFunctionCall),
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
pub name: String,
pub arguments: Value,
pub arguments: Box<RawValue>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]

View file

@ -6,7 +6,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
use strum::EnumIter;
pub use supported_countries::*;
@ -384,6 +384,57 @@ pub fn embed<'a>(
}
}
pub async fn extract_tool_args_from_events(
tool_name: String,
mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
) -> Result<impl Send + Stream<Item = Result<String>>> {
let mut tool_use_index = None;
let mut first_chunk = None;
while let Some(event) = events.next().await {
let call = event?.choices.into_iter().find_map(|choice| {
choice.delta.tool_calls?.into_iter().find_map(|call| {
if call.function.as_ref()?.name.as_deref()? == tool_name {
Some(call)
} else {
None
}
})
});
if let Some(call) = call {
tool_use_index = Some(call.index);
first_chunk = call.function.and_then(|func| func.arguments);
break;
}
}
let Some(tool_use_index) = tool_use_index else {
return Err(anyhow!("tool not used"));
};
Ok(events.filter_map(move |event| {
let result = match event {
Err(error) => Some(Err(error)),
Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
choice.delta.tool_calls?.into_iter().find_map(|call| {
if call.index == tool_use_index {
let func = call.function?;
let mut arguments = func.arguments?;
if let Some(mut first_chunk) = first_chunk.take() {
first_chunk.push_str(&arguments);
arguments = first_chunk
}
Some(Ok(arguments))
} else {
None
}
})
}),
};
async move { result }
}))
}
pub fn extract_text_from_events(
response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {