From 3ee2958ac5930113fdd4e6d0f642d26218f19a86 Mon Sep 17 00:00:00 2001 From: Juris Krumgolds Date: Sat, 27 Jul 2024 20:59:00 +0300 Subject: [PATCH] Create schemas for postgres schema generator --- tortoise/backends/base/schema_generator.py | 55 ++++++++++++++++--- .../base_postgres/schema_generator.py | 33 +++++++++-- 2 files changed, 75 insertions(+), 13 deletions(-) diff --git a/tortoise/backends/base/schema_generator.py b/tortoise/backends/base/schema_generator.py index 00ecc5d29..caa37c91b 100644 --- a/tortoise/backends/base/schema_generator.py +++ b/tortoise/backends/base/schema_generator.py @@ -17,13 +17,14 @@ class BaseSchemaGenerator: DIALECT = "sql" - TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}"{table_name}" ({fields}){extra}{comment};' + SCHEMA_CREATE_TEMPLATE = '' + TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}{schema_name}"{table_name}" ({fields}){extra}{comment};' FIELD_TEMPLATE = '"{name}" {type} {nullable} {unique}{primary}{default}{comment}' INDEX_CREATE_TEMPLATE = 'CREATE INDEX {exists}"{index_name}" ON "{table_name}" ({fields});' UNIQUE_INDEX_CREATE_TEMPLATE = INDEX_CREATE_TEMPLATE.replace(" INDEX", " UNIQUE INDEX") UNIQUE_CONSTRAINT_CREATE_TEMPLATE = 'CONSTRAINT "{index_name}" UNIQUE ({fields})' GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}{comment}' - FK_TEMPLATE = ' REFERENCES "{table}" ("{field}") ON DELETE {on_delete}{comment}' + FK_TEMPLATE = ' REFERENCES {schema_name}"{table}" ("{field}") ON DELETE {on_delete}{comment}' M2M_TABLE_TEMPLATE = ( 'CREATE TABLE {exists}"{table_name}" (\n' ' "{backward_key}" {backward_type} NOT NULL{backward_fk},\n' @@ -60,13 +61,14 @@ def _create_fk_string( self, constraint_name: str, db_column: str, + schema_name: str, table: str, field: str, on_delete: str, comment: str, ) -> str: return self.FK_TEMPLATE.format( - db_column=db_column, table=table, field=field, on_delete=on_delete, comment=comment + db_column=db_column, schema_name=schema_name, table=table, field=field, on_delete=on_delete, comment=comment ) def _table_comment_generator(self, table: str, comment: str) -> str: @@ -91,7 +93,7 @@ def _escape_default_value(self, default: Any): # needs to be implemented for each supported client raise NotImplementedError() - def _column_comment_generator(self, table: str, column: str, comment: str) -> str: + def _column_comment_generator(self, schema_name, table: str, column: str, comment: str) -> str: # Databases have their own way of supporting comments for column level # needs to be implemented for each supported client raise NotImplementedError() # pragma: nocoverage @@ -185,6 +187,9 @@ def _get_unique_constraint_sql(self, model: "Type[Model]", field_names: List[str index_name=self._generate_index_name("uid", model, field_names), fields=", ".join([self.quote(f) for f in field_names]), ) + + def _get_schema_name(self, model: "Type[Model]") -> str: + return "" def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: fields_to_create = [] @@ -196,10 +201,14 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: self._get_models_to_create(models_to_create) models_tables = [model._meta.db_table for model in models_to_create] for field_name, column_name in model._meta.fields_db_projection.items(): + schema_name = self._get_schema_name(model) field_object = model._meta.fields_map[field_name] comment = ( self._column_comment_generator( - table=model._meta.db_table, column=column_name, comment=field_object.description + schema_name=schema_name, + table=model._meta.db_table, + column=column_name, + comment=field_object.description ) if field_object.description else "" @@ -259,13 +268,14 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: if not to_field_name: to_field_name = reference.to_field_instance.model_field_name + schema_name = self._get_schema_name(reference.related_model) field_creation_string = self._create_string( db_column=column_name, field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), nullable=nullable, unique=unique, is_primary_key=field_object.pk, - comment=comment if not reference.db_constraint else "", + comment="", default=default, ) + ( self._create_fk_string( @@ -276,6 +286,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: to_field_name, ), db_column=column_name, + schema_name=schema_name, table=reference.related_model._meta.db_table, field=to_field_name, on_delete=reference.on_delete, @@ -335,16 +346,22 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: fields_to_create.extend(self._get_inner_statements()) table_fields_string = "\n {}\n".format(",\n ".join(fields_to_create)) + schema_name = self._get_schema_name(model) table_comment = ( self._table_comment_generator( - table=model._meta.db_table, comment=model._meta.table_description + schema_name=schema_name, + table=model._meta.db_table, + comment=model._meta.table_description ) if model._meta.table_description else "" ) + schema_name = self._get_schema_name(model) + table_create_string = self.TABLE_CREATE_TEMPLATE.format( exists="IF NOT EXISTS " if safe else "", + schema_name=schema_name, table_name=model._meta.db_table, fields=table_fields_string, comment=table_comment, @@ -362,9 +379,12 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: backward_key, forward_key = field_object.backward_key, field_object.forward_key backward_fk = forward_fk = "" if field_object.db_constraint: + backward_schema_name = self._get_schema_name(model) + forward_schema_name = self._get_schema_name(field_object.related_model) backward_fk = self._create_fk_string( "", backward_key, + backward_schema_name, model._meta.db_table, model._meta.db_pk_column, field_object.on_delete, @@ -373,6 +393,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: forward_fk = self._create_fk_string( "", forward_key, + forward_schema_name, field_object.related_model._meta.db_table, field_object.related_model._meta.db_pk_column, field_object.on_delete, @@ -382,6 +403,7 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: table_name = field_object.through m2m_create_string = self.M2M_TABLE_TEMPLATE.format( exists=exists, + schema_name=schema_name, table_name=table_name, backward_fk=backward_fk, forward_fk=forward_fk, @@ -394,7 +416,9 @@ def _get_table_sql(self, model: "Type[Model]", safe: bool = True) -> dict: extra=self._table_generate_extra(table=field_object.through), comment=( self._table_comment_generator( - table=field_object.through, comment=field_object.description + schema_name=schema_name, + table=field_object.through, + comment=field_object.description ) if field_object.description else "" @@ -439,17 +463,30 @@ def _get_models_to_create(self, models_to_create: "List[Type[Model]]") -> None: model._check() models_to_create.append(model) + def _get_schemas_to_create(self, models_to_create, schemas_to_create: "List[str]") -> None: + pass + def get_create_schema_sql(self, safe: bool = True) -> str: models_to_create: "List[Type[Model]]" = [] self._get_models_to_create(models_to_create) + schema_names_to_create = [] + self._get_schemas_to_create(models_to_create, schema_names_to_create) + tables_to_create = [] for model in models_to_create: tables_to_create.append(self._get_table_sql(model, safe)) tables_to_create_count = len(tables_to_create) + schemas_to_create: List[str] = [] + + for schema_name in schema_names_to_create: + schemas_to_create.append(self.SCHEMA_CREATE_TEMPLATE.format( + schema_name = schema_name + )) + created_tables: Set[dict] = set() ordered_tables_for_create: List[str] = [] m2m_tables_to_create: List[str] = [] @@ -469,7 +506,7 @@ def get_create_schema_sql(self, safe: bool = True) -> str: ordered_tables_for_create.append(next_table_for_create["table_creation_string"]) m2m_tables_to_create += next_table_for_create["m2m_tables"] - schema_creation_string = "\n".join(ordered_tables_for_create + m2m_tables_to_create) + schema_creation_string = "\n".join(schemas_to_create + ordered_tables_for_create + m2m_tables_to_create) return schema_creation_string async def generate_from_string(self, creation_string: str) -> None: diff --git a/tortoise/backends/base_postgres/schema_generator.py b/tortoise/backends/base_postgres/schema_generator.py index 556892396..2a444321e 100644 --- a/tortoise/backends/base_postgres/schema_generator.py +++ b/tortoise/backends/base_postgres/schema_generator.py @@ -9,8 +9,16 @@ class BasePostgresSchemaGenerator(BaseSchemaGenerator): DIALECT = "postgres" - TABLE_COMMENT_TEMPLATE = "COMMENT ON TABLE \"{table}\" IS '{comment}';" - COLUMN_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table}"."{column}" IS \'{comment}\';' + SCHEMA_CREATE_TEMPLATE = 'CREATE SCHEMA IF NOT EXISTS "{schema_name}";' + TABLE_CREATE_TEMPLATE = 'CREATE TABLE {exists}{schema_name}"{table_name}" ({fields}){extra}{comment};' + M2M_TABLE_TEMPLATE = ( + 'CREATE TABLE {exists}{schema_name}"{table_name}" (\n' + ' "{backward_key}" {backward_type} NOT NULL{backward_fk},\n' + ' "{forward_key}" {forward_type} NOT NULL{forward_fk}\n' + "){extra}{comment};" + ) + TABLE_COMMENT_TEMPLATE = "COMMENT ON TABLE {schema_name}\"{table}\" IS '{comment}';" + COLUMN_COMMENT_TEMPLATE = 'COMMENT ON COLUMN {schema_name}"{table}"."{column}" IS \'{comment}\';' GENERATED_PK_TEMPLATE = '"{field_name}" {generated_sql}' def __init__(self, client: "BasePostgresClient") -> None: @@ -30,9 +38,12 @@ def _table_comment_generator(self, table: str, comment: str) -> str: self.comments_array.append(comment) return "" - def _column_comment_generator(self, table: str, column: str, comment: str) -> str: + def _column_comment_generator(self, schema_name, table: str, column: str, comment: str) -> str: comment = self.COLUMN_COMMENT_TEMPLATE.format( - table=table, column=column, comment=self._escape_comment(comment) + schema_name=schema_name, + table=table, + column=column, + comment=self._escape_comment(comment) ) if comment not in self.comments_array: self.comments_array.append(comment) @@ -61,3 +72,17 @@ def _escape_default_value(self, default: Any): if isinstance(default, bool): return default return encoders.get(type(default))(default) # type: ignore + + def _get_schema_name(self, model: "Type[Model]") -> str: + schema_name = "" + if model._meta.schema and model._meta.schema != 'public': + schema_name = f'"{model._meta.schema}".' + return schema_name + + def _get_schemas_to_create(self, models_to_create, schemas_to_create: "List[String]") -> None: + for model in models_to_create: + schema_name = "" + if model._meta.schema and model._meta.schema != 'public': + schema_name = model._meta.schema + if schema_name and schema_name not in schemas_to_create: + schemas_to_create.append(schema_name) \ No newline at end of file