Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create schemas for postgres schema generator #1682

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 46 additions & 9 deletions tortoise/backends/base/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

class BaseSchemaGenerator:
DIALECT = "sql"
TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};'
SCHEMA_CREATE_TEMPLATE = ''
TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}{schema_name}"{table_name}" ({fields}){extra}{comment};'
FIELD_TEMPLATE = '"{name}" {type} {nullable} {unique}{primary}{default}{comment}'
INDEX_CREATE_TEMPLATE = 'CREATE INDEX {exists}"{index_name}" ON "{table_name}" ({fields});'
UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(" INDEX", " UNIQUE INDEX")
UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})'
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}'
FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}'
FK_TEMPLATE = ' REFERENCES {schema_name}"{table}" ("{field}") ON DELETE {on_delete}{comment}'
M2M_TABLE_TEMPLATE = (
'CREATE TABLE {exists}"{table_name}" (\n'
' "{backward_key}" {backward_type} NOT NULL{backward_fk},\n'
Expand Down Expand Up @@ -60,13 +61,14 @@ def _create_fk_string(
self,
constraint_name: str,
db_column: str,
schema_name: str,
table: str,
field: str,
on_delete: str,
comment: str,
) -> str:
return self.FK_TEMPLATE.format(
db_column=db_column, table=table, field=field, on_delete=on_delete, comment=comment
db_column=db_column, schema_name=schema_name, table=table, field=field, on_delete=on_delete, comment=comment
)

def _table_comment_generator(self, table: str, comment: str) -> str:
Expand All @@ -91,7 +93,7 @@ def _escape_default_value(self, default: Any):
# needs to be implemented for each supported client
raise NotImplementedError()

def _column_comment_generator(self, table: str, column: str, comment: str) -> str:
def _column_comment_generator(self, schema_name, table: str, column: str, comment: str) -> str:
# Databases have their own way of supporting comments for column level
# needs to be implemented for each supported client
raise NotImplementedError() # pragma: nocoverage
Expand Down Expand Up @@ -185,6 +187,9 @@ def _get_unique_constraint_sql(self, model: "Type[Model]", field_names: List[str
index_name=self._generate_index_name("uid", model, field_names),
fields=", ".join([self.quote(f) for f in field_names]),
)

def _get_schema_name(self, model: "Type[Model]") -> str:
return ""

def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
fields_to_create = []
Expand All @@ -196,10 +201,14 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
self._get_models_to_create(models_to_create)
models_tables = [model._meta.db_table for model in models_to_create]
for field_name, column_name in model._meta.fields_db_projection.items():
schema_name = self._get_schema_name(model)
field_object = model._meta.fields_map[field_name]
comment = (
self._column_comment_generator(
table=model._meta.db_table, column=column_name, comment=field_object.description
schema_name=schema_name,
table=model._meta.db_table,
column=column_name,
comment=field_object.description
)
if field_object.description
else ""
Expand Down Expand Up @@ -259,13 +268,14 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
if not to_field_name:
to_field_name = reference.to_field_instance.model_field_name

schema_name = self._get_schema_name(reference.related_model)
field_creation_string = self._create_string(
db_column=column_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable=nullable,
unique=unique,
is_primary_key=field_object.pk,
comment=comment if not reference.db_constraint else "",
comment="",
default=default,
) + (
self._create_fk_string(
Expand All @@ -276,6 +286,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
to_field_name,
),
db_column=column_name,
schema_name=schema_name,
table=reference.related_model._meta.db_table,
field=to_field_name,
on_delete=reference.on_delete,
Expand Down Expand Up @@ -335,16 +346,22 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
fields_to_create.extend(self._get_inner_statements())

table_fields_string = "\n {}\n".format(",\n ".join(fields_to_create))
schema_name = self._get_schema_name(model)
table_comment = (
self._table_comment_generator(
table=model._meta.db_table, comment=model._meta.table_description
schema_name=schema_name,
table=model._meta.db_table,
comment=model._meta.table_description
)
if model._meta.table_description
else ""
)

schema_name = self._get_schema_name(model)

table_create_string = self.TABLE_CREATE_TEMPLATE.format(
exists="IF NOT EXISTS " if safe else "",
schema_name=schema_name,
table_name=model._meta.db_table,
fields=table_fields_string,
comment=table_comment,
Expand All @@ -362,9 +379,12 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
backward_key, forward_key = field_object.backward_key, field_object.forward_key
backward_fk = forward_fk = ""
if field_object.db_constraint:
backward_schema_name = self._get_schema_name(model)
forward_schema_name = self._get_schema_name(field_object.related_model)
backward_fk = self._create_fk_string(
"",
backward_key,
backward_schema_name,
model._meta.db_table,
model._meta.db_pk_column,
field_object.on_delete,
Expand All @@ -373,6 +393,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
forward_fk = self._create_fk_string(
"",
forward_key,
forward_schema_name,
field_object.related_model._meta.db_table,
field_object.related_model._meta.db_pk_column,
field_object.on_delete,
Expand All @@ -382,6 +403,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
table_name = field_object.through
m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
exists=exists,
schema_name=schema_name,
table_name=table_name,
backward_fk=backward_fk,
forward_fk=forward_fk,
Expand All @@ -394,7 +416,9 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict:
extra=self._table_generate_extra(table=field_object.through),
comment=(
self._table_comment_generator(
table=field_object.through, comment=field_object.description
schema_name=schema_name,
table=field_object.through,
comment=field_object.description
)
if field_object.description
else ""
Expand Down Expand Up @@ -439,17 +463,30 @@ def _get_models_to_create(self, models_to_create: "List[Type[Model]]") -> None:
model._check()
models_to_create.append(model)

def _get_schemas_to_create(self, models_to_create, schemas_to_create: "List[str]") -> None:
pass

def get_create_schema_sql(self, safe: bool = True) -> str:
models_to_create: "List[Type[Model]]" = []

self._get_models_to_create(models_to_create)

schema_names_to_create = []
self._get_schemas_to_create(models_to_create, schema_names_to_create)

tables_to_create = []
for model in models_to_create:
tables_to_create.append(self._get_table_sql(model, safe))

tables_to_create_count = len(tables_to_create)

schemas_to_create: List[str] = []

for schema_name in schema_names_to_create:
schemas_to_create.append(self.SCHEMA_CREATE_TEMPLATE.format(
schema_name = schema_name
))

created_tables: Set[dict] = set()
ordered_tables_for_create: List[str] = []
m2m_tables_to_create: List[str] = []
Expand All @@ -469,7 +506,7 @@ def get_create_schema_sql(self, safe: bool = True) -> str:
ordered_tables_for_create.append(next_table_for_create["table_creation_string"])
m2m_tables_to_create += next_table_for_create["m2m_tables"]

schema_creation_string = "\n".join(ordered_tables_for_create + m2m_tables_to_create)
schema_creation_string = "\n".join(schemas_to_create + ordered_tables_for_create + m2m_tables_to_create)
return schema_creation_string

async def generate_from_string(self, creation_string: str) -> None:
Expand Down
33 changes: 29 additions & 4 deletions tortoise/backends/base_postgres/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,16 @@

class BasePostgresSchemaGenerator(BaseSchemaGenerator):
DIALECT = "postgres"
TABLE_COMMENT_TEMPLATE = "COMMENT ON TABLE \"{table}\" IS '{comment}';"
COLUMN_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table}"."{column}" IS \'{comment}\';'
SCHEMA_CREATE_TEMPLATE = 'CREATE SCHEMA IF NOT EXISTS "{schema_name}";'
TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}{schema_name}"{table_name}" ({fields}){extra}{comment};'
M2M_TABLE_TEMPLATE = (
'CREATE TABLE {exists}{schema_name}"{table_name}" (\n'
' "{backward_key}" {backward_type} NOT NULL{backward_fk},\n'
' "{forward_key}" {forward_type} NOT NULL{forward_fk}\n'
"){extra}{comment};"
)
TABLE_COMMENT_TEMPLATE = "COMMENT ON TABLE {schema_name}\"{table}\" IS '{comment}';"
COLUMN_COMMENT_TEMPLATE = 'COMMENT ON COLUMN {schema_name}"{table}"."{column}" IS \'{comment}\';'
GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}'

def __init__(self, client: "BasePostgresClient") -> None:
Expand All @@ -30,9 +38,12 @@ def _table_comment_generator(self, table: str, comment: str) -> str:
self.comments_array.append(comment)
return ""

def _column_comment_generator(self, table: str, column: str, comment: str) -> str:
def _column_comment_generator(self, schema_name, table: str, column: str, comment: str) -> str:
comment = self.COLUMN_COMMENT_TEMPLATE.format(
table=table, column=column, comment=self._escape_comment(comment)
schema_name=schema_name,
table=table,
column=column,
comment=self._escape_comment(comment)
)
if comment not in self.comments_array:
self.comments_array.append(comment)
Expand Down Expand Up @@ -61,3 +72,17 @@ def _escape_default_value(self, default: Any):
if isinstance(default, bool):
return default
return encoders.get(type(default))(default) # type: ignore

def _get_schema_name(self, model: "Type[Model]") -> str:
schema_name = ""
if model._meta.schema and model._meta.schema != 'public':
schema_name = f'"{model._meta.schema}".'
return schema_name

def _get_schemas_to_create(self, models_to_create, schemas_to_create: "List[String]") -> None:
for model in models_to_create:
schema_name = ""
if model._meta.schema and model._meta.schema != 'public':
schema_name = model._meta.schema
if schema_name and schema_name not in schemas_to_create:
schemas_to_create.append(schema_name)
Loading