Add fallible functions support and replace most arithmetic operations with checked versions.

This commit is contained in:
Stephen Chung
2020-03-08 22:47:13 +08:00
parent 3e7adc2e51
commit b1b25d3043
9 changed files with 387 additions and 40 deletions

View File

@@ -3,9 +3,15 @@
use crate::any::Any;
use crate::engine::{Array, Engine};
use crate::fn_register::RegisterFn;
use crate::fn_register::{RegisterFn, RegisterResultFn};
use crate::parser::Position;
use crate::result::EvalAltResult;
use num_traits::{
CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedShl, CheckedShr, CheckedSub,
};
use std::convert::TryFrom;
use std::fmt::{Debug, Display};
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Range, Rem, Shl, Shr, Sub};
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Range, Rem, Sub};
macro_rules! reg_op {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
@@ -15,6 +21,22 @@ macro_rules! reg_op {
)
}
macro_rules! reg_op_result {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
$(
$self.register_result_fn($x, $op as fn(x: $y, y: $y)->Result<$y,EvalAltResult>);
)*
)
}
macro_rules! reg_op_result1 {
($self:expr, $x:expr, $op:expr, $v:ty, $( $y:ty ),*) => (
$(
$self.register_result_fn($x, $op as fn(x: $y, y: $v)->Result<$y,EvalAltResult>);
)*
)
}
macro_rules! reg_un {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
$(
@@ -23,6 +45,13 @@ macro_rules! reg_un {
)
}
macro_rules! reg_un_result {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
$(
$self.register_result_fn($x, $op as fn(x: $y)->Result<$y,EvalAltResult>);
)*
)
}
macro_rules! reg_cmp {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
$(
@@ -69,21 +98,96 @@ macro_rules! reg_func3 {
impl Engine<'_> {
/// Register the core built-in library.
pub(crate) fn register_core_lib(&mut self) {
fn add<T: Add>(x: T, y: T) -> <T as Add>::Output {
fn add<T: Display + CheckedAdd>(x: T, y: T) -> Result<T, EvalAltResult> {
x.checked_add(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Addition overflow: {} + {}", x, y),
Position::none(),
)
})
}
fn sub<T: Display + CheckedSub>(x: T, y: T) -> Result<T, EvalAltResult> {
x.checked_sub(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Subtraction underflow: {} - {}", x, y),
Position::none(),
)
})
}
fn mul<T: Display + CheckedMul>(x: T, y: T) -> Result<T, EvalAltResult> {
x.checked_mul(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Multiplication overflow: {} * {}", x, y),
Position::none(),
)
})
}
fn div<T>(x: T, y: T) -> Result<T, EvalAltResult>
where
T: Display + CheckedDiv + PartialEq + TryFrom<i8>,
{
if y == <T as TryFrom<i8>>::try_from(0)
.map_err(|_| ())
.expect("zero should always succeed")
{
return Err(EvalAltResult::ErrorArithmetic(
format!("Division by zero: {} / {}", x, y),
Position::none(),
));
}
x.checked_div(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Division overflow: {} / {}", x, y),
Position::none(),
)
})
}
fn neg<T: Display + CheckedNeg>(x: T) -> Result<T, EvalAltResult> {
x.checked_neg().ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Negation overflow: -{}", x),
Position::none(),
)
})
}
fn abs<T: Display + CheckedNeg + PartialOrd + From<i8>>(x: T) -> Result<T, EvalAltResult> {
if x >= 0.into() {
Ok(x)
} else {
x.checked_neg().ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Negation overflow: -{}", x),
Position::none(),
)
})
}
}
fn add_unchecked<T: Add>(x: T, y: T) -> <T as Add>::Output {
x + y
}
fn sub<T: Sub>(x: T, y: T) -> <T as Sub>::Output {
fn sub_unchecked<T: Sub>(x: T, y: T) -> <T as Sub>::Output {
x - y
}
fn mul<T: Mul>(x: T, y: T) -> <T as Mul>::Output {
fn mul_unchecked<T: Mul>(x: T, y: T) -> <T as Mul>::Output {
x * y
}
fn div<T: Div>(x: T, y: T) -> <T as Div>::Output {
fn div_unchecked<T: Div>(x: T, y: T) -> <T as Div>::Output {
x / y
}
fn neg<T: Neg>(x: T) -> <T as Neg>::Output {
fn neg_unchecked<T: Neg>(x: T) -> <T as Neg>::Output {
-x
}
fn abs_unchecked<T: Neg + PartialOrd + From<i8>>(x: T) -> T
where
<T as Neg>::Output: Into<T>,
{
if x < 0.into() {
(-x).into()
} else {
x
}
}
fn lt<T: PartialOrd>(x: T, y: T) -> bool {
x < y
}
@@ -120,13 +224,45 @@ impl Engine<'_> {
fn binary_xor<T: BitXor>(x: T, y: T) -> <T as BitXor>::Output {
x ^ y
}
fn left_shift<T: Shl<T>>(x: T, y: T) -> <T as Shl<T>>::Output {
x.shl(y)
fn left_shift<T: Display + CheckedShl>(x: T, y: i64) -> Result<T, EvalAltResult> {
if y < 0 {
return Err(EvalAltResult::ErrorArithmetic(
format!("Left-shift by a negative number: {} << {}", x, y),
Position::none(),
));
}
CheckedShl::checked_shl(&x, y as u32).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Left-shift overflow: {} << {}", x, y),
Position::none(),
)
})
}
fn right_shift<T: Shr<T>>(x: T, y: T) -> <T as Shr<T>>::Output {
x.shr(y)
fn right_shift<T: Display + CheckedShr>(x: T, y: i64) -> Result<T, EvalAltResult> {
if y < 0 {
return Err(EvalAltResult::ErrorArithmetic(
format!("Right-shift by a negative number: {} >> {}", x, y),
Position::none(),
));
}
CheckedShr::checked_shr(&x, y as u32).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Right-shift overflow: {} % {}", x, y),
Position::none(),
)
})
}
fn modulo<T: Rem<T>>(x: T, y: T) -> <T as Rem<T>>::Output {
fn modulo<T: Display + CheckedRem>(x: T, y: T) -> Result<T, EvalAltResult> {
x.checked_rem(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Modulo division overflow: {} % {}", x, y),
Position::none(),
)
})
}
fn modulo_unchecked<T: Rem>(x: T, y: T) -> <T as Rem>::Output {
x % y
}
fn pow_i64_i64(x: i64, y: i64) -> i64 {
@@ -139,10 +275,15 @@ impl Engine<'_> {
x.powi(y as i32)
}
reg_op!(self, "+", add, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64);
reg_op!(self, "-", sub, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64);
reg_op!(self, "*", mul, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64);
reg_op!(self, "/", div, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64);
reg_op_result!(self, "+", add, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result!(self, "-", sub, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result!(self, "*", mul, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result!(self, "/", div, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, "+", add_unchecked, f32, f64);
reg_op!(self, "-", sub_unchecked, f32, f64);
reg_op!(self, "*", mul_unchecked, f32, f64);
reg_op!(self, "/", div_unchecked, f32, f64);
reg_cmp!(self, "<", lt, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64, String, char);
reg_cmp!(self, "<=", lte, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64, String, char);
@@ -162,15 +303,19 @@ impl Engine<'_> {
reg_op!(self, "&", binary_and, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, "&", and, bool);
reg_op!(self, "^", binary_xor, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, "<<", left_shift, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, ">>", right_shift, i8, u8, i16, u16);
reg_op!(self, ">>", right_shift, i32, i64, u32, u64);
reg_op!(self, "%", modulo, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result1!(self, "<<", left_shift, i64, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result1!(self, ">>", right_shift, i64, i8, u8, i16, u16);
reg_op_result1!(self, ">>", right_shift, i64, i32, i64, u32, u64);
reg_op_result!(self, "%", modulo, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, "%", modulo_unchecked, f32, f64);
self.register_fn("~", pow_i64_i64);
self.register_fn("~", pow_f64_f64);
self.register_fn("~", pow_f64_i64);
reg_un!(self, "-", neg, i8, i16, i32, i64, f32, f64);
reg_un_result!(self, "-", neg, i8, i16, i32, i64);
reg_un!(self, "-", neg_unchecked, f32, f64);
reg_un_result!(self, "abs", abs, i8, i16, i32, i64);
reg_un!(self, "abs", abs_unchecked, f32, f64);
reg_un!(self, "!", not, bool);
self.register_fn("+", |x: String, y: String| x + &y); // String + String

View File

@@ -143,6 +143,9 @@ impl fmt::Display for ParseError {
if !self.1.is_eof() {
write!(f, " ({})", self.1)
} else if !self.1.is_none() {
// Do not write any position if None
Ok(())
} else {
write!(f, " at the end of the script but there is no more input")
}

View File

@@ -58,6 +58,32 @@ pub trait RegisterDynamicFn<FN, ARGS> {
fn register_dynamic_fn(&mut self, name: &str, f: FN);
}
/// A trait to register fallible custom functions returning Result<_, EvalAltResult> with the `Engine`.
///
/// # Example
///
/// ```rust
/// use rhai::{Engine, RegisterFn};
///
/// // Normal function
/// fn add(x: i64, y: i64) -> i64 {
/// x + y
/// }
///
/// let mut engine = Engine::new();
///
/// // You must use the trait rhai::RegisterFn to get this method.
/// engine.register_fn("add", add);
///
/// if let Ok(result) = engine.eval::<i64>("add(40, 2)") {
/// println!("Answer: {}", result); // prints 42
/// }
/// ```
pub trait RegisterResultFn<FN, ARGS, RET> {
/// Register a custom function with the `Engine`.
fn register_result_fn(&mut self, name: &str, f: FN);
}
pub struct Ref<A>(A);
pub struct Mut<A>(A);
@@ -91,7 +117,7 @@ macro_rules! def_register {
let mut drain = args.drain(..);
$(
// Downcast every element, return in case of a type mismatch
let $par = (drain.next().unwrap().downcast_mut() as Option<&mut $par>).unwrap();
let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap();
)*
// Call the user-supplied function using ($clone) to
@@ -123,7 +149,7 @@ macro_rules! def_register {
let mut drain = args.drain(..);
$(
// Downcast every element, return in case of a type mismatch
let $par = (drain.next().unwrap().downcast_mut() as Option<&mut $par>).unwrap();
let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap();
)*
// Call the user-supplied function using ($clone) to
@@ -135,6 +161,44 @@ macro_rules! def_register {
}
}
impl<
$($par: Any + Clone,)*
FN: Fn($($param),*) -> Result<RET, EvalAltResult> + 'static,
RET: Any
> RegisterResultFn<FN, ($($mark,)*), RET> for Engine<'_>
{
fn register_result_fn(&mut self, name: &str, f: FN) {
let fn_name = name.to_string();
let fun = move |mut args: FnCallArgs, pos: Position| {
// Check for length at the beginning to avoid per-element bound checks.
const NUM_ARGS: usize = count_args!($($par)*);
if args.len() != NUM_ARGS {
Err(EvalAltResult::ErrorFunctionArgsMismatch(fn_name.clone(), NUM_ARGS, args.len(), pos))
} else {
#[allow(unused_variables, unused_mut)]
let mut drain = args.drain(..);
$(
// Downcast every element, return in case of a type mismatch
let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap();
)*
// Call the user-supplied function using ($clone) to
// potentially clone the value, otherwise pass the reference.
match f($(($clone)($par)),*) {
Ok(r) => Ok(Box::new(r) as Dynamic),
Err(mut err) => {
err.set_position(pos);
Err(err)
}
}
}
};
self.register_fn_raw(name, Some(vec![$(TypeId::of::<$par>()),*]), Box::new(fun));
}
}
//def_register!(imp_pop $($par => $mark => $param),*);
};
($p0:ident $(, $p:ident)*) => {

View File

@@ -69,7 +69,7 @@ pub use any::{Any, AnyExt, Dynamic, Variant};
pub use call::FuncArgs;
pub use engine::{Array, Engine};
pub use error::{ParseError, ParseErrorType};
pub use fn_register::{RegisterDynamicFn, RegisterFn};
pub use fn_register::{RegisterDynamicFn, RegisterFn, RegisterResultFn};
pub use parser::{Position, AST};
pub use result::EvalAltResult;
pub use scope::Scope;

View File

@@ -2,7 +2,7 @@
use crate::any::Dynamic;
use crate::error::{LexError, ParseError, ParseErrorType};
use std::{borrow::Cow, char, fmt, iter::Peekable, str::Chars};
use std::{borrow::Cow, char, fmt, iter::Peekable, str::Chars, usize};
type LERR = LexError;
type PERR = ParseErrorType;
@@ -17,25 +17,33 @@ pub struct Position {
impl Position {
/// Create a new `Position`.
pub fn new(line: usize, position: usize) -> Self {
if line == 0 || (line == usize::MAX && position == usize::MAX) {
panic!("invalid position: ({}, {})", line, position);
}
Self {
line,
pos: position,
}
}
/// Get the line number (1-based), or `None` if EOF.
/// Get the line number (1-based), or `None` if no position or EOF.
pub fn line(&self) -> Option<usize> {
match self.line {
0 => None,
x => Some(x),
if self.is_none() || self.is_eof() {
None
} else {
Some(self.line)
}
}
/// Get the character position (1-based), or `None` if at beginning of a line.
pub fn position(&self) -> Option<usize> {
match self.pos {
0 => None,
x => Some(x),
if self.is_none() || self.is_eof() {
None
} else if self.pos == 0 {
None
} else {
Some(self.pos)
}
}
@@ -61,14 +69,27 @@ impl Position {
self.pos = 0;
}
/// Create a `Position` representing no position.
pub(crate) fn none() -> Self {
Self { line: 0, pos: 0 }
}
/// Create a `Position` at EOF.
pub(crate) fn eof() -> Self {
Self { line: 0, pos: 0 }
Self {
line: usize::MAX,
pos: usize::MAX,
}
}
/// Is there no `Position`?
pub fn is_none(&self) -> bool {
self.line == 0 && self.pos == 0
}
/// Is the `Position` at EOF?
pub fn is_eof(&self) -> bool {
self.line == 0
self.line == usize::MAX && self.pos == usize::MAX
}
}
@@ -82,6 +103,8 @@ impl fmt::Display for Position {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.is_eof() {
write!(f, "EOF")
} else if self.is_none() {
write!(f, "none")
} else {
write!(f, "line {}, position {}", self.line, self.pos)
}

View File

@@ -118,9 +118,10 @@ impl fmt::Display for EvalAltResult {
Self::ErrorMismatchOutputType(s, pos) => write!(f, "{}: {} ({})", desc, s, pos),
Self::ErrorDotExpr(s, pos) if !s.is_empty() => write!(f, "{} {} ({})", desc, s, pos),
Self::ErrorDotExpr(_, pos) => write!(f, "{} ({})", desc, pos),
Self::ErrorArithmetic(s, pos) => write!(f, "{}: {} ({})", desc, s, pos),
Self::ErrorRuntime(s, pos) if s.is_empty() => write!(f, "{} ({})", desc, pos),
Self::ErrorRuntime(s, pos) => write!(f, "{}: {} ({})", desc, s, pos),
Self::ErrorArithmetic(s, pos) => write!(f, "{} ({})", s, pos),
Self::ErrorRuntime(s, pos) => {
write!(f, "{} ({})", if s.is_empty() { desc } else { s }, pos)
}
Self::LoopBreak => write!(f, "{}", desc),
Self::Return(_, pos) => write!(f, "{} ({})", desc, pos),
Self::ErrorReadingScriptFile(filename, err) => {
@@ -171,3 +172,37 @@ impl From<ParseError> for EvalAltResult {
Self::ErrorParsing(err)
}
}
impl EvalAltResult {
pub(crate) fn set_position(&mut self, new_position: Position) {
match self {
EvalAltResult::ErrorReadingScriptFile(_, _)
| EvalAltResult::LoopBreak
| EvalAltResult::ErrorParsing(_) => (),
EvalAltResult::ErrorFunctionNotFound(_, ref mut pos)
| EvalAltResult::ErrorFunctionArgsMismatch(_, _, _, ref mut pos)
| EvalAltResult::ErrorBooleanArgMismatch(_, ref mut pos)
| EvalAltResult::ErrorCharMismatch(ref mut pos)
| EvalAltResult::ErrorArrayBounds(_, _, ref mut pos)
| EvalAltResult::ErrorStringBounds(_, _, ref mut pos)
| EvalAltResult::ErrorIndexingType(_, ref mut pos)
| EvalAltResult::ErrorIndexExpr(ref mut pos)
| EvalAltResult::ErrorIfGuard(ref mut pos)
| EvalAltResult::ErrorFor(ref mut pos)
| EvalAltResult::ErrorVariableNotFound(_, ref mut pos)
| EvalAltResult::ErrorAssignmentToUnknownLHS(ref mut pos)
| EvalAltResult::ErrorMismatchOutputType(_, ref mut pos)
| EvalAltResult::ErrorDotExpr(_, ref mut pos)
| EvalAltResult::ErrorArithmetic(_, ref mut pos)
| EvalAltResult::ErrorRuntime(_, ref mut pos)
| EvalAltResult::Return(_, ref mut pos) => *pos = new_position,
}
}
}
impl<T: AsRef<str>> From<T> for EvalAltResult {
fn from(err: T) -> Self {
Self::ErrorRuntime(err.as_ref().to_string(), Position::none())
}
}