Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions tests/ui/fail/loop_invariant_generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

#[thrust_macros::invariant_context]
fn keep<T: Copy + PartialEq>(v: T) {
let mut x = v;
while rand() == 0 {
thrust_macros::invariant!(|v: T| v == v);
x = v;
}
assert!(x == v);
}

fn main() {
keep(0_i64);
keep(true);
}
30 changes: 30 additions & 0 deletions tests/ui/fail/loop_invariant_self.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

struct Counter(i64);
impl thrust_models::Model for Counter {
type Ty = (thrust_models::model::Int,);
}

#[thrust_macros::context]
impl Counter {
#[thrust_macros::invariant_context]
fn run(self) {
let mut c = self;
let mut x = 1_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64, c: Self| x >= 2 && c == c);
x = x + 1;
c = Counter(0);
}
let _last = c;
assert!(x >= 1);
}
}

fn main() { Counter(0).run(); }
22 changes: 22 additions & 0 deletions tests/ui/pass/loop_invariant_generic.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//@check-pass
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

#[thrust_macros::invariant_context]
fn keep<T: Copy + PartialEq>(v: T) {
let mut x = v;
while rand() == 0 {
thrust_macros::invariant!(|x: T, v: T| x == v);
x = v;
}
assert!(x == v);
}

fn main() {
keep(0_i64);
keep(true);
}
26 changes: 26 additions & 0 deletions tests/ui/pass/loop_invariant_generic_closure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//@check-pass
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

// A closure-typed generic param must not be given a `Model` bound: the
// invariant only constrains the `Model`-typed `T`, and `keep` must still be
// callable with a real closure.
#[thrust_macros::invariant_context]
fn keep<F: Fn(i64) -> i64, T: Copy + PartialEq>(f: F, v: T) {
let _ = f;
let mut x = v;
while rand() == 0 {
thrust_macros::invariant!(|x: T, v: T| x == v);
x = v;
}
assert!(x == v);
}

fn main() {
keep(|x| x + 1, 0_i64);
keep(|x| x, true);
}
30 changes: 30 additions & 0 deletions tests/ui/pass/loop_invariant_self.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//@check-pass
//@compile-flags: -C debug-assertions=off

#[thrust_macros::requires(true)]
#[thrust_macros::ensures(true)]
#[thrust::trusted]
fn rand() -> i64 { unimplemented!() }

struct Counter(i64);
impl thrust_models::Model for Counter {
type Ty = (thrust_models::model::Int,);
}

#[thrust_macros::context]
impl Counter {
#[thrust_macros::invariant_context]
fn run(self) {
let mut c = self;
let mut x = 1_i64;
while rand() == 0 {
thrust_macros::invariant!(|x: i64, c: Self| x >= 1 && c == c);
x = x + 1;
c = Counter(0);
}
let _last = c;
assert!(x >= 1);
}
}

fn main() { Counter(0).run(); }
207 changes: 185 additions & 22 deletions thrust-macros/src/invariant.rs
Original file line number Diff line number Diff line change
@@ -1,54 +1,217 @@
//! Expansion of `thrust_macros::invariant!`.
//! Expansion of `thrust_macros::invariant!` and its context-carrying sibling
//! `thrust_macros::_invariant_with_context!`.
//!
//! Expands a closure with concrete parameter types into a
//! `#[thrust::formula_fn]` over `Model::Ty` parameters and a marker call
//! referencing it.
//! Both expand a predicate closure with explicit parameter types into a
//! `#[thrust::formula_fn]` over `Model::Ty` parameters plus a marker call
//! referencing it; they share [`expand_invariant`] and differ only in input:
//!
//! - `invariant!(|x: i64| x >= 1)` takes a bare predicate closure and only sees
//! concrete types.
//! - `_invariant_with_context!(..)` additionally carries the enclosing generic
//! context. It is never written by hand: `#[thrust_macros::invariant_context]`
//! rewrites each `invariant!` it finds into this form, pasting the host
//! function's signature (and, in methods, a `#[thrust::_outer_context(..)]`
//! attribute carrying the enclosing `impl`/`trait` header) ahead of the
//! closure:
//!
//! ```ignore
//! _invariant_with_context!(
//! #[thrust::_outer_context(impl<T> Foo<T> where ..)] // methods only
//! fn f<U>(..) -> .. where ..; // host signature, as-is
//! |x: T, v: T| x == v
//! )
//! ```
//!
//! The in-scope generics (shadowing the enclosing ones) are re-declared on the
//! formula function and instantiated via turbofish; in methods, `Self` is
//! re-declared as a synthetic type parameter and instantiated with the real
//! `Self` (legal in expression position).

use std::sync::atomic::{AtomicUsize, Ordering};

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, FnArg};
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
visit_mut::VisitMut,
FnArg, GenericParam, Signature, WherePredicate,
};

use crate::fn_params_with_model_ty;
use crate::fn_outer_item::FnOuterItem;

static COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Expands `invariant!(CLOSURE)`: a bare predicate closure with no threaded
/// context.
pub fn expand(input: TokenStream) -> TokenStream {
let closure = parse_macro_input!(input as syn::ExprClosure);
match expand_invariant(&closure, None) {
Ok(expr) => expr.into_token_stream().into(),
Err(e) => e.to_compile_error().into(),
}
}

/// Expands `_invariant_with_context!(#outer_attr #sig; CLOSURE)`, the form
/// `#[thrust_macros::invariant_context]` rewrites each `invariant!` into.
pub fn expand_with_context(input: TokenStream) -> TokenStream {
struct WithContext {
context: Context,
closure: syn::ExprClosure,
}

impl Parse for WithContext {
fn parse(input: ParseStream) -> syn::Result<Self> {
let attrs = input.call(syn::Attribute::parse_outer)?;
let outer = crate::extract_outer_context(&attrs)?;
let sig: Signature = input.parse()?;
input.parse::<syn::Token![;]>()?;
let closure: syn::ExprClosure = input.parse()?;
Ok(Self {
context: Context { sig, outer },
closure,
})
}
}

let WithContext { closure, context } = parse_macro_input!(input as WithContext);
match expand_invariant(&closure, Some(&context)) {
Ok(expr) => expr.into_token_stream().into(),
Err(e) => e.to_compile_error().into(),
}
}

/// The enclosing context threaded into an invariant by
/// `#[thrust_macros::invariant_context]`: the host function signature and, for a
/// method, its `impl`/`trait` header. A standalone `invariant!` has none.
struct Context {
sig: Signature,
outer: Option<FnOuterItem>,
}

impl Context {
/// The generic params in scope: the host signature's own, plus the outer
/// `impl`/`trait`'s for a method.
fn generic_params(&self) -> impl Iterator<Item = &GenericParam> {
self.sig
.generics
.params
.iter()
.chain(self.outer.iter().flat_map(|o| o.generics().params.iter()))
}

/// The where-predicates in scope, from the host signature and (for a method)
/// the outer `impl`/`trait`.
fn where_predicates(&self) -> impl Iterator<Item = &WherePredicate> {
fn preds(g: &syn::Generics) -> impl Iterator<Item = &WherePredicate> {
g.where_clause.iter().flat_map(|wc| wc.predicates.iter())
}
preds(&self.sig.generics).chain(self.outer.iter().flat_map(|o| preds(o.generics())))
}
}

/// Expands a predicate closure into a `#[thrust::formula_fn]` plus a marker
/// call. With `context`, the in-scope generics (and, in methods, `Self`) are
/// re-declared on the formula function and instantiated via turbofish.
fn expand_invariant(
closure: &syn::ExprClosure,
context: Option<&Context>,
) -> syn::Result<syn::Expr> {
let mut fn_params: Vec<FnArg> = Vec::new();
for param in &closure.inputs {
let syn::Pat::Type(pt) = param else {
return syn::Error::new_spanned(
return Err(syn::Error::new_spanned(
param,
"invariant closure parameters must have explicit types, e.g. `|x: i64| ...`",
)
.to_compile_error()
.into();
));
};
let pat = &pt.pat;
let ty = &pt.ty;
fn_params.push(syn::parse_quote!(#pat: #ty));
}

let model_ty_params = fn_params_with_model_ty(&fn_params);
let mut def_params: Vec<TokenStream2> = Vec::new();
let mut turbofish_args: Vec<TokenStream2> = Vec::new();
for param in context.into_iter().flat_map(Context::generic_params) {
def_params.push(param.to_token_stream());
match param {
GenericParam::Type(tp) => turbofish_args.push(tp.ident.to_token_stream()),
GenericParam::Const(cp) => turbofish_args.push(cp.ident.to_token_stream()),
GenericParam::Lifetime(_) => {}
}
}

let mut def_wheres: Vec<WherePredicate> = context
.into_iter()
.flat_map(Context::where_predicates)
.cloned()
.collect();
if let Some(context) = context {
def_wheres.extend(crate::model_where_predicates(
&context.sig,
context.outer.as_ref(),
));
}

// `Self` in a method context: rewrite it to a synthetic generic, then pass
// the real `Self` via turbofish (legal in expression position).
if crate::tokens_contain_ident(&closure.to_token_stream(), "Self") {
let synth: syn::Ident = format_ident!("__ThrustSelf");
for param in &mut fn_params {
SelfRewriter { synth: &synth }.visit_fn_arg_mut(param);
}
def_params.push(quote!(#synth));
def_wheres.extend(crate::model_predicates(&synth));
turbofish_args.push(quote!(Self));
}
Comment thread
coord-e marked this conversation as resolved.

let model_ty_params = crate::fn_params_with_model_ty(&fn_params);
let body = &closure.body;

let id = COUNTER.fetch_add(1, Ordering::Relaxed);
let name = format_ident!("_thrust_invariant_{}", id);

quote! {
let def_generics = if def_params.is_empty() {
quote!()
} else {
quote!(<#(#def_params),*>)
};
let where_clause = if def_wheres.is_empty() {
quote!()
} else {
quote!(where #(#def_wheres),*)
};
let turbofish = if turbofish_args.is_empty() {
quote!()
} else {
quote!(::<#(#turbofish_args),*>)
};

Ok(syn::parse_quote!({
#[allow(unused_variables)]
#[allow(non_snake_case)]
#[thrust::formula_fn]
fn #name #def_generics(#model_ty_params) -> bool #where_clause {
#body
}

thrust_models::__invariant_marker(#name #turbofish)
}))
}

struct SelfRewriter<'a> {
synth: &'a syn::Ident,
}

impl VisitMut for SelfRewriter<'_> {
fn visit_path_mut(&mut self, path: &mut syn::Path) {
syn::visit_mut::visit_path_mut(self, path);
if path.leading_colon.is_none()
&& path.segments.len() == 1
&& path.segments[0].ident == "Self"
{
#[allow(unused_variables)]
#[allow(non_snake_case)]
#[thrust::formula_fn]
fn #name(#model_ty_params) -> bool {
#body
}

thrust_models::__invariant_marker(#name)
path.segments[0].ident = self.synth.clone();
}
}
.into()
}
Loading