From 21376245e49ec9417e7c2b9e385ad371cb76cbd6 Mon Sep 17 00:00:00 2001 From: Gleb Popov <6yearold@gmail.com> Date: Wed, 1 Feb 2023 19:54:01 +0300 Subject: [PATCH 1/6] Introduce entitySchema field into EntityDef --- persistent/Database/Persist/EntityDef.hs | 16 ++++++++++++++++ persistent/Database/Persist/Quasi/Internal.hs | 1 + persistent/Database/Persist/Types/Base.hs | 2 ++ 3 files changed, 19 insertions(+) diff --git a/persistent/Database/Persist/EntityDef.hs b/persistent/Database/Persist/EntityDef.hs index 4e2fe93fc..cebef79df 100644 --- a/persistent/Database/Persist/EntityDef.hs +++ b/persistent/Database/Persist/EntityDef.hs @@ -9,6 +9,7 @@ module Database.Persist.EntityDef -- * Accessors , getEntityHaskellName , getEntityDBName + , getEntityDBSchema , getEntityFields , getEntityFieldsDatabase , getEntityForeignDefs @@ -27,6 +28,7 @@ module Database.Persist.EntityDef , setEntityId , setEntityIdDef , setEntityDBName + , setEntityDBSchema , overEntityFields -- * Related Types , EntityIdDef(..) @@ -86,6 +88,14 @@ getEntityDBName -> EntityNameDB getEntityDBName = entityDB +-- | Return the database schema name for the given entity. +-- +-- @since XXX +getEntityDBSchema + :: EntityDef + -> Maybe Text +getEntityDBSchema = entitySchema + getEntityExtra :: EntityDef -> Map Text [[Text]] getEntityExtra = entityExtra @@ -98,6 +108,12 @@ setEntityDBName db ed = ed { entityDB = db } getEntityComments :: EntityDef -> Maybe Text getEntityComments = entityComments +-- | Sets or resets the database schema name for the given entity. +-- +-- @since XXX +setEntityDBSchema :: Maybe Text -> EntityDef -> EntityDef +setEntityDBSchema schema entDef = entDef { entitySchema = schema } + -- | -- -- @since 2.13.0.0 diff --git a/persistent/Database/Persist/Quasi/Internal.hs b/persistent/Database/Persist/Quasi/Internal.hs index 8f67e991b..668f3c5a7 100644 --- a/persistent/Database/Persist/Quasi/Internal.hs +++ b/persistent/Database/Persist/Quasi/Internal.hs @@ -693,6 +693,7 @@ mkUnboundEntityDef ps parsedEntDef = EntityDef { entityHaskell = entNameHS , entityDB = entNameDB + , entitySchema = Nothing -- idField is the user-specified Id -- otherwise useAutoIdField -- but, adjust it if the user specified a Primary diff --git a/persistent/Database/Persist/Types/Base.hs b/persistent/Database/Persist/Types/Base.hs index b17def38b..364874183 100644 --- a/persistent/Database/Persist/Types/Base.hs +++ b/persistent/Database/Persist/Types/Base.hs @@ -131,6 +131,8 @@ data EntityDef = EntityDef -- ^ The name of the entity as Haskell understands it. , entityDB :: !EntityNameDB -- ^ The name of the database table corresponding to the entity. + , entitySchema :: !(Maybe Text) + -- ^ The schema name of the database table. , entityId :: !EntityIdDef -- ^ The entity's primary key or identifier. , entityAttrs :: ![Attr] From 9fe94a2703531f0dbe5f7fd22ff2ac855fb572e3 Mon Sep 17 00:00:00 2001 From: Gleb Popov <6yearold@gmail.com> Date: Wed, 1 Feb 2023 19:07:39 +0300 Subject: [PATCH 2/6] Populate entitySchema with entity-level "schema=..." attribute --- persistent/Database/Persist/Quasi/Internal.hs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/persistent/Database/Persist/Quasi/Internal.hs b/persistent/Database/Persist/Quasi/Internal.hs index 668f3c5a7..9f6f75045 100644 --- a/persistent/Database/Persist/Quasi/Internal.hs +++ b/persistent/Database/Persist/Quasi/Internal.hs @@ -693,7 +693,7 @@ mkUnboundEntityDef ps parsedEntDef = EntityDef { entityHaskell = entNameHS , entityDB = entNameDB - , entitySchema = Nothing + , entitySchema = getSchemaName (parsedEntityDefEntityAttributes parsedEntDef) -- idField is the user-specified Id -- otherwise useAutoIdField -- but, adjust it if the user specified a Primary @@ -940,6 +940,9 @@ getDbName :: PersistSettings -> Text -> [Text] -> Text getDbName ps n = fromMaybe (psToDBName ps n) . listToMaybe . mapMaybe (T.stripPrefix "sql=") +getSchemaName :: [Attr] -> Maybe Text +getSchemaName = listToMaybe . mapMaybe (T.stripPrefix "schema=") + getDbName' :: PersistSettings -> Text -> [FieldAttr] -> FieldNameDB getDbName' ps n = getSqlNameOr (FieldNameDB $ psToDBName ps n) From 8970320cb4f32235d7f235d102a6c0c409e1b76f Mon Sep 17 00:00:00 2001 From: Gleb Popov <6yearold@gmail.com> Date: Wed, 1 Feb 2023 19:13:31 +0300 Subject: [PATCH 3/6] Introduce backendSpecificSchemaEntityName override. --- persistent/Database/Persist/Sql.hs | 2 ++ persistent/Database/Persist/Sql/Internal.hs | 34 +++++++++++++++++++-- persistent/Database/Persist/Sql/Types.hs | 1 + 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/persistent/Database/Persist/Sql.hs b/persistent/Database/Persist/Sql.hs index 32bba1021..4abcac19a 100644 --- a/persistent/Database/Persist/Sql.hs +++ b/persistent/Database/Persist/Sql.hs @@ -52,6 +52,8 @@ module Database.Persist.Sql , emptyBackendSpecificOverrides , getBackendSpecificForeignKeyName , setBackendSpecificForeignKeyName + , getBackendSpecificSchemaEntityName + , setBackendSpecificSchemaEntityName , defaultAttribute -- * Internal , IsolationLevel(..) diff --git a/persistent/Database/Persist/Sql/Internal.hs b/persistent/Database/Persist/Sql/Internal.hs index c8e099fee..9b8041c10 100644 --- a/persistent/Database/Persist/Sql/Internal.hs +++ b/persistent/Database/Persist/Sql/Internal.hs @@ -10,6 +10,8 @@ module Database.Persist.Sql.Internal , BackendSpecificOverrides(..) , getBackendSpecificForeignKeyName , setBackendSpecificForeignKeyName + , getBackendSpecificSchemaEntityName + , setBackendSpecificSchemaEntityName , emptyBackendSpecificOverrides ) where @@ -35,6 +37,7 @@ import Database.Persist.Types -- @since 2.11 data BackendSpecificOverrides = BackendSpecificOverrides { backendSpecificForeignKeyName :: Maybe (EntityNameDB -> FieldNameDB -> ConstraintNameDB) + , backendSpecificSchemaEntityName :: Maybe (Maybe Text -> EntityNameDB -> EntityNameDB) } -- | If the override is defined, then this returns a function that accepts an @@ -60,6 +63,28 @@ setBackendSpecificForeignKeyName setBackendSpecificForeignKeyName func bso = bso { backendSpecificForeignKeyName = Just func } +-- | If the override is defined, then this returns a function that accepts an +-- entity name and its schema name and provides the fully qualified entity name. +-- +-- An abstract accessor for the 'BackendSpecificOverrides' +-- +-- @since XXX +getBackendSpecificSchemaEntityName + :: BackendSpecificOverrides + -> Maybe (Maybe Text -> EntityNameDB -> EntityNameDB) +getBackendSpecificSchemaEntityName = + backendSpecificSchemaEntityName + +-- | Set the backend's schema and entity combining function to this value. +-- +-- @since XXX +setBackendSpecificSchemaEntityName + :: (Maybe Text -> EntityNameDB -> EntityNameDB) + -> BackendSpecificOverrides + -> BackendSpecificOverrides +setBackendSpecificSchemaEntityName func bso = + bso { backendSpecificSchemaEntityName = Just func } + findMaybe :: (a -> Maybe b) -> [a] -> Maybe b findMaybe p = listToMaybe . mapMaybe p @@ -67,7 +92,7 @@ findMaybe p = listToMaybe . mapMaybe p -- -- @since 2.11 emptyBackendSpecificOverrides :: BackendSpecificOverrides -emptyBackendSpecificOverrides = BackendSpecificOverrides Nothing +emptyBackendSpecificOverrides = BackendSpecificOverrides Nothing Nothing defaultAttribute :: [FieldAttr] -> Maybe Text defaultAttribute = findMaybe $ \case @@ -133,7 +158,7 @@ mkColumns allDefs t overrides = } tableName :: EntityNameDB - tableName = getEntityDBName t + tableName = schemaEntityNameFn (getEntityDBSchema t) (getEntityDBName t) go :: FieldDef -> Column go fd = @@ -156,6 +181,11 @@ mkColumns allDefs t overrides = FieldAttrMaxlen n -> Just n _ -> Nothing + schemaEntityNameFn = flip fromMaybe (backendSpecificSchemaEntityName overrides) $ + \mbSchema entName -> case mbSchema of + Nothing -> entName + Just "" -> entName + Just schema -> EntityNameDB (schema <> "." <> unEntityNameDB entName) refNameFn = fromMaybe refName (backendSpecificForeignKeyName overrides) mkColumnReference :: FieldDef -> Maybe ColumnReference diff --git a/persistent/Database/Persist/Sql/Types.hs b/persistent/Database/Persist/Sql/Types.hs index a9f592d86..324897d0c 100644 --- a/persistent/Database/Persist/Sql/Types.hs +++ b/persistent/Database/Persist/Sql/Types.hs @@ -34,6 +34,7 @@ data Column = Column -- | This value specifies how a field references another table. -- -- @since 2.11.0.0 +-- TODO: what about schema name there? data ColumnReference = ColumnReference { crTableName :: !EntityNameDB -- ^ The table name that the From ddedf5d4c9eacb87df8b923798a46f909bcbbc5f Mon Sep 17 00:00:00 2001 From: Gleb Popov <6yearold@gmail.com> Date: Wed, 1 Feb 2023 19:51:54 +0300 Subject: [PATCH 4/6] Schema support for Sqlite backend. --- persistent-sqlite/Database/Persist/Sqlite.hs | 81 +++++++++++++------- 1 file changed, 54 insertions(+), 27 deletions(-) diff --git a/persistent-sqlite/Database/Persist/Sqlite.hs b/persistent-sqlite/Database/Persist/Sqlite.hs index ccfecd605..a0628d059 100644 --- a/persistent-sqlite/Database/Persist/Sqlite.hs +++ b/persistent-sqlite/Database/Persist/Sqlite.hs @@ -299,7 +299,7 @@ wrapConnectionInfo connInfo conn logFunc = do , connCommit = helper "COMMIT" , connRollback = ignoreExceptions . helper "ROLLBACK" , connEscapeFieldName = escape . unFieldNameDB - , connEscapeTableName = escape . unEntityNameDB . getEntityDBName + , connEscapeTableName = escapeS . schemaNamePair , connEscapeRawName = escape , connNoLimit = "LIMIT -1" , connRDBMS = "sqlite" @@ -358,7 +358,7 @@ insertSql' ent vals = ISRManyKeys sql vals where sql = T.concat [ "INSERT INTO " - , escapeE $ getEntityDBName ent + , escapeS $ schemaNamePair ent , "(" , T.intercalate "," $ map (escapeF . fieldDB) cols , ") VALUES(" @@ -372,12 +372,12 @@ insertSql' ent vals = [ "SELECT " , escapeF $ fieldDB fd , " FROM " - , escapeE $ getEntityDBName ent + , escapeS $ schemaNamePair ent , " WHERE _ROWID_=last_insert_rowid()" ] ins = T.concat [ "INSERT INTO " - , escapeE $ getEntityDBName ent + , escapeS $ schemaNamePair ent , if null cols then " VALUES(null)" else T.concat @@ -434,8 +434,14 @@ showSqlType SqlBlob = "BLOB" showSqlType SqlBool = "BOOLEAN" showSqlType (SqlOther t) = t +type SchemaEntityName = (Text, EntityNameDB) + sqliteMkColumns :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -sqliteMkColumns allDefs t = mkColumns allDefs t emptyBackendSpecificOverrides +sqliteMkColumns allDefs t = mkColumns allDefs t sqliteSpecificOverrides + +sqliteSpecificOverrides :: BackendSpecificOverrides +sqliteSpecificOverrides = setBackendSpecificSchemaEntityName (\schema name -> unescapeS $ schemaNamePair' schema name) + $ emptyBackendSpecificOverrides migrate' :: [EntityDef] @@ -444,9 +450,9 @@ migrate' -> IO (Either [Text] CautiousMigration) migrate' allDefs getter val = do let (cols, uniqs, fdefs) = sqliteMkColumns allDefs val - let newSql = mkCreateTable False def (filter (not . safeToRemove val . cName) cols, uniqs, fdefs) + let newSql = mkCreateTable def (filter (not . safeToRemove val . cName) cols, uniqs, fdefs) stmt <- getter "SELECT sql FROM sqlite_master WHERE type='table' AND name=?" - oldSql' <- with (stmtQuery stmt [PersistText $ unEntityNameDB table]) + oldSql' <- with (stmtQuery stmt [PersistText table]) (\src -> runConduit $ src .| go) case oldSql' of Nothing -> return $ Right [(False, newSql)] @@ -458,7 +464,7 @@ migrate' allDefs getter val = do return $ Right sql where def = val - table = getEntityDBName def + table = unEntityNameDB $ unescapeS $ schemaNamePair def go = do x <- CL.head case x of @@ -490,7 +496,7 @@ mockMigration mig = do , connCommit = helper "COMMIT" , connRollback = ignoreExceptions . helper "ROLLBACK" , connEscapeFieldName = escape . unFieldNameDB - , connEscapeTableName = escape . unEntityNameDB . getEntityDBName + , connEscapeTableName = escapeS . schemaNamePair , connEscapeRawName = escape , connNoLimit = "LIMIT -1" , connRDBMS = "sqlite" @@ -528,7 +534,7 @@ getCopyTable :: [EntityDef] -> EntityDef -> IO [(Bool, Text)] getCopyTable allDefs getter def = do - stmt <- getter $ T.concat [ "PRAGMA table_info(", escapeE table, ")" ] + stmt <- getter $ T.concat [ "PRAGMA table_info(", escapeS table, ")" ] oldCols' <- with (stmtQuery stmt []) (\src -> runConduit $ src .| getCols) let oldCols = map FieldNameDB oldCols' let newCols = filter (not . safeToRemove def) $ map cName cols @@ -549,42 +555,44 @@ getCopyTable allDefs getter def = do names <- getCols return $ name : names Just y -> error $ "Invalid result from PRAGMA table_info: " ++ show y - table = getEntityDBName def - tableTmp = EntityNameDB $ unEntityNameDB table <> "_backup" + defTmp = setEntityDBSchema (Just "temp") + $ setEntityDBName (escapeWith (EntityNameDB . (<> "_backup")) $ getEntityDBName def) + def + table = schemaNamePair def + tableTmp = schemaNamePair defTmp (cols, uniqs, fdef) = sqliteMkColumns allDefs def cols' = filter (not . safeToRemove def . cName) cols - newSql = mkCreateTable False def (cols', uniqs, fdef) - tmpSql = mkCreateTable True (setEntityDBName tableTmp def) (cols', uniqs, []) - dropTmp = "DROP TABLE " <> escapeE tableTmp - dropOld = "DROP TABLE " <> escapeE table + newSql = mkCreateTable def (cols', uniqs, fdef) + tmpSql = mkCreateTable defTmp (cols', uniqs, []) + dropTmp = "DROP TABLE " <> escapeS tableTmp + dropOld = "DROP TABLE " <> escapeS table copyToTemp common = T.concat [ "INSERT INTO " - , escapeE tableTmp + , escapeS tableTmp , "(" , T.intercalate "," $ map escapeF common , ") SELECT " , T.intercalate "," $ map escapeF common , " FROM " - , escapeE table + , escapeS table ] copyToFinal newCols = T.concat [ "INSERT INTO " - , escapeE table + , escapeS table , " SELECT " , T.intercalate "," $ map escapeF newCols , " FROM " - , escapeE tableTmp + , escapeS tableTmp ] -mkCreateTable :: Bool -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -> Text -mkCreateTable isTemp entity (cols, uniqs, fdefs) = +mkCreateTable :: EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -> Text +mkCreateTable entity (cols, uniqs, fdefs) = T.concat (header <> columns <> footer) where + isTemp = getEntityDBSchema entity == Just "temp" header = - [ "CREATE" - , if isTemp then " TEMP" else "" - , " TABLE " - , escapeE $ getEntityDBName entity + [ "CREATE TABLE " + , escapeS $ schemaNamePair entity , "(" ] @@ -678,6 +686,16 @@ sqlUnique (UniqueDef _ cname cols _) = T.concat , ")" ] +schemaNamePair :: EntityDef -> SchemaEntityName +schemaNamePair ent = schemaNamePair' (getEntityDBSchema ent) (getEntityDBName ent) + +schemaNamePair' :: Maybe Text -> EntityNameDB -> SchemaEntityName +schemaNamePair' mbSchema entName = case mbSchema of + Nothing -> ("", entName) + Just "main" -> ("", entName) + Just "temp" -> ("temp", entName) + Just schema -> ("", EntityNameDB $ (schema <> "_") <> unEntityNameDB entName) + escapeC :: ConstraintNameDB -> Text escapeC = escapeWith escape @@ -687,6 +705,15 @@ escapeE = escapeWith escape escapeF :: FieldNameDB -> Text escapeF = escapeWith escape +escapeS :: (Text, EntityNameDB) -> Text +escapeS ("", entDBName) = escapeE entDBName +-- no need to escape schema as it is either "" or "temp" +escapeS (schema, entDBName) = schema <> "." <> escapeE entDBName + +unescapeS :: (Text, EntityNameDB) -> EntityNameDB +unescapeS ("", entDBName) = entDBName +unescapeS (schema, entDBName) = EntityNameDB $ escapeWith ((schema <> ".") <>) entDBName + escape :: Text -> Text escape s = T.concat [q, T.concatMap go s, q] @@ -713,7 +740,7 @@ putManySql' conflictColumns fields ent n = q fieldDbToText = escapeF . fieldDB mkAssignment f = T.concat [f, "=EXCLUDED.", f] - table = escapeE . getEntityDBName $ ent + table = escapeS . schemaNamePair $ ent columns = Util.commaSeparated $ map fieldDbToText fields placeholders = map (const "?") fields updates = map (mkAssignment . fieldDbToText) fields From 3e8fd0ad73646aa3b55ea9b89d817a53fd293631 Mon Sep 17 00:00:00 2001 From: Gleb Popov <6yearold@gmail.com> Date: Thu, 26 Jan 2023 14:10:49 +0300 Subject: [PATCH 5/6] Schema support for PostgreSQL backend --- .../Database/Persist/Postgresql.hs | 160 +++++++++++------- persistent-postgresql/test/main.hs | 3 + 2 files changed, 104 insertions(+), 59 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index f6c130633..d4ae5b62d 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -437,7 +437,7 @@ createBackend logFunc serverVersion smap conn = , connCommit = const $ PG.commit conn , connRollback = const $ PG.rollback conn , connEscapeFieldName = escapeF - , connEscapeTableName = escapeE . getEntityDBName + , connEscapeTableName = \entDef -> escapeS $ schemaNamePair entDef , connEscapeRawName = escape , connNoLimit = "LIMIT ALL" , connRDBMS = "postgresql" @@ -466,7 +466,7 @@ insertSql' ent vals = (fieldNames, placeholders) = unzip (Util.mkInsertPlaceholders ent escapeF) sql = T.concat [ "INSERT INTO " - , escapeE $ getEntityDBName ent + , escapeS . schemaNamePair $ ent , if null (getEntityFields ent) then " DEFAULT VALUES" else T.concat @@ -482,7 +482,7 @@ upsertSql' :: EntityDef -> NonEmpty (FieldNameHS, FieldNameDB) -> Text -> Text upsertSql' ent uniqs updateVal = T.concat [ "INSERT INTO " - , escapeE (getEntityDBName ent) + , escapeS . schemaNamePair $ ent , "(" , T.intercalate "," fieldNames , ") VALUES (" @@ -501,7 +501,7 @@ upsertSql' ent uniqs updateVal = wher = T.intercalate " AND " $ map (singleClause . snd) $ NEL.toList uniqs singleClause :: FieldNameDB -> Text - singleClause field = escapeE (getEntityDBName ent) <> "." <> (escapeF field) <> " =?" + singleClause field = escapeS (schemaNamePair ent) <> "." <> escapeF field <> " =?" -- | SQL for inserting multiple rows at once and returning their primary keys. insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult @@ -511,7 +511,7 @@ insertManySql' ent valss = (fieldNames, placeholders)= unzip (Util.mkInsertPlaceholders ent escapeF) sql = T.concat [ "INSERT INTO " - , escapeE (getEntityDBName ent) + , escapeS . schemaNamePair $ ent , "(" , T.intercalate "," fieldNames , ") VALUES (" @@ -593,14 +593,17 @@ withStmt' conn query vals = Ok v -> return v doesTableExist :: (Text -> IO Statement) - -> EntityNameDB + -> SchemaEntityName -> IO Bool -doesTableExist getter (EntityNameDB name) = do +doesTableExist getter (schema, EntityNameDB name) = do stmt <- getter sql with (stmtQuery stmt vals) (\src -> runConduit $ src .| start) where - sql = "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" - <> " AND schemaname != 'information_schema' AND tablename=?" + schemaSql = if schema == "" + then "current_schema()" + else "\'" <> schema <> "\'" + sql = "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname = " <> schemaSql + <> " AND tablename=?" vals = [PersistText name] start = await >>= maybe (error "No results when checking doesTableExist") start' @@ -619,12 +622,12 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do ([], old'') -> do exists' <- if null old - then doesTableExist getter name + then doesTableExist getter $ schemaNamePair entity else return True - return $ Right $ migrationText exists' old'' + return $ Right $ createSchemaFor entity <> migrationText exists' old'' (errs, _) -> return $ Left errs where - name = getEntityDBName entity + name = schemaNamePair entity (newcols', udefs, fdefs) = postgresMkColumns allDefs entity migrationText exists' old'' | not exists' = @@ -656,16 +659,21 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do newcols foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs_ +createSchemaFor :: EntityDef -> [AlterDB] +createSchemaFor ent = case fst $ schemaNamePair ent of + "" -> [] + schema -> pure $ AddSchema ("CREATe SCHEMA IF NOT EXISTS " <> schema) + mkForeignAlt :: EntityDef -> ForeignDef -> Maybe AlterDB mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference where - tableName_ = getEntityDBName entity + tableName_@(schema, _) = schemaNamePair entity addReference = AddReference - (foreignRefTableDBName fdef) + (schema, foreignRefTableDBName fdef) constraintName childfields escapedParentFields @@ -682,7 +690,7 @@ addTable cols entity = AddTable $ T.concat -- Lower case e: see Database.Persist.Sql.Migration [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! - , escapeE name + , name , "(" , idtxt , if null nonIdCols then "" else "," @@ -701,8 +709,7 @@ addTable cols entity = Just (cName c) /= fmap fieldDB (getEntityIdField entity) && not (safeToRemove entity (cName c)) - name = - getEntityDBName entity + name = escapeS $ schemaNamePair entity idtxt = case getEntityId entity of EntityIdNaturalKey pdef -> @@ -731,6 +738,7 @@ mayDefault def = case def of Just d -> " DEFAULT " <> d type SafeToRemove = Bool +type SchemaEntityName = (Text, EntityNameDB) data AlterColumn = ChangeType Column SqlType Text @@ -741,7 +749,7 @@ data AlterColumn | Default Column Text | NoDefault Column | Update' Column Text - | AddReference EntityNameDB ConstraintNameDB [FieldNameDB] [Text] FieldCascade + | AddReference SchemaEntityName ConstraintNameDB [FieldNameDB] [Text] FieldCascade | DropReference ConstraintNameDB deriving Show @@ -751,8 +759,9 @@ data AlterTable deriving Show data AlterDB = AddTable Text - | AlterColumn EntityNameDB AlterColumn - | AlterTable EntityNameDB AlterTable + | AddSchema Text + | AlterColumn SchemaEntityName AlterColumn + | AlterTable SchemaEntityName AlterTable deriving Show -- | Returns all of the columns in the given table currently in the database. @@ -760,7 +769,10 @@ getColumns :: (Text -> IO Statement) -> EntityDef -> [Column] -> IO [Either Text (Either Column (ConstraintNameDB, [FieldNameDB]))] getColumns getter def cols = do - let sqlv = T.concat + let tableSchemaSql = case fst $ schemaNamePair def of + "" -> "current_schema()" + schema -> "\'" <> schema <> "\'" + sqlv = T.concat [ "SELECT " , "column_name " , ",is_nullable " @@ -772,7 +784,7 @@ getColumns getter def cols = do , ",character_maximum_length " , "FROM information_schema.columns " , "WHERE table_catalog=current_database() " - , "AND table_schema=current_schema() " + , "AND table_schema=" <> tableSchemaSql <> " " , "AND table_name=? " ] @@ -797,7 +809,7 @@ getColumns getter def cols = do , "information_schema.table_constraints AS k " , "WHERE c.table_catalog=current_database() " , "AND c.table_catalog=k.table_catalog " - , "AND c.table_schema=current_schema() " + , "AND c.table_schema=" <> tableSchemaSql <> " " , "AND c.table_schema=k.table_schema " , "AND c.table_name=? " , "AND c.table_name=k.table_name " @@ -833,7 +845,7 @@ getColumns getter def cols = do $ groupBy ((==) `on` fst) rows processColumns = CL.mapM $ \x'@((PersistText cname) : _) -> do - col <- liftIO $ getColumn getter (getEntityDBName def) x' (Map.lookup cname refMap) + col <- liftIO $ getColumn getter (schemaNamePair def) x' (Map.lookup cname refMap) pure $ case col of Left e -> Left e Right c -> Right $ Left c @@ -890,7 +902,7 @@ getAlters defs def (c1, u1) (c2, u2) = getColumn :: (Text -> IO Statement) - -> EntityNameDB + -> SchemaEntityName -> [PersistValue] -> Maybe (EntityNameDB, ConstraintNameDB) -> IO (Either Text Column) @@ -981,7 +993,10 @@ getColumn getter tableName' [ PersistText columnName Just t' -> t' getRef cname (_, refName') = do - let sql = T.concat + let tableSchemaSql = case fst tableName' of + "" -> "current_schema()" + s -> "\'" <> s <> "\'" + sql = T.concat [ "SELECT DISTINCT " , "ccu.table_name, " , "tc.constraint_name, " @@ -996,6 +1011,7 @@ getColumn getter tableName' [ PersistText columnName , " ON rc.constraint_name = ccu.constraint_name " , "WHERE tc.constraint_type='FOREIGN KEY' " , "AND kcu.ordinal_position=1 " + , "AND kcu.table_schema=" <> tableSchemaSql <> " " , "AND kcu.table_name=? " , "AND kcu.column_name=? " , "AND tc.constraint_name=?" @@ -1004,7 +1020,7 @@ getColumn getter tableName' [ PersistText columnName cntrs <- with (stmtQuery stmt - [ PersistText $ unEntityNameDB tableName' + [ PersistText $ unEntityNameDB $ snd tableName' , PersistText $ unFieldNameDB cname , PersistText $ unConstraintNameDB refName' ] @@ -1018,7 +1034,7 @@ getColumn getter tableName' [ PersistText columnName xs -> error $ mconcat [ "Postgresql.getColumn: error fetching constraints. Expected a single result for foreign key query for table: " - , T.unpack (unEntityNameDB tableName') + , T.unpack (escapeS tableName') , " and column: " , T.unpack (unFieldNameDB cname) , " but got: " @@ -1046,7 +1062,7 @@ getColumn getter tableName' [ PersistText columnName [ "No precision and scale were specified for the column: " , columnName , " in table: " - , unEntityNameDB tableName' + , escapeS tableName' , ". Postgres defaults to a maximum scale of 147,455 and precision of 16383," , " which is probably not what you intended." , " Specify the values as numeric(total_digits, digits_after_decimal_place)." @@ -1056,7 +1072,7 @@ getColumn getter tableName' [ PersistText columnName [ "Can not get numeric field precision for the column: " , columnName , " in table: " - , unEntityNameDB tableName' + , escapeS tableName' , ". Expected an integer for both precision and scale, " , "got: " , T.pack $ show a @@ -1094,12 +1110,16 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName refAdd Nothing = [] refAdd (Just colRef) = - case find ((== crTableName colRef) . getEntityDBName) defs of + let mbRefdef = find (\ent -> getEntityDBSchema ent == getEntityDBSchema edef + && getEntityDBName ent == crTableName colRef) + defs + (schema, _) = schemaNamePair edef + in case mbRefdef of Just refdef | Just _oldName /= fmap fieldDB (getEntityIdField edef) -> [AddReference - (crTableName colRef) + (schema, crTableName colRef) (crConstraintName colRef) [name] (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) @@ -1108,7 +1128,7 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName Just _ -> [] Nothing -> error $ "could not find the entityDef for reftable[" - ++ show (crTableName colRef) ++ "]" + ++ show (escapeS (schema, crTableName colRef)) ++ "]" modRef = if equivalentRef ref ref' then [] @@ -1181,15 +1201,17 @@ getAddReference allDefs entity cname cr@ColumnReference {crTableName = s, crCons guard $ Just cname /= fmap fieldDB (getEntityIdField entity) pure $ AlterColumn table - (AddReference s constraintName [cname] id_ (crFieldCascade cr) + (AddReference (schema, s) constraintName [cname] id_ (crFieldCascade cr) ) where - table = getEntityDBName entity + table@(schema, _) = schemaNamePair entity id_ = fromMaybe (error $ "Could not find ID of entity " ++ show s) $ do - entDef <- find ((== s) . getEntityDBName) allDefs + entDef <- find (\ent -> getEntityDBSchema ent == getEntityDBSchema entity + && getEntityDBName ent == s) + allDefs return $ NEL.toList $ Util.dbIdColumnsEsc escapeF entDef showColumn :: Column -> Text @@ -1226,6 +1248,7 @@ showSqlType (SqlOther t) = t showAlterDb :: AlterDB -> (Bool, Text) showAlterDb (AddTable s) = (False, s) +showAlterDb (AddSchema s) = (False, s) showAlterDb (AlterColumn t ac) = (isUnsafe ac, showAlter t ac) where @@ -1233,10 +1256,10 @@ showAlterDb (AlterColumn t ac) = isUnsafe _ = False showAlterDb (AlterTable t at) = (False, showAlterTable t at) -showAlterTable :: EntityNameDB -> AlterTable -> Text +showAlterTable :: SchemaEntityName -> AlterTable -> Text showAlterTable table (AddUniqueConstraint cname cols) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " ADD CONSTRAINT " , escapeC cname , " UNIQUE(" @@ -1245,16 +1268,16 @@ showAlterTable table (AddUniqueConstraint cname cols) = T.concat ] showAlterTable table (DropConstraint cname) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " DROP CONSTRAINT " , escapeC cname ] -showAlter :: EntityNameDB -> AlterColumn -> Text +showAlter :: SchemaEntityName -> AlterColumn -> Text showAlter table (ChangeType c t extra) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " ALTER COLUMN " , escapeF (cName c) , " TYPE " @@ -1264,7 +1287,7 @@ showAlter table (ChangeType c t extra) = showAlter table (IsNull c) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " ALTER COLUMN " , escapeF (cName c) , " DROP NOT NULL" @@ -1272,7 +1295,7 @@ showAlter table (IsNull c) = showAlter table (NotNull c) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " ALTER COLUMN " , escapeF (cName c) , " SET NOT NULL" @@ -1280,21 +1303,21 @@ showAlter table (NotNull c) = showAlter table (Add' col) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " ADD COLUMN " , showColumn col ] showAlter table (Drop c _) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " DROP COLUMN " , escapeF (cName c) ] showAlter table (Default c s) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " ALTER COLUMN " , escapeF (cName c) , " SET DEFAULT " @@ -1302,14 +1325,14 @@ showAlter table (Default c s) = ] showAlter table (NoDefault c) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " ALTER COLUMN " , escapeF (cName c) , " DROP DEFAULT" ] showAlter table (Update' c s) = T.concat [ "UPDATE " - , escapeE table + , escapeS table , " SET " , escapeF (cName c) , "=" @@ -1320,24 +1343,33 @@ showAlter table (Update' c s) = T.concat ] showAlter table (AddReference reftable fkeyname t2 id2 cascade) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " ADD CONSTRAINT " , escapeC fkeyname , " FOREIGN KEY(" , T.intercalate "," $ map escapeF t2 , ") REFERENCES " - , escapeE reftable + , escapeS reftable , "(" , T.intercalate "," id2 , ")" ] <> renderFieldCascade cascade showAlter table (DropReference cname) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeS table , " DROP CONSTRAINT " , escapeC cname ] +schemaNamePair :: EntityDef -> SchemaEntityName +schemaNamePair ent = schemaNamePair' (getEntityDBSchema ent) (getEntityDBName ent) + +schemaNamePair' :: Maybe Text -> EntityNameDB -> SchemaEntityName +schemaNamePair' mbSchema entName = case mbSchema of + Nothing -> ("", entName) + Just "public" -> ("", entName) + Just schema -> (schema, entName) + -- | Get the SQL string for the table that a PeristEntity represents. -- Useful for raw SQL queries. tableName :: (PersistEntity record) => record -> Text @@ -1357,6 +1389,13 @@ escapeE = escapeWith escape escapeF :: FieldNameDB -> Text escapeF = escapeWith escape +escapeS :: SchemaEntityName -> Text +escapeS ("", entDBName) = escapeE entDBName +escapeS (schema, entDBName) = escape schema <> "." <> escapeE entDBName + +unescapeS :: SchemaEntityName -> EntityNameDB +unescapeS ("", entDBName) = entDBName +unescapeS (schema, entDBName) = EntityNameDB $ escapeWith ((schema <> ".") <>) entDBName escape :: Text -> Text escape s = @@ -1519,10 +1558,10 @@ mockMigrate :: [EntityDef] -> IO (Either [Text] [(Bool, Text)]) mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do case partitionEithers [] of - ([], old'') -> return $ Right $ migrationText False old'' + ([], old'') -> return $ Right $ createSchemaFor entity <> migrationText False old'' (errs, _) -> return $ Left errs where - name = getEntityDBName entity + name = schemaNamePair entity migrationText exists' old'' = if not exists' then createText newcols fdefs udspair @@ -1574,7 +1613,7 @@ mockMigration mig = do , connCommit = undefined , connRollback = undefined , connEscapeFieldName = escapeF - , connEscapeTableName = escapeE . getEntityDBName + , connEscapeTableName = \entDef -> escapeS $ schemaNamePair entDef , connEscapeRawName = escape , connNoLimit = undefined , connRDBMS = undefined @@ -1775,7 +1814,7 @@ mkBulkUpsertQuery records conn fieldValues updates filters uniqDef = [] -> error "The entity you're trying to insert does not have any fields." (field:_) -> field entityFieldNames = map fieldDbToText (getEntityFields entityDef') - nameOfTable = escapeE . getEntityDBName $ entityDef' + nameOfTable = escapeS . schemaNamePair $ entityDef' copyUnlessValues = map snd fieldsToMaybeCopy recordValues = concatMap (map toPersistValue . toPersistFields) records recordPlaceholders = @@ -1837,7 +1876,7 @@ putManySql' conflictColumns (filter isFieldNotGenerated -> fields) ent n = q fieldDbToText = escapeF . fieldDB mkAssignment f = T.concat [f, "=EXCLUDED.", f] - table = escapeE . getEntityDBName $ ent + table = escapeS . schemaNamePair $ ent columns = Util.commaSeparated $ map fieldDbToText fields placeholders = map (const "?") fields updates = map (mkAssignment . fieldDbToText) fields @@ -1867,9 +1906,12 @@ migrateEnableExtension extName = WriterT $ WriterT $ do else return (((), []), []) postgresMkColumns :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -postgresMkColumns allDefs t = - mkColumns allDefs t - $ setBackendSpecificForeignKeyName refName emptyBackendSpecificOverrides +postgresMkColumns allDefs t = mkColumns allDefs t postgresSpecificOverrides + +postgresSpecificOverrides :: BackendSpecificOverrides +postgresSpecificOverrides = setBackendSpecificForeignKeyName refName + $ setBackendSpecificSchemaEntityName (\schema name -> unescapeS $ schemaNamePair' schema name) + $ emptyBackendSpecificOverrides -- | Wrapper for persistent SqlBackends that carry the corresponding -- `Postgresql.Connection`. diff --git a/persistent-postgresql/test/main.hs b/persistent-postgresql/test/main.hs index c00650ac0..4487e003f 100644 --- a/persistent-postgresql/test/main.hs +++ b/persistent-postgresql/test/main.hs @@ -55,6 +55,7 @@ import qualified RawSqlTest import qualified ReadWriteTest import qualified Recursive import qualified RenameTest +import qualified SchemaTest import qualified SumTypeTest import qualified TransactionLevelTest import qualified TreeTest @@ -139,6 +140,7 @@ main = do , PgIntervalTest.pgIntervalMigrate , UpsertWhere.upsertWhereMigrate , ImplicitUuidSpec.implicitUuidMigrate + , SchemaTest.migration ] PersistentTest.cleanDB ForeignKey.cleanDB @@ -214,3 +216,4 @@ main = do PgIntervalTest.specs ArrayAggTest.specs GeneratedColumnTestSQL.specsWith runConnAssert + SchemaTest.specsWith runConnAssert From 1942681c7326358695b2e092f0b7e460169b893b Mon Sep 17 00:00:00 2001 From: Gleb Popov <6yearold@gmail.com> Date: Wed, 1 Feb 2023 19:06:24 +0300 Subject: [PATCH 6/6] Add a simple test for schema functionality. --- persistent-sqlite/test/main.hs | 3 ++ persistent-test/persistent-test.cabal | 5 ++-- persistent-test/src/SchemaTest.hs | 32 ++++++++++++++++++++++ persistent/test/Database/Persist/THSpec.hs | 1 + 4 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 persistent-test/src/SchemaTest.hs diff --git a/persistent-sqlite/test/main.hs b/persistent-sqlite/test/main.hs index f8a7a87ea..3f097a927 100644 --- a/persistent-sqlite/test/main.hs +++ b/persistent-sqlite/test/main.hs @@ -43,6 +43,7 @@ import qualified RawSqlTest import qualified ReadWriteTest import qualified Recursive import qualified RenameTest +import qualified SchemaTest import qualified SumTypeTest import qualified TransactionLevelTest import qualified TypeLitFieldDefsTest @@ -175,6 +176,7 @@ main = do , MigrationColumnLengthTest.migration , TransactionLevelTest.migration , LongIdentifierTest.migration + , SchemaTest.migration ] PersistentTest.cleanDB ForeignKey.cleanDB @@ -243,6 +245,7 @@ main = do MigrationTest.specsWith db LongIdentifierTest.specsWith db GeneratedColumnTestSQL.specsWith db + SchemaTest.specsWith db it "issue #328" $ asIO $ runSqliteInfo (mkSqliteConnectionInfo ":memory:") $ do void $ runMigrationSilent migrateAll diff --git a/persistent-test/persistent-test.cabal b/persistent-test/persistent-test.cabal index 72e900b80..a75414e96 100644 --- a/persistent-test/persistent-test.cabal +++ b/persistent-test/persistent-test.cabal @@ -15,7 +15,7 @@ bug-reports: https://github.com/yesodweb/persistent/issues extra-source-files: ChangeLog.md library - exposed-modules: + exposed-modules: CompositeTest CustomPersistField CustomPersistFieldTest @@ -51,6 +51,7 @@ library RenameTest Recursive SumTypeTest + SchemaTest TransactionLevelTest TreeTest TypeLitFieldDefsTest @@ -60,7 +61,7 @@ library hs-source-dirs: src - build-depends: + build-depends: base >= 4.9 && < 5 , aeson >= 1.0 , blaze-html >= 0.9 diff --git a/persistent-test/src/SchemaTest.hs b/persistent-test/src/SchemaTest.hs new file mode 100644 index 000000000..64bf79d66 --- /dev/null +++ b/persistent-test/src/SchemaTest.hs @@ -0,0 +1,32 @@ +{-# LANGUAGE UndecidableInstances #-} +module SchemaTest where + +import Init +import qualified PersistentTestModels as PTM + +share [mkPersist persistSettings, mkMigrate "migration"] [persistLowerCase| +Person schema=my_special_schema + name Text + age Int + weight Int + deriving Show Eq +|] + +cleanDB :: (MonadIO m, PersistQuery backend, PersistEntityBackend Person ~ backend) => ReaderT backend m () +cleanDB = do + deleteWhere ([] :: [Filter Person]) + deleteWhere ([] :: [Filter PTM.Person]) + +specsWith :: (MonadIO m, MonadFail m) => RunDb SqlBackend m -> Spec +specsWith runDb = describe "schema support" $ do + it "insert a Person under non-default schema" $ runDb $ do + insert_ $ Person "name" 1 2 + return () + it "insert PTM.Person and Person and check they end up in different tables" $ runDb $ do + cleanDB + insert_ $ Person "name" 1 2 + insert_ $ PTM.Person "name2" 3 Nothing + schemaPersonCount <- count ([] :: [Filter Person]) + ptmPersoncount <- count ([] :: [Filter PTM.Person]) + -- both tables should contain only one record despite similarly named Entities + schemaPersonCount + ptmPersoncount @== 2 diff --git a/persistent/test/Database/Persist/THSpec.hs b/persistent/test/Database/Persist/THSpec.hs index 0ea783206..cf9be8e51 100644 --- a/persistent/test/Database/Persist/THSpec.hs +++ b/persistent/test/Database/Persist/THSpec.hs @@ -318,6 +318,7 @@ spec = describe "THSpec" $ do EntityDef { entityHaskell = EntityNameHS "HasSimpleCascadeRef" , entityDB = EntityNameDB "HasSimpleCascadeRef" + , entitySchema = Nothing , entityId = EntityIdField FieldDef { fieldHaskell = FieldNameHS "Id"