diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 2dbb9ea..1d7658e 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -37,6 +37,10 @@ pub fn extern_spec_fn_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")] } +pub fn raw_command_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("raw_command")] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 9dd85f9..d5d6dde 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -2,12 +2,14 @@ use std::collections::HashSet; +use rustc_hir::def_id::CRATE_DEF_ID; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; use crate::analyze; use crate::chc; use crate::rty::{self, ClauseBuilderExt as _}; +use crate::annot; /// An implementation of local crate analysis. /// @@ -26,6 +28,21 @@ pub struct Analyzer<'tcx, 'ctx> { } impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { + fn analyze_raw_command_annot(&mut self) { + for attrs in self.tcx.get_attrs_by_path( + CRATE_DEF_ID.to_def_id(), + &analyze::annot::raw_command_path(), + ) { + let ts = analyze::annot::extract_annot_tokens(attrs.clone()); + let parser = annot::AnnotParser::new( + // TODO: this resolver is not actually used. + analyze::annot::ParamResolver::default() + ); + let raw_command = parser.parse_raw_command(ts).unwrap(); + self.ctx.system.borrow_mut().push_raw_command(raw_command); + } + } + fn refine_local_defs(&mut self) { for local_def_id in self.tcx.mir_keys(()) { if self.tcx.def_kind(*local_def_id).is_fn_like() { @@ -187,6 +204,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let span = tracing::debug_span!("crate", krate = %self.tcx.crate_name(rustc_span::def_id::LOCAL_CRATE)); let _guard = span.enter(); + self.analyze_raw_command_annot(); self.refine_local_defs(); self.analyze_local_defs(); self.assert_callable_entry(); diff --git a/src/annot.rs b/src/annot.rs index 128c289..5a9e63d 100644 --- a/src/annot.rs +++ b/src/annot.rs @@ -8,7 +8,7 @@ //! The main entry point is [`AnnotParser`], which parses a [`TokenStream`] into a //! [`rty::RefinedType`] or a [`chc::Formula`]. -use rustc_ast::token::{BinOpToken, Delimiter, LitKind, Token, TokenKind}; +use rustc_ast::token::{BinOpToken, Delimiter, LitKind, Lit, Token, TokenKind}; use rustc_ast::tokenstream::{RefTokenTreeCursor, Spacing, TokenStream, TokenTree}; use rustc_index::IndexVec; use rustc_span::symbol::Ident; @@ -1076,6 +1076,32 @@ where .ok_or_else(|| ParseAttrError::unexpected_term("in annotation"))?; Ok(AnnotFormula::Formula(formula)) } + + pub fn parse_annot_raw_command(&mut self) -> Result { + let t = self.next_token("raw CHC command")?; + + match t { + Token { + kind: TokenKind::Literal( + Lit { kind, symbol, .. } + ), + .. + } => { + match kind { + LitKind::Str => { + let command = symbol.to_string(); + Ok(chc::RawCommand{ command }) + }, + _ => Err(ParseAttrError::unexpected_token( + "string literal", t.clone() + )) + } + }, + _ => Err(ParseAttrError::unexpected_token( + "string literal", t.clone() + )) + } + } } /// A [`Resolver`] implementation for resolving specific variable as [`rty::RefinedTypeVar::Value`]. @@ -1208,4 +1234,15 @@ where parser.end_of_input()?; Ok(formula) } + + pub fn parse_raw_command(&self, ts: TokenStream) -> Result { + let mut parser = Parser { + resolver: &self.resolver, + cursor: ts.trees(), + formula_existentials: Default::default(), + }; + let raw_command = parser.parse_annot_raw_command()?; + parser.end_of_input()?; + Ok(raw_command) + } } diff --git a/src/chc.rs b/src/chc.rs index 5543de4..74550a3 100644 --- a/src/chc.rs +++ b/src/chc.rs @@ -1606,6 +1606,14 @@ impl Clause { } } +/// A command specified using #![thrust::define_raw()] +/// +/// Those will be directly inserted into the generated SMT-LIB2 file. +#[derive(Debug, Clone)] +pub struct RawCommand { + pub command: String, +} + /// A selector for a datatype constructor. /// /// A selector is a function that extracts a field from a datatype value. @@ -1655,6 +1663,7 @@ pub struct PredVarDef { /// A CHC system. #[derive(Debug, Clone, Default)] pub struct System { + pub raw_commands: Vec, pub datatypes: Vec, pub clauses: IndexVec, pub pred_vars: IndexVec, @@ -1665,6 +1674,10 @@ impl System { self.pred_vars.push(PredVarDef { sig, debug_info }) } + pub fn push_raw_command(&mut self, raw_command: RawCommand) { + self.raw_commands.push(raw_command) + } + pub fn push_clause(&mut self, clause: Clause) -> Option { if clause.is_nop() { return None; diff --git a/src/chc/smtlib2.rs b/src/chc/smtlib2.rs index 167d108..0bd7311 100644 --- a/src/chc/smtlib2.rs +++ b/src/chc/smtlib2.rs @@ -370,6 +370,30 @@ impl<'ctx, 'a> Clause<'ctx, 'a> { } } +/// A wrapper around a [`chc::RawCommand`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. +#[derive(Debug, Clone)] +pub struct RawCommand<'a> { + inner: &'a chc::RawCommand, +} + +impl<'a> std::fmt::Display for RawCommand<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + self.inner.command, + ) + } +} + +impl<'a> RawCommand<'a> { + pub fn new(inner: &'a chc::RawCommand) -> Self { + Self { + inner + } + } +} + /// A wrapper around a [`chc::DatatypeSelector`] that provides a [`std::fmt::Display`] implementation in SMT-LIB2 format. #[derive(Debug, Clone)] pub struct DatatypeSelector<'ctx, 'a> { @@ -555,21 +579,26 @@ impl<'a> std::fmt::Display for System<'a> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "(set-logic HORN)\n")?; + // insert command from #![thrust::raw_command()] here + for raw_command in &self.inner.raw_commands { + writeln!(f, "{}\n", RawCommand::new(raw_command))?; + } + writeln!(f, "{}\n", Datatypes::new(&self.ctx, self.ctx.datatypes()))?; for datatype in self.ctx.datatypes() { writeln!(f, "{}", DatatypeDiscrFun::new(&self.ctx, datatype))?; writeln!(f, "{}", MatcherPredFun::new(&self.ctx, datatype))?; } writeln!(f)?; - for (p, def) in self.inner.pred_vars.iter_enumerated() { - if !def.debug_info.is_empty() { - writeln!(f, "{}", def.debug_info.display("; "))?; + for (p, cmd) in self.inner.pred_vars.iter_enumerated() { + if !cmd.debug_info.is_empty() { + writeln!(f, "{}", cmd.debug_info.display("; "))?; } writeln!( f, "(declare-fun {} {} Bool)\n", p, - List::closed(def.sig.iter().map(|s| self.ctx.fmt_sort(s))) + List::closed(cmd.sig.iter().map(|s| self.ctx.fmt_sort(s))) )?; } for (id, clause) in self.inner.clauses.iter_enumerated() { diff --git a/src/chc/unbox.rs b/src/chc/unbox.rs index 5be1240..8ed320f 100644 --- a/src/chc/unbox.rs +++ b/src/chc/unbox.rs @@ -161,6 +161,7 @@ pub fn unbox(system: System) -> System { clauses, pred_vars, datatypes, + raw_commands, } = system; let datatypes = datatypes.into_iter().map(unbox_datatype).collect(); let clauses = clauses.into_iter().map(unbox_clause).collect(); @@ -169,5 +170,6 @@ pub fn unbox(system: System) -> System { clauses, pred_vars, datatypes, + raw_commands, } } diff --git a/tests/ui/fail/annot_raw_command.rs b/tests/ui/fail/annot_raw_command.rs new file mode 100644 index 0000000..7ca52a7 --- /dev/null +++ b/tests/ui/fail/annot_raw_command.rs @@ -0,0 +1,19 @@ +//@compile-flags: -Adead_code -C debug-assertions=off +// This test panics with "UnexpectedToken" for now. +// TODO: reporting rustc diagnostics for parse errors + +// Insert commands written in SMT-LIB2 format into .smt2 file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_command(true)] // argument must be single string literal + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); +} diff --git a/tests/ui/fail/annot_raw_command_without_params.rs b/tests/ui/fail/annot_raw_command_without_params.rs new file mode 100644 index 0000000..810e1b0 --- /dev/null +++ b/tests/ui/fail/annot_raw_command_without_params.rs @@ -0,0 +1,19 @@ +//@compile-flags: -Adead_code -C debug-assertions=off +// This test panics with "invalid attribute" for now. +// TODO: reporting rustc diagnostics for parse errors + +// Insert commands written in SMT-LIB2 format into .smt file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_command] // argument must be single string literal + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); +} diff --git a/tests/ui/pass/annot_raw_command.rs b/tests/ui/pass/annot_raw_command.rs new file mode 100644 index 0000000..7999376 --- /dev/null +++ b/tests/ui/pass/annot_raw_command.rs @@ -0,0 +1,23 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert commands written in SMT-LIB2 format into .smt2 file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_command("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); +} diff --git a/tests/ui/pass/annot_raw_command_multi.rs b/tests/ui/pass/annot_raw_command_multi.rs new file mode 100644 index 0000000..e6c2164 --- /dev/null +++ b/tests/ui/pass/annot_raw_command_multi.rs @@ -0,0 +1,31 @@ +//@check-pass +//@compile-flags: -Adead_code -C debug-assertions=off + +// Insert commands written in SMT-LIB2 format into .smt2 file directly. +// This feature is intended for debug or experiment purpose. +#![feature(custom_inner_attributes)] +#![thrust::raw_command("(define-fun is_double ((x Int) (doubled_x Int)) Bool + (= + (* x 2) + doubled_x + ) +)")] + +// multiple raw commands can be inserted. +#![thrust::raw_command("(define-fun is_triple ((x Int) (tripled_x Int)) Bool + (= + (* x 3) + tripled_x + ) +)")] + +#[thrust::requires(true)] +#[thrust::ensures(result == 2 * x)] +fn double(x: i64) -> i64 { + x + x +} + +fn main() { + let a = 3; + assert!(double(a) == 6); +}