From 4565f1a976da9794d781d0139abf2f72429bd58f Mon Sep 17 00:00:00 2001 From: Isaac Clayton Date: Mon, 13 Jun 2022 16:06:39 +0200 Subject: [PATCH] Add async host functions --- crates/plugin_macros/src/lib.rs | 7 -- crates/plugin_runtime/build.rs | 6 +- crates/plugin_runtime/src/lib.rs | 11 ++- crates/plugin_runtime/src/plugin.rs | 111 ++++++++++++++-------------- plugins/json_language/src/lib.rs | 44 +++++------ plugins/test_plugin/src/lib.rs | 8 ++ 6 files changed, 96 insertions(+), 91 deletions(-) diff --git a/crates/plugin_macros/src/lib.rs b/crates/plugin_macros/src/lib.rs index 1d86b6ddaf..55f31b8e0c 100644 --- a/crates/plugin_macros/src/lib.rs +++ b/crates/plugin_macros/src/lib.rs @@ -115,11 +115,8 @@ pub fn import(args: TokenStream, function: TokenStream) -> TokenStream { }) .unzip(); - dbg!("hello"); - let body = TokenStream::from(quote! { { - // dbg!("executing imported function"); // setup let data: (#( #tys ),*) = (#( #args ),*); let data = ::plugin::bincode::serialize(&data).unwrap(); @@ -137,12 +134,8 @@ pub fn import(args: TokenStream, function: TokenStream) -> TokenStream { } }); - dbg!("hello2"); - let block = parse_macro_input!(body as Block); - dbg!("hello {:?}", &block); - let inner_fn = ItemFn { attrs: fn_declare.attrs, vis: fn_declare.vis, diff --git a/crates/plugin_runtime/build.rs b/crates/plugin_runtime/build.rs index 1cc5b859c6..a6f0db38ac 100644 --- a/crates/plugin_runtime/build.rs +++ b/crates/plugin_runtime/build.rs @@ -10,7 +10,7 @@ fn main() { let _ = std::fs::create_dir_all(base.join("bin")).expect("Could not make plugins bin directory"); - std::process::Command::new("cargo") + let build_successful = std::process::Command::new("cargo") .args([ "build", "--release", @@ -20,7 +20,9 @@ fn main() { base.join("Cargo.toml").to_str().unwrap(), ]) .status() - .expect("Could not build plugins"); + .expect("Could not build plugins") + .success(); + assert!(build_successful); let binaries = std::fs::read_dir(base.join("target/wasm32-wasi/release")) .expect("Could not find compiled plugins in target"); diff --git a/crates/plugin_runtime/src/lib.rs b/crates/plugin_runtime/src/lib.rs index 81d7ba7e84..92a75ddb3c 100644 --- a/crates/plugin_runtime/src/lib.rs +++ b/crates/plugin_runtime/src/lib.rs @@ -18,12 +18,9 @@ mod tests { print: WasiFn, and_back: WasiFn, imports: WasiFn, + half_async: WasiFn, } - // async fn half(a: u32) -> u32 { - // a / 2 - // } - async { let mut runtime = PluginBuilder::new_with_default_ctx() .unwrap() @@ -35,8 +32,8 @@ mod tests { .unwrap() .host_function("import_swap", |(a, b): (u32, u32)| (b, a)) .unwrap() - // .host_function_async("import_half", half) - // .unwrap() + .host_function_async("import_half", |a: u32| async move { a / 2 }) + .unwrap() .init(include_bytes!("../../../plugins/bin/test_plugin.wasm")) .await .unwrap(); @@ -51,6 +48,7 @@ mod tests { print: runtime.function("print").unwrap(), and_back: runtime.function("and_back").unwrap(), imports: runtime.function("imports").unwrap(), + half_async: runtime.function("half_async").unwrap(), }; let unsorted = vec![1, 3, 4, 2, 5]; @@ -65,6 +63,7 @@ mod tests { assert_eq!(runtime.call(&plugin.print, "Hi!".into()).await.unwrap(), ()); assert_eq!(runtime.call(&plugin.and_back, 1).await.unwrap(), 8); assert_eq!(runtime.call(&plugin.imports, 1).await.unwrap(), 8); + assert_eq!(runtime.call(&plugin.half_async, 4).await.unwrap(), 2); // dbg!("{}", runtime.call(&plugin.and_back, 1).await.unwrap()); } diff --git a/crates/plugin_runtime/src/plugin.rs b/crates/plugin_runtime/src/plugin.rs index 98d9c2341b..a37cd9e60c 100644 --- a/crates/plugin_runtime/src/plugin.rs +++ b/crates/plugin_runtime/src/plugin.rs @@ -1,3 +1,5 @@ +use std::future::Future; +use std::pin::Pin; use std::{fs::File, marker::PhantomData, path::Path}; use anyhow::{anyhow, Error}; @@ -142,7 +144,7 @@ impl PluginBuilder { // move |_: Caller<'_, WasiCtxAlloc>, _: u64| { // // let function = &function; // Box::new(async { - // let function = function; + // // let function = function; // // Call the Host-side function // let result: u64 = function(7).await; // Ok(result) @@ -152,68 +154,69 @@ impl PluginBuilder { // Ok(self) // } - // pub fn host_function_async(mut self, name: &str, function: F) -> Result - // where - // F: Fn(A) -> Pin + Send + 'static>> + Send + Sync + 'static, - // A: DeserializeOwned + Send, - // R: Serialize + Send + Sync, - // { - // self.linker.func_wrap1_async( - // "env", - // &format!("__{}", name), - // move |mut caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| { - // let function = |args: Vec| { - // let args = args; - // let args: A = Wasi::deserialize_to_type(&args)?; - // Ok(async { - // let result = function(args); - // Wasi::serialize_to_bytes(result.await).map_err(|_| { - // Trap::new("Could not serialize value returned from function").into() - // }) - // }) - // }; + pub fn host_function_async( + mut self, + name: &str, + function: F, + ) -> Result + where + F: Fn(A) -> Fut + Send + Sync + 'static, + Fut: Future + Send + 'static, + A: DeserializeOwned + Send + 'static, + R: Serialize + Send + Sync + 'static, + { + self.linker.func_wrap1_async( + "env", + &format!("__{}", name), + move |mut caller: Caller<'_, WasiCtxAlloc>, packed_buffer: u64| { + // TODO: use try block once avaliable + let result: Result<(WasiBuffer, Memory, _), Trap> = (|| { + // grab a handle to the memory + let mut plugin_memory = match caller.get_export("memory") { + Some(Extern::Memory(mem)) => mem, + _ => return Err(Trap::new("Could not grab slice of plugin memory"))?, + }; - // // TODO: use try block once avaliable - // let result: Result<(WasiBuffer, Memory, _), Trap> = (|| { - // // grab a handle to the memory - // let mut plugin_memory = match caller.get_export("memory") { - // Some(Extern::Memory(mem)) => mem, - // _ => return Err(Trap::new("Could not grab slice of plugin memory"))?, - // }; + let buffer = WasiBuffer::from_u64(packed_buffer); - // let buffer = WasiBuffer::from_u64(packed_buffer); + // get the args passed from Guest + let args = Plugin::buffer_to_bytes(&mut plugin_memory, &mut caller, &buffer)?; - // // get the args passed from Guest - // let args = Wasi::buffer_to_bytes(&mut plugin_memory, &mut caller, &buffer)?; + let args: A = Plugin::deserialize_to_type(&args)?; - // // Call the Host-side function - // let result = function(args); + // Call the Host-side function + let result = function(args); - // Ok((buffer, plugin_memory, result)) - // })(); + Ok((buffer, plugin_memory, result)) + })(); - // Box::new(async move { - // let (buffer, mut plugin_memory, thingo) = result?; - // let thingo: Result<_, Error> = thingo; - // let result: Result, Error> = thingo?.await; + Box::new(async move { + let (buffer, mut plugin_memory, future) = result?; - // // Wasi::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer).await?; + let result: R = future.await; + let result: Result, Error> = Plugin::serialize_to_bytes(result) + .map_err(|_| { + Trap::new("Could not serialize value returned from function").into() + }); + let result = result?; - // // let buffer = Wasi::bytes_to_buffer( - // // caller.data().alloc_buffer(), - // // &mut plugin_memory, - // // &mut caller, - // // result, - // // ) - // // .await?; + Plugin::buffer_to_free(caller.data().free_buffer(), &mut caller, buffer) + .await?; - // // Ok(buffer.into_u64()) - // Ok(27) - // }) - // }, - // )?; - // Ok(self) - // } + let buffer = Plugin::bytes_to_buffer( + caller.data().alloc_buffer(), + &mut plugin_memory, + &mut caller, + result, + ) + .await?; + + Ok(buffer.into_u64()) + }) + }, + )?; + Ok(self) + } pub fn host_function( mut self, diff --git a/plugins/json_language/src/lib.rs b/plugins/json_language/src/lib.rs index f70d620ddb..7299a4748d 100644 --- a/plugins/json_language/src/lib.rs +++ b/plugins/json_language/src/lib.rs @@ -4,8 +4,8 @@ use serde_json::json; use std::fs; use std::path::PathBuf; -// #[import] -// fn command(string: &str) -> Option; +#[import] +fn command(string: &str) -> Option; // #[no_mangle] // // TODO: switch len from usize to u32? @@ -28,29 +28,29 @@ use std::path::PathBuf; // return new_buffer.leak_to_heap(); // } -extern "C" { - fn __command(buffer: u64) -> u64; -} +// extern "C" { +// fn __command(buffer: u64) -> u64; +// } -#[no_mangle] -fn command(string: &str) -> Option> { - dbg!("executing command: {}", string); - // setup - let data = string; - let data = ::plugin::bincode::serialize(&data).unwrap(); - let buffer = unsafe { ::plugin::__Buffer::from_vec(data) }; +// #[no_mangle] +// fn command(string: &str) -> Option> { +// dbg!("executing command: {}", string); +// // setup +// let data = string; +// let data = ::plugin::bincode::serialize(&data).unwrap(); +// let buffer = unsafe { ::plugin::__Buffer::from_vec(data) }; - // operation - let new_buffer = unsafe { __command(buffer.into_u64()) }; - let new_data = unsafe { ::plugin::__Buffer::from_u64(new_buffer).to_vec() }; - let new_data: Option> = match ::plugin::bincode::deserialize(&new_data) { - Ok(d) => d, - Err(e) => panic!("Data returned from function not deserializable."), - }; +// // operation +// let new_buffer = unsafe { __command(buffer.into_u64()) }; +// let new_data = unsafe { ::plugin::__Buffer::from_u64(new_buffer).to_vec() }; +// let new_data: Option> = match ::plugin::bincode::deserialize(&new_data) { +// Ok(d) => d, +// Err(e) => panic!("Data returned from function not deserializable."), +// }; - // teardown - return new_data; -} +// // teardown +// return new_data; +// } // TODO: some sort of macro to generate ABI bindings // extern "C" { diff --git a/plugins/test_plugin/src/lib.rs b/plugins/test_plugin/src/lib.rs index f0991d8c59..34d7d4cb13 100644 --- a/plugins/test_plugin/src/lib.rs +++ b/plugins/test_plugin/src/lib.rs @@ -61,3 +61,11 @@ pub fn imports(x: u32) -> u32 { assert_eq!(x, b); a + b // should be 7 + x } + +#[import] +fn import_half(a: u32) -> u32; + +#[export] +pub fn half_async(a: u32) -> u32 { + import_half(a) +}