diff --git a/codegen/src/function.rs b/codegen/src/function.rs index 89358274..78b26e3d 100644 --- a/codegen/src/function.rs +++ b/codegen/src/function.rs @@ -2,6 +2,8 @@ #[cfg(no_std)] use core::mem; +#[cfg(not(no_std))] +use std::mem; #[cfg(no_std)] use alloc::format; @@ -291,12 +293,21 @@ impl ExportedFn { } } - pub fn generate_with_params( - mut self, - mut params: ExportedFnParams, - ) -> proc_macro2::TokenStream { + pub fn set_params( + &mut self, mut params: ExportedFnParams, + ) -> syn::Result<()> { + + // Do not allow non-returning raw functions. + // + // This is caught now to avoid issues with diagnostics later. + if params.return_raw && mem::discriminant(&self.signature.output) == + mem::discriminant(&syn::ReturnType::Default) { + return Err(syn::Error::new(self.signature.span(), + "return_raw functions must return Result")); + } + self.params = params; - self.generate() + Ok(()) } pub fn generate(self) -> proc_macro2::TokenStream { @@ -353,7 +364,7 @@ impl ExportedFn { } } } else { - quote! { + quote_spanned! { self.return_type().unwrap().span()=> type EvalBox = Box; pub #dynamic_signature { super::#name(#(#arguments),*) @@ -520,7 +531,7 @@ impl ExportedFn { Ok(Dynamic::from(#sig_name(#(#unpack_exprs),*))) } } else { - quote! { + quote_spanned! { self.return_type().unwrap().span()=> #sig_name(#(#unpack_exprs),*) } }; diff --git a/codegen/src/lib.rs b/codegen/src/lib.rs index 9675ac66..a9e43d93 100644 --- a/codegen/src/lib.rs +++ b/codegen/src/lib.rs @@ -109,9 +109,12 @@ pub fn export_fn( let mut output = proc_macro2::TokenStream::from(input.clone()); let parsed_params = parse_macro_input!(args as function::ExportedFnParams); - let function_def = parse_macro_input!(input as function::ExportedFn); + let mut function_def = parse_macro_input!(input as function::ExportedFn); + if let Err(e) = function_def.set_params(parsed_params) { + return e.to_compile_error().into(); + } - output.extend(function_def.generate_with_params(parsed_params)); + output.extend(function_def.generate()); proc_macro::TokenStream::from(output) } diff --git a/codegen/ui_tests/export_fn_raw_noreturn.rs b/codegen/ui_tests/export_fn_raw_noreturn.rs new file mode 100644 index 00000000..7c8b42e0 --- /dev/null +++ b/codegen/ui_tests/export_fn_raw_noreturn.rs @@ -0,0 +1,25 @@ +use rhai::plugin::*; + +#[derive(Clone)] +struct Point { + x: f32, + y: f32, +} + +#[export_fn(return_raw)] +pub fn test_fn(input: &mut Point) { + input.x += 1.0; +} + +fn main() { + let n = Point { + x: 0.0, + y: 10.0, + }; + test_fn(&mut n); + if n.x >= 10.0 { + println!("yes"); + } else { + println!("no"); + } +} diff --git a/codegen/ui_tests/export_fn_raw_noreturn.stderr b/codegen/ui_tests/export_fn_raw_noreturn.stderr new file mode 100644 index 00000000..0687c8c6 --- /dev/null +++ b/codegen/ui_tests/export_fn_raw_noreturn.stderr @@ -0,0 +1,11 @@ +error: return_raw functions must return Result + --> $DIR/export_fn_raw_noreturn.rs:10:5 + | +10 | pub fn test_fn(input: &mut Point) { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error[E0425]: cannot find function `test_fn` in this scope + --> $DIR/export_fn_raw_noreturn.rs:19:5 + | +19 | test_fn(&mut n); + | ^^^^^^^ not found in this scope diff --git a/codegen/ui_tests/export_fn_raw_return.rs b/codegen/ui_tests/export_fn_raw_return.rs new file mode 100644 index 00000000..9df99549 --- /dev/null +++ b/codegen/ui_tests/export_fn_raw_return.rs @@ -0,0 +1,24 @@ +use rhai::plugin::*; + +#[derive(Clone)] +struct Point { + x: f32, + y: f32, +} + +#[export_fn(return_raw)] +pub fn test_fn(input: Point) -> bool { + input.x > input.y +} + +fn main() { + let n = Point { + x: 0.0, + y: 10.0, + }; + if test_fn(n) { + println!("yes"); + } else { + println!("no"); + } +} diff --git a/codegen/ui_tests/export_fn_raw_return.stderr b/codegen/ui_tests/export_fn_raw_return.stderr new file mode 100644 index 00000000..f570fda9 --- /dev/null +++ b/codegen/ui_tests/export_fn_raw_return.stderr @@ -0,0 +1,21 @@ +error[E0308]: mismatched types + --> $DIR/export_fn_raw_return.rs:10:8 + | +9 | #[export_fn(return_raw)] + | ------------------------ expected `std::result::Result>` because of return type +10 | pub fn test_fn(input: Point) -> bool { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected enum `std::result::Result`, found `bool` + | + = note: expected enum `std::result::Result>` + found type `bool` + +error[E0308]: mismatched types + --> $DIR/export_fn_raw_return.rs:10:33 + | +9 | #[export_fn(return_raw)] + | ------------------------ expected `std::result::Result>` because of return type +10 | pub fn test_fn(input: Point) -> bool { + | ^^^^ expected enum `std::result::Result`, found `bool` + | + = note: expected enum `std::result::Result>` + found type `bool`