Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions kite_sql_serde_macros/src/orm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
));
};

let mut assignments = Vec::new();
let mut field_initializers = Vec::new();
let mut field_index_declarations = Vec::new();
let mut field_index_resolvers = Vec::new();
let mut params = Vec::new();
let mut orm_fields = Vec::new();
let mut orm_columns = Vec::new();
Expand Down Expand Up @@ -132,6 +134,13 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
"decimal field cannot be skipped",
));
}
generics
.make_where_clause()
.predicates
.push(parse_quote!(#field_ty : ::core::default::Default));
field_initializers.push(quote! {
#field_name: ::core::default::Default::default()
});
continue;
}

Expand Down Expand Up @@ -163,6 +172,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
let is_unique = field.unique;
let is_index = field.index;
let column_index = orm_columns.len();
let field_index_ident = format_ident!("__kite_orm_{field_name}_index");

persisted_columns.push((field_name_string, column_name.clone()));
column_names.push(column_name.clone());
Expand Down Expand Up @@ -223,11 +233,22 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
quote! { None::<::kite_sql::types::value::DataValue> }
};

assignments.push(quote! {
if let Some(value) = ::kite_sql::orm::try_get::<#field_ty>(&mut tuple, schema, #column_name_lit) {
struct_instance.#field_name = value;
field_index_declarations.push(quote! {
let mut #field_index_ident = None;
});
field_index_resolvers.push(quote! {
if #field_index_ident.is_none() && __kite_orm_column_name == #column_name_lit {
#field_index_ident = Some(__kite_orm_index);
__kite_orm_found_fields += 1;
}
});
field_initializers.push(quote! {
#field_name: ::kite_sql::orm::take_value_at::<#field_ty>(
&mut tuple,
#field_index_ident,
#column_name_lit,
)?
});
params.push(quote! {
(#placeholder_lit, ::kite_sql::orm::ToDataValue::to_data_value(&self.#field_name))
});
Expand Down Expand Up @@ -379,23 +400,31 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
let primary_key_value = primary_key_value.expect("primary key checked above");
let _primary_key_column = primary_key_column.expect("primary key checked above");
let _primary_key_placeholder = primary_key_placeholder.expect("primary key checked above");
let mut from_generics = generics.clone();
from_generics.params.insert(0, parse_quote!('__kite_arena));
from_generics.params.insert(0, parse_quote!('__kite_schema));
let (from_impl_generics, _, from_where_clause) = from_generics.split_for_impl();
let field_count = field_index_declarations.len();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

Ok(quote! {
impl #from_impl_generics ::core::convert::From<(&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)>
impl #impl_generics ::kite_sql::orm::FromQueryRow
for #struct_name #ty_generics
#from_where_clause
#where_clause
{
fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)) -> Self {
let mut struct_instance = <Self as ::core::default::Default>::default();

#(#assignments)*
fn from_query_row(
schema: &::kite_sql::types::tuple::SchemaView<'_, '_>,
mut tuple: ::kite_sql::types::tuple::Tuple,
) -> ::std::result::Result<Self, ::kite_sql::errors::DatabaseError> {
let mut __kite_orm_found_fields = 0usize;
#(#field_index_declarations)*
for (__kite_orm_index, __kite_orm_column) in schema.iter().enumerate() {
let __kite_orm_column_name = __kite_orm_column.name();
#(#field_index_resolvers)*
if __kite_orm_found_fields == #field_count {
break;
}
}

struct_instance
Ok(Self {
#(#field_initializers),*
})
}
}

Expand Down
54 changes: 39 additions & 15 deletions kite_sql_serde_macros/src/projection.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use darling::ast::Data;
use darling::{FromDeriveInput, FromField};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use quote::{format_ident, quote};
use syn::{parse_quote, DeriveInput, Error, Generics, Ident, LitStr, Type};

#[derive(Debug, FromDeriveInput)]
Expand Down Expand Up @@ -34,7 +34,9 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
};

let mut projection_exprs = Vec::new();
let mut assignments = Vec::new();
let mut field_initializers = Vec::new();
let mut field_index_declarations = Vec::new();
let mut field_index_resolvers = Vec::new();

for field in data_struct.fields {
let ProjectionFieldOpts {
Expand All @@ -50,6 +52,7 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
let source_name = rename.clone().unwrap_or_else(|| field_name_string.clone());
let source_name_lit = LitStr::new(&source_name, Span::call_site());
let field_name_lit = LitStr::new(&field_name_string, Span::call_site());
let field_index_ident = format_ident!("__kite_projection_{field_name}_index");
let relation_expr = if let Some(source_relation) = from {
let relation_lit = LitStr::new(&source_relation, Span::call_site());
quote!(#relation_lit)
Expand All @@ -74,17 +77,25 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
scope.column_ref(#relation_expr, #source_name_lit)?
}
});
assignments.push(quote! {
if let Some(value) = ::kite_sql::orm::try_get::<#field_ty>(&mut tuple, schema, #field_name_lit) {
struct_instance.#field_name = value;
field_index_declarations.push(quote! {
let mut #field_index_ident = None;
});
field_index_resolvers.push(quote! {
if #field_index_ident.is_none() && __kite_orm_column_name == #field_name_lit {
#field_index_ident = Some(__kite_orm_index);
__kite_orm_found_fields += 1;
}
});
field_initializers.push(quote! {
#field_name: ::kite_sql::orm::take_value_at::<#field_ty>(
&mut tuple,
#field_index_ident,
#field_name_lit,
)?
});
}

let mut from_generics = generics.clone();
from_generics.params.insert(0, parse_quote!('__kite_arena));
from_generics.params.insert(0, parse_quote!('__kite_schema));
let (from_impl_generics, _, from_where_clause) = from_generics.split_for_impl();
let field_count = field_index_declarations.len();
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

Ok(quote! {
Expand All @@ -105,13 +116,26 @@ pub(crate) fn handle(ast: DeriveInput) -> Result<TokenStream, Error> {
}
}

impl #from_impl_generics From<(&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)> for #struct_name #ty_generics
#from_where_clause
impl #impl_generics ::kite_sql::orm::FromQueryRow for #struct_name #ty_generics
#where_clause
{
fn from((schema, mut tuple): (&::kite_sql::types::tuple::SchemaView<'__kite_schema, '__kite_arena>, ::kite_sql::types::tuple::Tuple)) -> Self {
let mut struct_instance = <Self as ::std::default::Default>::default();
#(#assignments)*
struct_instance
fn from_query_row(
schema: &::kite_sql::types::tuple::SchemaView<'_, '_>,
mut tuple: ::kite_sql::types::tuple::Tuple,
) -> ::std::result::Result<Self, ::kite_sql::errors::DatabaseError> {
let mut __kite_orm_found_fields = 0usize;
#(#field_index_declarations)*
for (__kite_orm_index, __kite_orm_column) in schema.iter().enumerate() {
let __kite_orm_column_name = __kite_orm_column.name();
#(#field_index_resolvers)*
if __kite_orm_found_fields == #field_count {
break;
}
}

Ok(Self {
#(#field_initializers),*
})
}
}
})
Expand Down
14 changes: 8 additions & 6 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ use crate::optimizer::heuristic::batch::HepBatchStrategy;
use crate::optimizer::heuristic::optimizer::HepOptimizerPipeline;
use crate::optimizer::rule::implementation::ImplementationRuleImpl;
use crate::optimizer::rule::normalization::NormalizationRuleImpl;
#[cfg(feature = "orm")]
use crate::orm::FromQueryRow;
use crate::planner::operator::Operator;
use crate::planner::{LogicalPlan, PlanArena, TableArenaCell};
#[cfg(all(not(target_arch = "wasm32"), feature = "lmdb"))]
Expand Down Expand Up @@ -945,12 +947,12 @@ pub trait ResultIter: BorrowResultIter + Iterator<Item = Result<Tuple, DatabaseE
/// Converts this iterator into a typed ORM iterator.
///
/// This is available when the `orm` feature is enabled and the target type
/// implements `From<(&SchemaView, Tuple)>`, which is typically generated by
/// `#[derive(Model)]`.
/// implements `FromQueryRow`, which is typically generated by
/// `#[derive(Model)]` or `#[derive(Projection)]`.
fn orm<T>(self) -> OrmIter<Self, T>
where
Self: Sized,
T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>,
T: FromQueryRow,
{
OrmIter::new(self)
}
Expand All @@ -969,7 +971,7 @@ pub struct OrmIter<I, T> {
impl<I, T> OrmIter<I, T>
where
I: ResultIter,
T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>,
T: FromQueryRow,
{
fn new(inner: I) -> Self {
Self {
Expand All @@ -993,7 +995,7 @@ where
impl<I, T> Iterator for OrmIter<I, T>
where
I: ResultIter,
T: for<'view, 'schema, 'arena> From<(&'view SchemaView<'schema, 'arena>, Tuple)>,
T: FromQueryRow,
{
type Item = Result<T, DatabaseError>;

Expand All @@ -1002,7 +1004,7 @@ where
Ok(tuple) => tuple,
Err(err) => return Some(Err(err)),
};
Some(Ok(self.inner.schema(|schema| T::from((schema, tuple)))))
Some(self.inner.schema(|schema| T::from_query_row(schema, tuple)))
}
}

Expand Down
Loading
Loading