diff --git a/README.md b/README.md index 04129ac..696dfa8 100644 --- a/README.md +++ b/README.md @@ -15,22 +15,15 @@ A mathematical expression evaluator library written in Rust with support for cus ## How It Works -The library implements a classic compiler pipeline: +Classic compiler pipeline with type-safe state transitions: ``` -Source → Lexer → Parser → AST → Semantic Analysis → IR → Bytecode → VM +Input → Lexer → Parser → Compiler → Program<Compiled> + ↓ link + Program<Linked> → Execute ``` -1. **Lexer** - Tokenizes the input string into operators, numbers, and identifiers -2. **Parser** - Uses operator precedence climbing to build an Abstract Syntax Tree (AST) -3. **Semantic Analysis** - Resolves symbols and validates function arities -4. **IR Builder** - Converts the AST into stack-based bytecode instructions -5. **Virtual Machine** - Executes the bytecode on a stack-based VM - -This architecture allows for: -- Separating parsing from execution -- Compiling expressions once and running them multiple times -- Serializing compiled bytecode to disk for later use +The `Program` type uses Rust's type system to enforce correct usage at compile time. You cannot execute an unlinked program, and you cannot link a program twice. 
## Usage @@ -40,7 +33,7 @@ Add this to your `Cargo.toml`: ```toml [dependencies] -expr-solver-lib = "1.0.3" +expr-solver-lib = "1.1.0" ``` ### As a binary @@ -49,166 +42,82 @@ Add this to your `Cargo.toml`: ```toml [dependencies] -expr-solver-bin = "1.0.3" +expr-solver-bin = "1.1.0" ``` -### Basic Example +### Quick Evaluation ```rust -use expr_solver::Eval; - -fn main() { - // Quick one-liner evaluation - match Eval::evaluate("2+3*4") { - Ok(result) => println!("Result: {}", result), - Err(e) => eprintln!("Error: {}", e), - } - - // Or create an evaluator instance for more control - let mut eval = Eval::new("sqrt(16) + pi"); - match eval.run() { - Ok(result) => println!("Result: {}", result), - Err(e) => eprintln!("Error: {}", e), - } -} +use expr_solver::eval; + +// Simple one-liner +let result = eval("2 + 3 * 4").unwrap(); +assert_eq!(result.to_string(), "14"); + +// With built-in functions +let result = eval("sqrt(16) + sin(pi/2)").unwrap(); ``` -### Advanced Example +### Custom Symbols ```rust -use expr_solver::{Eval, SymTable}; +use expr_solver::{eval_with_table, SymTable}; use rust_decimal_macros::dec; -fn main() { - // Create a custom symbol table - let mut table = SymTable::stdlib(); - table.add_const("x", dec!(10)).unwrap(); - table.add_func("double", 1, false, |args| Ok(args[0] * dec!(2))).unwrap(); - - // Evaluate with custom symbols - let mut eval = Eval::with_table("double(x) + sqrt(25)", table); - let result = eval.run().unwrap(); - println!("Result: {}", result); // 25 -} +let mut table = SymTable::stdlib(); +table.add_const("x", dec!(10)).unwrap(); +table.add_func("double", 1, false, |args| Ok(args[0] * dec!(2))).unwrap(); + +let result = eval_with_table("double(x)", table).unwrap(); +assert_eq!(result, dec!(20)); ``` -### Compile and Execute +### Compile Once, Execute Many Times ```rust -use expr_solver::Eval; -use std::path::PathBuf; - -fn main() { - // Compile expression to bytecode - let mut eval = Eval::new("2 + 3 * 4"); - 
eval.compile_to_file(&PathBuf::from("expr.bin")).unwrap(); - - // Load and execute the compiled bytecode - let mut eval = Eval::new_from_file(PathBuf::from("expr.bin")); - let result = eval.run().unwrap(); - println!("Result: {}", result); // 14 -} -``` - -### Viewing Assembly +use expr_solver::{load, SymTable}; +use rust_decimal_macros::dec; -You can inspect the generated bytecode as human-readable assembly: +// Compile expression +let program = load("x * 2 + y").unwrap(); -```rust -use expr_solver::Eval; +// Execute with different values +let mut table = SymTable::new(); +table.add_const("x", dec!(10)).unwrap(); +table.add_const("y", dec!(5)).unwrap(); -fn main() { - let mut eval = Eval::new("2 + 3 * 4"); - println!("{}", eval.get_assembly().unwrap()); -} +let linked = program.link(table).unwrap(); +let result = linked.execute().unwrap(); // 25 ``` -Output: -```asm -; VERSION 1.0.2 -0000 PUSH 2 -0001 PUSH 3 -0002 PUSH 4 -0003 MUL -0004 ADD -``` +## Precision -The assembly shows the stack-based bytecode instructions that will be executed by the VM. +Uses **128-bit `Decimal`** arithmetic for exact decimal calculations without floating-point errors. -## Precision and Data Types +## Built-in Functions -All calculations are performed using **128-bit `Decimal`** type from the `rust_decimal` crate, providing exact decimal arithmetic without floating-point errors. 
+| Category | Functions | +|----------------|---------------------------------------------------------------------------| +| **Arithmetic** | `abs`, `sign`, `floor`, `ceil`, `round`, `trunc`, `fract`, `mod`, `clamp` | +| **Trig** | `sin`, `cos`, `tan`, `asin`*, `acos`*, `atan`*, `atan2`* | +| **Hyperbolic** | `sinh`*, `cosh`*, `tanh`* | +| **Exp/Log** | `sqrt`, `cbrt`*, `pow`, `exp`, `exp2`*, `log`, `log2`*, `log10`, `hypot`* | +| **Variadic** | `min`, `max`, `sum`, `avg` (1+ args) | +| **Special** | `if(cond, then, else)` | -> **Note**: Some trigonometric and hyperbolic functions (`asin`, `acos`, `atan`, `atan2`, `sinh`, `cosh`, `tanh`, `cbrt`, `exp2`, `log2`, `hypot`) internally convert to/from `f64` for computation, which may introduce minor precision differences. All constants (`pi`, `e`, `tau`, `ln2`, `ln10`, `sqrt2`) are computed using native `Decimal` operations for maximum precision. +\* *Uses f64 internally, may have minor precision differences* -## Built-in Functions +## Built-in Constants -| Function | Arguments | Description | Notes | -|-----------------------------|-----------|------------------------------------------|---------------------------------| -| **Arithmetic** | | | | -| `abs(x)` | 1 | Absolute value | | -| `sign(x)` | 1 | Sign (-1, 0, or 1) | | -| `floor(x)` | 1 | Round down to integer | | -| `ceil(x)` | 1 | Round up to integer | | -| `round(x)` | 1 | Round to nearest integer | | -| `trunc(x)` | 1 | Truncate to integer | | -| `fract(x)` | 1 | Fractional part | | -| `mod(x, y)` | 2 | Remainder of x/y | | -| `clamp(x, min, max)` | 3 | Constrain value between bounds | | -| **Trigonometry** | | | | -| `sin(x)` | 1 | Sine | | -| `cos(x)` | 1 | Cosine | | -| `tan(x)` | 1 | Tangent | | -| `asin(x)` | 1 | Arcsine | Uses f64 internally | -| `acos(x)` | 1 | Arccosine | Uses f64 internally | -| `atan(x)` | 1 | Arctangent | Uses f64 internally | -| `atan2(y, x)` | 2 | Two-argument arctangent | Uses f64 internally | -| **Hyperbolic** | | | | -| 
`sinh(x)` | 1 | Hyperbolic sine | Uses f64 internally | -| `cosh(x)` | 1 | Hyperbolic cosine | Uses f64 internally | -| `tanh(x)` | 1 | Hyperbolic tangent | Uses f64 internally | -| **Exponential/Logarithmic** | | | | -| `sqrt(x)` | 1 | Square root | | -| `cbrt(x)` | 1 | Cube root | Uses f64 internally | -| `pow(x, y)` | 2 | x raised to power y | | -| `exp(x)` | 1 | e raised to power x | | -| `exp2(x)` | 1 | 2 raised to power x | Uses f64 internally | -| `log(x)` | 1 | Natural logarithm | | -| `log2(x)` | 1 | Base-2 logarithm | Uses f64 internally | -| `log10(x)` | 1 | Base-10 logarithm | | -| `hypot(x, y)` | 2 | Euclidean distance √(x²+y²) | Uses f64 internally | -| **Variadic** | | | | -| `min(x, ...)` | 1+ | Minimum value | Accepts any number of arguments | -| `max(x, ...)` | 1+ | Maximum value | Accepts any number of arguments | -| `sum(x, ...)` | 1+ | Sum of values | Accepts any number of arguments | -| `avg(x, ...)` | 1+ | Average of values | Accepts any number of arguments | -| **Special** | | | | -| `if(cond, t, f)` | 3 | Conditional: returns t if cond≠0, else f | | +`pi`, `e`, `tau`, `ln2`, `ln10`, `sqrt2` -## Built-in Constants +> All names are case-insensitive. + +## Operators -| Constant | Value | Description | -|----------|-------|-------------| -| `pi` | 3.14159... | π (pi) | -| `e` | 2.71828... | Euler's number | -| `tau` | 6.28318... | 2π (tau) | -| `ln2` | 0.69314... | Natural logarithm of 2 | -| `ln10` | 2.30258... | Natural logarithm of 10 | -| `sqrt2` | 1.41421... | Square root of 2 | - -> **Note**: All function and constant names are case-insensitive. 
- -## Supported Operators - -| Operator | Type | Associativity | Precedence | Description | -|----------|------|---------------|------------|-------------| -| `!` | Postfix Unary | Left | 6 | Factorial | -| `^` | Binary | Right | 5 | Exponentiation | -| `-` | Prefix Unary | Right | 4 | Negation | -| `*`, `/` | Binary | Left | 3 | Multiplication, Division | -| `+`, `-` | Binary | Left | 2 | Addition, Subtraction | -| `==`, `!=`, `<`, `<=`, `>`, `>=` | Binary | Left | 1 | Comparisons (return 1 or 0) | -| `()` | Grouping | - | - | Parentheses for grouping | +**Arithmetic**: `+`, `-`, `*`, `/`, `^` (power), `!` (factorial), unary `-` +**Comparison**: `==`, `!=`, `<`, `<=`, `>`, `>=` (returns 1 or 0) +**Grouping**: `(` `)` ## Command Line Usage diff --git a/bin/Cargo.toml b/bin/Cargo.toml index c283e3e..b06ccf9 100644 --- a/bin/Cargo.toml +++ b/bin/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "expr-solver-bin" -version = "1.0.3" +version = "1.1.0" edition = "2024" authors = ["Albert Varaksin "] description = "Binary using the expr-solver-lib to solve math expressions from command line" @@ -15,6 +15,6 @@ name = "expr-solver" path = "src/main.rs" [dependencies] -expr-solver-lib = { version = "1.0.3", path = "../lib" } +expr-solver-lib = { version = "1.1.0", path = "../lib" } clap = { version = "4.0", features = ["derive"] } rust_decimal = { workspace = true } diff --git a/bin/src/main.rs b/bin/src/main.rs index 6939b94..05304f8 100644 --- a/bin/src/main.rs +++ b/bin/src/main.rs @@ -1,5 +1,5 @@ use clap::{ArgAction, Parser}; -use expr_solver::{Eval, SymTable, Symbol}; +use expr_solver::{SymTable, Symbol, eval_file_with_table, load_with_table}; use rust_decimal::prelude::*; use std::path::PathBuf; @@ -44,11 +44,8 @@ fn parse_key_val(s: &str) -> Result<(String, f64), Box { - eprintln!("{err}"); - } - _ => {} + if let Err(err) = run() { + eprintln!("{err}"); } } @@ -65,25 +62,28 @@ fn run() -> Result<(), String> { } // load either from string input or a file - let mut eval = 
if let Some(expr) = args.expression.as_ref().or(args.expr.as_ref()) { - Eval::with_table(expr, table) - } else if let Some(input) = &args.input { - Eval::from_file_with_table(input.clone(), table) - } else { - return Err("no input".to_string()); - }; + if let Some(expr) = args.expression.as_ref().or(args.expr.as_ref()) { + let program = load_with_table(expr, table)?; - if args.assembly { - print!("{}", eval.get_assembly()?); - return Ok(()); - } + if args.assembly { + print!("{}", program.get_assembly()); + return Ok(()); + } - // save to a file? - if let Some(output_path) = &args.output { - eval.compile_to_file(output_path)? - } else { - let res = eval.run()?; + // save to a file? + if let Some(output_path) = &args.output { + program + .save_bytecode_to_file(output_path) + .map_err(|e| e.to_string())? + } else { + let res = program.execute().map_err(|e| e.to_string())?; + println!("{res}"); + } + } else if let Some(input) = &args.input { + let res = eval_file_with_table(input.to_string_lossy().as_ref(), table)?; println!("{res}"); + } else { + return Err("no input".to_string()); } Ok(()) diff --git a/lib/Cargo.toml b/lib/Cargo.toml index e4e2521..04dd1ce 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "expr-solver-lib" -version = "1.0.3" +version = "1.1.0" edition = "2024" authors = ["Albert Varaksin "] description = "A simple math expression solver library" diff --git a/lib/src/ast.rs b/lib/src/ast.rs index b733a5f..91c9e81 100644 --- a/lib/src/ast.rs +++ b/lib/src/ast.rs @@ -1,10 +1,15 @@ +//! Abstract Syntax Tree for mathematical expressions. + use crate::span::Span; use crate::token::Token; use rust_decimal::Decimal; +/// Unary operators: negation and factorial. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum UnOp { + /// Negation (`-`) Neg, + /// Factorial (`!`) Fact, } @@ -18,19 +23,30 @@ impl UnOp { } } +/// Binary operators: arithmetic and comparison. 
#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum BinOp { + /// Addition (`+`) Add, + /// Subtraction (`-`) Sub, + /// Multiplication (`*`) Mul, + /// Division (`/`) Div, + /// Exponentiation (`^`) Pow, - // Comparison operators + /// Equality (`==`) Equal, + /// Inequality (`!=`) NotEqual, + /// Less than (`<`) Less, + /// Less than or equal (`<=`) LessEqual, + /// Greater than (`>`) Greater, + /// Greater than or equal (`>=`) GreaterEqual, } @@ -53,36 +69,41 @@ impl BinOp { } } +/// Expression node in the AST with source location. #[derive(Debug, Clone)] -pub struct Expr<'src> { - pub kind: ExprKind<'src>, +pub struct Expr { + pub kind: ExprKind, pub span: Span, } +/// Expression kind representing different types of expressions. #[derive(Debug, Clone)] -pub enum ExprKind<'src> { +pub enum ExprKind { + /// Numeric literal Literal(Decimal), + /// Identifier (constant or variable) Ident { - name: &'src str, - sym_index: Option, + name: String, }, + /// Unary operation Unary { op: UnOp, - expr: Box>, + expr: Box, }, + /// Binary operation Binary { op: BinOp, - left: Box>, - right: Box>, + left: Box, + right: Box, }, + /// Function call Call { - name: &'src str, - args: Vec>, - sym_index: Option, + name: String, + args: Vec, }, } -impl<'src> Expr<'src> { +impl Expr { pub fn literal(value: Decimal, span: Span) -> Self { Self { kind: ExprKind::Literal(value), @@ -90,17 +111,14 @@ impl<'src> Expr<'src> { } } - pub fn ident(name: &'src str, span: Span) -> Self { + pub fn ident(name: String, span: Span) -> Self { Self { - kind: ExprKind::Ident { - name, - sym_index: None, - }, + kind: ExprKind::Ident { name }, span, } } - pub fn unary(op: UnOp, expr: Expr<'src>, span: Span) -> Self { + pub fn unary(op: UnOp, expr: Expr, span: Span) -> Self { Self { kind: ExprKind::Unary { op, @@ -110,7 +128,7 @@ impl<'src> Expr<'src> { } } - pub fn binary(op: BinOp, left: Expr<'src>, right: Expr<'src>, span: Span) -> Self { + pub fn binary(op: BinOp, left: Expr, right: Expr, span: 
Span) -> Self { Self { kind: ExprKind::Binary { op, @@ -121,13 +139,9 @@ impl<'src> Expr<'src> { } } - pub fn call(name: &'src str, args: Vec>, span: Span) -> Self { + pub fn call(name: String, args: Vec, span: Span) -> Self { Self { - kind: ExprKind::Call { - name, - args, - sym_index: None, - }, + kind: ExprKind::Call { name, args }, span, } } diff --git a/lib/src/error.rs b/lib/src/error.rs new file mode 100644 index 0000000..0c77d15 --- /dev/null +++ b/lib/src/error.rs @@ -0,0 +1,68 @@ +//! Error types for parsing, linking, and program operations. + +use crate::span::Span; +use crate::span::SpanError; +use thiserror::Error; + +/// Errors that can occur during parsing. +#[derive(Error, Debug)] +pub enum ParseError { + #[error("Unexpected token: {message}")] + UnexpectedToken { message: String, span: Span }, + #[error("Unexpected end of input")] + UnexpectedEof { span: Span }, + #[error("Invalid number literal: {message}")] + InvalidNumber { message: String, span: Span }, +} + +impl SpanError for ParseError { + fn span(&self) -> Span { + match self { + ParseError::UnexpectedToken { span, .. } => *span, + ParseError::UnexpectedEof { span } => *span, + ParseError::InvalidNumber { span, .. } => *span, + } + } +} + +/// Errors that can occur during linking. +#[derive(Error, Debug)] +pub enum LinkError { + #[error("Missing symbol: '{name}' is required by bytecode but not in symbol table")] + MissingSymbol { name: String }, + + #[error("Type mismatch for symbol '{name}': expected {expected}, found {found}")] + TypeMismatch { + name: String, + expected: String, + found: String, + }, + + #[error("Symbol table error: {0}")] + SymbolTableError(#[from] crate::symbol::SymbolError), +} + +/// Errors that can occur during program operations. 
+#[derive(Error, Debug)] +pub enum ProgramError { + #[error("{0}")] + ParseError(String), + + #[error("Link error: {0}")] + LinkError(#[from] LinkError), + + #[error("Serialization error: {0}")] + SerializationError(#[from] bincode::error::EncodeError), + + #[error("Deserialization error: {0}")] + DeserializationError(#[from] bincode::error::DecodeError), + + #[error("Incompatible program version: expected {expected}, got {found}")] + IncompatibleVersion { expected: String, found: String }, + + #[error("Invalid symbol index: {0}")] + InvalidSymbolIndex(usize), + + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), +} diff --git a/lib/src/ir.rs b/lib/src/ir.rs index 87e78e9..b324a0a 100644 --- a/lib/src/ir.rs +++ b/lib/src/ir.rs @@ -1,17 +1,9 @@ -use crate::ast::{BinOp, Expr, ExprKind, UnOp}; -use crate::program::Program; -use crate::span::Span; +//! Bytecode instruction definitions for the virtual machine. + use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; -use thiserror::Error; - -/// IR building errors. -#[derive(Error, Debug, Clone)] -pub enum IrError { - #[error("Undefined symbol {0}")] - UndefinedSymbol(String, Span), -} +/// Bytecode instructions for the stack-based virtual machine. #[derive(Debug, Clone, Serialize, Deserialize)] pub enum Instr { Push(Decimal), @@ -32,77 +24,3 @@ pub enum Instr { Greater, GreaterEqual, } - -/// Builder for converting AST expressions into bytecode programs. -pub struct IrBuilder { - prog: Program, -} - -impl IrBuilder { - /// Creates a new IR builder. - pub fn new() -> Self { - Self { - prog: Program::new(), - } - } - - /// Builds a bytecode program from an AST expression. 
- pub fn build<'src>(mut self, expr: &Expr<'src>) -> Result { - self.emit(expr)?; - Ok(self.prog) - } - - fn emit<'src>(&mut self, e: &Expr<'src>) -> Result<(), IrError> { - match &e.kind { - ExprKind::Literal(v) => { - self.prog.code.push(Instr::Push(*v)); - } - ExprKind::Ident { name, sym_index } => { - if let Some(idx) = sym_index { - self.prog.code.push(Instr::Load(*idx)); - } else { - return Err(IrError::UndefinedSymbol(name.to_string(), e.span)); - } - } - ExprKind::Unary { op, expr } => { - self.emit(expr)?; - match op { - UnOp::Neg => self.prog.code.push(Instr::Neg), - UnOp::Fact => self.prog.code.push(Instr::Fact), - } - } - ExprKind::Binary { op, left, right } => { - self.emit(left)?; - self.emit(right)?; - self.prog.code.push(match op { - BinOp::Add => Instr::Add, - BinOp::Sub => Instr::Sub, - BinOp::Mul => Instr::Mul, - BinOp::Div => Instr::Div, - BinOp::Pow => Instr::Pow, - BinOp::Equal => Instr::Equal, - BinOp::NotEqual => Instr::NotEqual, - BinOp::Less => Instr::Less, - BinOp::LessEqual => Instr::LessEqual, - BinOp::Greater => Instr::Greater, - BinOp::GreaterEqual => Instr::GreaterEqual, - }); - } - ExprKind::Call { - name, - args, - sym_index, - } => { - if let Some(idx) = sym_index { - for a in args.iter() { - self.emit(a)?; - } - self.prog.code.push(Instr::Call(*idx, args.len())); - } else { - return Err(IrError::UndefinedSymbol(name.to_string(), e.span)); - } - } - } - Ok(()) - } -} diff --git a/lib/src/lexer.rs b/lib/src/lexer.rs index 5a9d9f4..50b9c58 100644 --- a/lib/src/lexer.rs +++ b/lib/src/lexer.rs @@ -1,4 +1,5 @@ -use crate::source::Source; +//! Lexer for tokenizing mathematical expressions. + use crate::span::Span; use crate::token::Token; use rust_decimal::Decimal; @@ -15,11 +16,11 @@ pub struct Lexer<'src> { } impl<'src> Lexer<'src> { - /// Create a new lexer from a source. - pub fn new(source: &'src Source) -> Self { + /// Create a new lexer from a string slice. 
+ pub fn new(input: &'src str) -> Self { Self { - input: source.input, - iter: source.input.chars().peekable(), + input, + iter: input.chars().peekable(), start: 0, pos: 0, } @@ -31,7 +32,7 @@ impl<'src> Lexer<'src> { self.start = self.pos; let ch = match self.read() { Some(c) => c, - None => return Token::EOF, + None => return Token::Eof, }; match ch { '0'..='9' => self.number(false), @@ -115,7 +116,7 @@ impl<'src> Lexer<'src> { } fn peek(&mut self) -> Option { - self.iter.peek().map(|ch| *ch) + self.iter.peek().copied() } fn read(&mut self) -> Option { @@ -182,93 +183,3 @@ impl<'src> Lexer<'src> { } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::source::Source; - use crate::token::Token; - use rust_decimal::dec; - - fn lex_all<'src>(source: &'src Source) -> Vec> { - let mut lexer = Lexer::new(source); - let mut tokens = Vec::new(); - loop { - let tok = lexer.next(); - if matches!(tok, Token::EOF) { - break; - } - tokens.push(tok); - } - tokens - } - - #[test] - fn test_number_with_multiple_dots() { - let source = Source { input: "1.2.3" }; - let tokens = lex_all(&source); - assert_eq!(tokens[0], Token::Invalid("1.2.3")); - } - - #[test] - fn test_identifier_with_emoji() { - let source = Source { - input: "foo😀 bar🚀", - }; - let tokens = lex_all(&source); - assert_eq!(tokens[0], Token::Ident("foo😀")); - assert_eq!(tokens[1], Token::Ident("bar🚀")); - } - - #[test] - fn test_unknown_token() { - let source = Source { input: "$" }; - let tokens = lex_all(&source); - assert_eq!(tokens[0], Token::Invalid("$")); - } - - #[test] - fn test_whitespace_handling() { - let source = Source { - input: " 1 + 2\t\t*", - }; - let tokens = lex_all(&source); - assert_eq!(tokens[0], Token::Number(dec!(1))); - assert_eq!(tokens[1], Token::Plus); - assert_eq!(tokens[2], Token::Number(dec!(2))); - assert_eq!(tokens[3], Token::Star); - } - - #[test] - fn test_comparison_operators() { - let source = Source { - input: "== != < <= > >=", - }; - let tokens = lex_all(&source); - 
assert_eq!(tokens[0], Token::Equal); - assert_eq!(tokens[1], Token::NotEqual); - assert_eq!(tokens[2], Token::Less); - assert_eq!(tokens[3], Token::LessEqual); - assert_eq!(tokens[4], Token::Greater); - assert_eq!(tokens[5], Token::GreaterEqual); - } - - #[test] - fn test_factorial_vs_not_equal() { - // Test that ! is factorial but != is not equal - let source = Source { input: "5! != 100" }; - let tokens = lex_all(&source); - assert_eq!(tokens[0], Token::Number(dec!(5))); - assert_eq!(tokens[1], Token::Bang); - assert_eq!(tokens[2], Token::NotEqual); - assert_eq!(tokens[3], Token::Number(dec!(100))); - } - - #[test] - fn test_invalid_single_equals() { - // Single '=' should be invalid since we only support '==' - let source = Source { input: "=" }; - let tokens = lex_all(&source); - assert_eq!(tokens[0], Token::Invalid("=")); - } -} diff --git a/lib/src/lib.rs b/lib/src/lib.rs index c6d37a4..7b8ce69 100644 --- a/lib/src/lib.rs +++ b/lib/src/lib.rs @@ -1,267 +1,192 @@ -//! A simple expression solver library +//! A mathematical expression evaluator library with bytecode compilation. //! -//! Parses and evaluates mathematical expressions with built-in functions and constants. +//! This library provides a complete compiler pipeline for mathematical expressions, +//! from parsing to bytecode execution on a stack-based virtual machine. //! //! # Features //! -//! - Mathematical operators: `+`, `-`, `*`, `/`, `^`, unary `-`, `!` (factorial) -//! - Comparison operators: `==`, `!=`, `<`, `<=`, `>`, `>=` (return 1.0 or 0.0) -//! - Built-in constants: `pi`, `e`, `tau`, `ln2`, `ln10`, `sqrt2` -//! - Basic math functions: `abs`, `floor`, `ceil`, `round`, `trunc`, `fract` -//! - Variadic functions: `min`, `max`, `sum`, `avg` -//! - 128-bit decimal arithmetic (no floating-point representation errors!) -//! - Error handling with source location information +//! - **Type-safe compilation** - Uses Rust's type system to enforce correct pipeline order +//! 
- **128-bit decimal precision** - No floating-point errors using `rust_decimal` +//! - **Rich error messages** - Parse errors with syntax highlighting +//! - **Bytecode compilation** - Compile once, execute many times +//! - **Custom symbols** - Add your own constants and functions +//! - **Serialization** - Save/load compiled programs to/from disk +//! +//! # Quick Start +//! +//! ``` +//! use expr_solver::eval; +//! +//! // Simple evaluation +//! let result = eval("2 + 3 * 4").unwrap(); +//! assert_eq!(result.to_string(), "14"); +//! ``` +//! +//! # Custom Symbols +//! +//! ``` +//! use expr_solver::{eval_with_table, SymTable}; +//! use rust_decimal_macros::dec; +//! +//! let mut table = SymTable::stdlib(); +//! table.add_const("x", dec!(10)).unwrap(); +//! +//! let result = eval_with_table("x * 2", table).unwrap(); +//! assert_eq!(result, dec!(20)); +//! ``` +//! +//! # Advanced: Type-State Pattern +//! +//! The `Program` type uses the type-state pattern to enforce correct usage: +//! +//! ``` +//! use expr_solver::{load, SymTable}; +//! use rust_decimal_macros::dec; +//! +//! // Compile expression to bytecode +//! let program = load("x + y").unwrap(); +//! +//! // Link with symbol table (validated at link time) +//! let mut table = SymTable::new(); +//! table.add_const("x", dec!(10)).unwrap(); +//! table.add_const("y", dec!(5)).unwrap(); +//! +//! let linked = program.link(table).unwrap(); +//! +//! // Execute +//! let result = linked.execute().unwrap(); +//! assert_eq!(result, dec!(15)); +//! ``` +//! +//! # Supported Operators +//! +//! - Arithmetic: `+`, `-`, `*`, `/`, `^` (power), `!` (factorial), unary `-` +//! - Comparison: `==`, `!=`, `<`, `<=`, `>`, `>=` (return 1 or 0) +//! - Grouping: `(` `)` +//! +//! # Built-in Functions +//! +//! See [`SymTable::stdlib()`] for the complete list of built-in functions and constants. 
-mod ast; +// Core types (shared) mod ir; -mod lexer; -mod parser; -mod program; - -mod sema; -mod source; mod span; mod symbol; mod token; mod vm; -use std::{borrow::Cow, fmt, fs, path::PathBuf}; +// Expression solver implementation +mod ast; +mod error; +mod lexer; +mod metadata; +mod parser; +mod program; + +use rust_decimal::Decimal; // Public API -pub use ir::IrBuilder; +pub use ast::{BinOp, Expr, ExprKind, UnOp}; +pub use error::{LinkError, ParseError, ProgramError}; +pub use metadata::{SymbolKind, SymbolMetadata}; pub use parser::Parser; -pub use program::Program; - -use crate::ast::Expr; -use crate::span::SpanError; -use rust_decimal::Decimal; -pub use sema::Sema; -pub use source::Source; +pub use program::{Compiled, Linked, Program, ProgramOrigin}; pub use symbol::{SymTable, Symbol, SymbolError}; pub use vm::{Vm, VmError}; -/// A wrapper that formats errors with source code highlighting -struct FormattedError { - message: String, -} +// ============================================================================ +// Helper functions for evaluating expressions +// ============================================================================ -impl fmt::Display for FormattedError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.message) - } +/// Evaluates an expression string with the standard library. +/// +/// # Examples +/// +/// ``` +/// use expr_solver::eval; +/// +/// let result = eval("2 + 3 * 4").unwrap(); +/// assert_eq!(result.to_string(), "14"); +/// ``` +pub fn eval(expression: &str) -> Result { + let program = load_with_table(expression, SymTable::stdlib())?; + program.execute().map_err(|err| err.to_string()) } -impl From<(&T, &Source<'_>)> for FormattedError { - fn from((error, source): (&T, &Source<'_>)) -> Self { - Self { - message: format!("{}\n{}", error, source.highlight(&error.span())), - } - } +/// Evaluates an expression string with a custom symbol table. 
+/// +/// # Examples +/// +/// ``` +/// use expr_solver::{eval_with_table, SymTable}; +/// use rust_decimal_macros::dec; +/// +/// let mut table = SymTable::stdlib(); +/// table.add_const("x", dec!(42)).unwrap(); +/// +/// let result = eval_with_table("x * 2", table).unwrap(); +/// assert_eq!(result, dec!(84)); +/// ``` +pub fn eval_with_table(expression: &str, table: SymTable) -> Result { + let program = load_with_table(expression, table)?; + program.execute().map_err(|err| err.to_string()) } -#[derive(Debug)] -enum EvalSource<'str> { - Source(Cow<'str, Source<'str>>), - File(PathBuf), +/// Evaluates an expression from a binary file with the standard library. +/// +/// # Examples +/// +/// ```no_run +/// use expr_solver::eval_file; +/// +/// let result = eval_file("expr.bin").unwrap(); +/// ``` +pub fn eval_file(path: impl AsRef) -> Result { + eval_file_with_table(path, SymTable::stdlib()) } -/// Expression evaluator with support for custom symbols and bytecode compilation. +/// Evaluates an expression from a binary file with a custom symbol table. +/// +/// # Examples /// -/// `Eval` is the main entry point for evaluating mathematical expressions. It supports -/// both quick one-off evaluations and reusable evaluators with custom symbol tables. +/// ```no_run +/// use expr_solver::{eval_file_with_table, SymTable}; +/// +/// let result = eval_file_with_table("expr.bin", SymTable::stdlib()).unwrap(); +/// ``` +pub fn eval_file_with_table(path: impl AsRef, table: SymTable) -> Result { + let program = Program::new_from_file(path.as_ref()).map_err(|err| err.to_string())?; + let linked = program.link(table).map_err(|err| err.to_string())?; + linked.execute().map_err(|err| err.to_string()) +} + +/// Loads and compiles an expression, returning a compiled program. 
/// /// # Examples /// /// ``` -/// use expr_solver::Eval; +/// use expr_solver::{load, SymTable}; /// -/// // Quick evaluation -/// let result = Eval::evaluate("2 + 3 * 4").unwrap(); +/// let program = load("2 + 3 * 4").unwrap(); +/// let linked = program.link(SymTable::stdlib()).unwrap(); +/// let result = linked.execute().unwrap(); /// assert_eq!(result.to_string(), "14"); -/// -/// // Reusable evaluator -/// let mut eval = Eval::new("sqrt(16) + pi"); -/// let result = eval.run().unwrap(); /// ``` -#[derive(Debug)] -pub struct Eval<'str> { - source: EvalSource<'str>, - table: SymTable, +pub fn load(expression: &str) -> Result, String> { + Program::new_from_source(expression).map_err(|err| err.to_string()) } -impl<'str> Eval<'str> { - /// Quick evaluation of an expression with the standard library. - /// - /// This is a convenience method for one-off evaluations. - /// - /// # Examples - /// - /// ``` - /// use expr_solver::Eval; - /// - /// let result = Eval::evaluate("2^8").unwrap(); - /// assert_eq!(result.to_string(), "256"); - /// ``` - pub fn evaluate(expression: &'str str) -> Result { - Self::new(expression).run() - } - - /// Quick evaluation of an expression with a custom symbol table. - /// - /// # Examples - /// - /// ``` - /// use expr_solver::{Eval, SymTable}; - /// use rust_decimal_macros::dec; - /// - /// let mut table = SymTable::stdlib(); - /// table.add_const("x", dec!(42)).unwrap(); - /// - /// let result = Eval::evaluate_with_table("x * 2", table).unwrap(); - /// assert_eq!(result, dec!(84)); - /// ``` - pub fn evaluate_with_table(expression: &'str str, table: SymTable) -> Result { - Self::with_table(expression, table).run() - } - - /// Creates a new evaluator with the standard library. 
- /// - /// # Examples - /// - /// ``` - /// use expr_solver::Eval; - /// - /// let mut eval = Eval::new("sin(pi/2)"); - /// let result = eval.run().unwrap(); - /// ``` - pub fn new(string: &'str str) -> Self { - Self::with_table(string, SymTable::stdlib()) - } - - /// Creates a new evaluator with a custom symbol table. - /// - /// # Examples - /// - /// ``` - /// use expr_solver::{Eval, SymTable}; - /// use rust_decimal_macros::dec; - /// - /// let mut table = SymTable::stdlib(); - /// table.add_const("x", dec!(42)).unwrap(); - /// - /// let mut eval = Eval::with_table("x * 2", table); - /// let result = eval.run().unwrap(); - /// assert_eq!(result, dec!(84)); - /// ``` - pub fn with_table(string: &'str str, table: SymTable) -> Self { - let source = Source::new(string); - Self { - source: EvalSource::Source(Cow::Owned(source)), - table, - } - } - - /// Creates a new evaluator from a [`Source`] reference. - pub fn new_from_source(source: &'str Source<'str>) -> Self { - Self::from_source_with_table(source, SymTable::stdlib()) - } - - /// Creates a new evaluator from a [`Source`] reference with a custom symbol table. - pub fn from_source_with_table(source: &'str Source<'str>, table: SymTable) -> Self { - Self { - source: EvalSource::Source(Cow::Borrowed(source)), - table, - } - } - - /// Creates a new evaluator from a compiled binary file. - /// - /// The file must have been created using [`compile_to_file`](Self::compile_to_file). - pub fn new_from_file(path: PathBuf) -> Self { - Self::from_file_with_table(path, SymTable::stdlib()) - } - - /// Creates a new evaluator from a compiled binary file with a custom symbol table. - pub fn from_file_with_table(path: PathBuf, table: SymTable) -> Self { - Self { - source: EvalSource::File(path), - table, - } - } - - /// Evaluates the expression and returns the result. 
- /// - /// # Examples - /// - /// ``` - /// use expr_solver::Eval; - /// - /// let mut eval = Eval::new("2 + 3"); - /// assert_eq!(eval.run().unwrap().to_string(), "5"); - /// ``` - pub fn run(&mut self) -> Result { - let program = self.build_program()?; - Vm::default() - .run(&program, &self.table) - .map_err(|err| err.to_string()) - } - - /// Compiles the expression to a binary file. - /// - /// The compiled bytecode can later be loaded with [`new_from_file`](Self::new_from_file). - /// - /// # Examples - /// - /// ```no_run - /// use expr_solver::Eval; - /// use std::path::PathBuf; - /// - /// let mut eval = Eval::new("2 + 3 * 4"); - /// eval.compile_to_file(&PathBuf::from("expr.bin")).unwrap(); - /// ``` - pub fn compile_to_file(&mut self, path: &PathBuf) -> Result<(), String> { - let program = self.build_program()?; - let binary_data = program.compile().map_err(|err| err.to_string())?; - fs::write(path, binary_data).map_err(|err| err.to_string()) - } - - /// Returns a human-readable assembly representation of the compiled expression. - /// - /// # Examples - /// - /// ``` - /// use expr_solver::Eval; - /// - /// let mut eval = Eval::new("2 + 3"); - /// let assembly = eval.get_assembly().unwrap(); - /// assert!(assembly.contains("PUSH")); - /// assert!(assembly.contains("ADD")); - /// ``` - pub fn get_assembly(&mut self) -> Result { - let program = self.build_program()?; - Ok(program.get_assembly(&self.table)) - } - - fn build_program(&mut self) -> Result { - match &self.source { - EvalSource::Source(source) => { - let mut parser = Parser::new(source); - let mut ast: Expr = match parser - .parse() - .map_err(|err| FormattedError::from((&err, source.as_ref())).to_string())? 
- { - Some(ast) => ast, - None => return Ok(Program::default()), - }; - Sema::new(&self.table) - .visit(&mut ast) - .map_err(|err| FormattedError::from((&err, source.as_ref())).to_string())?; - IrBuilder::new().build(&ast).map_err(|err| err.to_string()) - } - EvalSource::File(path) => { - let binary_data = fs::read(path).map_err(|err| err.to_string())?; - Program::load(&binary_data).map_err(|err| err.to_string()) - } - } - } +/// Loads, compiles, and links an expression, returning a ready-to-execute program. +/// +/// # Examples +/// +/// ``` +/// use expr_solver::{load_with_table, SymTable}; +/// +/// let program = load_with_table("sin(pi/2)", SymTable::stdlib()).unwrap(); +/// let result = program.execute().unwrap(); +/// ``` +pub fn load_with_table(expression: &str, table: SymTable) -> Result, String> { + let program = load(expression)?; + program.link(table).map_err(|err| err.to_string()) } diff --git a/lib/src/metadata.rs b/lib/src/metadata.rs new file mode 100644 index 0000000..b4c989d --- /dev/null +++ b/lib/src/metadata.rs @@ -0,0 +1,56 @@ +//! Symbol metadata for bytecode validation and linking. + +use crate::symbol::Symbol; +use serde::{Deserialize, Serialize}; +use std::borrow::Cow; + +/// Metadata about a symbol required by compiled bytecode. +/// +/// This is used to validate and remap symbol indices when linking +/// bytecode with a symbol table. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SymbolMetadata { + /// The name of the symbol + pub name: Cow<'static, str>, + /// The kind and requirements of the symbol + pub kind: SymbolKind, + /// The resolved index in the linked symbol table (None until linked) + #[serde(skip)] + pub index: Option, +} + +/// The kind of symbol (constant or function) with its requirements. 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SymbolKind { + /// A constant value + Const, + /// A function with specified arity + Func { + /// Minimum number of arguments + arity: usize, + /// Whether the function accepts additional arguments + variadic: bool, + }, +} + +impl From<&Symbol> for SymbolKind { + fn from(symbol: &Symbol) -> Self { + match symbol { + Symbol::Const { .. } => SymbolKind::Const, + Symbol::Func { args, variadic, .. } => SymbolKind::Func { + arity: *args, + variadic: *variadic, + }, + } + } +} + +impl From<&Symbol> for SymbolMetadata { + fn from(symbol: &Symbol) -> Self { + SymbolMetadata { + name: symbol.name().to_string().into(), + kind: symbol.into(), + index: None, + } + } +} diff --git a/lib/src/parser.rs b/lib/src/parser.rs index 2697d2c..d7418d2 100644 --- a/lib/src/parser.rs +++ b/lib/src/parser.rs @@ -1,183 +1,195 @@ -use crate::ast::{BinOp, Expr, UnOp}; -use crate::lexer::Lexer; -use crate::source::Source; -use crate::span::{Span, SpanError}; -use crate::token::Token; -use thiserror::Error; - -/// Expression parsing errors. -#[derive(Error, Debug, Clone)] -pub enum ParseError { - #[error("Unexpected token '{found}', expected '{expected}'")] - UnexpectedToken { - found: String, - expected: String, - span: Span, - }, -} +//! Recursive descent parser for mathematical expressions. -impl SpanError for ParseError { - fn span(&self) -> Span { - match self { - ParseError::UnexpectedToken { span, .. } => *span, - } - } -} +use super::ast::{BinOp, Expr, UnOp}; +use super::error::ParseError; +use super::lexer::Lexer; +use crate::span::Span; +use crate::token::Token; -pub type ParseResult<'src> = Result, ParseError>; +pub type ParseResult = Result; /// Recursive descent parser for mathematical expressions. /// /// Uses operator precedence climbing for efficient binary operator parsing. 
+/// +/// # Examples +/// +/// ``` +/// use expr_solver::Parser; +/// +/// let mut parser = Parser::new("2 + 3 * 4"); +/// let ast = parser.parse().unwrap(); +/// assert!(ast.is_some()); +/// ``` pub struct Parser<'src> { - lexer: Lexer<'src>, - lookahead: Token<'src>, - span: Span, + input: &'src str, } impl<'src> Parser<'src> { - /// Creates a new parser from a source. - pub fn new(source: &'src Source) -> Self { - let mut lexer = Lexer::new(source); - let lookahead = lexer.next(); - let span = lexer.span(); - Self { - lexer, - lookahead, - span, - } + /// Creates a new parser from a string slice. + pub fn new(input: &'src str) -> Self { + Self { input } } - /// Parses the source into an abstract syntax tree. + /// Parses the input into an abstract syntax tree. /// - /// Returns `None` for empty input, or an expression AST on success. - pub fn parse(&mut self) -> Result>, ParseError> { - if self.lookahead == Token::EOF { + /// Returns `None` for empty input, or an expression on success. 
+ pub fn parse(&mut self) -> Result, ParseError> { + let mut lexer = Lexer::new(self.input); + let mut lookahead = lexer.next(); + let mut span = lexer.span(); + + if lookahead == Token::Eof { return Ok(None); } - let expr = self.expression()?; - self.expect(&Token::EOF)?; + + let expr = Self::expression(&mut lexer, &mut lookahead, &mut span)?; + Self::expect_token(&mut lexer, &mut lookahead, &mut span, &Token::Eof)?; Ok(Some(expr)) } - fn expression(&mut self) -> ParseResult<'src> { - let lhs = self.primary()?; - self.climb(lhs, 1) + fn expression<'lex>( + lexer: &mut Lexer<'lex>, + lookahead: &mut Token<'lex>, + span: &mut Span, + ) -> ParseResult { + let lhs = Self::primary(lexer, lookahead, span)?; + Self::climb(lexer, lookahead, span, lhs, 1) } - fn primary(&mut self) -> ParseResult<'src> { - let span = self.span; - match self.lookahead { + fn primary<'lex>( + lexer: &mut Lexer<'lex>, + lookahead: &mut Token<'lex>, + span: &mut Span, + ) -> ParseResult { + let current_span = *span; + match *lookahead { Token::Number(n) => { - self.advance(); - Ok(Expr::literal(n, span)) + Self::advance(lexer, lookahead, span); + Ok(Expr::literal(n, current_span)) } Token::Ident(id) => { - self.advance(); - if self.lookahead == Token::ParenOpen { - return self.call(id, span); + let id_string = id.to_string(); + Self::advance(lexer, lookahead, span); + if *lookahead == Token::ParenOpen { + return Self::call(lexer, lookahead, span, id_string, current_span); } - Ok(Expr::ident(id, span)) + Ok(Expr::ident(id_string, current_span)) } Token::Minus => { - self.advance(); - let expr = self.primary()?; - let expr = self.climb(expr, Token::Negate.precedence())?; - let span = self.span.merge(expr.span); - Ok(Expr::unary(UnOp::Neg, expr, span)) + Self::advance(lexer, lookahead, span); + let expr = Self::primary(lexer, lookahead, span)?; + let expr = Self::climb(lexer, lookahead, span, expr, Token::Negate.precedence())?; + let full_span = current_span.merge(expr.span); + 
Ok(Expr::unary(UnOp::Neg, expr, full_span)) } Token::ParenOpen => { - self.advance(); - let expr = self.expression()?; - self.expect(&Token::ParenClose)?; + Self::advance(lexer, lookahead, span); + let expr = Self::expression(lexer, lookahead, span)?; + Self::expect_token(lexer, lookahead, span, &Token::ParenClose)?; Ok(expr) } _ => Err(ParseError::UnexpectedToken { - found: self.lookahead.lexeme().to_string(), - expected: "an expression".to_string(), - span, + message: format!( + "unexpected token '{}', expected an expression", + lookahead.lexeme() + ), + span: current_span, }), } } - fn call(&mut self, id: &'src str, span: Span) -> ParseResult<'src> { + fn call<'lex>( + lexer: &mut Lexer<'lex>, + lookahead: &mut Token<'lex>, + span: &mut Span, + id: String, + start_span: Span, + ) -> ParseResult { // assume lookahead is '(' - self.advance(); + Self::advance(lexer, lookahead, span); - let mut args: Vec> = Vec::new(); - while self.lookahead != Token::ParenClose { - let arg = self.expression()?; + let mut args: Vec = Vec::new(); + while *lookahead != Token::ParenClose { + let arg = Self::expression(lexer, lookahead, span)?; args.push(arg); - if self.lookahead == Token::Comma { - self.advance(); + if *lookahead == Token::Comma { + Self::advance(lexer, lookahead, span); } else { break; } } - self.expect(&Token::ParenClose)?; + Self::expect_token(lexer, lookahead, span, &Token::ParenClose)?; - let span = span.merge(self.span); - Ok(Expr::call(id, args, span)) + let full_span = start_span.merge(*span); + Ok(Expr::call(id, args, full_span)) } - fn climb(&mut self, mut lhs: Expr<'src>, min_prec: u8) -> ParseResult<'src> { - let mut prec = self.lookahead.precedence(); + fn climb<'lex>( + lexer: &mut Lexer<'lex>, + lookahead: &mut Token<'lex>, + span: &mut Span, + mut lhs: Expr, + min_prec: u8, + ) -> ParseResult { + let mut prec = lookahead.precedence(); while prec >= min_prec { // Handle postfix unary operators - if self.lookahead.is_postfix_unary() { - let op = 
self.lookahead.clone(); - let op_span = self.span; - self.advance(); - prec = self.lookahead.precedence(); + if lookahead.is_postfix_unary() { + let op = lookahead.clone(); + let op_span = *span; + Self::advance(lexer, lookahead, span); + prec = lookahead.precedence(); let unary_op = UnOp::from_token(&op); - let span = lhs.span.merge(op_span); - lhs = Expr::unary(unary_op, lhs, span); + let full_span = lhs.span.merge(op_span); + lhs = Expr::unary(unary_op, lhs, full_span); continue; } - let op = self.lookahead.clone(); + let op = lookahead.clone(); - self.advance(); - let mut rhs = self.primary()?; - prec = self.lookahead.precedence(); + Self::advance(lexer, lookahead, span); + let mut rhs = Self::primary(lexer, lookahead, span)?; + prec = lookahead.precedence(); while prec > op.precedence() - || (self.lookahead.is_right_associative() && prec == op.precedence()) + || (lookahead.is_right_associative() && prec == op.precedence()) { - rhs = self.climb(rhs, prec)?; - prec = self.lookahead.precedence(); + rhs = Self::climb(lexer, lookahead, span, rhs, prec)?; + prec = lookahead.precedence(); } - let op = BinOp::from_token(&op); - let span = lhs.span.merge(rhs.span); - lhs = Expr::binary(op, lhs, rhs, span); + let binop = BinOp::from_token(&op); + let full_span = lhs.span.merge(rhs.span); + lhs = Expr::binary(binop, lhs, rhs, full_span); } Ok(lhs) } - fn advance(&mut self) { - self.lookahead = self.lexer.next(); - self.span = self.lexer.span(); + fn advance<'lex>(lexer: &mut Lexer<'lex>, lookahead: &mut Token<'lex>, span: &mut Span) { + *lookahead = lexer.next(); + *span = lexer.span(); } - fn accept(&mut self, t: &Token<'src>) -> bool { - if self.lookahead == *t { - self.advance(); - true + fn expect_token<'lex>( + lexer: &mut Lexer<'lex>, + lookahead: &mut Token<'lex>, + span: &mut Span, + expected: &Token<'lex>, + ) -> Result<(), ParseError> { + if lookahead == expected { + Self::advance(lexer, lookahead, span); + Ok(()) } else { - false - } - } - - fn expect(&mut 
self, tkn: &Token<'src>) -> Result<(), ParseError> { - if !self.accept(tkn) { - return Err(ParseError::UnexpectedToken { - found: self.lookahead.lexeme().to_string(), - expected: tkn.lexeme().to_string(), - span: self.span, - }); + Err(ParseError::UnexpectedToken { + message: format!( + "unexpected token '{}', expected '{}'", + lookahead.lexeme(), + expected.lexeme() + ), + span: *span, + }) } - Ok(()) } } diff --git a/lib/src/program.rs b/lib/src/program.rs index 7fc7e67..d4ba48f 100644 --- a/lib/src/program.rs +++ b/lib/src/program.rs @@ -1,130 +1,563 @@ +//! Type-state program implementation for compile-link-execute workflow. + +use super::ast::{BinOp, Expr, ExprKind, UnOp}; +use super::error::{LinkError, ParseError, ProgramError}; +use super::metadata::{SymbolKind, SymbolMetadata}; +use super::parser::Parser; use crate::ir::Instr; -use bincode::config; +use crate::span::{Span, SpanError}; +use crate::symbol::{SymTable, Symbol}; +use crate::vm::{Vm, VmError}; use colored::Colorize; +use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; -use thiserror::Error; +use unicode_width::UnicodeWidthStr; /// Current version of the program format const PROGRAM_VERSION: &str = env!("CARGO_PKG_VERSION"); -/// Expression parsing and evaluation errors. -#[derive(Error, Debug)] -pub enum ProgramError { - #[error("Compilation error: {0}")] - CompileError(String), - #[error("Decoding error: {0}")] - DecodingError(#[from] bincode::error::DecodeError), - #[error("incompatible program version: expected {0}, got {1}")] - IncompatibleVersions(String, String), +/// Binary format for serialization +#[derive(Debug, Clone, Serialize, Deserialize)] +struct BinaryFormat { + version: String, + bytecode: Vec, + symbols: Vec, } -/// Executable program containing bytecode instructions. +/// Origin of a compiled program. 
+#[derive(Debug, Clone)] +pub enum ProgramOrigin { + /// Loaded from a file (path stored) + File(String), + /// Compiled from source string + Source, + /// Deserialized from bytecode bytes + Bytecode, +} + +/// Type-state program using Rust's type system to enforce correct usage. +/// +/// # Examples +/// +/// ``` +/// use expr_solver::{Program, SymTable}; +/// use rust_decimal_macros::dec; /// -/// Programs reference symbols by index into a [`SymTable`] and can be serialized -/// to binary format for storage or transmission. -#[derive(Default)] -pub struct Program { - pub version: String, - pub code: Vec, +/// // Compile from source +/// let program = Program::new_from_source("x * 2 + 1").unwrap(); +/// +/// // Link with symbol table +/// let mut table = SymTable::new(); +/// table.add_const("x", dec!(5)).unwrap(); +/// let linked = program.link(table).unwrap(); +/// +/// // Execute +/// assert_eq!(linked.execute().unwrap(), dec!(11)); +/// ``` +#[derive(Debug)] +pub struct Program<'src, State> { + source: Option<&'src str>, + state: State, } -/// Binary format for serialization. -#[derive(Debug, Clone, Serialize, Deserialize)] -struct Binary { +/// Compiled state - bytecode ready for linking. +#[derive(Debug)] +pub struct Compiled { + origin: ProgramOrigin, version: String, - code: Vec, + bytecode: Vec, + symbols: Vec, } -impl Program { - /// Creates a new empty program. - pub fn new() -> Self { - Self { - version: PROGRAM_VERSION.to_string(), - code: Vec::new(), - } +/// Linked state - ready to execute. 
+#[derive(Debug)] +pub struct Linked { + #[allow(dead_code)] + origin: ProgramOrigin, + version: String, + bytecode: Vec, + symtable: SymTable, +} + +// ============================================================================ +// Program - Public constructors (return Compiled state directly) +// ============================================================================ + +impl<'src> Program<'src, Compiled> { + // ======================================================================== + // Public API + // ======================================================================== + + /// Creates a compiled program from source code. + /// + /// # Examples + /// + /// ``` + /// use expr_solver::Program; + /// + /// let program = Program::new_from_source("2 + 3 * 4").unwrap(); + /// ``` + pub fn new_from_source(source: &'src str) -> Result { + let trimmed = source.trim(); + + // Parse + let mut parser = Parser::new(trimmed); + let ast = parser + .parse() + .map_err(|parse_err| { + // Format error with source highlighting + let highlighted = Self::highlight_error(trimmed, &parse_err); + ProgramError::ParseError(format!("{}\n{}", parse_err, highlighted)) + })? + .ok_or_else(|| { + let parse_err = ParseError::UnexpectedEof { + span: Span::new(0, 0), + }; + let highlighted = Self::highlight_error(trimmed, &parse_err); + ProgramError::ParseError(format!("{}\n{}", parse_err, highlighted)) + })?; + + // Compile + let (bytecode, symbols) = Self::generate_bytecode(&ast); + + Ok(Program { + source: Some(trimmed), + state: Compiled { + origin: ProgramOrigin::Source, + version: PROGRAM_VERSION.to_string(), + bytecode, + symbols, + }, + }) } - /// Compiles the program to binary format for serialization. 
- pub fn compile(&self) -> Result, ProgramError> { - let binary = Binary { - version: self.version.clone(), - code: self.code.clone(), - }; - let config = config::standard(); - bincode::serde::encode_to_vec(&binary, config) - .map_err(|err| ProgramError::CompileError(format!("failed to encode program: {}", err))) + /// Creates a compiled program from a binary file. + /// + /// # Examples + /// + /// ```no_run + /// use expr_solver::Program; + /// + /// let program = Program::new_from_file("expr.bin").unwrap(); + /// ``` + pub fn new_from_file(path: impl Into) -> Result { + let path_str = path.into(); + let data = std::fs::read(&path_str)?; + Self::from_bytecode(&data, ProgramOrigin::File(path_str)) + } + + /// Creates a compiled program from bytecode bytes. + /// + /// Deserializes the bytecode and validates the version. + pub fn new_from_bytecode(data: &[u8]) -> Result { + Self::from_bytecode(data, ProgramOrigin::Bytecode) } - /// Loads a program from binary data. + /// Links the bytecode with a symbol table. + /// + /// Validates that all required symbols are present and compatible. + /// + /// # Examples + /// + /// ``` + /// use expr_solver::{Program, SymTable}; /// - /// The binary data must have been created with [`compile`](Self::compile). 
- pub fn load(data: &[u8]) -> Result { - let config = config::standard(); - let (decoded, _): (Binary, usize) = bincode::serde::decode_from_slice(&data, config) - .map_err(ProgramError::DecodingError)?; + /// let program = Program::new_from_source("sin(pi)").unwrap(); + /// let linked = program.link(SymTable::stdlib()).unwrap(); + /// ``` + pub fn link(mut self, table: SymTable) -> Result, ProgramError> { + // Validate symbols and fill in their resolved indices + for metadata in &mut self.state.symbols { + let (resolved_idx, symbol) = + table + .get_with_index(&metadata.name) + .ok_or_else(|| LinkError::MissingSymbol { + name: metadata.name.to_string(), + })?; - Self::validate_version(&decoded.version)?; + // Validate kind matches + Self::validate_symbol_kind(metadata, symbol)?; + + // Store resolved index in metadata + metadata.index = Some(resolved_idx); + } + + // Rewrite all indices in bytecode using resolved indices from metadata + for instr in &mut self.state.bytecode { + match instr { + Instr::Load(idx) => { + *idx = self.state.symbols[*idx] + .index + .expect("Symbol should have been resolved during linking"); + } + Instr::Call(idx, _) => { + *idx = self.state.symbols[*idx] + .index + .expect("Symbol should have been resolved during linking"); + } + _ => {} + } + } Ok(Program { - version: decoded.version, - code: decoded.code, + source: self.source, + state: Linked { + origin: self.state.origin, + version: self.state.version, + bytecode: self.state.bytecode, + symtable: table, + }, }) } - fn validate_version(version: &String) -> Result<(), ProgramError> { - if version != PROGRAM_VERSION { - return Err(ProgramError::IncompatibleVersions( - PROGRAM_VERSION.to_string(), - version.clone(), - )); + /// Returns the symbol metadata required by this program. + pub fn symbols(&self) -> &[SymbolMetadata] { + &self.state.symbols + } + + /// Returns the version of this program. 
+ pub fn version(&self) -> &str { + &self.state.version + } + + // ======================================================================== + // Private helpers + // ======================================================================== + + /// Internal helper to create program from bytecode with a specific origin. + fn from_bytecode(data: &[u8], origin: ProgramOrigin) -> Result { + let config = bincode::config::standard(); + let (binary, _): (BinaryFormat, _) = bincode::serde::decode_from_slice(data, config)?; + + // Validate version + if binary.version != PROGRAM_VERSION { + return Err(ProgramError::IncompatibleVersion { + expected: PROGRAM_VERSION.to_string(), + found: binary.version, + }); + } + + Ok(Program { + source: None, // No source for bytecode + state: Compiled { + origin, + version: binary.version, + bytecode: binary.bytecode, + symbols: binary.symbols, + }, + }) + } + + /// Highlights an error in the source code. + fn highlight_error(input: &str, error: &ParseError) -> String { + let span = error.span(); + let pre = Self::escape(&input[..span.start]); + let tok = Self::escape(&input[span.start..span.end]); + let post = Self::escape(&input[span.end..]); + let line = format!("{}{}{}", pre, tok.red().bold(), post); + + let caret = "^".green().bold(); + let squiggly_len = UnicodeWidthStr::width(tok.as_str()); + let caret_offset = UnicodeWidthStr::width(pre.as_str()) + caret.len(); + + format!( + "1 | {0}\n | {1: >2$}{3}", + line, + caret, + caret_offset, + "~".repeat(squiggly_len.saturating_sub(1)).green() + ) + } + + /// Escapes special characters for display. + fn escape(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + for c in s.chars() { + match c { + '\n' => out.push_str("\\n"), + '\r' => out.push_str("\\r"), + other => out.push(other), + } + } + out + } + + /// Generates bytecode and collects symbol metadata in a single AST traversal. 
+ fn generate_bytecode(ast: &Expr) -> (Vec, Vec) { + let mut bytecode = Vec::new(); + let mut symbols = Vec::new(); + Self::emit_instr(ast, &mut bytecode, &mut symbols); + (bytecode, symbols) + } + + /// Emits bytecode instructions for an expression node. + fn emit_instr(expr: &Expr, bytecode: &mut Vec, symbols: &mut Vec) { + match &expr.kind { + ExprKind::Literal(v) => { + bytecode.push(Instr::Push(*v)); + } + ExprKind::Ident { name } => { + // Get or create index for this constant + let idx = Self::get_or_create_symbol(name, SymbolKind::Const, symbols); + bytecode.push(Instr::Load(idx)); + } + ExprKind::Unary { op, expr } => { + Self::emit_instr(expr, bytecode, symbols); + match op { + UnOp::Neg => bytecode.push(Instr::Neg), + UnOp::Fact => bytecode.push(Instr::Fact), + } + } + ExprKind::Binary { op, left, right } => { + Self::emit_instr(left, bytecode, symbols); + Self::emit_instr(right, bytecode, symbols); + bytecode.push(match op { + BinOp::Add => Instr::Add, + BinOp::Sub => Instr::Sub, + BinOp::Mul => Instr::Mul, + BinOp::Div => Instr::Div, + BinOp::Pow => Instr::Pow, + BinOp::Equal => Instr::Equal, + BinOp::NotEqual => Instr::NotEqual, + BinOp::Less => Instr::Less, + BinOp::LessEqual => Instr::LessEqual, + BinOp::Greater => Instr::Greater, + BinOp::GreaterEqual => Instr::GreaterEqual, + }); + } + ExprKind::Call { name, args } => { + // Emit arguments first + for arg in args { + Self::emit_instr(arg, bytecode, symbols); + } + + // Get or create index for this function + let idx = Self::get_or_create_symbol( + name, + SymbolKind::Func { + arity: args.len(), + variadic: false, // Will be validated during linking + }, + symbols, + ); + bytecode.push(Instr::Call(idx, args.len())); + } } - Ok(()) + } + + /// Gets existing symbol index or creates a new one. + /// For ~50 symbols, linear search is faster than HashMap overhead. 
+ fn get_or_create_symbol( + name: &str, + kind: SymbolKind, + symbols: &mut Vec, + ) -> usize { + // Check if symbol already exists + if let Some(pos) = symbols.iter().position(|s| s.name == name) { + return pos; + } + + // Create new symbol entry + symbols.push(SymbolMetadata { + name: name.to_string().into(), + kind, + index: None, + }); + symbols.len() - 1 + } + + /// Validates that a symbol matches the expected kind. + fn validate_symbol_kind(metadata: &SymbolMetadata, symbol: &Symbol) -> Result<(), LinkError> { + match (&metadata.kind, symbol) { + (SymbolKind::Const, Symbol::Const { .. }) => Ok(()), + ( + SymbolKind::Func { arity, .. }, + Symbol::Func { + args: min_args, + variadic, + .. + }, + ) => { + // Check if the call is valid: + // - For non-variadic: arity must match exactly + // - For variadic: arity must be >= min_args + let valid = if *variadic { + arity >= min_args + } else { + arity == min_args + }; + + if valid { + Ok(()) + } else { + let expected_msg = if *variadic { + format!("at least {} arguments", min_args) + } else { + format!("exactly {} arguments", min_args) + }; + Err(LinkError::TypeMismatch { + name: metadata.name.to_string(), + expected: expected_msg, + found: format!("{} arguments provided", arity), + }) + } + } + (SymbolKind::Const, Symbol::Func { .. }) => Err(LinkError::TypeMismatch { + name: metadata.name.to_string(), + expected: "constant".to_string(), + found: "function".to_string(), + }), + (SymbolKind::Func { .. }, Symbol::Const { .. 
}) => Err(LinkError::TypeMismatch { + name: metadata.name.to_string(), + expected: "function".to_string(), + found: "constant".to_string(), + }), + } + } +} + +// ============================================================================ +// Program - After linking, ready to execute +// ============================================================================ + +impl<'src> Program<'src, Linked> { + // ======================================================================== + // Public API + // ======================================================================== + + /// Executes the program and returns the result. + pub fn execute(&self) -> Result { + Vm.run_bytecode(&self.state.bytecode, &self.state.symtable) + } + + /// Returns a reference to the symbol table. + pub fn symtable(&self) -> &SymTable { + &self.state.symtable + } + + /// Returns a mutable reference to the symbol table. + pub fn symtable_mut(&mut self) -> &mut SymTable { + &mut self.state.symtable + } + + /// Returns the version of this program. + pub fn version(&self) -> &str { + &self.state.version } /// Returns a human-readable assembly representation of the program. - pub fn get_assembly(&self, table: &crate::symbol::SymTable) -> String { + pub fn get_assembly(&self) -> String { + Self::format_assembly( + &self.state.version, + &self.state.bytecode, + &self.state.symtable, + ) + } + + /// Converts the program to bytecode bytes. + /// + /// This involves reverse-mapping the bytecode indices back to metadata indices. 
+ pub fn to_bytecode(&self) -> Result, ProgramError> { + use std::collections::HashMap; + + let mut reverse_map = HashMap::new(); + let mut symbols = Vec::new(); + + // Helper closure to get or create metadata index + // All indices are valid since we successfully linked + let mut get_or_create_metadata = |idx: usize| -> usize { + if let Some(&existing) = reverse_map.get(&idx) { + existing + } else { + let symbol = self + .state + .symtable + .get_by_index(idx) + .expect("symbol index must be valid after linking"); + + let new_idx = symbols.len(); + symbols.push(symbol.into()); + reverse_map.insert(idx, new_idx); + new_idx + } + }; + + // Single pass: build symbol mapping and rewrite bytecode + let bytecode: Vec = self + .state + .bytecode + .iter() + .map(|instr| match instr { + Instr::Load(idx) => Instr::Load(get_or_create_metadata(*idx)), + Instr::Call(idx, argc) => Instr::Call(get_or_create_metadata(*idx), *argc), + other => other.clone(), + }) + .collect(); + + // Serialize + let binary = BinaryFormat { + version: self.state.version.clone(), + bytecode, + symbols, + }; + + let config = bincode::config::standard(); + Ok(bincode::serde::encode_to_vec(&binary, config)?) + } + + /// Saves the program bytecode to a file. + pub fn save_bytecode_to_file( + &self, + path: impl AsRef, + ) -> Result<(), ProgramError> { + let bytecode = self.to_bytecode()?; + std::fs::write(path, bytecode)?; + Ok(()) + } + + // ======================================================================== + // Private helpers + // ======================================================================== + + /// Formats bytecode as human-readable assembly. 
+ fn format_assembly(version: &str, bytecode: &[Instr], table: &SymTable) -> String { use std::fmt::Write as _; let mut out = String::new(); - out += &format!("; VERSION {}\n", self.version) + out += &format!("; VERSION {}\n", version) .bright_black() .to_string(); - let emit = |mnemonic: &str| -> String { format!("{}", mnemonic.magenta()) }; - let emit1 = |mnemonic: &str, op: &str| -> String { - format!("{} {}", mnemonic.magenta(), op.green()) - }; - - for (i, instr) in self.code.iter().enumerate() { + for (i, instr) in bytecode.iter().enumerate() { let _ = write!(out, "{} ", format!("{:04X}", i).yellow()); let line = match instr { - Instr::Push(v) => emit1("PUSH", &v.to_string().green()), + Instr::Push(v) => format!("{} {}", "PUSH".magenta(), v.to_string().green()), Instr::Load(idx) => { let sym_name = table.get_by_index(*idx).map(|s| s.name()).unwrap_or("???"); - emit1("LOAD", &sym_name.blue()) + format!("{} {}", "LOAD".magenta(), sym_name.blue()) } - Instr::Neg => emit("NEG"), - Instr::Add => emit("ADD"), - Instr::Sub => emit("SUB"), - Instr::Mul => emit("MUL"), - Instr::Div => emit("DIV"), - Instr::Pow => emit("POW"), - Instr::Fact => emit("FACT"), + Instr::Neg => format!("{}", "NEG".magenta()), + Instr::Add => format!("{}", "ADD".magenta()), + Instr::Sub => format!("{}", "SUB".magenta()), + Instr::Mul => format!("{}", "MUL".magenta()), + Instr::Div => format!("{}", "DIV".magenta()), + Instr::Pow => format!("{}", "POW".magenta()), + Instr::Fact => format!("{}", "FACT".magenta()), Instr::Call(idx, argc) => { let sym_name = table.get_by_index(*idx).map(|s| s.name()).unwrap_or("???"); format!( "{} {} args: {}", - emit("CALL"), + "CALL".magenta(), sym_name.cyan(), argc.to_string().bright_blue() ) } - Instr::Equal => emit("EQ"), - Instr::NotEqual => emit("NEQ"), - Instr::Less => emit("LT"), - Instr::LessEqual => emit("LTE"), - Instr::Greater => emit("GT"), - Instr::GreaterEqual => emit("GTE"), + Instr::Equal => format!("{}", "EQ".magenta()), + Instr::NotEqual => 
format!("{}", "NEQ".magenta()), + Instr::Less => format!("{}", "LT".magenta()), + Instr::LessEqual => format!("{}", "LTE".magenta()), + Instr::Greater => format!("{}", "GT".magenta()), + Instr::GreaterEqual => format!("{}", "GTE".magenta()), }; let _ = writeln!(out, "{}", line); } diff --git a/lib/src/sema.rs b/lib/src/sema.rs deleted file mode 100644 index 8ed80c8..0000000 --- a/lib/src/sema.rs +++ /dev/null @@ -1,180 +0,0 @@ -use crate::ast::*; -use crate::span::{Span, SpanError}; -use crate::symbol::{SymTable, Symbol}; -use thiserror::Error; - -/// Expression parsing and evaluation errors. -#[derive(Error, Debug, Clone)] -pub enum SemaError { - #[error("Undefined symbol '{name}'")] - UndefinedSymbol { name: String, span: Span }, - #[error("Symbol '{name}' is not a constant")] - SymbolIsNotAConstant { name: String, span: Span }, - #[error("Symbol '{name}' is not a function")] - SymbolIsNotAFunction { name: String, span: Span }, - #[error("Function '{name}' expects exactly {expected} arguments but got {got}")] - ArgumentCountMismatch { - name: String, - expected: usize, - got: usize, - span: Span, - }, - #[error("Function '{name}' expects at least {min} arguments but got {got}")] - InsufficientArguments { - name: String, - min: usize, - got: usize, - span: Span, - }, -} - -impl SpanError for SemaError { - fn span(&self) -> Span { - match self { - SemaError::UndefinedSymbol { span, .. } => *span, - SemaError::SymbolIsNotAConstant { span, .. } => *span, - SemaError::SymbolIsNotAFunction { span, .. } => *span, - SemaError::ArgumentCountMismatch { span, .. } => *span, - SemaError::InsufficientArguments { span, .. } => *span, - } - } -} - -/// Semantic analyzer for type checking and symbol resolution. -/// -/// Validates that identifiers reference valid symbols and that function -/// calls have the correct number of arguments. 
-#[derive(Debug)] -pub struct Sema<'sym> { - table: &'sym SymTable, -} - -impl<'src, 'sym> Sema<'sym> { - /// Creates a new semantic analyzer with the given symbol table. - pub fn new(table: &'sym SymTable) -> Self { - Self { table } - } - - /// Analyzes an AST expression, resolving symbols and checking types. - pub fn visit(&mut self, ast: &mut Expr<'src>) -> Result<(), SemaError> { - match &mut ast.kind { - ExprKind::Literal(_) => Ok(()), - ExprKind::Ident { name, sym_index } => self.visit_ident(name, sym_index, ast.span), - ExprKind::Unary { op: _, expr } => self.visit_unary(expr), - ExprKind::Binary { op: _, left, right } => self.visit_binary(left, right), - ExprKind::Call { - name, - args, - sym_index, - } => self.visit_call(name, args, sym_index, ast.span), - } - } - - fn visit_ident( - &mut self, - name: &str, - sym_index: &mut Option, - span: Span, - ) -> Result<(), SemaError> { - let (idx, sym) = self.get_symbol_with_index(name, span)?; - - let Symbol::Const { .. } = sym else { - return Err(SemaError::SymbolIsNotAConstant { - name: name.to_string(), - span, - }); - }; - - *sym_index = Some(idx); - Ok(()) - } - - fn visit_unary(&mut self, expr: &mut Expr<'src>) -> Result<(), SemaError> { - self.visit(expr) - } - - fn visit_binary( - &mut self, - left: &mut Expr<'src>, - right: &mut Expr<'src>, - ) -> Result<(), SemaError> { - self.visit(left)?; - self.visit(right) - } - - fn visit_call( - &mut self, - name: &str, - args: &mut Vec>, - sym_index: &mut Option, - span: Span, - ) -> Result<(), SemaError> { - // span here will include a whole call expression, - // but is guaranteed to start with the symbol - let sym_span = Span::new(span.start, span.start + name.len()); - let (idx, sym) = self.get_symbol_with_index(name, sym_span)?; - - let Symbol::Func { - args: min_args, - variadic, - .. 
- } = sym - else { - return Err(SemaError::SymbolIsNotAFunction { - name: name.to_string(), - span: sym_span, - }); - }; - - self.validate_arity(name, args.len(), *min_args, *variadic, span)?; - self.analyse_arguments(args)?; - - *sym_index = Some(idx); - Ok(()) - } - - fn validate_arity( - &self, - name: &str, - args: usize, - min_args: usize, - variadic: bool, - span: Span, - ) -> Result<(), SemaError> { - if args == min_args || variadic && args > min_args { - return Ok(()); - } - if variadic { - Err(SemaError::InsufficientArguments { - name: name.to_string(), - min: min_args, - got: args, - span, - }) - } else { - Err(SemaError::ArgumentCountMismatch { - name: name.to_string(), - expected: min_args, - got: args, - span, - }) - } - } - - fn analyse_arguments(&mut self, args: &mut [Expr<'src>]) -> Result<(), SemaError> { - args.iter_mut().try_for_each(|a| self.visit(a)) - } - - fn get_symbol_with_index( - &self, - name: &str, - span: Span, - ) -> Result<(usize, &Symbol), SemaError> { - self.table - .get_with_index(name) - .ok_or_else(|| SemaError::UndefinedSymbol { - name: name.to_string(), - span, - }) - } -} diff --git a/lib/src/source.rs b/lib/src/source.rs deleted file mode 100644 index ff891d3..0000000 --- a/lib/src/source.rs +++ /dev/null @@ -1,54 +0,0 @@ -use crate::span::Span; -use colored::Colorize; -use unicode_width::UnicodeWidthStr; - -/// Source code container with input validation and error highlighting. -#[derive(Debug, Clone)] -pub struct Source<'str> { - pub input: &'str str, -} - -impl<'str> Source<'str> { - /// Creates a new source from an input string. - /// - /// The input is trimmed of leading and trailing whitespace. - pub fn new(input: &'str str) -> Self { - let trimmed = input.trim(); - Self { input: trimmed } - } - - /// Returns a formatted string with syntax highlighting for the given span. - /// - /// The output includes a caret and squiggly line pointing to the error location. 
- pub fn highlight(&self, span: &Span) -> String { - let input = &self.input; - let pre = Self::escape(&input[..span.start]); - let tok = Self::escape(&input[span.start..span.end]); - let post = Self::escape(&input[span.end..]); - let line = format!("{}{}{}", pre, tok.red().bold(), post); - - let caret = "^".green().bold(); - let squiggly_len = UnicodeWidthStr::width(tok.as_str()); - let caret_offset = UnicodeWidthStr::width(pre.as_str()) + caret.len(); - - format!( - "1 | {0}\n | {1: >2$}{3}", - line, - caret, - caret_offset, - "~".repeat(squiggly_len.saturating_sub(1)).green() - ) - } - - fn escape(s: &str) -> String { - let mut out = String::with_capacity(s.len()); - for c in s.chars() { - match c { - '\n' => out.push_str("\\n"), - '\r' => out.push_str("\\r"), - other => out.push(other), - } - } - out - } -} diff --git a/lib/src/token.rs b/lib/src/token.rs index 3a1802e..ca33b43 100644 --- a/lib/src/token.rs +++ b/lib/src/token.rs @@ -23,7 +23,7 @@ pub enum Token<'src> { LessEqual, // <= Greater, // > GreaterEqual, // >= - EOF, + Eof, Invalid(&'src str), } @@ -73,7 +73,7 @@ impl<'src> Token<'src> { Token::LessEqual => Borrowed("<="), Token::Greater => Borrowed(">"), Token::GreaterEqual => Borrowed(">="), - Token::EOF => Borrowed("EOF"), + Token::Eof => Borrowed("EOF"), Token::Invalid(str) => match *str { "\n" => Borrowed("\\n"), "\r" => Borrowed("\\r"), diff --git a/lib/src/vm.rs b/lib/src/vm.rs index 3a9a1e4..c5f6ba4 100644 --- a/lib/src/vm.rs +++ b/lib/src/vm.rs @@ -1,5 +1,4 @@ use crate::ir::Instr; -use crate::program::Program; use crate::symbol::{FuncError, SymTable, Symbol}; use rust_decimal::Decimal; use rust_decimal::prelude::*; @@ -47,7 +46,7 @@ pub enum VmError { pub struct Vm; impl Vm { - /// Executes a program and returns the result. + /// Executes bytecode directly and returns the result. 
/// /// # Errors /// @@ -57,14 +56,14 @@ impl Vm { /// - Invalid operations (e.g., factorial of non-integer) /// - Function errors /// - Invalid symbol indices - pub fn run(&self, prog: &Program, table: &SymTable) -> Result { - if prog.code.is_empty() { + pub fn run_bytecode(&self, bytecode: &[Instr], table: &SymTable) -> Result { + if bytecode.is_empty() { return Ok(Decimal::ZERO); } let mut stack: Vec = Vec::new(); - for op in &prog.code { + for op in bytecode { self.execute_instruction(op, table, &mut stack)?; } @@ -287,45 +286,37 @@ mod tests { use super::*; use crate::symbol::SymTable; - fn make(code: Vec) -> Program { - let mut program = Program::new(); - program.code = code; - program - } - #[test] fn test_vm_error_stack_underflow() { - let vm = Vm::default(); + let vm = Vm; let table = SymTable::stdlib(); - let program = make( - vec![Instr::Add], // No values on stack - ); + let bytecode = vec![Instr::Add]; // No values on stack - let result = vm.run(&program, &table); + let result = vm.run_bytecode(&bytecode, &table); assert!(matches!(result, Err(VmError::StackUnderflow))); } #[test] fn test_vm_error_division_by_zero() { - let vm = Vm::default(); + let vm = Vm; let table = SymTable::stdlib(); - let program = make(vec![Instr::Push(dec!(5)), Instr::Push(dec!(0)), Instr::Div]); + let bytecode = vec![Instr::Push(dec!(5)), Instr::Push(dec!(0)), Instr::Div]; - let result = vm.run(&program, &table); + let result = vm.run_bytecode(&bytecode, &table); assert!(matches!(result, Err(VmError::DivisionByZero))); } #[test] fn test_vm_error_invalid_final_stack() { - let vm = Vm::default(); + let vm = Vm; let table = SymTable::stdlib(); - let program = make(vec![ + let bytecode = vec![ Instr::Push(dec!(1)), Instr::Push(dec!(2)), // No operation to combine them - ]); + ]; - let result = vm.run(&program, &table); + let result = vm.run_bytecode(&bytecode, &table); assert!(matches!( result, Err(VmError::InvalidFinalStack { count: 2 }) @@ -334,15 +325,13 @@ mod tests { #[test] 
fn test_vm_error_invalid_load() { - let vm = Vm::default(); + let vm = Vm; let table = SymTable::stdlib(); let (sin_idx, _) = table.get_with_index("sin").unwrap(); - let program = make( - vec![Instr::Load(sin_idx)], // Trying to load a function as constant - ); + let bytecode = vec![Instr::Load(sin_idx)]; // Trying to load a function as constant - let result = vm.run(&program, &table); + let result = vm.run_bytecode(&bytecode, &table); assert!(matches!( result, Err(VmError::InvalidLoad { symbol_name: _ }) @@ -351,15 +340,13 @@ mod tests { #[test] fn test_vm_error_invalid_call() { - let vm = Vm::default(); + let vm = Vm; let table = SymTable::stdlib(); let (pi_idx, _) = table.get_with_index("pi").unwrap(); - let program = make( - vec![Instr::Call(pi_idx, 0)], // Trying to call a constant as function - ); + let bytecode = vec![Instr::Call(pi_idx, 0)]; // Trying to call a constant as function - let result = vm.run(&program, &table); + let result = vm.run_bytecode(&bytecode, &table); assert!(matches!( result, Err(VmError::InvalidCall { symbol_name: _ }) @@ -368,15 +355,13 @@ mod tests { #[test] fn test_vm_error_call_stack_underflow() { - let vm = Vm::default(); + let vm = Vm; let table = SymTable::stdlib(); let (sin_idx, _) = table.get_with_index("sin").unwrap(); - let program = make( - vec![Instr::Call(sin_idx, 0)], // No arguments for sin function - ); + let bytecode = vec![Instr::Call(sin_idx, 0)]; // No arguments for sin function - let result = vm.run(&program, &table); + let result = vm.run_bytecode(&bytecode, &table); assert!(matches!( result, Err(VmError::CallStackUnderflow { @@ -425,7 +410,7 @@ mod tests { #[test] fn test_binary_operations() { - let vm = Vm::default(); + let vm = Vm; let table = SymTable::stdlib(); // Test all binary operations @@ -445,8 +430,7 @@ mod tests { ]; for (code, expected) in test_cases { - let program = make(code); - assert_eq!(vm.run(&program, &table).unwrap(), expected); + assert_eq!(vm.run_bytecode(&code, &table).unwrap(), 
expected); } } } diff --git a/lib/tests/integration_tests.rs b/lib/tests/integration_tests.rs index d3d1b95..a15ea20 100644 --- a/lib/tests/integration_tests.rs +++ b/lib/tests/integration_tests.rs @@ -1,24 +1,22 @@ -use expr_solver::{Eval, SymTable}; +use expr_solver::{SymTable, eval, eval_with_table, load, load_with_table}; use indoc::indoc; use rust_decimal::{Decimal, MathematicalOps}; use rust_decimal_macros::dec; // Helper function to evaluate an expression and expect an Ok result. fn eval_ok(expr: &str) -> Decimal { - let mut eval = Eval::new(expr); - eval.run().expect("Evaluation should be successful") + eval(expr).expect("Evaluation should be successful") } // Helper function to evaluate an expression and expect an Err result. fn eval_err(expr: &str) -> String { colored::control::set_override(false); - let mut eval = Eval::new(expr); - eval.run().expect_err("Evaluation should fail") + eval(expr).expect_err("Evaluation should fail") } // Helper function to evaluate an expression with a custom symbol table and expect an Ok result. fn eval_with_custom_table_ok(expr: &str, table: SymTable) -> Decimal { - Eval::evaluate_with_table(expr, table).expect("Evaluation should be successful") + eval_with_table(expr, table).expect("Evaluation should be successful") } #[test] @@ -128,47 +126,47 @@ fn test_custom_symbols() { #[rustfmt::skip] fn test_syntax_errors() { assert_eq!(eval_err("1 + * 2"), indoc! {r#" - Unexpected token '*', expected 'an expression' + Unexpected token: unexpected token '*', expected an expression 1 | 1 + * 2 | ^"# }); assert_eq!(eval_err("(1 + 2"), indoc! {r#" - Unexpected token 'EOF', expected ')' + Unexpected token: unexpected token 'EOF', expected ')' 1 | (1 + 2 | ^"# }); assert_eq!(eval_err("1 2"), indoc! {r#" - Unexpected token '2', expected 'EOF' + Unexpected token: unexpected token '2', expected 'EOF' 1 | 1 2 | ^"# }); assert_eq!(eval_err("()"), indoc! 
{r#" - Unexpected token ')', expected 'an expression' + Unexpected token: unexpected token ')', expected an expression 1 | () | ^"# }); assert_eq!(eval_err("sin("), indoc! {r#" - Unexpected token 'EOF', expected 'an expression' + Unexpected token: unexpected token 'EOF', expected an expression 1 | sin( | ^"# }); assert_eq!(eval_err("1 + "), indoc! {r#" - Unexpected token 'EOF', expected 'an expression' + Unexpected token: unexpected token 'EOF', expected an expression 1 | 1 + | ^"# }); assert_eq!(eval_err("* 2"), indoc! {r#" - Unexpected token '*', expected 'an expression' + Unexpected token: unexpected token '*', expected an expression 1 | * 2 | ^"# }); assert_eq!(eval_err("1 (2 + 3)"), indoc! {r#" - Unexpected token '(', expected 'EOF' + Unexpected token: unexpected token '(', expected 'EOF' 1 | 1 (2 + 3) | ^"# }); assert_eq!(eval_err("sin 1"), indoc! {r#" - Unexpected token '1', expected 'EOF' + Unexpected token: unexpected token '1', expected 'EOF' 1 | sin 1 | ^"# }); @@ -177,51 +175,16 @@ fn test_syntax_errors() { #[test] #[rustfmt::skip] fn test_semantic_errors() { - assert_eq!(eval_err("foo()"), indoc! {r#" - Undefined symbol 'foo' - 1 | foo() - | ^~~"# - }); - assert_eq!(eval_err("🙈🍅🎉🌴🎶()"), indoc! {r#" - Undefined symbol '🙈🍅🎉🌴🎶' - 1 | 🙈🍅🎉🌴🎶() - | ^~~~~~~~~~"# - }); - assert_eq!(eval_err("bar"), indoc! {r#" - Undefined symbol 'bar' - 1 | bar - | ^~~"# - }); - assert_eq!(eval_err("sin(1, 2)"), indoc! {r#" - Function 'sin' expects exactly 1 arguments but got 2 - 1 | sin(1, 2) - | ^~~~~~~~~"# - }); - assert_eq!(eval_err("max()"), indoc! {r#" - Function 'max' expects at least 1 arguments but got 0 - 1 | max() - | ^~~~~"# - }); - assert_eq!(eval_err("pi()"), indoc! {r#" - Symbol 'pi' is not a function - 1 | pi() - | ^~"# - }); - assert_eq!(eval_err("1 + sin"), indoc! {r#" - Symbol 'sin' is not a constant - 1 | 1 + sin - | ^~~"# - }); - assert_eq!(eval_err("avg()"), indoc! 
{r#" - Function 'avg' expects at least 1 arguments but got 0 - 1 | avg() - | ^~~~~"# - }); - assert_eq!(eval_err("clamp(1, 2)"), indoc! {r#" - Function 'clamp' expects exactly 3 arguments but got 2 - 1 | clamp(1, 2) - | ^~~~~~~~~~~"# - }); + // V2 defers validation to link time, so we get link errors instead of semantic errors + assert_eq!(eval_err("foo()"), "Link error: Missing symbol: 'foo' is required by bytecode but not in symbol table"); + assert_eq!(eval_err("🙈🍅🎉🌴🎶()"), "Link error: Missing symbol: '🙈🍅🎉🌴🎶' is required by bytecode but not in symbol table"); + assert_eq!(eval_err("bar"), "Link error: Missing symbol: 'bar' is required by bytecode but not in symbol table"); + assert_eq!(eval_err("sin(1, 2)"), "Link error: Type mismatch for symbol 'sin': expected exactly 1 arguments, found 2 arguments provided"); + assert_eq!(eval_err("max()"), "Link error: Type mismatch for symbol 'max': expected at least 1 arguments, found 0 arguments provided"); + assert_eq!(eval_err("pi()"), "Link error: Type mismatch for symbol 'pi': expected function, found constant"); + assert_eq!(eval_err("1 + sin"), "Link error: Type mismatch for symbol 'sin': expected constant, found function"); + assert_eq!(eval_err("avg()"), "Link error: Type mismatch for symbol 'avg': expected at least 1 arguments, found 0 arguments provided"); + assert_eq!(eval_err("clamp(1, 2)"), "Link error: Type mismatch for symbol 'clamp': expected exactly 3 arguments, found 2 arguments provided"); } #[test] @@ -278,14 +241,102 @@ fn test_if_function() { #[test] #[rustfmt::skip] fn test_if_function_semantic_errors() { - assert_eq!(eval_err("if(1, 2)"), indoc! {r#" - Function 'if' expects exactly 3 arguments but got 2 - 1 | if(1, 2) - | ^~~~~~~~"# - }); - assert_eq!(eval_err("if(1, 2, 3, 4)"), indoc! 
{r#" - Function 'if' expects exactly 3 arguments but got 4 - 1 | if(1, 2, 3, 4) - | ^~~~~~~~~~~~~~"# - }); + // V2 defers validation to link time + assert_eq!(eval_err("if(1, 2)"), "Link error: Type mismatch for symbol 'if': expected exactly 3 arguments, found 2 arguments provided"); + assert_eq!(eval_err("if(1, 2, 3, 4)"), "Link error: Type mismatch for symbol 'if': expected exactly 3 arguments, found 4 arguments provided"); +} + +// ==================== +// Program API Tests +// ==================== + +#[test] +fn test_program_basic_arithmetic() { + let program = load_with_table("2 + 3 * 4", SymTable::stdlib()).expect("link failed"); + + let result = program.execute().expect("execution failed"); + assert_eq!(result, dec!(14)); +} + +#[test] +fn test_program_with_constants() { + let program = load_with_table("pi * 2", SymTable::stdlib()).expect("link failed"); + + let result = program.execute().expect("execution failed"); + // pi * 2 ≈ 6.28... + assert!(result > dec!(6.28) && result < dec!(6.29)); +} + +#[test] +fn test_program_with_functions() { + let program = load_with_table("sqrt(16) + sin(0)", SymTable::stdlib()).expect("link failed"); + + let result = program.execute().expect("execution failed"); + assert_eq!(result, dec!(4)); // sqrt(16) + sin(0) = 4 + 0 = 4 +} + +#[test] +fn test_program_symtable_mutation() { + let program = load("x + y").expect("compilation failed"); + + // Create symbol table with x and y + let mut table = SymTable::new(); + table.add_const("x", dec!(10)).unwrap(); + table.add_const("y", dec!(20)).unwrap(); + + let mut program = program.link(table).expect("link failed"); + + // First execution + let result = program.execute().expect("execution failed"); + assert_eq!(result, dec!(30)); + + // Modify symbol table + program.symtable_mut().add_const("z", dec!(100)).unwrap(); + + // Execute again (x + y should still be 30) + let result = program.execute().expect("execution failed"); + assert_eq!(result, dec!(30)); +} + +#[test] +fn 
test_program_serialization() { + let program = load_with_table("sqrt(pi) + 2", SymTable::stdlib()).expect("link failed"); + + // Execute original + let result1 = program.execute().expect("execution failed"); + + // Serialize + let bytes = program.to_bytecode().expect("serialization failed"); + + // Deserialize and re-link + use expr_solver::Program; + let program2 = Program::new_from_bytecode(&bytes) + .expect("deserialization failed") + .link(SymTable::stdlib()) + .expect("link failed"); + + // Execute deserialized + let result2 = program2.execute().expect("execution failed"); + + assert_eq!(result1, result2); +} + +#[test] +fn test_program_get_assembly() { + let program = load_with_table("2 + 3", SymTable::stdlib()).expect("link failed"); + + let assembly = program.get_assembly(); + assert!(assembly.contains("PUSH")); + assert!(assembly.contains("ADD")); +} + +#[test] +fn test_program_link_validation() { + let program = load("x + y").expect("compilation failed"); + + // Try to link with empty symbol table (should fail) + let empty_table = SymTable::new(); + let result = program.link(empty_table); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Missing symbol")); }