diff --git a/kite_sql_serde_macros/src/orm.rs b/kite_sql_serde_macros/src/orm.rs index a4fe59fc..0b2b5752 100644 --- a/kite_sql_serde_macros/src/orm.rs +++ b/kite_sql_serde_macros/src/orm.rs @@ -66,7 +66,9 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { )); }; - let mut assignments = Vec::new(); + let mut field_initializers = Vec::new(); + let mut field_index_declarations = Vec::new(); + let mut field_index_resolvers = Vec::new(); let mut params = Vec::new(); let mut orm_fields = Vec::new(); let mut orm_columns = Vec::new(); @@ -132,6 +134,13 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { "decimal field cannot be skipped", )); } + generics + .make_where_clause() + .predicates + .push(parse_quote!(#field_ty : ::core::default::Default)); + field_initializers.push(quote! { + #field_name: ::core::default::Default::default() + }); continue; } @@ -163,6 +172,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { let is_unique = field.unique; let is_index = field.index; let column_index = orm_columns.len(); + let field_index_ident = format_ident!("__kite_orm_{field_name}_index"); persisted_columns.push((field_name_string, column_name.clone())); column_names.push(column_name.clone()); @@ -223,11 +233,22 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { quote! { None::<::kite_sql::types::value::DataValue> } }; - assignments.push(quote! { - if let Some(value) = ::kite_sql::orm::try_get::<#field_ty>(&mut tuple, schema, #column_name_lit) { - struct_instance.#field_name = value; + field_index_declarations.push(quote! { + let mut #field_index_ident = None; + }); + field_index_resolvers.push(quote! { + if #field_index_ident.is_none() && __kite_orm_column_name == #column_name_lit { + #field_index_ident = Some(__kite_orm_index); + __kite_orm_found_fields += 1; } }); + field_initializers.push(quote! { + #field_name: ::kite_sql::orm::take_value_at::<#field_ty>( + &mut tuple, + #field_index_ident, + #column_name_lit, + )? + }); params.push(quote! { (#placeholder_lit, ::kite_sql::orm::ToDataValue::to_data_value(&self.#field_name)) }); @@ -379,23 +400,31 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { let primary_key_value = primary_key_value.expect("primary key checked above"); let _primary_key_column = primary_key_column.expect("primary key checked above"); let _primary_key_placeholder = primary_key_placeholder.expect("primary key checked above"); - let mut from_generics = generics.clone(); - from_generics.params.insert(0, parse_quote!('__kite_arena)); - from_generics.params.insert(0, parse_quote!('__kite_schema)); - let (from_impl_generics, _, from_where_clause) = from_generics.split_for_impl(); + let field_count = field_index_declarations.len(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); Ok(quote! { - impl #from_impl_generics ::core::convert::From<(&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)> + impl #impl_generics ::kite_sql::orm::FromQueryRow for #struct_name #ty_generics - #from_where_clause + #where_clause { - fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)) -> Self { - let mut struct_instance = ::default(); - - #(#assignments)* + fn from_query_row( + schema: &::kite_sql::types::tuple::SchemaView<'_, '_>, + mut tuple: ::kite_sql::types::tuple::Tuple, + ) -> ::std::result::Result { + let mut __kite_orm_found_fields = 0usize; + #(#field_index_declarations)* + for (__kite_orm_index, __kite_orm_column) in schema.iter().enumerate() { + let __kite_orm_column_name = __kite_orm_column.name(); + #(#field_index_resolvers)* + if __kite_orm_found_fields == #field_count { + break; + } + } - struct_instance + Ok(Self { + #(#field_initializers),* + }) } } diff --git a/kite_sql_serde_macros/src/projection.rs b/kite_sql_serde_macros/src/projection.rs index 3d617019..e9b5b5ce 100644 --- a/kite_sql_serde_macros/src/projection.rs +++ b/kite_sql_serde_macros/src/projection.rs @@ -1,7 +1,7 @@ use darling::ast::Data; use darling::{FromDeriveInput, FromField}; use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{format_ident, quote}; use syn::{parse_quote, DeriveInput, Error, Generics, Ident, LitStr, Type}; #[derive(Debug, FromDeriveInput)] @@ -34,7 +34,9 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { }; let mut projection_exprs = Vec::new(); - let mut assignments = Vec::new(); + let mut field_initializers = Vec::new(); + let mut field_index_declarations = Vec::new(); + let mut field_index_resolvers = Vec::new(); for field in data_struct.fields { let ProjectionFieldOpts { @@ -50,6 +52,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { let source_name = rename.clone().unwrap_or_else(|| field_name_string.clone()); let source_name_lit = LitStr::new(&source_name, Span::call_site()); let field_name_lit = LitStr::new(&field_name_string, Span::call_site()); + let field_index_ident = format_ident!("__kite_projection_{field_name}_index"); let relation_expr = if let Some(source_relation) = from { let relation_lit = LitStr::new(&source_relation, Span::call_site()); quote!(#relation_lit) @@ -74,17 +77,25 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { scope.column_ref(#relation_expr, #source_name_lit)? } }); - assignments.push(quote! { - if let Some(value) = ::kite_sql::orm::try_get::<#field_ty>(&mut tuple, schema, #field_name_lit) { - struct_instance.#field_name = value; + field_index_declarations.push(quote! { + let mut #field_index_ident = None; + }); + field_index_resolvers.push(quote! { + if #field_index_ident.is_none() && __kite_orm_column_name == #field_name_lit { + #field_index_ident = Some(__kite_orm_index); + __kite_orm_found_fields += 1; } }); + field_initializers.push(quote! { + #field_name: ::kite_sql::orm::take_value_at::<#field_ty>( + &mut tuple, + #field_index_ident, + #field_name_lit, + )? + }); } - let mut from_generics = generics.clone(); - from_generics.params.insert(0, parse_quote!('__kite_arena)); - from_generics.params.insert(0, parse_quote!('__kite_schema)); - let (from_impl_generics, _, from_where_clause) = from_generics.split_for_impl(); + let field_count = field_index_declarations.len(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); Ok(quote! { @@ -105,13 +116,26 @@ pub(crate) fn handle(ast: DeriveInput) -> Result { } } - impl #from_impl_generics From<(&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)> for #struct_name #ty_generics - #from_where_clause + impl #impl_generics ::kite_sql::orm::FromQueryRow for #struct_name #ty_generics + #where_clause { - fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)) -> Self { - let mut struct_instance = ::default(); - #(#assignments)* - struct_instance + fn from_query_row( + schema: &::kite_sql::types::tuple::SchemaView<'_, '_>, + mut tuple: ::kite_sql::types::tuple::Tuple, + ) -> ::std::result::Result { + let mut __kite_orm_found_fields = 0usize; + #(#field_index_declarations)* + for (__kite_orm_index, __kite_orm_column) in schema.iter().enumerate() { + let __kite_orm_column_name = __kite_orm_column.name(); + #(#field_index_resolvers)* + if __kite_orm_found_fields == #field_count { + break; + } + } + + Ok(Self { + #(#field_initializers),* + }) } } }) diff --git a/src/db.rs b/src/db.rs index 66dd5073..f1303294 100644 --- a/src/db.rs +++ b/src/db.rs @@ -37,6 +37,8 @@ use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline; use crate::optimizer::rule::implementation::ImplementationRuleImpl; use crate::optimizer::rule::normalization::NormalizationRuleImpl; +#[cfg(feature = "orm")] +use crate::orm::FromQueryRow; use crate::planner::operator::Operator; use crate::planner::{LogicalPlan, PlanArena, TableArenaCell}; #[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))] @@ -945,12 +947,12 @@ pub trait ResultIter: BorrowResultIter + Iterator`, which is typically generated by - /// `#[derive(Model)]`. + /// implements `FromQueryRow`, which is typically generated by + /// `#[derive(Model)]` or `#[derive(Projection)]`. fn orm(self) -> OrmIter where Self: Sized, - T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, + T: FromQueryRow, { OrmIter::new(self) } @@ -969,7 +971,7 @@ pub struct OrmIter { impl OrmIter where I: ResultIter, - T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, + T: FromQueryRow, { fn new(inner: I) -> Self { Self { @@ -993,7 +995,7 @@ where impl Iterator for OrmIter where I: ResultIter, - T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, + T: FromQueryRow, { type Item = Result; @@ -1002,7 +1004,7 @@ where Ok(tuple) => tuple, Err(err) => return Some(Err(err)), }; - Some(Ok(self.inner.schema(|schema| T::from((schema, tuple))))) + Some(self.inner.schema(|schema| T::from_query_row(schema, tuple))) } } diff --git a/src/orm/mod.rs b/src/orm/mod.rs index 5c4c17d5..bef198e4 100644 --- a/src/orm/mod.rs +++ b/src/orm/mod.rs @@ -2278,9 +2278,7 @@ where } #[doc(hidden)] -pub trait Projection: - for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)> -{ +pub trait Projection: FromQueryRow { fn bind_projection<'ctx, 'bind, 'parent, 'arena, T, A>( scope: &mut ExprBindScope<'ctx, 'bind, 'parent, 'arena, T, A>, relation: &str, @@ -2449,9 +2447,7 @@ fn describe_text_value(value: Option) -> String { /// In normal usage you should derive this trait with `#[derive(Model)]` rather /// than implementing it by hand. The derive macro generates tuple mapping and /// model metadata. -pub trait Model: - Sized + for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)> -{ +pub trait Model: Sized + FromQueryRow { /// Rust type used as the model primary key. /// /// This associated type lets APIs such as @@ -2512,8 +2508,8 @@ pub trait FromDataValue: Sized { /// Returns the logical SQL type used for conversion, when one is required. fn logical_type() -> Option; - /// Attempts to convert a raw [`DataValue`] into `Self`. - fn from_data_value(value: DataValue) -> Option; + /// Converts a raw [`DataValue`] into `Self`. + fn from_data_value(value: DataValue) -> Result; } /// Conversion trait from a projected result tuple into a Rust value. @@ -2531,6 +2527,25 @@ pub trait FromQueryTuple: Sized { fn from_query_tuple(tuple: Tuple) -> Result; } +/// Conversion trait from a query result row into a Rust value. +/// +/// `#[derive(Model)]` and `#[derive(Projection)]` generate this automatically. +/// Types that still implement the older `From<(&SchemaView, Tuple)>` mapping +/// are also accepted through a compatibility implementation. +pub trait FromQueryRow: Sized { + /// Decodes one result row into `Self`. + fn from_query_row(schema: &SchemaView<'_, '_>, tuple: Tuple) -> Result; +} + +impl FromQueryRow for T +where + T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, +{ + fn from_query_row(schema: &SchemaView<'_, '_>, tuple: Tuple) -> Result { + Ok(T::from((schema, tuple))) + } +} + /// Typed adapter over a [`ResultIter`] that yields projected values instead of raw tuples. /// /// This adapts a raw ORM result iterator into scalar projected values. @@ -2698,17 +2713,24 @@ pub trait StringType {} pub trait DecimalType {} #[doc(hidden)] -pub fn try_get( +pub fn take_value_at( tuple: &mut Tuple, - schema: &SchemaView<'_, '_>, + index: Option, field_name: &str, -) -> Option { - let ty = T::logical_type()?; - let idx = schema.position(field_name)?; - - let value = std::mem::replace(&mut tuple.values[idx], DataValue::Null) - .cast(&ty) - .ok()?; +) -> Result { + let idx = index.ok_or_else(|| DatabaseError::ColumnNotFound { + name: field_name.to_string(), + span: None, + })?; + let value = tuple.values.get_mut(idx).ok_or(DatabaseError::MisMatch( + "the query result schema", + "the query result tuple", + ))?; + let value = std::mem::replace(value, DataValue::Null); + let value = match T::logical_type() { + Some(ty) => value.cast(&ty)?, + None => value, + }; T::from_data_value(value) } @@ -2720,8 +2742,10 @@ macro_rules! impl_from_data_value_by_method { LogicalType::type_trans::() } - fn from_data_value(value: DataValue) -> Option { - value.$method() + fn from_data_value(value: DataValue) -> Result { + value + .$method() + .ok_or_else(|| crate::orm::invalid_from_data_value::(&value)) } } }; @@ -2828,11 +2852,11 @@ impl FromDataValue for String { LogicalType::type_trans::() } - fn from_data_value(value: DataValue) -> Option { + fn from_data_value(value: DataValue) -> Result { if let DataValue::Utf8 { value, .. } = value { - Some(value) + Ok(value) } else { - None + Err(invalid_from_data_value::(&value)) } } } @@ -2842,11 +2866,11 @@ impl FromDataValue for Arc { Some(LogicalType::Varchar(None, CharLengthUnits::Characters)) } - fn from_data_value(value: DataValue) -> Option { + fn from_data_value(value: DataValue) -> Result { if let DataValue::Utf8 { value, .. } = value { - Some(value.into()) + Ok(value.into()) } else { - None + Err(invalid_from_data_value::(&value)) } } } @@ -2874,9 +2898,9 @@ impl FromDataValue for Option { T::logical_type() } - fn from_data_value(value: DataValue) -> Option { + fn from_data_value(value: DataValue) -> Result { if matches!(value, DataValue::Null) { - Some(None) + Ok(None) } else { T::from_data_value(value).map(Some) } @@ -2993,12 +3017,12 @@ where fn extract_optional_row(mut iter: I) -> Result, DatabaseError> where I: ResultIter, - T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>, + T: FromQueryRow, { Ok(match iter.next() { Some(tuple) => { let tuple = tuple?; - Some(iter.schema(|schema| T::from((schema, tuple)))) + Some(iter.schema(|schema| T::from_query_row(schema, tuple))?) } None => None, }) @@ -3010,12 +3034,15 @@ fn convert_projected_value(value: DataValue) -> Result value, }; - T::from_data_value(value).ok_or_else(|| { - DatabaseError::InvalidValue(format!( - "failed to convert projected value into {}", - std::any::type_name::() - )) - }) + T::from_data_value(value) +} + +fn invalid_from_data_value(value: &DataValue) -> DatabaseError { + DatabaseError::InvalidValue(format!( + "failed to convert {} value `{value}` into {}", + value.logical_type(), + std::any::type_name::() + )) } fn extract_projected_data_value( diff --git a/tests/macros-test/src/main.rs b/tests/macros-test/src/main.rs index a424cdeb..c620b4b1 100644 --- a/tests/macros-test/src/main.rs +++ b/tests/macros-test/src/main.rs @@ -73,6 +73,22 @@ mod test { Ok((temp_dir, database)) } + #[test] + fn test_from_data_value_reports_conversion_error() { + let err = ::from_data_value(DataValue::Utf8 { + value: "not an integer".to_string(), + ty: Utf8Type::Variable(None), + unit: CharLengthUnits::Characters, + }) + .expect_err("converting utf8 directly into i32 should fail"); + + assert!(matches!( + err, + DatabaseError::InvalidValue(message) + if message.contains("Varchar") && message.contains("i32") + )); + } + fn create_model_table( database: &mut Database, ) -> Result<(), DatabaseError> { @@ -137,6 +153,14 @@ mod test { cache: String, } + #[derive(Debug, PartialEq, Model)] + #[model(table = "no_default_users")] + struct NoDefaultUser { + #[model(primary_key)] + id: i32, + name: String, + } + #[derive(Default, Debug, PartialEq, Model)] #[model(table = "wallets")] struct Wallet { @@ -203,6 +227,27 @@ mod test { age: Option, } + #[derive(Debug, PartialEq, Projection)] + struct InvalidUserProjection { + #[projection(rename = "user_name")] + name: RejectDataValue, + } + + #[derive(Debug, PartialEq)] + struct RejectDataValue; + + impl kite_sql::orm::FromDataValue for RejectDataValue { + fn logical_type() -> Option { + Some(LogicalType::Varchar(None, CharLengthUnits::Characters)) + } + + fn from_data_value(_: DataValue) -> Result { + Err(DatabaseError::InvalidValue( + "rejecting value for test".to_string(), + )) + } + } + #[derive(Default, Debug, PartialEq, Projection)] struct UserOrderSummary { #[projection(from = "users", rename = "user_name")] @@ -325,7 +370,8 @@ mod test { ); let schema = SchemaView::new(&schema, &plan_arena); - let derived = DerivedStruct::from((&schema, tuple)); + let derived = + ::from_query_row(&schema, tuple).unwrap(); assert_eq!(derived.c1, 9); assert_eq!(derived.name, "LOL"); @@ -333,6 +379,27 @@ mod test { assert_eq!(derived.skipped, ""); } + #[test] + fn test_model_decode_does_not_require_default() -> Result<(), DatabaseError> { + let (_temp_dir, mut database) = build_test_database()?; + + create_model_table::(&mut database)?; + database.insert(&NoDefaultUser { + id: 1, + name: "Alice".to_string(), + })?; + + assert_eq!( + database.get::(&1)?, + Some(NoDefaultUser { + id: 1, + name: "Alice".to_string(), + }) + ); + + Ok(()) + } + #[test] fn test_result_iter_to_orm_iter() -> Result<(), DatabaseError> { let (_temp_dir, mut database) = build_test_database()?; @@ -1056,6 +1123,21 @@ mod test { }) ); + let invalid_projection = database + .bind(|ctx| { + ctx.from::()? + .filter(|e| e.column(User::id())?.eq(1))? + .project::()? + .finish() + })? + .orm::() + .next() + .transpose(); + assert!(matches!( + invalid_projection, + Err(DatabaseError::InvalidValue(_)) + )); + let aliased_total_users = database.bind(|ctx| { ctx.from::()? .project_value(|e| {