diff --git a/tests/ui/fail/loop_invariant_generic.rs b/tests/ui/fail/loop_invariant_generic.rs new file mode 100644 index 0000000..5342e97 --- /dev/null +++ b/tests/ui/fail/loop_invariant_generic.rs @@ -0,0 +1,22 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +#[thrust_macros::invariant_context] +fn keep(v: T) { + let mut x = v; + while rand() == 0 { + thrust_macros::invariant!(|v: T| v == v); + x = v; + } + assert!(x == v); +} + +fn main() { + keep(0_i64); + keep(true); +} diff --git a/tests/ui/fail/loop_invariant_self.rs b/tests/ui/fail/loop_invariant_self.rs new file mode 100644 index 0000000..5ff43b6 --- /dev/null +++ b/tests/ui/fail/loop_invariant_self.rs @@ -0,0 +1,30 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +struct Counter(i64); +impl thrust_models::Model for Counter { + type Ty = (thrust_models::model::Int,); +} + +#[thrust_macros::context] +impl Counter { + #[thrust_macros::invariant_context] + fn run(self) { + let mut c = self; + let mut x = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64, c: Self| x >= 2 && c == c); + x = x + 1; + c = Counter(0); + } + let _last = c; + assert!(x >= 1); + } +} + +fn main() { Counter(0).run(); } diff --git a/tests/ui/pass/loop_invariant_generic.rs b/tests/ui/pass/loop_invariant_generic.rs new file mode 100644 index 0000000..7e31b00 --- /dev/null +++ b/tests/ui/pass/loop_invariant_generic.rs @@ -0,0 +1,22 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +#[thrust_macros::invariant_context] +fn keep(v: T) { + let mut x = v; + while rand() == 0 { + thrust_macros::invariant!(|x: T, v: T| x == v); + x = v; + } + assert!(x == v); +} + +fn main() { + keep(0_i64); + keep(true); +} diff --git a/tests/ui/pass/loop_invariant_generic_closure.rs b/tests/ui/pass/loop_invariant_generic_closure.rs new file mode 100644 index 0000000..73668eb --- /dev/null +++ b/tests/ui/pass/loop_invariant_generic_closure.rs @@ -0,0 +1,26 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +// A closure-typed generic param must not be given a `Model` bound: the +// invariant only constrains the `Model`-typed `T`, and `keep` must still be +// callable with a real closure. +#[thrust_macros::invariant_context] +fn keep i64, T: Copy + PartialEq>(f: F, v: T) { + let _ = f; + let mut x = v; + while rand() == 0 { + thrust_macros::invariant!(|x: T, v: T| x == v); + x = v; + } + assert!(x == v); +} + +fn main() { + keep(|x| x + 1, 0_i64); + keep(|x| x, true); +} diff --git a/tests/ui/pass/loop_invariant_self.rs b/tests/ui/pass/loop_invariant_self.rs new file mode 100644 index 0000000..f24809d --- /dev/null +++ b/tests/ui/pass/loop_invariant_self.rs @@ -0,0 +1,30 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[thrust_macros::requires(true)] +#[thrust_macros::ensures(true)] +#[thrust::trusted] +fn rand() -> i64 { unimplemented!() } + +struct Counter(i64); +impl thrust_models::Model for Counter { + type Ty = (thrust_models::model::Int,); +} + +#[thrust_macros::context] +impl Counter { + #[thrust_macros::invariant_context] + fn run(self) { + let mut c = self; + let mut x = 1_i64; + while rand() == 0 { + thrust_macros::invariant!(|x: i64, c: Self| x >= 1 && c == c); + x = x + 1; + c = Counter(0); + } + let _last = c; + assert!(x >= 1); + } +} + +fn main() { Counter(0).run(); } diff --git a/thrust-macros/src/invariant.rs b/thrust-macros/src/invariant.rs index 8edf62e..afdd447 100644 --- a/thrust-macros/src/invariant.rs +++ b/thrust-macros/src/invariant.rs @@ -1,54 +1,217 @@ -//! Expansion of `thrust_macros::invariant!`. +//! Expansion of `thrust_macros::invariant!` and its context-carrying sibling +//! `thrust_macros::_invariant_with_context!`. //! -//! Expands a closure with concrete parameter types into a -//! `#[thrust::formula_fn]` over `Model::Ty` parameters and a marker call -//! referencing it. +//! Both expand a predicate closure with explicit parameter types into a +//! `#[thrust::formula_fn]` over `Model::Ty` parameters plus a marker call +//! referencing it; they share [`expand_invariant`] and differ only in input: +//! +//! - `invariant!(|x: i64| x >= 1)` takes a bare predicate closure and only sees +//! concrete types. +//! - `_invariant_with_context!(..)` additionally carries the enclosing generic +//! context. It is never written by hand: `#[thrust_macros::invariant_context]` +//! rewrites each `invariant!` it finds into this form, pasting the host +//! function's signature (and, in methods, a `#[thrust::_outer_context(..)]` +//! attribute carrying the enclosing `impl`/`trait` header) ahead of the +//! closure: +//! +//! ```ignore +//! _invariant_with_context!( +//! #[thrust::_outer_context(impl Foo where ..)] // methods only +//! fn f(..) -> .. where ..; // host signature, as-is +//! |x: T, v: T| x == v +//! ) +//! ``` +//! +//! The in-scope generics (shadowing the enclosing ones) are re-declared on the +//! formula function and instantiated via turbofish; in methods, `Self` is +//! re-declared as a synthetic type parameter and instantiated with the real +//! `Self` (legal in expression position). use std::sync::atomic::{AtomicUsize, Ordering}; use proc_macro::TokenStream; -use quote::{format_ident, quote}; -use syn::{parse_macro_input, FnArg}; +use proc_macro2::TokenStream as TokenStream2; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, + visit_mut::VisitMut, + FnArg, GenericParam, Signature, WherePredicate, +}; -use crate::fn_params_with_model_ty; +use crate::fn_outer_item::FnOuterItem; static COUNTER: AtomicUsize = AtomicUsize::new(0); +/// Expands `invariant!(CLOSURE)`: a bare predicate closure with no threaded +/// context. pub fn expand(input: TokenStream) -> TokenStream { let closure = parse_macro_input!(input as syn::ExprClosure); + match expand_invariant(&closure, None) { + Ok(expr) => expr.into_token_stream().into(), + Err(e) => e.to_compile_error().into(), + } +} + +/// Expands `_invariant_with_context!(#outer_attr #sig; CLOSURE)`, the form +/// `#[thrust_macros::invariant_context]` rewrites each `invariant!` into. +pub fn expand_with_context(input: TokenStream) -> TokenStream { + struct WithContext { + context: Context, + closure: syn::ExprClosure, + } + + impl Parse for WithContext { + fn parse(input: ParseStream) -> syn::Result { + let attrs = input.call(syn::Attribute::parse_outer)?; + let outer = crate::extract_outer_context(&attrs)?; + let sig: Signature = input.parse()?; + input.parse::()?; + let closure: syn::ExprClosure = input.parse()?; + Ok(Self { + context: Context { sig, outer }, + closure, + }) + } + } + + let WithContext { closure, context } = parse_macro_input!(input as WithContext); + match expand_invariant(&closure, Some(&context)) { + Ok(expr) => expr.into_token_stream().into(), + Err(e) => e.to_compile_error().into(), + } +} + +/// The enclosing context threaded into an invariant by +/// `#[thrust_macros::invariant_context]`: the host function signature and, for a +/// method, its `impl`/`trait` header. A standalone `invariant!` has none. +struct Context { + sig: Signature, + outer: Option, +} + +impl Context { + /// The generic params in scope: the host signature's own, plus the outer + /// `impl`/`trait`'s for a method. + fn generic_params(&self) -> impl Iterator { + self.sig + .generics + .params + .iter() + .chain(self.outer.iter().flat_map(|o| o.generics().params.iter())) + } + + /// The where-predicates in scope, from the host signature and (for a method) + /// the outer `impl`/`trait`. + fn where_predicates(&self) -> impl Iterator { + fn preds(g: &syn::Generics) -> impl Iterator { + g.where_clause.iter().flat_map(|wc| wc.predicates.iter()) + } + preds(&self.sig.generics).chain(self.outer.iter().flat_map(|o| preds(o.generics()))) + } +} +/// Expands a predicate closure into a `#[thrust::formula_fn]` plus a marker +/// call. With `context`, the in-scope generics (and, in methods, `Self`) are +/// re-declared on the formula function and instantiated via turbofish. +fn expand_invariant( + closure: &syn::ExprClosure, + context: Option<&Context>, +) -> syn::Result { let mut fn_params: Vec = Vec::new(); for param in &closure.inputs { let syn::Pat::Type(pt) = param else { - return syn::Error::new_spanned( + return Err(syn::Error::new_spanned( param, "invariant closure parameters must have explicit types, e.g. `|x: i64| ...`", - ) - .to_compile_error() - .into(); + )); }; let pat = &pt.pat; let ty = &pt.ty; fn_params.push(syn::parse_quote!(#pat: #ty)); } - let model_ty_params = fn_params_with_model_ty(&fn_params); + let mut def_params: Vec = Vec::new(); + let mut turbofish_args: Vec = Vec::new(); + for param in context.into_iter().flat_map(Context::generic_params) { + def_params.push(param.to_token_stream()); + match param { + GenericParam::Type(tp) => turbofish_args.push(tp.ident.to_token_stream()), + GenericParam::Const(cp) => turbofish_args.push(cp.ident.to_token_stream()), + GenericParam::Lifetime(_) => {} + } + } + + let mut def_wheres: Vec = context + .into_iter() + .flat_map(Context::where_predicates) + .cloned() + .collect(); + if let Some(context) = context { + def_wheres.extend(crate::model_where_predicates( + &context.sig, + context.outer.as_ref(), + )); + } + + // `Self` in a method context: rewrite it to a synthetic generic, then pass + // the real `Self` via turbofish (legal in expression position). + if crate::tokens_contain_ident(&closure.to_token_stream(), "Self") { + let synth: syn::Ident = format_ident!("__ThrustSelf"); + for param in &mut fn_params { + SelfRewriter { synth: &synth }.visit_fn_arg_mut(param); + } + def_params.push(quote!(#synth)); + def_wheres.extend(crate::model_predicates(&synth)); + turbofish_args.push(quote!(Self)); + } + + let model_ty_params = crate::fn_params_with_model_ty(&fn_params); let body = &closure.body; let id = COUNTER.fetch_add(1, Ordering::Relaxed); let name = format_ident!("_thrust_invariant_{}", id); - quote! { + let def_generics = if def_params.is_empty() { + quote!() + } else { + quote!(<#(#def_params),*>) + }; + let where_clause = if def_wheres.is_empty() { + quote!() + } else { + quote!(where #(#def_wheres),*) + }; + let turbofish = if turbofish_args.is_empty() { + quote!() + } else { + quote!(::<#(#turbofish_args),*>) + }; + + Ok(syn::parse_quote!({ + #[allow(unused_variables)] + #[allow(non_snake_case)] + #[thrust::formula_fn] + fn #name #def_generics(#model_ty_params) -> bool #where_clause { + #body + } + + thrust_models::__invariant_marker(#name #turbofish) + })) +} + +struct SelfRewriter<'a> { + synth: &'a syn::Ident, +} + +impl VisitMut for SelfRewriter<'_> { + fn visit_path_mut(&mut self, path: &mut syn::Path) { + syn::visit_mut::visit_path_mut(self, path); + if path.leading_colon.is_none() + && path.segments.len() == 1 + && path.segments[0].ident == "Self" { - #[allow(unused_variables)] - #[allow(non_snake_case)] - #[thrust::formula_fn] - fn #name(#model_ty_params) -> bool { - #body - } - - thrust_models::__invariant_marker(#name) + path.segments[0].ident = self.synth.clone(); } } - .into() } diff --git a/thrust-macros/src/invariant_context.rs b/thrust-macros/src/invariant_context.rs new file mode 100644 index 0000000..af3a30a --- /dev/null +++ b/thrust-macros/src/invariant_context.rs @@ -0,0 +1,89 @@ +//! Expansion of `#[thrust_macros::invariant_context]`. +//! +//! Threads the surrounding generic context (and, in methods, `Self`) into +//! every `thrust_macros::invariant!(...)` call inside the annotated function, so +//! an invariant may refer to generic- and `Self`-typed variables that the +//! standalone `invariant!` macro cannot see. +//! +//! It also extends the host function's where clause with the `Model` predicates +//! (see [`crate::model_where_predicates`]) for every in-scope type parameter +//! (and for `Self` when used), since each injected marker call instantiates a +//! `Model`-bounded formula function with the host's own generics. + +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::{quote, ToTokens}; +use syn::{visit_mut::VisitMut, Signature}; + +use crate::fn_outer_item::FnOuterItem; + +pub fn expand(item: TokenStream) -> TokenStream { + let mut item_fn = syn::parse_macro_input!(item as syn::ItemFn); + + let outer = match crate::extract_outer_context(&item_fn.attrs) { + Ok(outer) => outer, + Err(e) => return e.to_compile_error().into(), + }; + + let sig = item_fn.sig.clone(); + let mut injector = ContextInjector { + sig: &sig, + outer: outer.as_ref(), + self_used: false, + }; + injector.visit_block_mut(&mut item_fn.block); + + let mut predicates = crate::model_where_predicates(&sig, outer.as_ref()); + if injector.self_used { + predicates.extend(crate::model_predicates("e!(Self))); + } + if !predicates.is_empty() { + item_fn + .sig + .generics + .make_where_clause() + .predicates + .extend(predicates); + } + + item_fn.into_token_stream().into() +} + +struct ContextInjector<'a> { + sig: &'a Signature, + outer: Option<&'a FnOuterItem>, + self_used: bool, +} + +impl<'a> ContextInjector<'a> { + fn inject_context(&self, closure: &TokenStream2) -> TokenStream2 { + let sig = self.sig; + let outer_attr = self + .outer + .map(|outer| quote!(#[thrust::_outer_context(#outer)])); + + quote! { + #outer_attr + #sig; + #closure + } + } +} + +impl VisitMut for ContextInjector<'_> { + fn visit_macro_mut(&mut self, mac: &mut syn::Macro) { + if !is_invariant_macro(&mac.path) { + return; + } + if crate::tokens_contain_ident(&mac.tokens, "Self") { + self.self_used = true; + } + mac.tokens = self.inject_context(&mac.tokens); + mac.path = syn::parse_quote!(::thrust_macros::_invariant_with_context); + } +} + +fn is_invariant_macro(path: &syn::Path) -> bool { + // TODO: identify the macro precisely + path.segments.last().is_some_and(|s| s.ident == "invariant") +} diff --git a/thrust-macros/src/lib.rs b/thrust-macros/src/lib.rs index 9f62b3a..a0eb9bd 100644 --- a/thrust-macros/src/lib.rs +++ b/thrust-macros/src/lib.rs @@ -1,11 +1,14 @@ use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{TokenStream as TokenStream2, TokenTree as TokenTree2}; mod context; mod fn_outer_item; mod invariant; +mod invariant_context; mod spec; +use fn_outer_item::FnOuterItem; + #[proc_macro_attribute] pub fn context(_attr: TokenStream, item: TokenStream) -> TokenStream { context::expand(item) @@ -30,6 +33,25 @@ pub fn invariant(input: TokenStream) -> TokenStream { invariant::expand(input) } +/// Context-carrying counterpart of `invariant!`, emitted by +/// `#[thrust_macros::invariant_context]`. Not intended to be written by hand: +/// it takes a `fn` header carrying the threaded generics/where clause whose +/// body is the predicate closure (see [`invariant`]). +#[proc_macro] +pub fn _invariant_with_context(input: TokenStream) -> TokenStream { + invariant::expand_with_context(input) +} + +/// Threads the surrounding generic context (and, in methods, `Self`) into +/// every `thrust_macros::invariant!(...)` inside the annotated function, so an +/// invariant may refer to generic- and `Self`-typed variables that the +/// standalone `invariant!` macro cannot see. Each such call is rewritten into +/// `thrust_macros::_invariant_with_context!`. +#[proc_macro_attribute] +pub fn invariant_context(_attr: TokenStream, item: TokenStream) -> TokenStream { + invariant_context::expand(item) +} + #[proc_macro_attribute] pub fn predicate(_attr: TokenStream, item: TokenStream) -> TokenStream { spec::expand_predicate(item) @@ -50,6 +72,138 @@ pub fn _requires_ensures(attr: TokenStream, item: TokenStream) -> TokenStream { spec::expand_requires_ensures(attr, item) } +/// Reads the `#[thrust::_outer_context(..)]` attribute stamped onto methods by +/// `#[thrust_macros::context]` (and threaded by `invariant_context`), returning +/// the enclosing `impl`/`trait` header it carries, or `None` if absent. +fn extract_outer_context(attrs: &[syn::Attribute]) -> syn::Result> { + let outer_context_path: syn::Path = syn::parse_quote!(thrust::_outer_context); + let mut outer_context = None; + for attr in attrs { + if attr.path() != &outer_context_path { + continue; + } + if outer_context.is_some() { + return Err(syn::Error::new_spanned( + attr, + "multiple _outer_context attributes found; expected at most one", + )); + } + outer_context = Some(attr.parse_args()?); + } + Ok(outer_context) +} + +fn has_fn_bound<'a>(bounds: impl IntoIterator) -> bool { + bounds.into_iter().any(|b| { + let syn::TypeParamBound::Trait(tb) = b else { + return false; + }; + tb.path + .segments + .last() + .is_some_and(|s| matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut")) + }) +} + +fn model_predicates(ty: &impl quote::ToTokens) -> [syn::WherePredicate; 2] { + [ + syn::parse_quote!(#ty: thrust_models::Model), + syn::parse_quote!(<#ty as thrust_models::Model>::Ty: PartialEq), + ] +} + +/// `T: Model` / `::Ty: PartialEq` predicates for every type param +/// in scope for `sig` (its own, plus the outer `impl`/`trait`'s and, for a +/// trait, `Self`) that does not carry an `Fn`/`FnOnce`/`FnMut` bound, plus the +/// same for any generic associated-type projection appearing in `sig`. +fn model_where_predicates( + sig: &syn::Signature, + outer_context: Option<&FnOuterItem>, +) -> Vec { + struct GenericTypeParam { + ident: syn::Ident, + bounds: Vec, + } + + impl From for GenericTypeParam { + fn from(tp: syn::TypeParam) -> Self { + Self { + ident: tp.ident, + bounds: tp.bounds.into_iter().collect(), + } + } + } + + let mut generic_type_params: Vec = Vec::new(); + for param in &sig.generics.params { + let syn::GenericParam::Type(tp) = param else { + continue; + }; + generic_type_params.push(tp.clone().into()); + } + if let Some(outer_item) = outer_context { + for param in &outer_item.generics().params { + let syn::GenericParam::Type(tp) = param else { + continue; + }; + generic_type_params.push(tp.clone().into()); + } + if let FnOuterItem::ItemTrait(outer_item) = &outer_item { + generic_type_params.push(GenericTypeParam { + ident: quote::format_ident!("Self"), + bounds: outer_item.supertraits.iter().cloned().collect(), + }); + } + } + generic_type_params.retain(|p| !has_fn_bound(&p.bounds)); + + let mut predicates: Vec = Vec::new(); + for param in &generic_type_params { + predicates.extend(model_predicates(¶m.ident)); + } + + struct Visitor { + generic_type_params: Vec, + generic_paths: Vec, + } + + impl syn::visit::Visit<'_> for Visitor { + fn visit_type_path(&mut self, tp: &syn::TypePath) { + for param in &self.generic_type_params { + if let Some(qself) = &tp.qself { + let param = ¶m.ident; + let param_ty: syn::Type = syn::parse_quote!(#param); + if *qself.ty == param_ty { + self.generic_paths.push(tp.clone()); + } + } + if tp.path.segments.len() > 1 + && tp.path.segments.first().unwrap().ident == param.ident + && tp.qself.is_none() + { + self.generic_paths.push(tp.clone()); + } + } + syn::visit::visit_type_path(self, tp); + } + } + + let mut visitor = Visitor { + generic_type_params, + generic_paths: Vec::new(), + }; + use syn::visit::Visit as _; + for arg in &sig.inputs { + visitor.visit_fn_arg(arg); + } + visitor.visit_return_type(&sig.output); + for tp in visitor.generic_paths { + predicates.extend(model_predicates(&tp)); + } + + predicates +} + /// Maps each function parameter `x: T` to `x: ::Ty`. fn fn_params_with_model_ty<'ast, I>(args: I) -> TokenStream2 where @@ -71,3 +225,15 @@ where } quote::quote!(#(#model_inputs),*) } + +fn tokens_contain_ident(tokens: &TokenStream2, target: T) -> bool +where + T: AsRef, +{ + let target = target.as_ref(); + tokens.clone().into_iter().any(|tt| match tt { + TokenTree2::Ident(ident) => ident == target, + TokenTree2::Group(group) => tokens_contain_ident(&group.stream(), target), + _ => false, + }) +} diff --git a/thrust-macros/src/spec.rs b/thrust-macros/src/spec.rs index e049b33..1faf6a3 100644 --- a/thrust-macros/src/spec.rs +++ b/thrust-macros/src/spec.rs @@ -10,8 +10,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::{format_ident, quote, ToTokens}; use syn::{ - parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, TypeParamBound, - WherePredicate, + parse_macro_input, punctuated::Punctuated, FnArg, GenericParam, Generics, WherePredicate, }; use crate::{fn_outer_item::FnOuterItem, fn_params_with_model_ty}; @@ -31,7 +30,7 @@ pub fn expand_predicate(item: TokenStream) -> TokenStream { let model_ty_params = fn_params_with_model_ty(&func.sig().inputs); let model_ret = fn_return_ty_with_model_ty(&func.sig().output); - let model_preds = model_where_predicates(&func, outer_context.as_ref()); + let model_preds = crate::model_where_predicates(func.sig(), outer_context.as_ref()); let extended_where = extended_where_clause(&func, &model_preds); let sig = quote! { @@ -235,22 +234,7 @@ fn extract_requires_ensures(func: &mut FnItemWithSignature) -> syn::Result<(syn: } fn extract_outer_context(func: &FnItemWithSignature) -> syn::Result> { - let outer_context_path: syn::Path = syn::parse_quote!(thrust::_outer_context); - let mut outer_context = None; - for attr in func.attrs() { - if attr.path() != &outer_context_path { - continue; - } - - let item = attr.parse_args()?; - if outer_context.is_some() { - return Err(syn::Error::new_spanned( - attr, - "multiple _outer_context attributes found; expected at most one", - )); - } - outer_context = Some(item); - } + let outer_context = crate::extract_outer_context(func.attrs())?; if mentions_self(func.sig()) && outer_context.is_none() { return Err(syn::Error::new_spanned( func.sig().ident.clone(), @@ -325,7 +309,8 @@ impl ExpandedTokens { } fn extended_where_clause(&self) -> TokenStream2 { - let model_preds = model_where_predicates(&self.func, self.outer_context.as_ref()); + let model_preds = + crate::model_where_predicates(self.func.sig(), self.outer_context.as_ref()); extended_where_clause(&self.func, &model_preds) } @@ -557,112 +542,6 @@ fn rewrite_inputs_for_call( (quote!(#(#rewritten),*), quote!(#(#call_args),*)) } -/// Returns `T: thrust_models::Model` predicates for every type param that does not -/// already carry an `Fn`, `FnOnce`, or `FnMut` bound. -fn model_where_predicates( - func: &FnItemWithSignature, - outer_context: Option<&FnOuterItem>, -) -> Vec { - struct GenericTypeParam { - ident: syn::Ident, - bounds: Vec, - } - - impl From for GenericTypeParam { - fn from(tp: syn::TypeParam) -> Self { - Self { - ident: tp.ident, - bounds: tp.bounds.into_iter().collect(), - } - } - } - - impl GenericTypeParam { - fn has_fn_bound(&self) -> bool { - self.bounds.iter().any(|b| { - let TypeParamBound::Trait(tb) = b else { - return false; - }; - tb.path.segments.last().is_some_and(|s| { - matches!(s.ident.to_string().as_str(), "Fn" | "FnOnce" | "FnMut") - }) - }) - } - } - - let mut generic_type_params: Vec = Vec::new(); - for param in &func.sig().generics.params { - let GenericParam::Type(tp) = param else { - continue; - }; - generic_type_params.push(tp.clone().into()); - } - if let Some(outer_item) = outer_context { - for param in &outer_item.generics().params { - let GenericParam::Type(tp) = param else { - continue; - }; - generic_type_params.push(tp.clone().into()); - } - if let FnOuterItem::ItemTrait(outer_item) = &outer_item { - generic_type_params.push(GenericTypeParam { - ident: format_ident!("Self"), - bounds: outer_item.supertraits.iter().cloned().collect(), - }); - } - } - generic_type_params.retain(|p| !p.has_fn_bound()); - - let mut predicates: Vec = Vec::new(); - for param in &generic_type_params { - let ident = ¶m.ident; - predicates.push(syn::parse_quote!(#ident: thrust_models::Model)); - predicates.push(syn::parse_quote!(<#ident as thrust_models::Model>::Ty: PartialEq)); - } - - struct Visitor { - generic_type_params: Vec, - generic_paths: Vec, - } - - impl syn::visit::Visit<'_> for Visitor { - fn visit_type_path(&mut self, tp: &syn::TypePath) { - for param in &self.generic_type_params { - if let Some(qself) = &tp.qself { - let param = ¶m.ident; - let param_ty: syn::Type = syn::parse_quote!(#param); - if *qself.ty == param_ty { - self.generic_paths.push(tp.clone()); - } - } - if tp.path.segments.len() > 1 - && tp.path.segments.first().unwrap().ident == param.ident - && tp.qself.is_none() - { - self.generic_paths.push(tp.clone()); - } - } - syn::visit::visit_type_path(self, tp); - } - } - - let mut visitor = Visitor { - generic_type_params, - generic_paths: Vec::new(), - }; - use syn::visit::Visit as _; - for arg in &func.sig().inputs { - visitor.visit_fn_arg(arg); - } - visitor.visit_return_type(&func.sig().output); - for tp in visitor.generic_paths { - predicates.push(syn::parse_quote!(#tp: thrust_models::Model)); - predicates.push(syn::parse_quote!(<#tp as thrust_models::Model>::Ty: PartialEq)); - } - - predicates -} - /// Builds `where , `. /// Returns an empty token stream when both sets are empty. fn extended_where_clause(