Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Escape table name if it is a reserved keyword #26

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ormx-macros/src/backend/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(crate) fn getters<B: Backend>(table: &Table<B>) -> TokenStream {
let sql = format!(
"SELECT {} FROM {} WHERE {} = {}",
column_list,
table.table,
table.name(),
field.column(),
B::Bindings::default().next().unwrap()
);
Expand Down Expand Up @@ -103,7 +103,7 @@ pub fn setters<B: Backend>(table: &Table<B>) -> TokenStream {
let mut bindings = B::Bindings::default();
let sql = format!(
"UPDATE {} SET {} = {} WHERE {} = {}",
table.table,
table.name(),
field.column(),
bindings.next().unwrap(),
table.id.column(),
Expand Down Expand Up @@ -141,7 +141,7 @@ pub fn setters<B: Backend>(table: &Table<B>) -> TokenStream {
}
}

pub(crate) fn impl_patch<B: Backend>(patch: &Patch) -> TokenStream {
pub(crate) fn impl_patch<B: Backend>(patch: &Patch<B>) -> TokenStream {
let patch_ident = &patch.ident;
let table_path = &patch.table;
let field_idents = &patch
Expand All @@ -165,7 +165,7 @@ pub(crate) fn impl_patch<B: Backend>(patch: &Patch) -> TokenStream {

let sql = format!(
"UPDATE {} SET {} WHERE {} = {}",
&patch.table_name,
&patch.table_name(),
assignments,
patch.id,
bindings.next().unwrap()
Expand Down
10 changes: 5 additions & 5 deletions ormx-macros/src/backend/common/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn get<B: Backend>(table: &Table<B>, column_list: &str) -> TokenStream {
let get_sql = format!(
"SELECT {} FROM {} WHERE {} = {}",
column_list,
table.table,
table.name(),
table.id.column(),
B::Bindings::default().next().unwrap()
);
Expand Down Expand Up @@ -82,7 +82,7 @@ fn update<B: Backend>(table: &Table<B>) -> TokenStream {

let update_sql = format!(
"UPDATE {} SET {} WHERE {} = {}",
table.table,
table.name(),
assignments,
table.id.column(),
bindings.next().unwrap()
Expand All @@ -107,7 +107,7 @@ fn update<B: Backend>(table: &Table<B>) -> TokenStream {

fn stream_all<B: Backend>(table: &Table<B>, column_list: &str) -> TokenStream {
let box_stream = crate::utils::box_stream();
let all_sql = format!("SELECT {} FROM {}", column_list, table.table);
let all_sql = format!("SELECT {} FROM {}", column_list, table.name());

quote! {
fn stream_all<'a, 'c: 'a>(
Expand All @@ -125,7 +125,7 @@ fn stream_all_paginated<B: Backend>(table: &Table<B>, column_list: &str) -> Toke
let all_sql = format!(
"SELECT {} FROM {} LIMIT {} OFFSET {}",
column_list,
table.table,
table.name(),
bindings.next().unwrap(),
bindings.next().unwrap()
);
Expand All @@ -147,7 +147,7 @@ fn delete<B: Backend>(table: &Table<B>) -> TokenStream {
let id_ty = &table.id.ty;
let delete_sql = format!(
"DELETE FROM {} WHERE {} = {}",
table.table,
table.name(),
table.id.column(),
B::Bindings::default().next().unwrap()
);
Expand Down
2 changes: 1 addition & 1 deletion ormx-macros/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub trait Backend: Sized + Clone {
}

/// Implement [Patch]
fn impl_patch(patch: &Patch) -> TokenStream {
fn impl_patch(patch: &Patch<Self>) -> TokenStream {
common::impl_patch::<Self>(patch)
}
}
4 changes: 2 additions & 2 deletions ormx-macros/src/backend/mysql/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ fn query_default(table: &Table<MySqlBackend>) -> TokenStream {
let query_default_sql = format!(
"SELECT {} FROM {} WHERE {} = ?",
default_fields.map(TableField::fmt_for_select).join(", "),
table.table,
table.name(),
table.id.column()
);

Expand All @@ -97,7 +97,7 @@ fn insert(table: &Table<MySqlBackend>) -> TokenStream {

let insert_sql = format!(
"INSERT INTO {} ({}) VALUES ({})",
table.table,
table.name(),
insert_fields.iter().map(|field| field.column()).join(", "),
MySqlBindings.take(insert_fields.len()).join(", ")
);
Expand Down
4 changes: 2 additions & 2 deletions ormx-macros/src/backend/postgres/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ fn insert_sql(table: &Table<PgBackend>, insert_fields: &[&TableField<PgBackend>]
if returning_fields.is_empty() {
format!(
"INSERT INTO {} ({}) VALUES ({})",
table.table, columns, fields
table.name(), columns, fields
)
} else {
format!(
"INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
table.table, columns, fields, returning_fields
table.name(), columns, fields, returning_fields
)
}
}
Expand Down
16 changes: 14 additions & 2 deletions ormx-macros/src/patch/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::convert::TryFrom;
use std::{convert::TryFrom, borrow::Cow, marker::PhantomData};

use proc_macro2::TokenStream;
use quote::quote;
Expand All @@ -8,12 +8,14 @@ use crate::backend::{Backend, Implementation};

mod parse;

pub struct Patch {
pub struct Patch<B: Backend> {
pub ident: Ident,
pub table_name: String,
pub reserved_table_name: bool,
pub table: Path,
pub id: String,
pub fields: Vec<PatchField>,
pub _phantom: PhantomData<*const B>,
}

pub struct PatchField {
Expand All @@ -24,6 +26,16 @@ pub struct PatchField {
pub by_ref: bool,
}

impl<B: Backend> Patch<B> {
pub fn table_name(&self) -> Cow<str> {
if self.reserved_table_name {
format!("{}{}{}", B::QUOTE, self.table_name, B::QUOTE).into()
} else {
Cow::Borrowed(&self.table_name)
}
}
}

impl PatchField {
pub fn fmt_as_argument(&self) -> TokenStream {
let ident = &self.ident;
Expand Down
13 changes: 9 additions & 4 deletions ormx-macros/src/patch/parse.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use std::convert::TryFrom;
use std::{convert::TryFrom, marker::PhantomData};

use syn::{Data, DeriveInput, Error, Field, Result};

use super::Patch;
use crate::{
attrs::{parse_attrs, PatchAttr, PatchFieldAttr},
patch::PatchField,
utils::{missing_attr, set_once},
utils::{missing_attr, set_once}, backend::Backend,
};

impl TryFrom<&syn::DeriveInput> for Patch {
impl<B: Backend> TryFrom<&syn::DeriveInput> for Patch<B> {
type Error = Error;

fn try_from(value: &DeriveInput) -> Result<Self> {
Expand All @@ -35,12 +35,17 @@ impl TryFrom<&syn::DeriveInput> for Patch {
}
}

let table_name = table_name.ok_or_else(|| missing_attr("table_name"))?;
let reserved_table_name = B::RESERVED_IDENTS.contains(&&*table_name.to_string().to_uppercase());

Ok(Patch {
ident: value.ident.clone(),
table_name: table_name.ok_or_else(|| missing_attr("table_name"))?,
table_name,
reserved_table_name,
table: table.ok_or_else(|| missing_attr("table"))?,
id: id.ok_or_else(|| missing_attr("id"))?,
fields,
_phantom: PhantomData,
})
}
}
Expand Down
9 changes: 9 additions & 0 deletions ormx-macros/src/table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct Table<B: Backend> {
pub ident: Ident,
pub vis: Visibility,
pub table: String,
pub reserved_table_name: bool,
pub id: TableField<B>,
pub fields: Vec<TableField<B>>,
pub insertable: Option<Insertable>,
Expand Down Expand Up @@ -59,6 +60,14 @@ impl<B: Backend> Table<B> {
.map(|field| field.fmt_for_select())
.join(", ")
}

pub fn name(&self) -> Cow<str> {
if self.reserved_table_name {
format!("{}{}{}", B::QUOTE, self.table, B::QUOTE).into()
} else {
Cow::Borrowed(&self.table)
}
}
}

impl<B: Backend> TableField<B> {
Expand Down
12 changes: 11 additions & 1 deletion ormx-macros/src/table/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,20 @@ impl<B: Backend> TryFrom<&syn::DeriveInput> for Table<B> {
));
}

let table = table.ok_or_else(|| missing_attr("table"))?;
let reserved_table_name = B::RESERVED_IDENTS.contains(&&*table.to_string().to_uppercase());
if reserved_table_name {
proc_macro_error::emit_warning!(
Span::call_site(),
"This table name is a reserved keyword, you might want to consider choosing a different name."
);
}

Ok(Table {
ident: value.ident.clone(),
vis: value.vis.clone(),
table: table.ok_or_else(|| missing_attr("table"))?,
table,
reserved_table_name,
id,
insertable,
fields,
Expand Down