zed/script/gemini.py
Nathan Sobo 8ae5a3b61a
Allow AI interactions to be proxied through Zed's server so you don't need an API key (#7367)
Co-authored-by: Antonio <antonio@zed.dev>

Resurrected this from some assistant work I did in Spring of 2023.
- [x] Resurrect streaming responses
- [x] Use streaming responses to enable AI via Zed's servers by default
(but preserve API key option for now)
- [x] Simplify protobuf
- [x] Proxy to OpenAI on zed.dev
- [x] Proxy to Gemini on zed.dev
- [x] Improve UX for switching between OpenAI and Google models
- We currently disallow cycling when setting a custom model, but we need a
better solution to keep OpenAI models available while testing the Google
ones
- [x] Show remaining tokens correctly for Google models
- [x] Remove semantic index
- [x] Delete `ai` crate
- [x] Cloud front so we can ban abuse
- [x] Rate-limiting
- [x] Fix panic when using inline assistant
- [x] Double check the upgraded `AssistantSettings` are
backwards-compatible
- [x] Add hosted LLM interaction behind a `language-models` feature
flag.

Release Notes:

- We are temporarily removing the semantic index in order to redesign it
from scratch.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Thorsten <thorsten@zed.dev>
Co-authored-by: Max <max@zed.dev>
2024-03-19 19:22:26 +01:00


import codecs
import http.client
import json
import os
import subprocess


def get_text_files():
    text_files = []
    # List all files tracked by Git
    git_files_proc = subprocess.run(['git', 'ls-files'], stdout=subprocess.PIPE, text=True)
    for file in git_files_proc.stdout.splitlines():
        # Check the MIME type of each file; -b omits the filename from the output,
        # so a path such as "context.bin" cannot produce a false 'text' match
        mime_check_proc = subprocess.run(['file', '-b', '--mime-type', file], stdout=subprocess.PIPE, text=True)
        if mime_check_proc.stdout.startswith('text/'):
            text_files.append(file)
    print(f"File count: {len(text_files)}")
    return text_files


def get_file_contents(file):
    # Read file content as UTF-8, replacing undecodable bytes instead of raising
    with open(file, 'r', encoding='utf-8', errors='replace') as f:
        return f.read()


def main():
    GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY')
    if not GEMINI_API_KEY:
        raise SystemExit("GEMINI_API_KEY is not set")
    # Your prompt
    prompt = "Document the data types and dataflow in this codebase in preparation to port a streaming implementation to rust:\n\n"
    # Fetch all text files
    text_files = get_text_files()
    code_blocks = []
    for file in text_files:
        file_contents = get_file_contents(file)
        # Create a code block for each text file; the newline after the opening
        # fence keeps the file's first line from being read as an info string
        code_blocks.append(f"\n`{file}`\n\n```\n{file_contents}\n```\n")
    # Construct the JSON payload (json.dumps escapes non-ASCII, so len() equals the byte length)
    payload = json.dumps({
        "contents": [{
            "parts": [{
                "text": prompt + "".join(code_blocks)
            }]
        }]
    })
    # Prepare the HTTP connection
    conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
    # Define headers
    headers = {
        'Content-Type': 'application/json',
        'Content-Length': str(len(payload))
    }
    # Output the content length in kilobytes
    print(f"Content length: {len(payload.encode('utf-8')) / 1024:.2f} KB")
    # Send a request to count the tokens
    conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:countTokens?key={GEMINI_API_KEY}", body=payload, headers=headers)
    # Get the response
    response = conn.getresponse()
    if response.status == 200:
        token_count = json.loads(response.read().decode('utf-8')).get('totalTokens')
        print(f"Token count: {token_count}")
    else:
        print(f"Failed to get token count. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
    conn.close()
    # Prepare a new HTTP connection for the generation request
    conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
    # streamGenerateContent takes the same JSON body, so it is sent with POST, not GET
    conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?key={GEMINI_API_KEY}", body=payload, headers=headers)
    # Get the response in a streaming manner
    response = conn.getresponse()
    if response.status == 200:
        print("Successfully sent the data to the API.")
        # Decode incrementally: a 4096-byte chunk can end partway through a multi-byte UTF-8 character
        decoder = codecs.getincrementaldecoder('utf-8')()
        while chunk := response.read(4096):
            print(decoder.decode(chunk), end='', flush=True)
        print(decoder.decode(b'', final=True))
    else:
        print(f"Failed to send the data to the API. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
    # Close the connection
    conn.close()


if __name__ == "__main__":
    main()
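The script runs `file` once per tracked file, which spawns one process per file and requires the `file` utility to be installed. A minimal pure-Python sketch of the same filter follows, assuming that "text" can be approximated as "no NUL bytes in the first few KiB, and valid UTF-8 there". This heuristic is a swapped-in technique, not what `file` does internally, and `SNIFF_BYTES`, `looks_like_text`, and `get_text_files_heuristic` are hypothetical names.

```python
import codecs
import os
import subprocess

# Hypothetical sniff size: only the start of each file is inspected.
SNIFF_BYTES = 8192


def looks_like_text(data: bytes) -> bool:
    """Hypothetical heuristic: no NUL bytes and valid UTF-8.

    A truncated multi-byte sequence at the very end of the sample is
    tolerated, because the sample may cut a character in half.
    """
    if b"\x00" in data:
        return False
    decoder = codecs.getincrementaldecoder("utf-8")()
    try:
        # final=False buffers an incomplete trailing sequence instead of raising
        decoder.decode(data, final=False)
    except UnicodeDecodeError:
        return False
    return True


def get_text_files_heuristic() -> list[str]:
    """Hypothetical replacement for get_text_files() that avoids `file`."""
    proc = subprocess.run(["git", "ls-files", "-z"], stdout=subprocess.PIPE, check=True)
    text_files = []
    # -z separates paths with NUL bytes, so unusual filenames are not quoted
    for raw_path in proc.stdout.split(b"\x00"):
        if not raw_path:
            continue
        path = os.fsdecode(raw_path)
        if not os.path.isfile(path):
            continue  # e.g. a tracked file deleted from the working tree
        with open(path, "rb") as f:
            if looks_like_text(f.read(SNIFF_BYTES)):
                text_files.append(path)
    return text_files
```

Unlike `file`, this sketch treats empty files as text and rejects files in legacy encodings such as Latin-1 whose first bytes are not valid UTF-8.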
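Each file is wrapped in a triple-backtick fence, so any tracked Markdown file that itself contains a triple-backtick line ends the fence early, and the rest of that file no longer reads as one block. Under CommonMark rules a fenced block only closes on a backtick run at least as long as its opening fence, so one way to keep the boundaries unambiguous is to pick a fence longer than any backtick run in the file. This is a sketch: `fence_for`, `format_file_block`, and `build_payload` are hypothetical helpers, and the JSON shape simply mirrors the `contents`/`parts`/`text` body the script already sends.

```python
import json
import re


def fence_for(text: str) -> str:
    """Hypothetical helper: a backtick fence longer than any backtick run in text."""
    longest = max((len(run) for run in re.findall(r"`+", text)), default=0)
    return "`" * max(3, longest + 1)


def format_file_block(path: str, contents: str) -> str:
    """Hypothetical helper: one file as a labelled Markdown code block."""
    fence = fence_for(contents)
    return f"\n`{path}`\n\n{fence}\n{contents}\n{fence}\n"


def build_payload(prompt: str, files: dict[str, str]) -> str:
    """Build the same request body shape the script sends to countTokens."""
    text = prompt + "".join(format_file_block(p, c) for p, c in files.items())
    return json.dumps({"contents": [{"parts": [{"text": text}]}]})
```

A file with no backticks still gets the usual three-backtick fence, so the prompt looks the same as before for ordinary source files.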
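Without extra query parameters, `streamGenerateContent` streams one JSON array of response objects, so printing raw chunks, as the script does, shows fragments of JSON rather than the generated text. The sketch below assumes the request URL also carries `alt=sse`, so that each response object arrives as a single server-sent-event line of the form `data: {...}`, and assumes the text sits under `candidates[].content.parts[].text`. `iter_sse_text` is a hypothetical helper; check the response shape against the current Gemini API reference before relying on it.

```python
import json
from typing import Iterable, Iterator


def iter_sse_text(lines: Iterable[bytes]) -> Iterator[str]:
    """Hypothetical helper: yield generated text from SSE lines.

    Assumes the request used ?alt=sse, so each event is one `data: {...}`
    line holding a JSON response with candidates[].content.parts[].text.
    """
    for raw in lines:
        line = raw.decode("utf-8").rstrip("\r\n")
        if not line.startswith("data:"):
            continue  # blank separators, comments, or other SSE fields
        event = json.loads(line[len("data:"):].strip())
        for candidate in event.get("candidates", []):
            for part in candidate.get("content", {}).get("parts", []):
                if "text" in part:
                    yield part["text"]
```

Because iterating an `http.client.HTTPResponse` yields lines, the response of a POST to `...:streamGenerateContent?alt=sse&key=...` can be passed straight to `iter_sse_text` and each piece printed with `end=""`. Decoding whole lines also sidesteps the split-character problem that fixed-size chunks have.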