Kyle Kelley f176e8f0e4
Accept Views on LanguageModelTools (#10956)
Creates a `ToolView` trait to allow interactivity. This brings expanding
and collapsing to the excerpts from project index searches.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2024-04-25 13:03:43 -07:00

Assistant Tooling

Bringing OpenAI-compatible tool calling to GPUI.

This unlocks:

  • Structured extraction of model responses
  • Validation of model inputs
  • Execution of chosen tools

Overview

Language models can produce structured outputs that are well suited to calling functions. The best-known example is OpenAI's tool calling: when you make a chat completion request, you can pass a list of tools available to the model. The model then chooses zero or more tools (0..n) to help it complete the user's task. It's up to you to create the tools that the model can call.

User: "Hey I need help with implementing a collapsible panel in GPUI"

Assistant: "Sure, I can help with that. Let me see what I can find."

tool_calls: [{ "name": "query_codebase", "arguments": "{ 'query': 'GPUI collapsible panel' }" }]

result: "['crates/gpui/src/panel.rs:12: impl Panel { ... }', 'crates/gpui/src/panel.rs:20: impl Panel { ... }']"

Assistant: "Here are some excerpts from the GPUI codebase that might help you."

This library supports this interaction pattern by letting you turn a struct into a tool by implementing a single trait, LanguageModelTool.
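Stripped of GPUI, serde, and schemars, the core of this interaction is a registry that maps tool names to handlers and dispatches each requested call by name. The sketch below is a toy, std-only illustration under that assumption: ToyRegistry and its string-in, string-out handlers are hypothetical and are not the assistant_tooling ToolRegistry API, which works with typed inputs and GPUI tasks.

```rust
use std::collections::HashMap;

// Hypothetical handler: receives the raw JSON argument string from the model
// and returns the tool's output or an error message.
type Handler = Box<dyn Fn(&str) -> Result<String, String>>;

// Toy registry mapping tool names to handlers (not the assistant_tooling API).
struct ToyRegistry {
    handlers: HashMap<String, Handler>,
}

impl ToyRegistry {
    fn new() -> Self {
        ToyRegistry { handlers: HashMap::new() }
    }

    // Rejects duplicate names so the model never sees two tools with one name.
    fn register(&mut self, name: &str, handler: Handler) -> Result<(), String> {
        if self.handlers.contains_key(name) {
            return Err(format!("tool {name} is already registered"));
        }
        self.handlers.insert(name.to_string(), handler);
        Ok(())
    }

    // Dispatches one requested call by name; unknown names become errors.
    fn call(&self, name: &str, args: &str) -> Result<String, String> {
        match self.handlers.get(name) {
            Some(handler) => handler(args),
            None => Err(format!("no tool named {name}")),
        }
    }
}

fn main() {
    let mut registry = ToyRegistry::new();
    registry
        .register(
            "query_codebase",
            Box::new(|args: &str| -> Result<String, String> {
                Ok(format!("pretend results for {args}"))
            }),
        )
        .unwrap();

    // The model may request any number of calls in one turn, including
    // calls to tools that were never registered.
    let requested = [
        ("query_codebase", r#"{"query":"GPUI collapsible panel"}"#),
        ("unknown_tool", "{}"),
    ];
    for (name, args) in requested {
        println!("{name}: {:?}", registry.call(name, args));
    }
}
```

Returning an error for an unknown name, rather than panicking, matters because the model can and sometimes does request tools that do not exist.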

Example

Let's let the model query a semantic index directly. First, we'll set up the necessary imports:

use anyhow::Result;
use assistant_tooling::{LanguageModelTool, ToolRegistry};
use gpui::{App, AppContext, Task};
use schemars::JsonSchema;
use serde::Deserialize;
use serde_json::json;

Then we'll define the query structure the model must fill in. This must derive Deserialize from serde and JsonSchema from the schemars crate.

#[derive(Deserialize, JsonSchema)]
struct CodebaseQuery {
    query: String,
}

After that, we can define our tool, which needs a ProjectIndex to search against. For this example, the index is a stand-in that uses the same interface as semantic_index::ProjectIndex.

struct ProjectIndex {}

impl ProjectIndex {
    fn new() -> Self {
        ProjectIndex {}
    }

    fn search(&self, query: &str, _limit: usize, _cx: &AppContext) -> Task<Result<Vec<String>>> {
        // Instead of hooking up a real index, we fake one: any query that
        // mentions GPUI (in any letter case) returns a single canned excerpt
        if query.to_lowercase().contains("gpui") {
            return Task::ready(Ok(vec![r#"// crates/gpui/src/gpui.rs
//! # Welcome to GPUI!
//!
//! GPUI is a hybrid immediate and retained mode, GPU accelerated, UI framework
//! for Rust, designed to support a wide variety of applications
"#
            .to_string()]));
        }
        Task::ready(Ok(vec![]))
    }
}
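The stand-in above ignores limit and only knows one excerpt. If you want a fake that behaves a little more like a search, a minimal sketch could filter an in-memory list and honor the limit. This assumes plain case-insensitive term matching, which is not how a semantic index ranks results; the search function and the sample excerpts are hypothetical.

```rust
// Hypothetical in-memory "index": an excerpt matches when every
// whitespace-separated query term appears in it, ignoring case.
// A real semantic index ranks by embedding similarity instead.
fn search(excerpts: &[&str], query: &str, limit: usize) -> Vec<String> {
    let terms: Vec<String> = query
        .split_whitespace()
        .map(|term| term.to_lowercase())
        .collect();
    excerpts
        .iter()
        .filter(|excerpt| {
            let haystack = excerpt.to_lowercase();
            terms.iter().all(|term| haystack.contains(term.as_str()))
        })
        .take(limit) // honor the result limit the tool passes in
        .map(|excerpt| excerpt.to_string())
        .collect()
}

fn main() {
    // Made-up excerpts, for illustration only
    let excerpts = [
        "crates/gpui/src/gpui.rs: Welcome to GPUI!",
        "example excerpt about GPUI panels",
        "example excerpt about the editor",
    ];
    for hit in search(&excerpts, "gpui panels", 10) {
        println!("{hit}");
    }
}
```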

struct ProjectIndexTool {
    project_index: ProjectIndex,
}

Now we can implement the LanguageModelTool trait for our tool by:

  • Defining the Input from the model, which is CodebaseQuery
  • Defining the Output
  • Implementing the name and description functions to provide the model information when it's choosing a tool
  • Implementing the execute function to run the tool

impl LanguageModelTool for ProjectIndexTool {
    type Input = CodebaseQuery;
    type Output = String;

    fn name(&self) -> String {
        "query_codebase".to_string()
    }

    fn description(&self) -> String {
        "Executes a query against the codebase, returning excerpts related to the query".to_string()
    }

    fn execute(&self, query: Self::Input, cx: &AppContext) -> Task<Result<Self::Output>> {
        let results = self.project_index.search(query.query.as_str(), 10, cx);

        cx.spawn(|_cx| async move {
            let results = results.await?;

            if !results.is_empty() {
                Ok(results.join("\n"))
            } else {
                Ok("No results".to_string())
            }
        })
    }
}
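The name, description, and Input schema are what the model sees when choosing a tool. In an OpenAI chat completion request they are sent as an entry in the tools array, shaped like {"type": "function", "function": {"name", "description", "parameters"}}. The std-only sketch below builds that entry by hand to show the shape; tool_definition and json_escape are hypothetical helpers, and the schema string is a hand-written approximation of what schemars would derive for CodebaseQuery, not its exact output.

```rust
// Escapes text for use inside a JSON string literal.
fn json_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            // Remaining control characters use the \uXXXX form
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}

// Builds one entry of the `tools` array in an OpenAI chat completion request.
// `parameters_schema` is assumed to already be valid JSON Schema text.
fn tool_definition(name: &str, description: &str, parameters_schema: &str) -> String {
    format!(
        r#"{{"type":"function","function":{{"name":"{}","description":"{}","parameters":{}}}}}"#,
        json_escape(name),
        json_escape(description),
        parameters_schema
    )
}

fn main() {
    // Hand-written approximation of a schema for `CodebaseQuery`
    let schema = r#"{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}"#;
    println!(
        "{}",
        tool_definition(
            "query_codebase",
            "Executes a query against the codebase, returning excerpts related to the query",
            schema,
        )
    );
}
```

In practice you would let serde_json and schemars produce this JSON rather than formatting strings; the point here is only which pieces of the trait end up in the request.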

For the sake of this example, let's look at the types OpenAI will pass to us:

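// Simplified stand-ins for OpenAI's definitions, shown here for demonstration
// (OpenAI's own field is named `arguments`; this example uses `args` throughout)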
#[derive(Deserialize)]
struct FunctionCall {
    name: String,
    args: String,
}

#[derive(Deserialize, Eq, PartialEq)]
enum ToolCallType {
    #[serde(rename = "function")]
    Function,
    Other,
}

#[derive(Deserialize, Clone, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
struct ToolCallId(String);

#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ToolCall {
    Function {
        #[allow(dead_code)]
        id: ToolCallId,
        function: FunctionCall,
    },
    Other {
        #[allow(dead_code)]
        id: ToolCallId,
    },
}

#[derive(Deserialize)]
struct AssistantMessage {
    role: String,
    content: Option<String>,
    tool_calls: Option<Vec<ToolCall>>,
}
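Only the Function variant carries a name and arguments, so before dispatching anything you need to separate calls you can run from ones you can't, keeping each call's id so its result can be matched up later. A minimal std-only sketch, using simplified mirrors of the types above without the serde derives (split_calls is a hypothetical helper):

```rust
// Simplified mirrors of the FunctionCall / ToolCall types above
struct FunctionCall {
    name: String,
    args: String,
}

enum ToolCall {
    Function { id: String, function: FunctionCall },
    Other { id: String },
}

// Returns (id, name, args) for every call we can execute, plus the ids of
// calls we cannot, so the caller can still answer those with an error.
fn split_calls(calls: &[ToolCall]) -> (Vec<(&str, &str, &str)>, Vec<&str>) {
    let mut runnable = Vec::new();
    let mut unsupported = Vec::new();
    for call in calls {
        match call {
            ToolCall::Function { id, function } => {
                runnable.push((id.as_str(), function.name.as_str(), function.args.as_str()))
            }
            ToolCall::Other { id } => unsupported.push(id.as_str()),
        }
    }
    (runnable, unsupported)
}

fn main() {
    let calls = vec![
        ToolCall::Function {
            id: "call_1".to_string(),
            function: FunctionCall {
                name: "query_codebase".to_string(),
                args: r#"{"query":"GPUI Task background_executor"}"#.to_string(),
            },
        },
        ToolCall::Other { id: "call_2".to_string() },
    ];
    let (runnable, unsupported) = split_calls(&calls);
    for (id, name, args) in runnable {
        println!("run {name} for {id} with {args}");
    }
    for id in unsupported {
        println!("cannot run {id}");
    }
}
```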

When the model wants to call tools, it will pass a list of ToolCalls. When those are functions that we can handle, we'll pass them to our ToolRegistry to get a future that we can await.

// Inside `fn main()`
App::new().run(|cx: &mut AppContext| {
    let tool = ProjectIndexTool {
        project_index: ProjectIndex::new(),
    };

    let mut registry = ToolRegistry::new();
    let registered = registry.register(tool);
    assert!(registered.is_ok());

Let's pretend the model sent back a message requesting a call to query_codebase:

let model_response = json!({
    "role": "assistant",
    "tool_calls": [
        {
            "id": "call_1",
            "function": {
                "name": "query_codebase",
                "args": r#"{"query":"GPUI Task background_executor"}"#
            },
            "type": "function"
        }
    ]
});

let message: AssistantMessage = serde_json::from_value(model_response).unwrap();

// We know there's a tool call, so let's skip straight to it for this example
let tool_calls = message.tool_calls.as_ref().unwrap();
let tool_call = tool_calls.get(0).unwrap();

We can now use our registry to call the tool.

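// Only function calls carry a name and arguments we can dispatch
let ToolCall::Function { function, .. } = tool_call else {
    panic!("expected a function tool call");
};

let task = registry.call(&function.name, &function.args);

// GPUI tasks are cancelled when dropped, so detach this one to let it finish
cx.spawn(|_cx| async move {
    let result = task.await?;
    println!("{}", result.unwrap());
    anyhow::Ok(())
})
.detach();
});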
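Executing a tool is only half of the round trip: with OpenAI's chat API, each tool's output goes back to the model as a message with role "tool" and the matching tool_call_id, after which you request another completion so the model can answer the user. The API expects a reply for every tool call id, so a failed call should still produce a message. The std-only sketch below builds those messages by hand; tool_result_messages and json_escape are hypothetical helpers, and reporting errors as "Error: ..." content is just one reasonable choice.

```rust
// Escapes text for use inside a JSON string literal.
fn json_escape(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}

// Turns each (tool_call_id, outcome) pair into a "tool" role message.
// Errors are reported as content too, so every tool_call_id gets a reply.
fn tool_result_messages(outcomes: &[(&str, Result<String, String>)]) -> Vec<String> {
    outcomes
        .iter()
        .map(|(id, outcome)| {
            let content = match outcome {
                Ok(output) => output.clone(),
                Err(err) => format!("Error: {err}"),
            };
            format!(
                r#"{{"role":"tool","tool_call_id":"{}","content":"{}"}}"#,
                json_escape(id),
                json_escape(&content)
            )
        })
        .collect()
}

fn main() {
    let outcomes = [
        ("call_1", Ok("No results".to_string())),
        ("call_2", Err("unsupported tool call".to_string())),
    ];
    for message in tool_result_messages(&outcomes) {
        println!("{message}");
    }
}
```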