diff --git a/RELEASES.md b/RELEASES.md index 3975648a..5f251322 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -16,6 +16,7 @@ New features ------------ * `Engine::compile_to_self_contained` compiles a script into an `AST` and _eagerly_ resolves all `import` statements with string literal paths. The resolved modules are directly embedded into the `AST`. When the `AST` is later evaluated, `import` statements directly yield the pre-resolved modules without going through the resolution process once again. +* `AST::walk`, `Stmt::walk` and `Expr::walk` internal API's to recursively walk an `AST`. Enhancements ------------ diff --git a/src/ast.rs b/src/ast.rs index 6a5f478e..de1d613e 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -682,6 +682,47 @@ impl AST { pub fn clear_statements(&mut self) { self.statements = vec![]; } + /// Recursively walk the [`AST`], including function bodies (if any). + #[cfg(not(feature = "internals"))] + #[inline(always)] + pub(crate) fn walk(&self, on_node: &mut impl FnMut(&[ASTNode])) { + let mut path = Default::default(); + + self.statements() + .iter() + .chain({ + #[cfg(not(feature = "no_function"))] + { + self.iter_fn_def().map(|f| &f.body) + } + #[cfg(feature = "no_function")] + { + crate::stdlib::iter::empty() + } + }) + .for_each(|stmt| stmt.walk(&mut path, on_node)); + } + /// _(INTERNALS)_ Recursively walk the [`AST`], including function bodies (if any). + /// Exported under the `internals` feature only. + #[cfg(feature = "internals")] + #[inline(always)] + pub fn walk(&self, on_node: &mut impl FnMut(&[ASTNode])) { + let mut path = Default::default(); + + self.statements() + .iter() + .chain({ + #[cfg(not(feature = "no_function"))] + { + self.iter_fn_def().map(|f| &f.body) + } + #[cfg(feature = "no_function")] + { + crate::stdlib::iter::empty() + } + }) + .for_each(|stmt| stmt.walk(&mut path, on_node)); + } } impl> Add for &AST { @@ -749,6 +790,30 @@ pub enum ReturnType { Exception, } +/// _(INTERNALS)_ An [`AST`] node, consisting of either an [`Expr`] or a [`Stmt`]. +/// Exported under the `internals` feature only. +/// +/// # WARNING +/// +/// This type is volatile and may change. +#[derive(Debug, Clone, Hash)] +pub enum ASTNode<'a> { + Stmt(&'a Stmt), + Expr(&'a Expr), +} + +impl<'a> From<&'a Stmt> for ASTNode<'a> { + fn from(stmt: &'a Stmt) -> Self { + Self::Stmt(stmt) + } +} + +impl<'a> From<&'a Expr> for ASTNode<'a> { + fn from(expr: &'a Expr) -> Self { + Self::Expr(expr) + } +} + /// _(INTERNALS)_ A statement. /// Exported under the `internals` feature only. /// @@ -949,50 +1014,50 @@ impl Stmt { } /// Recursively walk this statement. #[inline(always)] - pub fn walk(&self, process_stmt: &mut impl FnMut(&Stmt), process_expr: &mut impl FnMut(&Expr)) { - process_stmt(self); + pub fn walk<'a>(&'a self, path: &mut Vec>, on_node: &mut impl FnMut(&[ASTNode])) { + path.push(self.into()); + on_node(path); match self { - Self::Let(_, Some(e), _, _) | Self::Const(_, Some(e), _, _) => { - e.walk(process_stmt, process_expr) - } + Self::Let(_, Some(e), _, _) | Self::Const(_, Some(e), _, _) => e.walk(path, on_node), Self::If(e, x, _) => { - e.walk(process_stmt, process_expr); - x.0.walk(process_stmt, process_expr); + e.walk(path, on_node); + x.0.walk(path, on_node); if let Some(ref s) = x.1 { - s.walk(process_stmt, process_expr); + s.walk(path, on_node); } } Self::Switch(e, x, _) => { - e.walk(process_stmt, process_expr); - x.0.values() - .for_each(|s| s.walk(process_stmt, process_expr)); + e.walk(path, on_node); + x.0.values().for_each(|s| s.walk(path, on_node)); if let Some(ref s) = x.1 { - s.walk(process_stmt, process_expr); + s.walk(path, on_node); } } Self::While(e, s, _) | Self::Do(s, e, _, _) => { - e.walk(process_stmt, process_expr); - s.walk(process_stmt, process_expr); + e.walk(path, on_node); + s.walk(path, on_node); } Self::For(e, x, _) => { - e.walk(process_stmt, process_expr); - x.1.walk(process_stmt, process_expr); + e.walk(path, on_node); + x.1.walk(path, on_node); } Self::Assignment(x, _) => { - x.0.walk(process_stmt, process_expr); - x.2.walk(process_stmt, process_expr); + x.0.walk(path, on_node); + x.2.walk(path, on_node); } - Self::Block(x, _) => x.iter().for_each(|s| s.walk(process_stmt, process_expr)), + Self::Block(x, _) => x.iter().for_each(|s| s.walk(path, on_node)), Self::TryCatch(x, _, _) => { - x.0.walk(process_stmt, process_expr); - x.2.walk(process_stmt, process_expr); + x.0.walk(path, on_node); + x.2.walk(path, on_node); } - Self::Expr(e) | Self::Return(_, Some(e), _) => e.walk(process_stmt, process_expr), + Self::Expr(e) | Self::Return(_, Some(e), _) => e.walk(path, on_node), #[cfg(not(feature = "no_module"))] - Self::Import(e, _, _) => e.walk(process_stmt, process_expr), + Self::Import(e, _, _) => e.walk(path, on_node), _ => (), } + + path.pop().unwrap(); } } @@ -1416,25 +1481,23 @@ impl Expr { } /// Recursively walk this expression. #[inline(always)] - pub fn walk(&self, process_stmt: &mut impl FnMut(&Stmt), process_expr: &mut impl FnMut(&Expr)) { - process_expr(self); + pub fn walk<'a>(&'a self, path: &mut Vec>, on_node: &mut impl FnMut(&[ASTNode])) { + path.push(self.into()); + on_node(path); match self { - Self::Stmt(x, _) => x.iter().for_each(|s| s.walk(process_stmt, process_expr)), - Self::Array(x, _) => x.iter().for_each(|e| e.walk(process_stmt, process_expr)), - Self::Map(x, _) => x - .iter() - .for_each(|(_, e)| e.walk(process_stmt, process_expr)), + Self::Stmt(x, _) => x.iter().for_each(|s| s.walk(path, on_node)), + Self::Array(x, _) => x.iter().for_each(|e| e.walk(path, on_node)), + Self::Map(x, _) => x.iter().for_each(|(_, e)| e.walk(path, on_node)), Self::Index(x, _) | Expr::In(x, _) | Expr::And(x, _) | Expr::Or(x, _) => { - x.lhs.walk(process_stmt, process_expr); - x.rhs.walk(process_stmt, process_expr); + x.lhs.walk(path, on_node); + x.rhs.walk(path, on_node); } - Self::Custom(x, _) => x - .keywords - .iter() - .for_each(|e| e.walk(process_stmt, process_expr)), + Self::Custom(x, _) => x.keywords.iter().for_each(|e| e.walk(path, on_node)), _ => (), } + + path.pop().unwrap(); } } diff --git a/src/engine_api.rs b/src/engine_api.rs index d2af0960..9efa0c7e 100644 --- a/src/engine_api.rs +++ b/src/engine_api.rs @@ -903,47 +903,53 @@ impl Engine { script: &str, ) -> Result> { use crate::{ - ast::{Expr, Stmt}, + ast::{ASTNode, Expr, Stmt}, fn_native::shared_take_or_clone, module::resolvers::StaticModuleResolver, - stdlib::collections::HashMap, + stdlib::collections::HashSet, ImmutableString, }; - let mut ast = self.compile_scripts_with_scope(scope, &[script])?; - let mut imports = HashMap::::new(); - - ast.statements() - .iter() - .chain({ - #[cfg(not(feature = "no_function"))] + fn collect_imports( + ast: &AST, + resolver: &StaticModuleResolver, + imports: &mut HashSet, + ) { + ast.walk(&mut |path| match path.last().unwrap() { + // Collect all `import` statements with a string constant path + ASTNode::Stmt(Stmt::Import(Expr::StringConstant(s, _), _, _)) + if !resolver.contains_path(s) && !imports.contains(s) => { - ast.iter_fn_def().map(|f| &f.body) + imports.insert(s.clone()); } - #[cfg(feature = "no_function")] - { - crate::stdlib::iter::empty() - } - }) - .for_each(|stmt| { - stmt.walk( - &mut |stmt| match stmt { - Stmt::Import(Expr::StringConstant(s, pos), _, _) - if !imports.contains_key(s) => - { - imports.insert(s.clone(), *pos); - } - _ => (), - }, - &mut |_| {}, - ) + _ => (), }); + } + + let mut resolver = StaticModuleResolver::new(); + let mut ast = self.compile_scripts_with_scope(scope, &[script])?; + let mut imports = HashSet::::new(); + + collect_imports(&ast, &mut resolver, &mut imports); if !imports.is_empty() { - let mut resolver = StaticModuleResolver::new(); - for (path, pos) in imports { - let module = self.module_resolver.resolve(self, &path, pos)?; - let module = shared_take_or_clone(module); + while let Some(path) = imports.iter().next() { + let path = path.clone(); + + if let Some(module_ast) = + self.module_resolver + .resolve_ast(self, &path, Position::NONE)? + { + collect_imports(&module_ast, &mut resolver, &mut imports); + } + + let module = shared_take_or_clone(self.module_resolver.resolve( + self, + &path, + Position::NONE, + )?); + + imports.remove(&path); resolver.insert(path, module); } ast.set_resolver(resolver); diff --git a/src/lib.rs b/src/lib.rs index fb28dcf1..8cb99223 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -184,7 +184,8 @@ pub use token::{get_next_token, parse_string_literal, InputStream, Token, Tokeni #[cfg(feature = "internals")] #[deprecated = "this type is volatile and may change"] pub use ast::{ - BinaryExpr, CustomExpr, Expr, FloatWrapper, FnCallExpr, Ident, ReturnType, ScriptFnDef, Stmt, + ASTNode, BinaryExpr, CustomExpr, Expr, FloatWrapper, FnCallExpr, Ident, ReturnType, + ScriptFnDef, Stmt, }; #[cfg(feature = "internals")] diff --git a/src/module/resolvers/file.rs b/src/module/resolvers/file.rs index c3fc4a5a..841a2d4e 100644 --- a/src/module/resolvers/file.rs +++ b/src/module/resolvers/file.rs @@ -227,4 +227,30 @@ impl ModuleResolver for FileModuleResolver { Ok(m) } + + fn resolve_ast( + &self, + engine: &Engine, + path: &str, + pos: Position, + ) -> Result, Box> { + // Construct the script file path + let mut file_path = self.base_path.clone(); + file_path.push(path); + file_path.set_extension(&self.extension); // Force extension + + // Load the script file and compile it + let mut ast = engine + .compile_file(file_path.clone()) + .map_err(|err| match *err { + EvalAltResult::ErrorSystem(_, err) if err.is::() => { + Box::new(EvalAltResult::ErrorModuleNotFound(path.to_string(), pos)) + } + _ => Box::new(EvalAltResult::ErrorInModule(path.to_string(), err, pos)), + })?; + + ast.set_source(path); + + Ok(Some(ast)) + } } diff --git a/src/module/resolvers/mod.rs b/src/module/resolvers/mod.rs index f257f0e8..17e0c0ac 100644 --- a/src/module/resolvers/mod.rs +++ b/src/module/resolvers/mod.rs @@ -1,6 +1,6 @@ use crate::fn_native::SendSync; use crate::stdlib::boxed::Box; -use crate::{Engine, EvalAltResult, Module, Position, Shared}; +use crate::{Engine, EvalAltResult, Module, Position, Shared, AST}; mod dummy; pub use dummy::DummyModuleResolver; @@ -28,4 +28,23 @@ pub trait ModuleResolver: SendSync { path: &str, pos: Position, ) -> Result, Box>; + + /// Resolve a module into an `AST` based on a path string. + /// + /// Returns [`None`] (default) if such resolution is not supported + /// (e.g. if the module is Rust-based). + /// + /// ## Low-Level API + /// + /// Override the default implementation of this method if the module resolver + /// serves modules based on compiled Rhai scripts. + #[allow(unused_variables)] + fn resolve_ast( + &self, + engine: &Engine, + path: &str, + pos: Position, + ) -> Result, Box> { + Ok(None) + } } diff --git a/tests/modules.rs b/tests/modules.rs index 4f6baa55..aadf2a5d 100644 --- a/tests/modules.rs +++ b/tests/modules.rs @@ -176,11 +176,11 @@ fn test_module_resolver() -> Result<(), Box> { assert_eq!( engine.eval::( r#" - import "hello" as h; - let x = 21; - h::sum_of_three_args(x, 14, 26, 2.0); - x - "# + import "hello" as h; + let x = 21; + h::sum_of_three_args(x, 14, 26, 2.0); + x + "# )?, 42 );