diff --git a/Cargo.toml b/Cargo.toml index 34322b4a..d0f79d5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "salsa" -version = "0.5.0" +version = "0.6.0" authors = ["Niko Matsakis "] edition = "2018" license = "Apache-2.0 OR MIT" diff --git a/examples/compiler/class_table.rs b/examples/compiler/class_table.rs index db21980a..0e009a28 100644 --- a/examples/compiler/class_table.rs +++ b/examples/compiler/class_table.rs @@ -9,12 +9,12 @@ salsa::query_group! { } /// Get the list of all classes - fn all_classes(key: ()) -> Arc> { + fn all_classes() -> Arc> { type AllClasses; } /// Get the list of all fields - fn all_fields(key: ()) -> Arc> { + fn all_fields() -> Arc> { type AllFields; } } @@ -23,7 +23,7 @@ salsa::query_group! { #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)] pub struct DefId(usize); -fn all_classes(_: &impl ClassTableDatabase, (): ()) -> Arc> { +fn all_classes(_: &impl ClassTableDatabase) -> Arc> { Arc::new(vec![DefId(0), DefId(10)]) // dummy impl } @@ -31,14 +31,15 @@ fn fields(_: &impl ClassTableDatabase, class: DefId) -> Arc> { Arc::new(vec![DefId(class.0 + 1), DefId(class.0 + 2)]) // dummy impl } -fn all_fields(db: &impl ClassTableDatabase, (): ()) -> Arc> { +fn all_fields(db: &impl ClassTableDatabase) -> Arc> { Arc::new( - db.all_classes(()) + db.all_classes() .iter() .cloned() .flat_map(|def_id| { let fields = db.fields(def_id); (0..fields.len()).map(move |i| fields[i]) - }).collect(), + }) + .collect(), ) } diff --git a/examples/compiler/main.rs b/examples/compiler/main.rs index 85d9c5fc..5f0cee0f 100644 --- a/examples/compiler/main.rs +++ b/examples/compiler/main.rs @@ -8,7 +8,7 @@ use self::implementation::DatabaseImpl; #[test] fn test() { let query = DatabaseImpl::default(); - let all_def_ids = query.all_fields(()); + let all_def_ids = query.all_fields(); assert_eq!( format!("{:?}", all_def_ids), "[DefId(1), DefId(2), DefId(11), DefId(12)]" @@ -17,7 +17,7 @@ fn test() { fn main() { let query = DatabaseImpl::default(); - for f in query.all_fields(()).iter() { + for f in query.all_fields().iter() { println!("{:?}", f); } } diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md index d84b4165..57450f21 100644 --- a/examples/hello_world/README.md +++ b/examples/hello_world/README.md @@ -51,8 +51,11 @@ database). Within this trait, we list out the queries that this group provides. Here, there are two: `input_string` and `length`. For each query, you -specify the key and value type of the query in the form of a function: -but the "fn body" is obviously not real Rust syntax. Rather, it's just +specify a function signature: the parameters to the function are +called the "key types" (in this case, we just give a single key of +type `()`) and the return type is the "value type". You can have any +number of key types. As you can see, though, this is not a real fn -- +the "fn body" is obviously not real Rust syntax. Rather, it's just used to specify a few bits of metadata about the query. We'll see how to define the fn body in the next step. diff --git a/src/lib.rs b/src/lib.rs index 682bda79..1fab6bf0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -140,10 +140,10 @@ where /// expand into a trait and a set of structs, one per query. /// /// For each query, you give the name of the accessor method to invoke -/// the query (e.g., `my_query`, below), as well as its input/output -/// types. You also give the name for a query type (e.g., `MyQuery`, -/// below) that represents the query, and optionally other details, -/// such as its storage. +/// the query (e.g., `my_query`, below), as well as its parameter +/// types and the output type. You also give the name for a query type +/// (e.g., `MyQuery`, below) that represents the query, and optionally +/// other details, such as its storage. /// /// ### Examples /// @@ -158,6 +158,12 @@ where /// storage memoized; // optional, this is the default /// use fn path::to::fn; // optional, default is `my_query` /// } +/// +/// /// Queries can have any number of inputs; the key type will be +/// /// a tuple of the input types, so in this case `(u32, f32)`. +/// fn other_query(input1: u32, input2: f32) -> u64 { +/// type OtherQuery; +/// } /// } /// } /// ``` @@ -190,7 +196,7 @@ macro_rules! query_group { tokens[{ $( $(#[$method_attr:meta])* - fn $method_name:ident($key_name:ident: $key_ty:ty) -> $value_ty:ty { + fn $method_name:ident($($key_name:ident: $key_ty:ty),* $(,)*) -> $value_ty:ty { type $QueryType:ident; $(storage $storage:ident;)* // FIXME(rust-lang/rust#48075) should be `?` $(use fn $fn_path:path;)* // FIXME(rust-lang/rust#48075) should be `?` @@ -201,9 +207,9 @@ macro_rules! query_group { $($trait_attr)* $v trait $query_trait: $($crate::plumbing::GetQueryTable<$QueryType> +)* $($header)* { $( $(#[$method_attr])* - fn $method_name(&self, key: $key_ty) -> $value_ty { + fn $method_name(&self, $($key_name: $key_ty),*) -> $value_ty { >::get_query_table(self) - .get(key) + .get(($($key_name),*)) } )* } @@ -216,7 +222,7 @@ macro_rules! query_group { where DB: $query_trait, { - type Key = $key_ty; + type Key = ($($key_ty),*); type Value = $value_ty; type Storage = $crate::query_group! { @storage_ty[DB, Self, $($storage)*] }; } @@ -228,6 +234,7 @@ macro_rules! query_group { fn_path($($fn_path)*); db_trait($query_trait); query_type($QueryType); + key($($key_name: $key_ty),*); ] } )* @@ -277,6 +284,8 @@ macro_rules! query_group { } }; + // Handle fns of one argument: once parenthesized patterns are stable on beta, + // we can remove this special case. ( @query_fn[ storage($($storage:ident)*); @@ -284,15 +293,39 @@ macro_rules! query_group { fn_path($fn_path:path); db_trait($DbTrait:path); query_type($QueryType:ty); + key($key_name:ident: $key_ty:ty); ] ) => { impl $crate::plumbing::QueryFunction for $QueryType where DB: $DbTrait { - fn execute(db: &DB, key: >::Key) + fn execute(db: &DB, $key_name: >::Key) -> >::Value { - $fn_path(db, key) + $fn_path(db, $key_name) + } + } + }; + + // Handle fns of N arguments: once parenthesized patterns are stable on beta, + // we can use this code for all cases. + ( + @query_fn[ + storage($($storage:ident)*); + method_name($method_name:ident); + fn_path($fn_path:path); + db_trait($DbTrait:path); + query_type($QueryType:ty); + key($($key_name:ident: $key_ty:ty),*); + ] + ) => { + impl $crate::plumbing::QueryFunction for $QueryType + where DB: $DbTrait + { + fn execute(db: &DB, ($($key_name),*): >::Key) + -> >::Value + { + $fn_path(db, $($key_name),*) } } }; diff --git a/tests/cycles.rs b/tests/cycles.rs index 4bc27c36..39e887aa 100644 --- a/tests/cycles.rs +++ b/tests/cycles.rs @@ -23,49 +23,49 @@ salsa::database_storage! { salsa::query_group! { trait Database: salsa::Database { // `a` and `b` depend on each other and form a cycle - fn memoized_a(key: ()) -> () { + fn memoized_a() -> () { type MemoizedA; } - fn memoized_b(key: ()) -> () { + fn memoized_b() -> () { type MemoizedB; } - fn volatile_a(key: ()) -> () { + fn volatile_a() -> () { type VolatileA; storage volatile; } - fn volatile_b(key: ()) -> () { + fn volatile_b() -> () { type VolatileB; storage volatile; } } } -fn memoized_a(db: &impl Database, (): ()) -> () { - db.memoized_b(()) +fn memoized_a(db: &impl Database) -> () { + db.memoized_b() } -fn memoized_b(db: &impl Database, (): ()) -> () { - db.memoized_a(()) +fn memoized_b(db: &impl Database) -> () { + db.memoized_a() } -fn volatile_a(db: &impl Database, (): ()) -> () { - db.volatile_b(()) +fn volatile_a(db: &impl Database) -> () { + db.volatile_b() } -fn volatile_b(db: &impl Database, (): ()) -> () { - db.volatile_a(()) +fn volatile_b(db: &impl Database) -> () { + db.volatile_a() } #[test] #[should_panic(expected = "cycle detected")] fn cycle_memoized() { let query = DatabaseImpl::default(); - query.memoized_a(()); + query.memoized_a(); } #[test] #[should_panic(expected = "cycle detected")] fn cycle_volatile() { let query = DatabaseImpl::default(); - query.volatile_a(()); + query.volatile_a(); } diff --git a/tests/incremental/memoized_dep_inputs.rs b/tests/incremental/memoized_dep_inputs.rs index 152ceb61..66692784 100644 --- a/tests/incremental/memoized_dep_inputs.rs +++ b/tests/incremental/memoized_dep_inputs.rs @@ -3,40 +3,40 @@ use salsa::Database; salsa::query_group! { pub(crate) trait MemoizedDepInputsContext: TestContext { - fn dep_memoized2(key: ()) -> usize { + fn dep_memoized2() -> usize { type Memoized2; } - fn dep_memoized1(key: ()) -> usize { + fn dep_memoized1() -> usize { type Memoized1; } - fn dep_derived1(key: ()) -> usize { + fn dep_derived1() -> usize { type Derived1; storage dependencies; } - fn dep_input1(key: ()) -> usize { + fn dep_input1() -> usize { type Input1; storage input; } - fn dep_input2(key: ()) -> usize { + fn dep_input2() -> usize { type Input2; storage input; } } } -fn dep_memoized2(db: &impl MemoizedDepInputsContext, (): ()) -> usize { +fn dep_memoized2(db: &impl MemoizedDepInputsContext) -> usize { db.log().add("Memoized2 invoked"); - db.dep_memoized1(()) + db.dep_memoized1() } -fn dep_memoized1(db: &impl MemoizedDepInputsContext, (): ()) -> usize { +fn dep_memoized1(db: &impl MemoizedDepInputsContext) -> usize { db.log().add("Memoized1 invoked"); - db.dep_derived1(()) * 2 + db.dep_derived1() * 2 } -fn dep_derived1(db: &impl MemoizedDepInputsContext, (): ()) -> usize { +fn dep_derived1(db: &impl MemoizedDepInputsContext) -> usize { db.log().add("Derived1 invoked"); - db.dep_input1(()) / 2 + db.dep_input1() / 2 } #[test] @@ -44,7 +44,7 @@ fn revalidate() { let db = &TestContextImpl::default(); // Initial run starts from Memoized2: - let v = db.dep_memoized2(()); + let v = db.dep_memoized2(); assert_eq!(v, 0); db.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Derived1 invoked"]); @@ -52,19 +52,19 @@ fn revalidate() { // running Memoized2. Note that we don't try to validate // Derived1, so it is invoked by Memoized1. db.query(Input1).set((), 44); - let v = db.dep_memoized2(()); + let v = db.dep_memoized2(); assert_eq!(v, 44); db.assert_log(&["Memoized1 invoked", "Derived1 invoked", "Memoized2 invoked"]); // Here validation of Memoized1 succeeds so Memoized2 never runs. db.query(Input1).set((), 45); - let v = db.dep_memoized2(()); + let v = db.dep_memoized2(); assert_eq!(v, 44); db.assert_log(&["Memoized1 invoked", "Derived1 invoked"]); // Here, a change to input2 doesn't affect us, so nothing runs. db.query(Input2).set((), 45); - let v = db.dep_memoized2(()); + let v = db.dep_memoized2(); assert_eq!(v, 44); db.assert_log(&[]); } diff --git a/tests/incremental/memoized_inputs.rs b/tests/incremental/memoized_inputs.rs index 23e690f3..d518d711 100644 --- a/tests/incremental/memoized_inputs.rs +++ b/tests/incremental/memoized_inputs.rs @@ -3,45 +3,45 @@ use salsa::Database; salsa::query_group! { pub(crate) trait MemoizedInputsContext: TestContext { - fn max(key: ()) -> usize { + fn max() -> usize { type Max; } - fn input1(key: ()) -> usize { + fn input1() -> usize { type Input1; storage input; } - fn input2(key: ()) -> usize { + fn input2() -> usize { type Input2; storage input; } } } -fn max(db: &impl MemoizedInputsContext, (): ()) -> usize { +fn max(db: &impl MemoizedInputsContext) -> usize { db.log().add("Max invoked"); - std::cmp::max(db.input1(()), db.input2(())) + std::cmp::max(db.input1(), db.input2()) } #[test] fn revalidate() { let db = &TestContextImpl::default(); - let v = db.max(()); + let v = db.max(); assert_eq!(v, 0); db.assert_log(&["Max invoked"]); - let v = db.max(()); + let v = db.max(); assert_eq!(v, 0); db.assert_log(&[]); db.query(Input1).set((), 44); db.assert_log(&[]); - let v = db.max(()); + let v = db.max(); assert_eq!(v, 44); db.assert_log(&["Max invoked"]); - let v = db.max(()); + let v = db.max(); assert_eq!(v, 44); db.assert_log(&[]); @@ -52,11 +52,11 @@ fn revalidate() { db.query(Input1).set((), 64); db.assert_log(&[]); - let v = db.max(()); + let v = db.max(); assert_eq!(v, 66); db.assert_log(&["Max invoked"]); - let v = db.max(()); + let v = db.max(); assert_eq!(v, 66); db.assert_log(&[]); } @@ -68,12 +68,12 @@ fn set_after_no_change() { let db = &TestContextImpl::default(); db.query(Input1).set((), 44); - let v = db.max(()); + let v = db.max(); assert_eq!(v, 44); db.assert_log(&["Max invoked"]); db.query(Input1).set((), 44); - let v = db.max(()); + let v = db.max(); assert_eq!(v, 44); db.assert_log(&[]); } diff --git a/tests/incremental/memoized_volatile.rs b/tests/incremental/memoized_volatile.rs index b01c2976..ac0166ba 100644 --- a/tests/incremental/memoized_volatile.rs +++ b/tests/incremental/memoized_volatile.rs @@ -5,31 +5,31 @@ salsa::query_group! { pub(crate) trait MemoizedVolatileContext: TestContext { // Queries for testing a "volatile" value wrapped by // memoization. - fn memoized2(key: ()) -> usize { + fn memoized2() -> usize { type Memoized2; } - fn memoized1(key: ()) -> usize { + fn memoized1() -> usize { type Memoized1; } - fn volatile(key: ()) -> usize { + fn volatile() -> usize { type Volatile; storage volatile; } } } -fn memoized2(db: &impl MemoizedVolatileContext, (): ()) -> usize { +fn memoized2(db: &impl MemoizedVolatileContext) -> usize { db.log().add("Memoized2 invoked"); - db.memoized1(()) + db.memoized1() } -fn memoized1(db: &impl MemoizedVolatileContext, (): ()) -> usize { +fn memoized1(db: &impl MemoizedVolatileContext) -> usize { db.log().add("Memoized1 invoked"); - let v = db.volatile(()); + let v = db.volatile(); v / 2 } -fn volatile(db: &impl MemoizedVolatileContext, (): ()) -> usize { +fn volatile(db: &impl MemoizedVolatileContext) -> usize { db.log().add("Volatile invoked"); db.clock().increment() } @@ -40,8 +40,8 @@ fn volatile_x2() { // Invoking volatile twice doesn't execute twice, because volatile // queries are memoized by default. - query.volatile(()); - query.volatile(()); + query.volatile(); + query.volatile(); query.assert_log(&["Volatile invoked"]); } @@ -57,20 +57,20 @@ fn volatile_x2() { fn revalidate() { let query = TestContextImpl::default(); - query.memoized2(()); + query.memoized2(); query.assert_log(&["Memoized2 invoked", "Memoized1 invoked", "Volatile invoked"]); - query.memoized2(()); + query.memoized2(); query.assert_log(&[]); // Second generation: volatile will change (to 1) but memoized1 // will not (still 0, as 1/2 = 0) query.salsa_runtime().next_revision(); - query.memoized2(()); + query.memoized2(); query.assert_log(&["Volatile invoked", "Memoized1 invoked"]); - query.memoized2(()); + query.memoized2(); query.assert_log(&[]); // Third generation: volatile will change (to 2) and memoized1 @@ -78,9 +78,9 @@ fn revalidate() { // changed, we now invoke Memoized2. query.salsa_runtime().next_revision(); - query.memoized2(()); + query.memoized2(); query.assert_log(&["Volatile invoked", "Memoized1 invoked", "Memoized2 invoked"]); - query.memoized2(()); + query.memoized2(); query.assert_log(&[]); } diff --git a/tests/set_unchecked.rs b/tests/set_unchecked.rs index 3569b687..f27b0c8c 100644 --- a/tests/set_unchecked.rs +++ b/tests/set_unchecked.rs @@ -2,29 +2,29 @@ use salsa::Database; salsa::query_group! { trait HelloWorldDatabase: salsa::Database { - fn input(key: ()) -> String { + fn input() -> String { type Input; storage input; } - fn length(key: ()) -> usize { + fn length() -> usize { type Length; } - fn double_length(key: ()) -> usize { + fn double_length() -> usize { type DoubleLength; } } } -fn length(db: &impl HelloWorldDatabase, (): ()) -> usize { - let l = db.input(()).len(); +fn length(db: &impl HelloWorldDatabase) -> usize { + let l = db.input().len(); assert!(l > 0); // not meant to be invoked with no input l } -fn double_length(db: &impl HelloWorldDatabase, (): ()) -> usize { - db.length(()) * 2 +fn double_length(db: &impl HelloWorldDatabase) -> usize { + db.length() * 2 } #[derive(Default)] @@ -52,35 +52,35 @@ salsa::database_storage! { fn normal() { let db = DatabaseStruct::default(); db.query(Input).set((), format!("Hello, world")); - assert_eq!(db.double_length(()), 24); + assert_eq!(db.double_length(), 24); db.query(Input).set((), format!("Hello, world!")); - assert_eq!(db.double_length(()), 26); + assert_eq!(db.double_length(), 26); } #[test] #[should_panic] fn use_without_set() { let db = DatabaseStruct::default(); - db.double_length(()); + db.double_length(); } #[test] fn using_set_unchecked_on_input() { let db = DatabaseStruct::default(); db.query(Input).set_unchecked((), format!("Hello, world")); - assert_eq!(db.double_length(()), 24); + assert_eq!(db.double_length(), 24); } #[test] fn using_set_unchecked_on_input_after() { let db = DatabaseStruct::default(); db.query(Input).set((), format!("Hello, world")); - assert_eq!(db.double_length(()), 24); + assert_eq!(db.double_length(), 24); // If we use `set_unchecked`, we don't notice that `double_length` // is out of date. Oh well, don't do that. db.query(Input).set_unchecked((), format!("Hello, world!")); - assert_eq!(db.double_length(()), 24); + assert_eq!(db.double_length(), 24); } #[test] @@ -91,5 +91,5 @@ fn using_set_unchecked() { // demonstrating that the code never runs. db.query(Length).set_unchecked((), 24); - assert_eq!(db.double_length(()), 48); + assert_eq!(db.double_length(), 48); } diff --git a/tests/storage_varieties/queries.rs b/tests/storage_varieties/queries.rs index 763e48ab..89cfaedf 100644 --- a/tests/storage_varieties/queries.rs +++ b/tests/storage_varieties/queries.rs @@ -4,10 +4,10 @@ pub(crate) trait Counter: salsa::Database { salsa::query_group! { pub(crate) trait Database: Counter { - fn memoized(key: ()) -> usize { + fn memoized() -> usize { type Memoized; } - fn volatile(key: ()) -> usize { + fn volatile() -> usize { type Volatile; storage volatile; } @@ -16,12 +16,12 @@ salsa::query_group! { /// Because this query is memoized, we only increment the counter /// the first time it is invoked. -fn memoized(db: &impl Database, (): ()) -> usize { - db.volatile(()) +fn memoized(db: &impl Database) -> usize { + db.volatile() } /// Because this query is volatile, each time it is invoked, /// we will increment the counter. -fn volatile(db: &impl Database, (): ()) -> usize { +fn volatile(db: &impl Database) -> usize { db.increment() } diff --git a/tests/storage_varieties/tests.rs b/tests/storage_varieties/tests.rs index 486a44fe..ce38ba8e 100644 --- a/tests/storage_varieties/tests.rs +++ b/tests/storage_varieties/tests.rs @@ -7,22 +7,22 @@ use salsa::Database as _Database; #[test] fn memoized_twice() { let db = DatabaseImpl::default(); - let v1 = db.memoized(()); - let v2 = db.memoized(()); + let v1 = db.memoized(); + let v2 = db.memoized(); assert_eq!(v1, v2); } #[test] fn volatile_twice() { let db = DatabaseImpl::default(); - let v1 = db.volatile(()); - let v2 = db.volatile(()); // volatiles are cached, so 2nd read returns the same + let v1 = db.volatile(); + let v2 = db.volatile(); // volatiles are cached, so 2nd read returns the same assert_eq!(v1, v2); db.salsa_runtime().next_revision(); // clears volatile caches - let v3 = db.volatile(()); // will re-increment the counter - let v4 = db.volatile(()); // second call will be cached + let v3 = db.volatile(); // will re-increment the counter + let v4 = db.volatile(); // second call will be cached assert_eq!(v1 + 1, v3); assert_eq!(v3, v4); } @@ -30,10 +30,10 @@ fn volatile_twice() { #[test] fn intermingled() { let db = DatabaseImpl::default(); - let v1 = db.volatile(()); - let v2 = db.memoized(()); - let v3 = db.volatile(()); // cached - let v4 = db.memoized(()); // cached + let v1 = db.volatile(); + let v2 = db.memoized(); + let v3 = db.volatile(); // cached + let v4 = db.memoized(); // cached assert_eq!(v1, v2); assert_eq!(v1, v3); @@ -41,8 +41,8 @@ fn intermingled() { db.salsa_runtime().next_revision(); // clears volatile caches - let v5 = db.memoized(()); // re-executes volatile, caches new result - let v6 = db.memoized(()); // re-use cached result + let v5 = db.memoized(); // re-executes volatile, caches new result + let v6 = db.memoized(); // re-use cached result assert_eq!(v4 + 1, v5); assert_eq!(v5, v6); } diff --git a/tests/variadic.rs b/tests/variadic.rs new file mode 100644 index 00000000..4d184e85 --- /dev/null +++ b/tests/variadic.rs @@ -0,0 +1,79 @@ +use salsa::Database; + +salsa::query_group! { + trait HelloWorldDatabase: salsa::Database { + fn input(a: u32, b: u32) -> u32 { + type Input; + storage input; + } + + fn none() -> u32 { + type None; + } + + fn one(k: u32) -> u32 { + type One; + } + + fn two(a: u32, b: u32) -> u32 { + type Two; + } + + fn trailing(a: u32, b: u32,) -> u32 { + type Trailing; + } + } +} + +fn none(_db: &impl HelloWorldDatabase) -> u32 { + 22 +} + +fn one(_db: &impl HelloWorldDatabase, k: u32) -> u32 { + k * 2 +} + +fn two(_db: &impl HelloWorldDatabase, a: u32, b: u32) -> u32 { + a * b +} + +fn trailing(_db: &impl HelloWorldDatabase, a: u32, b: u32) -> u32 { + a - b +} + +#[derive(Default)] +struct DatabaseStruct { + runtime: salsa::Runtime, +} + +impl salsa::Database for DatabaseStruct { + fn salsa_runtime(&self) -> &salsa::Runtime { + &self.runtime + } +} + +salsa::database_storage! { + struct DatabaseStorage for DatabaseStruct { + impl HelloWorldDatabase { + fn input() for Input; + fn none() for None; + fn one() for One; + fn two() for Two; + fn trailing() for Trailing; + } + } +} + +#[test] +fn execute() { + let db = DatabaseStruct::default(); + + // test what happens with inputs: + db.query(Input).set((1, 2), 3); + assert_eq!(db.input(1, 2), 3); + + assert_eq!(db.none(), 22); + assert_eq!(db.one(11), 22); + assert_eq!(db.two(11, 2), 22); + assert_eq!(db.trailing(24, 2), 22); +}