mirror of
https://github.com/zed-industries/zed.git
synced 2025-01-29 21:49:33 +00:00
Make LanguageModel::use_any_tool return a stream of chunks (#16262)
Some checks are pending
CI / Check formatting and spelling (push) Waiting to run
CI / (macOS) Run Clippy and tests (push) Waiting to run
CI / (Linux) Run Clippy and tests (push) Waiting to run
CI / (Windows) Run Clippy and tests (push) Waiting to run
CI / Create a macOS bundle (push) Blocked by required conditions
CI / Create a Linux bundle (push) Blocked by required conditions
CI / Create arm64 Linux bundle (push) Blocked by required conditions
Deploy Docs / Deploy Docs (push) Waiting to run
Docs / Check formatting (push) Waiting to run
Some checks are pending
CI / Check formatting and spelling (push) Waiting to run
CI / (macOS) Run Clippy and tests (push) Waiting to run
CI / (Linux) Run Clippy and tests (push) Waiting to run
CI / (Windows) Run Clippy and tests (push) Waiting to run
CI / Create a macOS bundle (push) Blocked by required conditions
CI / Create a Linux bundle (push) Blocked by required conditions
CI / Create arm64 Linux bundle (push) Blocked by required conditions
Deploy Docs / Deploy Docs (push) Waiting to run
Docs / Check formatting (push) Waiting to run
This PR is a refactor to pave the way for allowing the user to view and edit workflow step resolutions. I've made tool calls work more like normal streaming completions for all providers. The `use_any_tool` method returns a stream of strings (which contain chunks of JSON). I've also done some minor cleanup of language model providers in general, removing the duplication around handling streaming responses. Release Notes: - N/A
This commit is contained in:
parent
1117d89057
commit
4c390b82fb
14 changed files with 253 additions and 400 deletions
|
@ -5,8 +5,8 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
use isahc::config::Configurable;
|
use isahc::config::Configurable;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::str::FromStr;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use std::{pin::Pin, str::FromStr};
|
||||||
use strum::{EnumIter, EnumString};
|
use strum::{EnumIter, EnumString};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
|
@ -241,6 +241,50 @@ pub fn extract_text_from_events(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn extract_tool_args_from_events(
|
||||||
|
tool_name: String,
|
||||||
|
mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
|
||||||
|
) -> Result<impl Send + Stream<Item = Result<String>>> {
|
||||||
|
let mut tool_use_index = None;
|
||||||
|
while let Some(event) = events.next().await {
|
||||||
|
if let Event::ContentBlockStart {
|
||||||
|
index,
|
||||||
|
content_block,
|
||||||
|
} = event?
|
||||||
|
{
|
||||||
|
if let Content::ToolUse { name, .. } = content_block {
|
||||||
|
if name == tool_name {
|
||||||
|
tool_use_index = Some(index);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(tool_use_index) = tool_use_index else {
|
||||||
|
return Err(anyhow!("tool not used"));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(events.filter_map(move |event| {
|
||||||
|
let result = match event {
|
||||||
|
Err(error) => Some(Err(error)),
|
||||||
|
Ok(Event::ContentBlockDelta { index, delta }) => match delta {
|
||||||
|
ContentDelta::TextDelta { .. } => None,
|
||||||
|
ContentDelta::InputJsonDelta { partial_json } => {
|
||||||
|
if index == tool_use_index {
|
||||||
|
Some(Ok(partial_json))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
async move { result }
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub role: Role,
|
pub role: Role,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InitialInsertion,
|
prompts::PromptBuilder, slash_command::SlashCommandLine, AssistantPanel, InlineAssistId,
|
||||||
InlineAssistId, InlineAssistant, MessageId, MessageStatus,
|
InlineAssistant, MessageId, MessageStatus,
|
||||||
};
|
};
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use assistant_slash_command::{
|
use assistant_slash_command::{
|
||||||
|
@ -3342,7 +3342,7 @@ mod tests {
|
||||||
|
|
||||||
model
|
model
|
||||||
.as_fake()
|
.as_fake()
|
||||||
.respond_to_last_tool_use(Ok(serde_json::to_value(tool::WorkflowStepResolution {
|
.respond_to_last_tool_use(tool::WorkflowStepResolution {
|
||||||
step_title: "Title".into(),
|
step_title: "Title".into(),
|
||||||
suggestions: vec![tool::WorkflowSuggestion {
|
suggestions: vec![tool::WorkflowSuggestion {
|
||||||
path: "/root/hello.rs".into(),
|
path: "/root/hello.rs".into(),
|
||||||
|
@ -3352,8 +3352,7 @@ mod tests {
|
||||||
description: "Extract a greeting function".into(),
|
description: "Extract a greeting function".into(),
|
||||||
},
|
},
|
||||||
}],
|
}],
|
||||||
})
|
});
|
||||||
.unwrap()));
|
|
||||||
|
|
||||||
// Wait for tool use to be processed.
|
// Wait for tool use to be processed.
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
@ -4084,44 +4083,4 @@ mod tool {
|
||||||
symbol: String,
|
symbol: String,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WorkflowSuggestionKind {
|
|
||||||
pub fn symbol(&self) -> Option<&str> {
|
|
||||||
match self {
|
|
||||||
Self::Update { symbol, .. } => Some(symbol),
|
|
||||||
Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
|
|
||||||
Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
|
|
||||||
Self::PrependChild { symbol, .. } => symbol.as_deref(),
|
|
||||||
Self::AppendChild { symbol, .. } => symbol.as_deref(),
|
|
||||||
Self::Delete { symbol } => Some(symbol),
|
|
||||||
Self::Create { .. } => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn description(&self) -> Option<&str> {
|
|
||||||
match self {
|
|
||||||
Self::Update { description, .. } => Some(description),
|
|
||||||
Self::Create { description } => Some(description),
|
|
||||||
Self::InsertSiblingBefore { description, .. } => Some(description),
|
|
||||||
Self::InsertSiblingAfter { description, .. } => Some(description),
|
|
||||||
Self::PrependChild { description, .. } => Some(description),
|
|
||||||
Self::AppendChild { description, .. } => Some(description),
|
|
||||||
Self::Delete { .. } => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn initial_insertion(&self) -> Option<InitialInsertion> {
|
|
||||||
match self {
|
|
||||||
WorkflowSuggestionKind::InsertSiblingBefore { .. } => {
|
|
||||||
Some(InitialInsertion::NewlineAfter)
|
|
||||||
}
|
|
||||||
WorkflowSuggestionKind::InsertSiblingAfter { .. } => {
|
|
||||||
Some(InitialInsertion::NewlineBefore)
|
|
||||||
}
|
|
||||||
WorkflowSuggestionKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
|
|
||||||
WorkflowSuggestionKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
|
|
||||||
_ => None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1280,12 +1280,6 @@ fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
|
||||||
pub enum InitialInsertion {
|
|
||||||
NewlineBefore,
|
|
||||||
NewlineAfter,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||||
pub struct InlineAssistId(usize);
|
pub struct InlineAssistId(usize);
|
||||||
|
|
||||||
|
|
|
@ -351,10 +351,13 @@ impl Asset for ImageAsset {
|
||||||
let mut body = Vec::new();
|
let mut body = Vec::new();
|
||||||
response.body_mut().read_to_end(&mut body).await?;
|
response.body_mut().read_to_end(&mut body).await?;
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
|
let mut body = String::from_utf8_lossy(&body).into_owned();
|
||||||
|
let first_line = body.lines().next().unwrap_or("").trim_end();
|
||||||
|
body.truncate(first_line.len());
|
||||||
return Err(ImageCacheError::BadStatus {
|
return Err(ImageCacheError::BadStatus {
|
||||||
uri,
|
uri,
|
||||||
status: response.status(),
|
status: response.status(),
|
||||||
body: String::from_utf8_lossy(&body).into_owned(),
|
body,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
body
|
body
|
||||||
|
|
|
@ -8,7 +8,7 @@ pub mod settings;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use client::{Client, UserStore};
|
use client::{Client, UserStore};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream};
|
use futures::{future::BoxFuture, stream::BoxStream, TryStreamExt as _};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
|
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
|
||||||
};
|
};
|
||||||
|
@ -76,7 +76,7 @@ pub trait LanguageModel: Send + Sync {
|
||||||
description: String,
|
description: String,
|
||||||
schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>>;
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
|
||||||
|
|
||||||
#[cfg(any(test, feature = "test-support"))]
|
#[cfg(any(test, feature = "test-support"))]
|
||||||
fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
|
fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
|
||||||
|
@ -92,10 +92,11 @@ impl dyn LanguageModel {
|
||||||
) -> impl 'static + Future<Output = Result<T>> {
|
) -> impl 'static + Future<Output = Result<T>> {
|
||||||
let schema = schemars::schema_for!(T);
|
let schema = schemars::schema_for!(T);
|
||||||
let schema_json = serde_json::to_value(&schema).unwrap();
|
let schema_json = serde_json::to_value(&schema).unwrap();
|
||||||
let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
|
let stream = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
|
||||||
async move {
|
async move {
|
||||||
let response = request.await?;
|
let stream = stream.await?;
|
||||||
Ok(serde_json::from_value(response)?)
|
let response = stream.try_collect::<String>().await?;
|
||||||
|
Ok(serde_json::from_str(&response)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ use anthropic::AnthropicError;
|
||||||
use anyhow::{anyhow, Context as _, Result};
|
use anyhow::{anyhow, Context as _, Result};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryStreamExt as _};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
|
AnyView, AppContext, AsyncAppContext, FontStyle, ModelContext, Subscription, Task, TextStyle,
|
||||||
View, WhiteSpace,
|
View, WhiteSpace,
|
||||||
|
@ -264,29 +264,6 @@ pub fn count_anthropic_tokens(
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicModel {
|
impl AnthropicModel {
|
||||||
fn request_completion(
|
|
||||||
&self,
|
|
||||||
request: anthropic::Request,
|
|
||||||
cx: &AsyncAppContext,
|
|
||||||
) -> BoxFuture<'static, Result<anthropic::Response>> {
|
|
||||||
let http_client = self.http_client.clone();
|
|
||||||
|
|
||||||
let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
|
|
||||||
let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
|
|
||||||
(state.api_key.clone(), settings.api_url.clone())
|
|
||||||
}) else {
|
|
||||||
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
|
|
||||||
};
|
|
||||||
|
|
||||||
async move {
|
|
||||||
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
|
|
||||||
anthropic::complete(http_client.as_ref(), &api_url, &api_key, request)
|
|
||||||
.await
|
|
||||||
.context("failed to retrieve completion")
|
|
||||||
}
|
|
||||||
.boxed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn stream_completion(
|
fn stream_completion(
|
||||||
&self,
|
&self,
|
||||||
request: anthropic::Request,
|
request: anthropic::Request,
|
||||||
|
@ -381,7 +358,7 @@ impl LanguageModel for AnthropicModel {
|
||||||
tool_description: String,
|
tool_description: String,
|
||||||
input_schema: serde_json::Value,
|
input_schema: serde_json::Value,
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let mut request = request.into_anthropic(self.model.tool_model_id().into());
|
let mut request = request.into_anthropic(self.model.tool_model_id().into());
|
||||||
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||||
name: tool_name.clone(),
|
name: tool_name.clone(),
|
||||||
|
@ -392,25 +369,16 @@ impl LanguageModel for AnthropicModel {
|
||||||
input_schema,
|
input_schema,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let response = self.request_completion(request, cx);
|
let response = self.stream_completion(request, cx);
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let response = response.await?;
|
let response = response.await?;
|
||||||
response
|
Ok(anthropic::extract_tool_args_from_events(
|
||||||
.content
|
tool_name,
|
||||||
.into_iter()
|
Box::pin(response.map_err(|e| anyhow!(e))),
|
||||||
.find_map(|content| {
|
)
|
||||||
if let anthropic::Content::ToolUse { name, input, .. } = content {
|
.await?
|
||||||
if name == tool_name {
|
.boxed())
|
||||||
Some(input)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.context("tool not used")
|
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,18 +5,21 @@ use crate::{
|
||||||
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
|
||||||
};
|
};
|
||||||
use anthropic::AnthropicError;
|
use anthropic::AnthropicError;
|
||||||
use anyhow::{anyhow, bail, Context as _, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use feature_flags::{FeatureFlagAppExt, ZedPro};
|
use feature_flags::{FeatureFlagAppExt, ZedPro};
|
||||||
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
|
use futures::{
|
||||||
|
future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, Stream, StreamExt,
|
||||||
|
TryStreamExt as _,
|
||||||
|
};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
|
AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
|
||||||
Subscription, Task,
|
Subscription, Task,
|
||||||
};
|
};
|
||||||
use http_client::{AsyncBody, HttpClient, Method, Response};
|
use http_client::{AsyncBody, HttpClient, Method, Response};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||||
use serde_json::value::RawValue;
|
use serde_json::value::RawValue;
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use smol::{
|
use smol::{
|
||||||
|
@ -451,21 +454,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let body = BufReader::new(response.into_body());
|
Ok(anthropic::extract_text_from_events(
|
||||||
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
response_lines(response).map_err(AnthropicError::Other),
|
||||||
let mut buffer = String::new();
|
))
|
||||||
match body.read_line(&mut buffer).await {
|
|
||||||
Ok(0) => Ok(None),
|
|
||||||
Ok(_) => {
|
|
||||||
let event: anthropic::Event = serde_json::from_str(&buffer)
|
|
||||||
.context("failed to parse Anthropic event")?;
|
|
||||||
Ok(Some((event, body)))
|
|
||||||
}
|
|
||||||
Err(err) => Err(AnthropicError::Other(err.into())),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(anthropic::extract_text_from_events(stream))
|
|
||||||
});
|
});
|
||||||
async move {
|
async move {
|
||||||
Ok(future
|
Ok(future
|
||||||
|
@ -492,21 +483,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let body = BufReader::new(response.into_body());
|
Ok(open_ai::extract_text_from_events(response_lines(response)))
|
||||||
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
|
||||||
let mut buffer = String::new();
|
|
||||||
match body.read_line(&mut buffer).await {
|
|
||||||
Ok(0) => Ok(None),
|
|
||||||
Ok(_) => {
|
|
||||||
let event: open_ai::ResponseStreamEvent =
|
|
||||||
serde_json::from_str(&buffer)?;
|
|
||||||
Ok(Some((event, body)))
|
|
||||||
}
|
|
||||||
Err(e) => Err(e.into()),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(open_ai::extract_text_from_events(stream))
|
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
@ -527,21 +504,9 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let body = BufReader::new(response.into_body());
|
Ok(google_ai::extract_text_from_events(response_lines(
|
||||||
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
response,
|
||||||
let mut buffer = String::new();
|
)))
|
||||||
match body.read_line(&mut buffer).await {
|
|
||||||
Ok(0) => Ok(None),
|
|
||||||
Ok(_) => {
|
|
||||||
let event: google_ai::GenerateContentResponse =
|
|
||||||
serde_json::from_str(&buffer)?;
|
|
||||||
Ok(Some((event, body)))
|
|
||||||
}
|
|
||||||
Err(e) => Err(e.into()),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(google_ai::extract_text_from_events(stream))
|
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
@ -563,21 +528,7 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
let body = BufReader::new(response.into_body());
|
Ok(open_ai::extract_text_from_events(response_lines(response)))
|
||||||
let stream = futures::stream::try_unfold(body, move |mut body| async move {
|
|
||||||
let mut buffer = String::new();
|
|
||||||
match body.read_line(&mut buffer).await {
|
|
||||||
Ok(0) => Ok(None),
|
|
||||||
Ok(_) => {
|
|
||||||
let event: open_ai::ResponseStreamEvent =
|
|
||||||
serde_json::from_str(&buffer)?;
|
|
||||||
Ok(Some((event, body)))
|
|
||||||
}
|
|
||||||
Err(e) => Err(e.into()),
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
Ok(open_ai::extract_text_from_events(stream))
|
|
||||||
});
|
});
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
@ -591,10 +542,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
tool_description: String,
|
tool_description: String,
|
||||||
input_schema: serde_json::Value,
|
input_schema: serde_json::Value,
|
||||||
_cx: &AsyncAppContext,
|
_cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
|
let client = self.client.clone();
|
||||||
|
let llm_api_token = self.llm_api_token.clone();
|
||||||
|
|
||||||
match &self.model {
|
match &self.model {
|
||||||
CloudModel::Anthropic(model) => {
|
CloudModel::Anthropic(model) => {
|
||||||
let client = self.client.clone();
|
|
||||||
let mut request = request.into_anthropic(model.tool_model_id().into());
|
let mut request = request.into_anthropic(model.tool_model_id().into());
|
||||||
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
request.tool_choice = Some(anthropic::ToolChoice::Tool {
|
||||||
name: tool_name.clone(),
|
name: tool_name.clone(),
|
||||||
|
@ -605,7 +558,6 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
input_schema,
|
input_schema,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -621,70 +573,34 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut tool_use_index = None;
|
Ok(anthropic::extract_tool_args_from_events(
|
||||||
let mut tool_input = String::new();
|
tool_name,
|
||||||
let mut body = BufReader::new(response.into_body());
|
Box::pin(response_lines(response)),
|
||||||
let mut line = String::new();
|
)
|
||||||
while body.read_line(&mut line).await? > 0 {
|
.await?
|
||||||
let event: anthropic::Event = serde_json::from_str(&line)?;
|
.boxed())
|
||||||
line.clear();
|
|
||||||
|
|
||||||
match event {
|
|
||||||
anthropic::Event::ContentBlockStart {
|
|
||||||
content_block,
|
|
||||||
index,
|
|
||||||
} => {
|
|
||||||
if let anthropic::Content::ToolUse { name, .. } = content_block
|
|
||||||
{
|
|
||||||
if name == tool_name {
|
|
||||||
tool_use_index = Some(index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
anthropic::Event::ContentBlockDelta { index, delta } => match delta
|
|
||||||
{
|
|
||||||
anthropic::ContentDelta::TextDelta { .. } => {}
|
|
||||||
anthropic::ContentDelta::InputJsonDelta { partial_json } => {
|
|
||||||
if Some(index) == tool_use_index {
|
|
||||||
tool_input.push_str(&partial_json);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
anthropic::Event::ContentBlockStop { index } => {
|
|
||||||
if Some(index) == tool_use_index {
|
|
||||||
return Ok(serde_json::from_str(&tool_input)?);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tool_use_index.is_some() {
|
|
||||||
Err(anyhow!("tool content incomplete"))
|
|
||||||
} else {
|
|
||||||
Err(anyhow!("tool not used"))
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
CloudModel::OpenAi(model) => {
|
CloudModel::OpenAi(model) => {
|
||||||
let mut request = request.into_open_ai(model.id().into());
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
let client = self.client.clone();
|
request.tool_choice = Some(open_ai::ToolChoice::Other(
|
||||||
let mut function = open_ai::FunctionDefinition {
|
open_ai::ToolDefinition::Function {
|
||||||
name: tool_name.clone(),
|
function: open_ai::FunctionDefinition {
|
||||||
description: None,
|
name: tool_name.clone(),
|
||||||
parameters: None,
|
description: None,
|
||||||
};
|
parameters: None,
|
||||||
let func = open_ai::ToolDefinition::Function {
|
},
|
||||||
function: function.clone(),
|
},
|
||||||
};
|
));
|
||||||
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
|
request.tools = vec![open_ai::ToolDefinition::Function {
|
||||||
// Fill in description and params separately, as they're not needed for tool_choice field.
|
function: open_ai::FunctionDefinition {
|
||||||
function.description = Some(tool_description);
|
name: tool_name.clone(),
|
||||||
function.parameters = Some(input_schema);
|
description: Some(tool_description),
|
||||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
parameters: Some(input_schema),
|
||||||
|
},
|
||||||
|
}];
|
||||||
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -700,41 +616,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut body = BufReader::new(response.into_body());
|
Ok(open_ai::extract_tool_args_from_events(
|
||||||
let mut line = String::new();
|
tool_name,
|
||||||
let mut load_state = None;
|
Box::pin(response_lines(response)),
|
||||||
|
)
|
||||||
while body.read_line(&mut line).await? > 0 {
|
.await?
|
||||||
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
|
.boxed())
|
||||||
line.clear();
|
|
||||||
|
|
||||||
for choice in part.choices {
|
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
for call in tool_calls {
|
|
||||||
if let Some(func) = call.function {
|
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
|
||||||
load_state = Some((String::default(), call.index));
|
|
||||||
}
|
|
||||||
if let Some((arguments, (output, index))) =
|
|
||||||
func.arguments.zip(load_state.as_mut())
|
|
||||||
{
|
|
||||||
if call.index == *index {
|
|
||||||
output.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some((arguments, _)) = load_state {
|
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
|
||||||
} else {
|
|
||||||
bail!("tool not used");
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
@ -744,22 +631,23 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
CloudModel::Zed(model) => {
|
CloudModel::Zed(model) => {
|
||||||
// All Zed models are OpenAI-based at the time of writing.
|
// All Zed models are OpenAI-based at the time of writing.
|
||||||
let mut request = request.into_open_ai(model.id().into());
|
let mut request = request.into_open_ai(model.id().into());
|
||||||
let client = self.client.clone();
|
request.tool_choice = Some(open_ai::ToolChoice::Other(
|
||||||
let mut function = open_ai::FunctionDefinition {
|
open_ai::ToolDefinition::Function {
|
||||||
name: tool_name.clone(),
|
function: open_ai::FunctionDefinition {
|
||||||
description: None,
|
name: tool_name.clone(),
|
||||||
parameters: None,
|
description: None,
|
||||||
};
|
parameters: None,
|
||||||
let func = open_ai::ToolDefinition::Function {
|
},
|
||||||
function: function.clone(),
|
},
|
||||||
};
|
));
|
||||||
request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
|
request.tools = vec![open_ai::ToolDefinition::Function {
|
||||||
// Fill in description and params separately, as they're not needed for tool_choice field.
|
function: open_ai::FunctionDefinition {
|
||||||
function.description = Some(tool_description);
|
name: tool_name.clone(),
|
||||||
function.parameters = Some(input_schema);
|
description: Some(tool_description),
|
||||||
request.tools = vec![open_ai::ToolDefinition::Function { function }];
|
parameters: Some(input_schema),
|
||||||
|
},
|
||||||
|
}];
|
||||||
|
|
||||||
let llm_api_token = self.llm_api_token.clone();
|
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let response = Self::perform_llm_completion(
|
let response = Self::perform_llm_completion(
|
||||||
|
@ -775,40 +663,12 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let mut body = BufReader::new(response.into_body());
|
Ok(open_ai::extract_tool_args_from_events(
|
||||||
let mut line = String::new();
|
tool_name,
|
||||||
let mut load_state = None;
|
Box::pin(response_lines(response)),
|
||||||
|
)
|
||||||
while body.read_line(&mut line).await? > 0 {
|
.await?
|
||||||
let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
|
.boxed())
|
||||||
line.clear();
|
|
||||||
|
|
||||||
for choice in part.choices {
|
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
for call in tool_calls {
|
|
||||||
if let Some(func) = call.function {
|
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
|
||||||
load_state = Some((String::default(), call.index));
|
|
||||||
}
|
|
||||||
if let Some((arguments, (output, index))) =
|
|
||||||
func.arguments.zip(load_state.as_mut())
|
|
||||||
{
|
|
||||||
if call.index == *index {
|
|
||||||
output.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Some((arguments, _)) = load_state {
|
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
|
||||||
} else {
|
|
||||||
bail!("tool not used");
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
@ -816,6 +676,25 @@ impl LanguageModel for CloudLanguageModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn response_lines<T: DeserializeOwned>(
|
||||||
|
response: Response<AsyncBody>,
|
||||||
|
) -> impl Stream<Item = Result<T>> {
|
||||||
|
futures::stream::try_unfold(
|
||||||
|
(String::new(), BufReader::new(response.into_body())),
|
||||||
|
move |(mut line, mut body)| async {
|
||||||
|
match body.read_line(&mut line).await {
|
||||||
|
Ok(0) => Ok(None),
|
||||||
|
Ok(_) => {
|
||||||
|
let event: T = serde_json::from_str(&line)?;
|
||||||
|
line.clear();
|
||||||
|
Ok(Some((event, (line, body))))
|
||||||
|
}
|
||||||
|
Err(e) => Err(e.into()),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
impl LlmApiToken {
|
impl LlmApiToken {
|
||||||
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
|
async fn acquire(&self, client: &Arc<Client>) -> Result<String> {
|
||||||
let lock = self.0.upgradable_read().await;
|
let lock = self.0.upgradable_read().await;
|
||||||
|
|
|
@ -252,7 +252,7 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
_description: String,
|
_description: String,
|
||||||
_schema: serde_json::Value,
|
_schema: serde_json::Value,
|
||||||
_cx: &AsyncAppContext,
|
_cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
future::ready(Err(anyhow!("not implemented"))).boxed()
|
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,16 +3,11 @@ use crate::{
|
||||||
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
|
||||||
LanguageModelRequest,
|
LanguageModelRequest,
|
||||||
};
|
};
|
||||||
use anyhow::Context as _;
|
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
|
||||||
use futures::{
|
|
||||||
channel::{mpsc, oneshot},
|
|
||||||
future::BoxFuture,
|
|
||||||
stream::BoxStream,
|
|
||||||
FutureExt, StreamExt,
|
|
||||||
};
|
|
||||||
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
|
||||||
use http_client::Result;
|
use http_client::Result;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
|
use serde::Serialize;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use ui::WindowContext;
|
use ui::WindowContext;
|
||||||
|
|
||||||
|
@ -90,7 +85,7 @@ pub struct ToolUseRequest {
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct FakeLanguageModel {
|
pub struct FakeLanguageModel {
|
||||||
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
|
current_completion_txs: Mutex<Vec<(LanguageModelRequest, mpsc::UnboundedSender<String>)>>,
|
||||||
current_tool_use_txs: Mutex<Vec<(ToolUseRequest, oneshot::Sender<Result<serde_json::Value>>)>>,
|
current_tool_use_txs: Mutex<Vec<(ToolUseRequest, mpsc::UnboundedSender<String>)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FakeLanguageModel {
|
impl FakeLanguageModel {
|
||||||
|
@ -130,25 +125,11 @@ impl FakeLanguageModel {
|
||||||
self.end_completion_stream(self.pending_completions().last().unwrap());
|
self.end_completion_stream(self.pending_completions().last().unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn respond_to_tool_use(
|
pub fn respond_to_last_tool_use<T: Serialize>(&self, response: T) {
|
||||||
&self,
|
let response = serde_json::to_string(&response).unwrap();
|
||||||
tool_call: &ToolUseRequest,
|
|
||||||
response: Result<serde_json::Value>,
|
|
||||||
) {
|
|
||||||
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
|
|
||||||
if let Some(index) = current_tool_call_txs
|
|
||||||
.iter()
|
|
||||||
.position(|(call, _)| call == tool_call)
|
|
||||||
{
|
|
||||||
let (_, tx) = current_tool_call_txs.remove(index);
|
|
||||||
tx.send(response).unwrap();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn respond_to_last_tool_use(&self, response: Result<serde_json::Value>) {
|
|
||||||
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
|
let mut current_tool_call_txs = self.current_tool_use_txs.lock();
|
||||||
let (_, tx) = current_tool_call_txs.pop().unwrap();
|
let (_, tx) = current_tool_call_txs.pop().unwrap();
|
||||||
tx.send(response).unwrap();
|
tx.unbounded_send(response).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -202,8 +183,8 @@ impl LanguageModel for FakeLanguageModel {
|
||||||
description: String,
|
description: String,
|
||||||
schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
_cx: &AsyncAppContext,
|
_cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
let (tx, rx) = oneshot::channel();
|
let (tx, rx) = mpsc::unbounded();
|
||||||
let tool_call = ToolUseRequest {
|
let tool_call = ToolUseRequest {
|
||||||
request,
|
request,
|
||||||
name,
|
name,
|
||||||
|
@ -211,7 +192,7 @@ impl LanguageModel for FakeLanguageModel {
|
||||||
schema,
|
schema,
|
||||||
};
|
};
|
||||||
self.current_tool_use_txs.lock().push((tool_call, tx));
|
self.current_tool_use_txs.lock().push((tool_call, tx));
|
||||||
async move { rx.await.context("FakeLanguageModel was dropped")? }.boxed()
|
async move { Ok(rx.map(Ok).boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn as_fake(&self) -> &Self {
|
fn as_fake(&self) -> &Self {
|
||||||
|
|
|
@ -302,7 +302,7 @@ impl LanguageModel for GoogleLanguageModel {
|
||||||
_description: String,
|
_description: String,
|
||||||
_schema: serde_json::Value,
|
_schema: serde_json::Value,
|
||||||
_cx: &AsyncAppContext,
|
_cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||||
future::ready(Err(anyhow!("not implemented"))).boxed()
|
future::ready(Err(anyhow!("not implemented"))).boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ use ollama::{
|
||||||
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
|
||||||
ChatResponseDelta, OllamaToolCall,
|
ChatResponseDelta, OllamaToolCall,
|
||||||
};
|
};
|
||||||
use serde_json::Value;
|
|
||||||
use settings::{Settings, SettingsStore};
|
use settings::{Settings, SettingsStore};
|
||||||
use std::{sync::Arc, time::Duration};
|
use std::{sync::Arc, time::Duration};
|
||||||
use ui::{prelude::*, ButtonLike, Indicator};
|
use ui::{prelude::*, ButtonLike, Indicator};
|
||||||
|
@ -311,7 +310,7 @@ impl LanguageModel for OllamaLanguageModel {
|
||||||
tool_description: String,
|
tool_description: String,
|
||||||
schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
|
||||||
use ollama::{OllamaFunctionTool, OllamaTool};
|
use ollama::{OllamaFunctionTool, OllamaTool};
|
||||||
let function = OllamaFunctionTool {
|
let function = OllamaFunctionTool {
|
||||||
name: tool_name.clone(),
|
name: tool_name.clone(),
|
||||||
|
@ -324,23 +323,19 @@ impl LanguageModel for OllamaLanguageModel {
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let response = response.await?;
|
let response = response.await?;
|
||||||
let ChatMessage::Assistant {
|
let ChatMessage::Assistant { tool_calls, .. } = response.message else {
|
||||||
tool_calls,
|
|
||||||
content,
|
|
||||||
} = response.message
|
|
||||||
else {
|
|
||||||
bail!("message does not have an assistant role");
|
bail!("message does not have an assistant role");
|
||||||
};
|
};
|
||||||
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
|
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
|
||||||
for call in tool_calls {
|
for call in tool_calls {
|
||||||
let OllamaToolCall::Function(function) = call;
|
let OllamaToolCall::Function(function) = call;
|
||||||
if function.name == tool_name {
|
if function.name == tool_name {
|
||||||
return Ok(function.arguments);
|
return Ok(futures::stream::once(async move {
|
||||||
|
Ok(function.arguments.to_string())
|
||||||
|
})
|
||||||
|
.boxed());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if let Ok(args) = serde_json::from_str::<Value>(&content) {
|
|
||||||
// Parse content as arguments.
|
|
||||||
return Ok(args);
|
|
||||||
} else {
|
} else {
|
||||||
bail!("assistant message does not have any tool calls");
|
bail!("assistant message does not have any tool calls");
|
||||||
};
|
};
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use anyhow::{anyhow, bail, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use collections::BTreeMap;
|
use collections::BTreeMap;
|
||||||
use editor::{Editor, EditorElement, EditorStyle};
|
use editor::{Editor, EditorElement, EditorStyle};
|
||||||
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
use futures::{future::BoxFuture, FutureExt, StreamExt};
|
||||||
|
@ -243,6 +243,7 @@ impl OpenAiLanguageModel {
|
||||||
async move { Ok(future.await?.boxed()) }.boxed()
|
async move { Ok(future.await?.boxed()) }.boxed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModel for OpenAiLanguageModel {
|
impl LanguageModel for OpenAiLanguageModel {
|
||||||
fn id(&self) -> LanguageModelId {
|
fn id(&self) -> LanguageModelId {
|
||||||
self.id.clone()
|
self.id.clone()
|
||||||
|
@ -293,55 +294,32 @@ impl LanguageModel for OpenAiLanguageModel {
|
||||||
tool_description: String,
|
tool_description: String,
|
||||||
schema: serde_json::Value,
|
schema: serde_json::Value,
|
||||||
cx: &AsyncAppContext,
|
cx: &AsyncAppContext,
|
||||||
) -> BoxFuture<'static, Result<serde_json::Value>> {
|
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
|
||||||
let mut request = request.into_open_ai(self.model.id().into());
|
let mut request = request.into_open_ai(self.model.id().into());
|
||||||
let mut function = FunctionDefinition {
|
request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
|
||||||
name: tool_name.clone(),
|
function: FunctionDefinition {
|
||||||
description: None,
|
name: tool_name.clone(),
|
||||||
parameters: None,
|
description: None,
|
||||||
};
|
parameters: None,
|
||||||
let func = ToolDefinition::Function {
|
},
|
||||||
function: function.clone(),
|
}));
|
||||||
};
|
request.tools = vec![ToolDefinition::Function {
|
||||||
request.tool_choice = Some(ToolChoice::Other(func.clone()));
|
function: FunctionDefinition {
|
||||||
// Fill in description and params separately, as they're not needed for tool_choice field.
|
name: tool_name.clone(),
|
||||||
function.description = Some(tool_description);
|
description: Some(tool_description),
|
||||||
function.parameters = Some(schema);
|
parameters: Some(schema),
|
||||||
request.tools = vec![ToolDefinition::Function { function }];
|
},
|
||||||
|
}];
|
||||||
|
|
||||||
let response = self.stream_completion(request, cx);
|
let response = self.stream_completion(request, cx);
|
||||||
self.request_limiter
|
self.request_limiter
|
||||||
.run(async move {
|
.run(async move {
|
||||||
let mut response = response.await?;
|
let response = response.await?;
|
||||||
|
Ok(
|
||||||
// Call arguments are gonna be streamed in over multiple chunks.
|
open_ai::extract_tool_args_from_events(tool_name, Box::pin(response))
|
||||||
let mut load_state = None;
|
.await?
|
||||||
while let Some(Ok(part)) = response.next().await {
|
.boxed(),
|
||||||
for choice in part.choices {
|
)
|
||||||
let Some(tool_calls) = choice.delta.tool_calls else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
for call in tool_calls {
|
|
||||||
if let Some(func) = call.function {
|
|
||||||
if func.name.as_deref() == Some(tool_name.as_str()) {
|
|
||||||
load_state = Some((String::default(), call.index));
|
|
||||||
}
|
|
||||||
if let Some((arguments, (output, index))) =
|
|
||||||
func.arguments.zip(load_state.as_mut())
|
|
||||||
{
|
|
||||||
if call.index == *index {
|
|
||||||
output.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Some((arguments, _)) = load_state {
|
|
||||||
return Ok(serde_json::from_str(&arguments)?);
|
|
||||||
} else {
|
|
||||||
bail!("tool not used");
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
use isahc::config::Configurable;
|
use isahc::config::Configurable;
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::{value::RawValue, Value};
|
||||||
use std::{convert::TryFrom, sync::Arc, time::Duration};
|
use std::{convert::TryFrom, sync::Arc, time::Duration};
|
||||||
|
|
||||||
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
|
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
|
||||||
|
@ -92,7 +92,7 @@ impl Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
#[serde(tag = "role", rename_all = "lowercase")]
|
#[serde(tag = "role", rename_all = "lowercase")]
|
||||||
pub enum ChatMessage {
|
pub enum ChatMessage {
|
||||||
Assistant {
|
Assistant {
|
||||||
|
@ -107,16 +107,16 @@ pub enum ChatMessage {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum OllamaToolCall {
|
pub enum OllamaToolCall {
|
||||||
Function(OllamaFunctionCall),
|
Function(OllamaFunctionCall),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
pub struct OllamaFunctionCall {
|
pub struct OllamaFunctionCall {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub arguments: Value,
|
pub arguments: Box<RawValue>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||||
|
|
|
@ -6,7 +6,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||||
use isahc::config::Configurable;
|
use isahc::config::Configurable;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::{convert::TryFrom, future::Future, time::Duration};
|
use std::{convert::TryFrom, future::Future, pin::Pin, time::Duration};
|
||||||
use strum::EnumIter;
|
use strum::EnumIter;
|
||||||
|
|
||||||
pub use supported_countries::*;
|
pub use supported_countries::*;
|
||||||
|
@ -384,6 +384,57 @@ pub fn embed<'a>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn extract_tool_args_from_events(
|
||||||
|
tool_name: String,
|
||||||
|
mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
|
||||||
|
) -> Result<impl Send + Stream<Item = Result<String>>> {
|
||||||
|
let mut tool_use_index = None;
|
||||||
|
let mut first_chunk = None;
|
||||||
|
while let Some(event) = events.next().await {
|
||||||
|
let call = event?.choices.into_iter().find_map(|choice| {
|
||||||
|
choice.delta.tool_calls?.into_iter().find_map(|call| {
|
||||||
|
if call.function.as_ref()?.name.as_deref()? == tool_name {
|
||||||
|
Some(call)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
if let Some(call) = call {
|
||||||
|
tool_use_index = Some(call.index);
|
||||||
|
first_chunk = call.function.and_then(|func| func.arguments);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(tool_use_index) = tool_use_index else {
|
||||||
|
return Err(anyhow!("tool not used"));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(events.filter_map(move |event| {
|
||||||
|
let result = match event {
|
||||||
|
Err(error) => Some(Err(error)),
|
||||||
|
Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
|
||||||
|
choice.delta.tool_calls?.into_iter().find_map(|call| {
|
||||||
|
if call.index == tool_use_index {
|
||||||
|
let func = call.function?;
|
||||||
|
let mut arguments = func.arguments?;
|
||||||
|
if let Some(mut first_chunk) = first_chunk.take() {
|
||||||
|
first_chunk.push_str(&arguments);
|
||||||
|
arguments = first_chunk
|
||||||
|
}
|
||||||
|
Some(Ok(arguments))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
async move { result }
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn extract_text_from_events(
|
pub fn extract_text_from_events(
|
||||||
response: impl Stream<Item = Result<ResponseStreamEvent>>,
|
response: impl Stream<Item = Result<ResponseStreamEvent>>,
|
||||||
) -> impl Stream<Item = Result<String>> {
|
) -> impl Stream<Item = Result<String>> {
|
||||||
|
|
Loading…
Reference in a new issue