use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::serde_json; use isahc::prelude::Configurable; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use std::env; use std::sync::Arc; use util::http::{HttpClient, Request}; lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); } pub struct OpenAIEmbeddings { pub client: Arc, } #[derive(Serialize)] struct OpenAIEmbeddingRequest<'a> { model: &'static str, input: Vec<&'a str>, } #[derive(Deserialize)] struct OpenAIEmbeddingResponse { data: Vec, usage: OpenAIEmbeddingUsage, } #[derive(Debug, Deserialize)] struct OpenAIEmbedding { embedding: Vec, index: usize, object: String, } #[derive(Deserialize)] struct OpenAIEmbeddingUsage { prompt_tokens: usize, total_tokens: usize, } #[async_trait] pub trait EmbeddingProvider: Sync { async fn embed_batch(&self, spans: Vec<&str>) -> Result>>; } #[async_trait] impl EmbeddingProvider for OpenAIEmbeddings { async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { let api_key = OPENAI_API_KEY .as_ref() .ok_or_else(|| anyhow!("no api key"))?; let request = Request::post("https://api.openai.com/v1/embeddings") .redirect_policy(isahc::config::RedirectPolicy::Follow) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) .body( serde_json::to_string(&OpenAIEmbeddingRequest { input: spans, model: "text-embedding-ada-002", }) .unwrap() .into(), )?; let mut response = self.client.send(request).await?; if !response.status().is_success() { return Err(anyhow!("openai embedding failed {}", response.status())); } let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; log::info!( "openai embedding completed. tokens: {:?}", response.usage.total_tokens ); // do we need to re-order these based on the `index` field? eprintln!( "indices: {:?}", response .data .iter() .map(|embedding| embedding.index) .collect::>() ); Ok(response .data .into_iter() .map(|embedding| embedding.embedding) .collect()) } }