mirror of
https://github.com/zed-industries/zed.git
synced 2024-12-24 17:28:40 +00:00
Add async host functions
This commit is contained in:
parent
a5a0abb895
commit
4565f1a976
6 changed files with 96 additions and 91 deletions
|
@ -115,11 +115,8 @@ pub fn import(args: TokenStream, function: TokenStream) -> TokenStream {
|
|||
})
|
||||
.unzip();
|
||||
|
||||
dbg!("hello");
|
||||
|
||||
let body = TokenStream::from(quote! {
|
||||
{
|
||||
// dbg!("executing imported function");
|
||||
// setup
|
||||
let data: (#( #tys ),*) = (#( #args ),*);
|
||||
let data = ::plugin::bincode::serialize(&data).unwrap();
|
||||
|
@ -137,12 +134,8 @@ pub fn import(args: TokenStream, function: TokenStream) -> TokenStream {
|
|||
}
|
||||
});
|
||||
|
||||
dbg!("hello2");
|
||||
|
||||
let block = parse_macro_input!(body as Block);
|
||||
|
||||
dbg!("hello {:?}", &block);
|
||||
|
||||
let inner_fn = ItemFn {
|
||||
attrs: fn_declare.attrs,
|
||||
vis: fn_declare.vis,
|
||||
|
|
|
@ -10,7 +10,7 @@ fn main() {
|
|||
let _ =
|
||||
std::fs::create_dir_all(base.join("bin")).expect("Could not make plugins bin directory");
|
||||
|
||||
std::process::Command::new("cargo")
|
||||
let build_successful = std::process::Command::new("cargo")
|
||||
.args([
|
||||
"build",
|
||||
"--release",
|
||||
|
@ -20,7 +20,9 @@ fn main() {
|
|||
base.join("Cargo.toml").to_str().unwrap(),
|
||||
])
|
||||
.status()
|
||||
.expect("Could not build plugins");
|
||||
.expect("Could not build plugins")
|
||||
.success();
|
||||
assert!(build_successful);
|
||||
|
||||
let binaries = std::fs::read_dir(base.join("target/wasm32-wasi/release"))
|
||||
.expect("Could not find compiled plugins in target");
|
||||
|
|
|
@ -18,12 +18,9 @@ mod tests {
|
|||
print: WasiFn<String, ()>,
|
||||
and_back: WasiFn<u32, u32>,
|
||||
imports: WasiFn<u32, u32>,
|
||||
half_async: WasiFn<u32, u32>,
|
||||
}
|
||||
|
||||
// async fn half(a: u32) -> u32 {
|
||||
// a / 2
|
||||
// }
|
||||
|
||||
async {
|
||||
let mut runtime = PluginBuilder::new_with_default_ctx()
|
||||
.unwrap()
|
||||
|
@ -35,8 +32,8 @@ mod tests {
|
|||
.unwrap()
|
||||
.host_function("import_swap", |(a, b): (u32, u32)| (b, a))
|
||||
.unwrap()
|
||||
// .host_function_async("import_half", half)
|
||||
// .unwrap()
|
||||
.host_function_async("import_half", |a: u32| async move { a / 2 })
|
||||
.unwrap()
|
||||
.init(include_bytes!("../../../plugins/bin/test_plugin.wasm"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
@ -51,6 +48,7 @@ mod tests {
|
|||
print: runtime.function("print").unwrap(),
|
||||
and_back: runtime.function("and_back").unwrap(),
|
||||
imports: runtime.function("imports").unwrap(),
|
||||
half_async: runtime.function("half_async").unwrap(),
|
||||
};
|
||||
|
||||
let unsorted = vec![1, 3, 4, 2, 5];
|
||||
|
@ -65,6 +63,7 @@ mod tests {
|
|||
assert_eq!(runtime.call(&plugin.print, "Hi!".into()).await.unwrap(), ());
|
||||
assert_eq!(runtime.call(&plugin.and_back, 1).await.unwrap(), 8);
|
||||
assert_eq!(runtime.call(&plugin.imports, 1).await.unwrap(), 8);
|
||||
assert_eq!(runtime.call(&plugin.half_async, 4).await.unwrap(), 2);
|
||||
|
||||
// dbg!("{}", runtime.call(&plugin.and_back, 1).await.unwrap());
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::{fs::File, marker::PhantomData, path::Path};
|
||||
|
||||
use anyhow::{anyhow, Error};
|
||||
|
@ -142,7 +144,7 @@ impl PluginBuilder {
|
|||
// move |_: Caller<'_, WasiCtxAlloc>, _: u64| {
|
||||
// // let function = &function;
|
||||
// Box::new(async {
|
||||
// let function = function;
|
||||
// // let function = function;
|
||||
// // Call the Host-side function
|
||||
// let result: u64 = function(7).await;
|
||||
// Ok(result)
|
||||
|
@ -152,68 +154,69 @@ impl PluginBuilder {
|
|||
// Ok(self)
|
||||
// }
|
||||
|
||||
// pub fn host_function_async<F, A, R>(mut self, name: &str, function: F) -> Result<Self, Error>
|
||||
// where
|
||||
// F: Fn(A) -> Pin<Box<dyn Future<Output = R> + Send + 'static>> + Send + Sync + 'static,
|
||||
// A: DeserializeOwned + Send,
|
||||
// R: Serialize + Send + Sync,
|
||||
// {
|
||||
// self.linker.func_wrap1_async(
|
||||
// "env",
|
||||
// &format!("__{}", name),
|
||||
// move |mut caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| {
|
||||
// let function = |args: Vec<u8>| {
|
||||
// let args = args;
|
||||
// let args: A = Wasi::deserialize_to_type(&args)?;
|
||||
// Ok(async {
|
||||
// let result = function(args);
|
||||
// Wasi::serialize_to_bytes(result.await).map_err(|_| {
|
||||
// Trap::new("Could not serialize value returned from function").into()
|
||||
// })
|
||||
// })
|
||||
// };
|
||||
pub fn host_function_async<F, A, R, Fut>(
|
||||
mut self,
|
||||
name: &str,
|
||||
function: F,
|
||||
) -> Result<Self, Error>
|
||||
where
|
||||
F: Fn(A) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = R> + Send + 'static,
|
||||
A: DeserializeOwned + Send + 'static,
|
||||
R: Serialize + Send + Sync + 'static,
|
||||
{
|
||||
self.linker.func_wrap1_async(
|
||||
"env",
|
||||
&format!("__{}", name),
|
||||
move |mut caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| {
|
||||
// TODO: use try block once avaliable
|
||||
let result: Result<(WasiBuffer, Memory, _), Trap> = (|| {
|
||||
// grab a handle to the memory
|
||||
let mut plugin_memory = match caller.get_export("memory") {
|
||||
Some(Extern::Memory(mem)) => mem,
|
||||
_ => return Err(Trap::new("Could not grab slice of plugin memory"))?,
|
||||
};
|
||||
|
||||
// // TODO: use try block once avaliable
|
||||
// let result: Result<(WasiBuffer, Memory, _), Trap> = (|| {
|
||||
// // grab a handle to the memory
|
||||
// let mut plugin_memory = match caller.get_export("memory") {
|
||||
// Some(Extern::Memory(mem)) => mem,
|
||||
// _ => return Err(Trap::new("Could not grab slice of plugin memory"))?,
|
||||
// };
|
||||
let buffer = WasiBuffer::from_u64(packed_buffer);
|
||||
|
||||
// let buffer = WasiBuffer::from_u64(packed_buffer);
|
||||
// get the args passed from Guest
|
||||
let args = Plugin::buffer_to_bytes(&mut plugin_memory, &mut caller, &buffer)?;
|
||||
|
||||
// // get the args passed from Guest
|
||||
// let args = Wasi::buffer_to_bytes(&mut plugin_memory, &mut caller, &buffer)?;
|
||||
let args: A = Plugin::deserialize_to_type(&args)?;
|
||||
|
||||
// // Call the Host-side function
|
||||
// let result = function(args);
|
||||
// Call the Host-side function
|
||||
let result = function(args);
|
||||
|
||||
// Ok((buffer, plugin_memory, result))
|
||||
// })();
|
||||
Ok((buffer, plugin_memory, result))
|
||||
})();
|
||||
|
||||
// Box::new(async move {
|
||||
// let (buffer, mut plugin_memory, thingo) = result?;
|
||||
// let thingo: Result<_, Error> = thingo;
|
||||
// let result: Result<Vec<u8>, Error> = thingo?.await;
|
||||
Box::new(async move {
|
||||
let (buffer, mut plugin_memory, future) = result?;
|
||||
|
||||
// // Wasi::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer).await?;
|
||||
let result: R = future.await;
|
||||
let result: Result<Vec<u8>, Error> = Plugin::serialize_to_bytes(result)
|
||||
.map_err(|_| {
|
||||
Trap::new("Could not serialize value returned from function").into()
|
||||
});
|
||||
let result = result?;
|
||||
|
||||
// // let buffer = Wasi::bytes_to_buffer(
|
||||
// // caller.data().alloc_buffer(),
|
||||
// // &mut plugin_memory,
|
||||
// // &mut caller,
|
||||
// // result,
|
||||
// // )
|
||||
// // .await?;
|
||||
Plugin::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer)
|
||||
.await?;
|
||||
|
||||
// // Ok(buffer.into_u64())
|
||||
// Ok(27)
|
||||
// })
|
||||
// },
|
||||
// )?;
|
||||
// Ok(self)
|
||||
// }
|
||||
let buffer = Plugin::bytes_to_buffer(
|
||||
caller.data().alloc_buffer(),
|
||||
&mut plugin_memory,
|
||||
&mut caller,
|
||||
result,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(buffer.into_u64())
|
||||
})
|
||||
},
|
||||
)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn host_function<A, R>(
|
||||
mut self,
|
||||
|
|
|
@ -4,8 +4,8 @@ use serde_json::json;
|
|||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// #[import]
|
||||
// fn command(string: &str) -> Option<String>;
|
||||
#[import]
|
||||
fn command(string: &str) -> Option<String>;
|
||||
|
||||
// #[no_mangle]
|
||||
// // TODO: switch len from usize to u32?
|
||||
|
@ -28,29 +28,29 @@ use std::path::PathBuf;
|
|||
// return new_buffer.leak_to_heap();
|
||||
// }
|
||||
|
||||
extern "C" {
|
||||
fn __command(buffer: u64) -> u64;
|
||||
}
|
||||
// extern "C" {
|
||||
// fn __command(buffer: u64) -> u64;
|
||||
// }
|
||||
|
||||
#[no_mangle]
|
||||
fn command(string: &str) -> Option<Vec<u8>> {
|
||||
dbg!("executing command: {}", string);
|
||||
// setup
|
||||
let data = string;
|
||||
let data = ::plugin::bincode::serialize(&data).unwrap();
|
||||
let buffer = unsafe { ::plugin::__Buffer::from_vec(data) };
|
||||
// #[no_mangle]
|
||||
// fn command(string: &str) -> Option<Vec<u8>> {
|
||||
// dbg!("executing command: {}", string);
|
||||
// // setup
|
||||
// let data = string;
|
||||
// let data = ::plugin::bincode::serialize(&data).unwrap();
|
||||
// let buffer = unsafe { ::plugin::__Buffer::from_vec(data) };
|
||||
|
||||
// operation
|
||||
let new_buffer = unsafe { __command(buffer.into_u64()) };
|
||||
let new_data = unsafe { ::plugin::__Buffer::from_u64(new_buffer).to_vec() };
|
||||
let new_data: Option<Vec<u8>> = match ::plugin::bincode::deserialize(&new_data) {
|
||||
Ok(d) => d,
|
||||
Err(e) => panic!("Data returned from function not deserializable."),
|
||||
};
|
||||
// // operation
|
||||
// let new_buffer = unsafe { __command(buffer.into_u64()) };
|
||||
// let new_data = unsafe { ::plugin::__Buffer::from_u64(new_buffer).to_vec() };
|
||||
// let new_data: Option<Vec<u8>> = match ::plugin::bincode::deserialize(&new_data) {
|
||||
// Ok(d) => d,
|
||||
// Err(e) => panic!("Data returned from function not deserializable."),
|
||||
// };
|
||||
|
||||
// teardown
|
||||
return new_data;
|
||||
}
|
||||
// // teardown
|
||||
// return new_data;
|
||||
// }
|
||||
|
||||
// TODO: some sort of macro to generate ABI bindings
|
||||
// extern "C" {
|
||||
|
|
|
@ -61,3 +61,11 @@ pub fn imports(x: u32) -> u32 {
|
|||
assert_eq!(x, b);
|
||||
a + b // should be 7 + x
|
||||
}
|
||||
|
||||
#[import]
|
||||
fn import_half(a: u32) -> u32;
|
||||
|
||||
#[export]
|
||||
pub fn half_async(a: u32) -> u32 {
|
||||
import_half(a)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue