diff --git a/scripts/for1.rhai b/scripts/for1.rhai new file mode 100644 index 00000000..ea3675b7 --- /dev/null +++ b/scripts/for1.rhai @@ -0,0 +1,15 @@ +let arr = [1,2,3,4] +for a in arr { + for b in [10,20] { + print(a) + print(b) + } + if a == 3 { + break; + } +} +//print(a) + +for i in range(0,5) { + print(i) +} \ No newline at end of file diff --git a/src/any.rs b/src/any.rs index 97f5bf57..b6913517 100644 --- a/src/any.rs +++ b/src/any.rs @@ -76,7 +76,7 @@ impl Clone for Box { impl fmt::Debug for dyn Any { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.pad("Any") + f.pad("?") } } diff --git a/src/engine.rs b/src/engine.rs index 6a1543af..32061dc3 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -4,20 +4,25 @@ use std::collections::HashMap; use std::error::Error; use std::fmt; use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Rem, Shl, Shr, Sub}; -use std::sync::Arc; +use std::{convert::TryInto, sync::Arc}; use crate::any::{Any, AnyExt}; use crate::call::FunArgs; -use crate::fn_register::RegisterFn; +use crate::fn_register::{RegisterBoxFn, RegisterFn}; use crate::parser::{lex, parse, Expr, FnDef, Stmt}; +use fmt::{Debug, Display}; + +type Array = Vec>; #[derive(Debug)] pub enum EvalAltResult { ErrorFunctionNotFound(String), ErrorFunctionArgMismatch, ErrorArrayOutOfBounds(usize, i64), + ErrorArrayMismatch, ErrorIndexMismatch, ErrorIfGuardMismatch, + ErrorForMismatch, ErrorVariableNotFound(String), ErrorAssignmentToUnknownLHS, ErrorMismatchOutputType(String), @@ -46,10 +51,12 @@ impl PartialEq for EvalAltResult { (&ErrorFunctionNotFound(ref a), &ErrorFunctionNotFound(ref b)) => a == b, (&ErrorFunctionArgMismatch, &ErrorFunctionArgMismatch) => true, (&ErrorIndexMismatch, &ErrorIndexMismatch) => true, + (&ErrorArrayMismatch, &ErrorArrayMismatch) => true, (&ErrorArrayOutOfBounds(max1, index1), &ErrorArrayOutOfBounds(max2, index2)) => { max1 == max2 && index1 == index2 } (&ErrorIfGuardMismatch, &ErrorIfGuardMismatch) => true, + (&ErrorForMismatch, &ErrorForMismatch) => true, (&ErrorVariableNotFound(ref a), &ErrorVariableNotFound(ref b)) => a == b, (&ErrorAssignmentToUnknownLHS, &ErrorAssignmentToUnknownLHS) => true, (&ErrorMismatchOutputType(ref a), &ErrorMismatchOutputType(ref b)) => a == b, @@ -67,12 +74,14 @@ impl Error for EvalAltResult { EvalAltResult::ErrorFunctionNotFound(_) => "Function not found", EvalAltResult::ErrorFunctionArgMismatch => "Function argument types do not match", EvalAltResult::ErrorIndexMismatch => "Array access expects integer index", + EvalAltResult::ErrorArrayMismatch => "Indexing can only be performed on an array", EvalAltResult::ErrorArrayOutOfBounds(_, index) if index < 0 => { "Array access expects non-negative index" } EvalAltResult::ErrorArrayOutOfBounds(max, _) if max == 0 => "Access of empty array", EvalAltResult::ErrorArrayOutOfBounds(_, _) => "Array index out of bounds", EvalAltResult::ErrorIfGuardMismatch => "If guards expect boolean expression", + EvalAltResult::ErrorForMismatch => "For loops expect array", EvalAltResult::ErrorVariableNotFound(_) => "Variable not found", EvalAltResult::ErrorAssignmentToUnknownLHS => { "Assignment to an unsupported left-hand side" @@ -107,7 +116,7 @@ impl fmt::Display for EvalAltResult { EvalAltResult::ErrorArrayOutOfBounds(max, index) => { write!(f, "{} (max {}): {}", self.description(), max - 1, index) } - _ => write!(f, "{}", self.description()), + err => write!(f, "{}", err.description()), } } } @@ -119,6 +128,8 @@ pub struct FnSpec { args: Option>, } +type IteratorFn = dyn Fn(&Box) -> Box>>; + /// Rhai's engine type. This is what you use to run Rhai scripts /// /// ```rust @@ -137,6 +148,7 @@ pub struct FnSpec { pub struct Engine { /// A hashmap containing all functions known to the engine pub fns: HashMap>, + pub type_iterators: HashMap>, } pub enum FnIntExt { @@ -187,16 +199,14 @@ impl Engine { debug_println!( "Trying to call function {:?} with args {:?}", ident, - args.iter().map(|x| (&**x).type_id()).collect::>() + args.iter() + .map(|x| Any::type_name(&**x)) + .collect::>() ); let spec = FnSpec { ident: ident.clone(), - args: Some( - args.iter() - .map(|a| ::type_id(&**a)) - .collect(), - ), + args: Some(args.iter().map(|a| Any::type_id(&**a)).collect()), }; self.fns @@ -248,6 +258,14 @@ impl Engine { // currently a no-op, exists for future extensibility } + /// Register an iterator adapter for a type. + pub fn register_iterator(&mut self, f: F) + where + F: 'static + Fn(&Box) -> Box>>, + { + self.type_iterators.insert(TypeId::of::(), Arc::new(f)); + } + /// Register a get function for a member of a registered type pub fn register_get(&mut self, name: &str, get_fn: F) where @@ -290,7 +308,7 @@ impl Engine { match *dot_rhs { Expr::FnCall(ref fn_name, ref args) => { - let mut args: Vec> = args + let mut args: Array = args .iter() .map(|arg| self.eval_expr(scope, arg)) .collect::, _>>()?; @@ -311,9 +329,13 @@ impl Engine { let mut val = self.call_fn_raw(get_fn_name, vec![this_ptr])?; - ((*val).downcast_mut() as Option<&mut Vec>>) - .and_then(|arr| idx.downcast_ref::().map(|idx| (arr, *idx))) - .ok_or(EvalAltResult::ErrorIndexMismatch) + ((*val).downcast_mut() as Option<&mut Array>) + .ok_or(EvalAltResult::ErrorArrayMismatch) + .and_then(|arr| { + idx.downcast_ref::() + .map(|idx| (arr, *idx)) + .ok_or(EvalAltResult::ErrorIndexMismatch) + }) .and_then(|(arr, idx)| match idx { x if x < 0 => Err(EvalAltResult::ErrorArrayOutOfBounds(0, x)), x => arr @@ -367,8 +389,8 @@ impl Engine { x => x as usize, }; let (idx_sc, val) = Self::search_scope(scope, id, |val| { - ((*val).downcast_mut() as Option<&mut Vec>>) - .ok_or(EvalAltResult::ErrorIndexMismatch) + ((*val).downcast_mut() as Option<&mut Array>) + .ok_or(EvalAltResult::ErrorArrayMismatch) .and_then(|arr| { arr.get(idx) .cloned() @@ -402,7 +424,7 @@ impl Engine { // In case the expression mutated `target`, we need to reassign it because // of the above `clone`. - scope[sc_idx].1.downcast_mut::>>().unwrap()[idx] = target; + scope[sc_idx].1.downcast_mut::().unwrap()[idx] = target; value } @@ -465,7 +487,7 @@ impl Engine { // In case the expression mutated `target`, we need to reassign it because // of the above `clone`. - scope[sc_idx].1.downcast_mut::>>().unwrap()[idx] = target; + scope[sc_idx].1.downcast_mut::().unwrap()[idx] = target; value } @@ -509,24 +531,27 @@ impl Engine { for &mut (ref name, ref mut val) in &mut scope.iter_mut().rev() { if *id == *name { - if let Some(&i) = idx.downcast_ref::() { + return if let Some(&i) = idx.downcast_ref::() { if let Some(arr_typed) = - (*val).downcast_mut() as Option<&mut Vec>> + (*val).downcast_mut() as Option<&mut Array> { - return if i < 0 { + if i < 0 { Err(EvalAltResult::ErrorArrayOutOfBounds(0, i)) } else if i as usize >= arr_typed.len() { - Err(EvalAltResult::ErrorArrayOutOfBounds(arr_typed.len(), i)) + Err(EvalAltResult::ErrorArrayOutOfBounds( + arr_typed.len(), + i, + )) } else { arr_typed[i as usize] = rhs_val; Ok(Box::new(())) - }; + } } else { - return Err(EvalAltResult::ErrorIndexMismatch); + Err(EvalAltResult::ErrorIndexMismatch) } } else { - return Err(EvalAltResult::ErrorIndexMismatch); - } + Err(EvalAltResult::ErrorIndexMismatch) + }; } } @@ -553,7 +578,7 @@ impl Engine { fn_name.to_owned(), args.iter() .map(|ex| self.eval_expr(scope, ex)) - .collect::>, _>>()? + .collect::>()? .iter_mut() .map(|b| b.as_mut()) .collect(), @@ -635,6 +660,26 @@ impl Engine { _ => (), } }, + Stmt::For(ref name, ref expr, ref body) => { + let arr = self.eval_expr(scope, expr)?; + let tid = Any::type_id(&*arr); + if let Some(iter_fn) = self.type_iterators.get(&tid) { + scope.push((name.clone(), Box::new(()))); + let idx = scope.len() - 1; + for a in iter_fn(&arr) { + scope[idx].1 = a; + match self.eval_stmt(scope, body) { + Err(EvalAltResult::LoopBreak) => break, + Err(x) => return Err(x), + _ => (), + } + } + scope.remove(idx); + Ok(Box::new(())) + } else { + return Err(EvalAltResult::ErrorForMismatch); + } + } Stmt::Break => Err(EvalAltResult::LoopBreak), Stmt::Return => Err(EvalAltResult::Return(Box::new(()))), Stmt::ReturnWithVal(ref a) => { @@ -824,6 +869,30 @@ impl Engine { ) } + macro_rules! reg_func1 { + ($engine:expr, $x:expr, $op:expr, $r:ty, $( $y:ty ),*) => ( + $( + $engine.register_fn($x, $op as fn(x: $y)->$r); + )* + ) + } + + macro_rules! reg_func2 { + ($engine:expr, $x:expr, $op:expr, $v:ty, $r:ty, $( $y:ty ),*) => ( + $( + $engine.register_fn($x, $op as fn(x: $v, y: $y)->$r); + )* + ) + } + + macro_rules! reg_func2b { + ($engine:expr, $x:expr, $op:expr, $v:ty, $r:ty, $( $y:ty ),*) => ( + $( + $engine.register_fn($x, $op as fn(y: $y, x: $v)->$r); + )* + ) + } + fn add(x: T, y: T) -> ::Output { x + y } @@ -936,12 +1005,91 @@ impl Engine { // FIXME? Registering array lookups are a special case because we want to return boxes // directly let ent = engine.fns.entry("[]".to_string()).or_insert_with(Vec::new); // (*ent).push(FnType::ExternalFn2(Box::new(idx))); + + // Register print and debug + fn print_debug(x: T) { + println!("{:?}", x); + } + fn print(x: T) { + println!("{}", x); + } + + reg_func1!(engine, "print", print, (), i32, i64, u32, u64); + reg_func1!(engine, "print", print, (), f32, f64, bool, String); + reg_func1!(engine, "print", print_debug, (), Array); + + reg_func1!(engine, "debug", print_debug, (), i32, i64, u32, u64); + reg_func1!(engine, "debug", print_debug, (), f32, f64, bool, String); + reg_func1!(engine, "debug", print_debug, (), Array); + + // Register array functions + fn push(list: &mut Array, item: T) { + list.push(Box::new(item)); + } + fn pop(list: &mut Array) -> Box { + list.pop().unwrap() + } + fn shift(list: &mut Array) -> Box { + list.remove(0) + } + fn len(list: &mut Array) -> i64 { + list.len().try_into().unwrap() + } + + reg_func2!(engine, "push", push, &mut Array, (), i32, i64, u32, u64); + reg_func2!(engine, "push", push, &mut Array, (), f32, f64, bool); + reg_func2!(engine, "push", push, &mut Array, (), String, Array); + + engine.register_box_fn("pop", pop); + engine.register_box_fn("shift", shift); + engine.register_fn("len", len); + + // Register string concatenate functions + fn prepend(x: T, y: String) -> String { + format!("{}{}", x, y) + } + fn append(x: &mut String, y: T) -> String { + format!("{}{}", x, y) + } + fn prepend_array(x: Array, y: String) -> String { + format!("{:?}{}", x, y) + } + fn append_array(x: &mut String, y: Array) -> String { + format!("{}{:?}", x, y) + } + + reg_func2!(engine, "+", append, &mut String, String, i32, i64); + reg_func2!(engine, "+", append, &mut String, String, u32, u64); + reg_func2!(engine, "+", append, &mut String, String, f32, f64, bool); + engine.register_fn("+", append_array); + + reg_func2b!(engine, "+", prepend, String, String, i32, i64, u32, u64, f32, f64, bool); + engine.register_fn("+", prepend_array); + + // Register array iterator + engine.register_iterator::(|a| { + Box::new(a.downcast_ref::().unwrap().clone().into_iter()) + }); + + // Register range function + use std::ops::Range; + engine.register_iterator::, _>(|a| { + Box::new( + a.downcast_ref::>() + .unwrap() + .clone() + .map(|n| Box::new(n) as Box), + ) + }); + + engine.register_fn("range", |i1: i64, i2: i64| (i1..i2)); } /// Make a new engine pub fn new() -> Engine { let mut engine = Engine { fns: HashMap::new(), + type_iterators: HashMap::new(), }; Engine::register_default_lib(&mut engine); diff --git a/src/fn_register.rs b/src/fn_register.rs index 0dcf6fe4..d1bb6b6e 100644 --- a/src/fn_register.rs +++ b/src/fn_register.rs @@ -6,6 +6,9 @@ use crate::engine::{Engine, EvalAltResult}; pub trait RegisterFn { fn register_fn(&mut self, name: &str, f: FN); } +pub trait RegisterBoxFn { + fn register_box_fn(&mut self, name: &str, f: FN); +} pub struct Ref(A); pub struct Mut(A); @@ -44,7 +47,37 @@ macro_rules! def_register { // Call the user-supplied function using ($clone) to // potentially clone the value, otherwise pass the reference. - Ok(Box::new(f($(($clone)($par)),*)) as Box) + let r = f($(($clone)($par)),*); + Ok(Box::new(r) as Box) + }; + self.register_fn_raw(name.to_owned(), Some(vec![$(TypeId::of::<$par>()),*]), Box::new(fun)); + } + } + + impl<$($par,)* FN> RegisterBoxFn for Engine + where + $($par: Any + Clone,)* + FN: Fn($($param),*) -> Box + 'static + { + fn register_box_fn(&mut self, name: &str, f: FN) { + let fun = move |mut args: Vec<&mut dyn Any>| { + // Check for length at the beginning to avoid + // per-element bound checks. + if args.len() != count_args!($($par)*) { + return Err(EvalAltResult::ErrorFunctionArgMismatch); + } + + #[allow(unused_variables, unused_mut)] + let mut drain = args.drain(..); + $( + // Downcast every element, return in case of a type mismatch + let $par = ((*drain.next().unwrap()).downcast_mut() as Option<&mut $par>) + .ok_or(EvalAltResult::ErrorFunctionArgMismatch)?; + )* + + // Call the user-supplied function using ($clone) to + // potentially clone the value, otherwise pass the reference. + Ok(f($(($clone)($par)),*)) }; self.register_fn_raw(name.to_owned(), Some(vec![$(TypeId::of::<$par>()),*]), Box::new(fun)); } diff --git a/src/lib.rs b/src/lib.rs index 6aa339bd..4c621796 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,4 +47,4 @@ mod parser; pub use any::Any; pub use engine::{Engine, EvalAltResult, Scope}; -pub use fn_register::RegisterFn; +pub use fn_register::{RegisterBoxFn, RegisterFn}; diff --git a/src/parser.rs b/src/parser.rs index 0e82f0f4..f524e8f0 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -89,6 +89,7 @@ pub enum Stmt { IfElse(Box, Box, Box), While(Box, Box), Loop(Box), + For(String, Box, Box), Var(String, Option>), Block(Vec), Expr(Box), @@ -175,6 +176,8 @@ pub enum Token { ModuloAssign, PowerOf, PowerOfAssign, + For, + In, LexErr(LexError), } @@ -230,6 +233,7 @@ impl Token { ModuloAssign | Return | PowerOf | + In | PowerOfAssign => true, _ => false, } @@ -505,6 +509,8 @@ impl<'a> TokenIterator<'a> { "break" => return Some(Token::Break), "return" => return Some(Token::Return), "fn" => return Some(Token::Fn), + "for" => return Some(Token::For), + "in" => return Some(Token::In), x => return Some(Token::Identifier(x.to_string())), } } @@ -1102,6 +1108,26 @@ fn parse_loop<'a>(input: &mut Peekable>) -> Result(input: &mut Peekable>) -> Result { + input.next(); + + let name = match input.next() { + Some(Token::Identifier(ref s)) => s.clone(), + _ => return Err(ParseError::VarExpectsIdentifier), + }; + + match input.next() { + Some(Token::In) => {} + _ => return Err(ParseError::VarExpectsIdentifier), + } + + let expr = parse_expr(input)?; + + let body = parse_block(input)?; + + Ok(Stmt::For(name, Box::new(expr), Box::new(body))) +} + fn parse_var<'a>(input: &mut Peekable>) -> Result { input.next(); @@ -1168,6 +1194,7 @@ fn parse_stmt<'a>(input: &mut Peekable>) -> Result parse_if(input), Some(&Token::While) => parse_while(input), Some(&Token::Loop) => parse_loop(input), + Some(&Token::For) => parse_for(input), Some(&Token::Break) => { input.next(); Ok(Stmt::Break) diff --git a/tests/for.rs b/tests/for.rs new file mode 100644 index 00000000..687c1b4b --- /dev/null +++ b/tests/for.rs @@ -0,0 +1,24 @@ +use rhai::Engine; + +#[test] +fn test_for() { + let mut engine = Engine::new(); + + let script = r" + let sum1 = 0; + let sum2 = 0; + let inputs = [1, 2, 3, 4, 5]; + + for x in inputs { + sum1 += x; + } + + for x in range(1, 6) { + sum2 += x; + } + + sum1 + sum2 + "; + + assert_eq!(engine.eval::(script).unwrap(), 30); +}