Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Schema support #1466

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 101 additions & 59 deletions persistent-postgresql/Database/Persist/Postgresql.hs
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since ColumnReference isn't supported yet, this addition will fail for things that need to use it - which is specifying foreign key relationships in migrations. I suspect that a test that works on this schema would fail:

mkPersist sqlSettings [persistLowerCase|
  Foo schema=foo_schema
    name Int

  Bar 
    foo FooId 
|]

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions persistent-postgresql/test/main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ import qualified RawSqlTest
import qualified ReadWriteTest
import qualified Recursive
import qualified RenameTest
import qualified SchemaTest
import qualified SumTypeTest
import qualified TransactionLevelTest
import qualified TreeTest
Expand Down Expand Up @@ -139,6 +140,7 @@ main = do
, PgIntervalTest.pgIntervalMigrate
, UpsertWhere.upsertWhereMigrate
, ImplicitUuidSpec.implicitUuidMigrate
, SchemaTest.migration
]
PersistentTest.cleanDB
ForeignKey.cleanDB
Expand Down Expand Up @@ -214,3 +216,4 @@ main = do
PgIntervalTest.specs
ArrayAggTest.specs
GeneratedColumnTestSQL.specsWith runConnAssert
SchemaTest.specsWith runConnAssert
81 changes: 54 additions & 27 deletions persistent-sqlite/Database/Persist/Sqlite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ wrapConnectionInfo connInfo conn logFunc = do
, connCommit = helper "COMMIT"
, connRollback = ignoreExceptions . helper "ROLLBACK"
, connEscapeFieldName = escape . unFieldNameDB
, connEscapeTableName = escape . unEntityNameDB . getEntityDBName
, connEscapeTableName = escapeS . schemaNamePair
, connEscapeRawName = escape
, connNoLimit = "LIMIT -1"
, connRDBMS = "sqlite"
Expand Down Expand Up @@ -358,7 +358,7 @@ insertSql' ent vals =
ISRManyKeys sql vals
where sql = T.concat
[ "INSERT INTO "
, escapeE $ getEntityDBName ent
, escapeS $ schemaNamePair ent
, "("
, T.intercalate "," $ map (escapeF . fieldDB) cols
, ") VALUES("
Expand All @@ -372,12 +372,12 @@ insertSql' ent vals =
[ "SELECT "
, escapeF $ fieldDB fd
, " FROM "
, escapeE $ getEntityDBName ent
, escapeS $ schemaNamePair ent
, " WHERE _ROWID_=last_insert_rowid()"
]
ins = T.concat
[ "INSERT INTO "
, escapeE $ getEntityDBName ent
, escapeS $ schemaNamePair ent
, if null cols
then " VALUES(null)"
else T.concat
Expand Down Expand Up @@ -434,8 +434,14 @@ showSqlType SqlBlob = "BLOB"
showSqlType SqlBool = "BOOLEAN"
showSqlType (SqlOther t) = t

type SchemaEntityName = (Text, EntityNameDB)

sqliteMkColumns :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef])
sqliteMkColumns allDefs t = mkColumns allDefs t emptyBackendSpecificOverrides
sqliteMkColumns allDefs t = mkColumns allDefs t sqliteSpecificOverrides

sqliteSpecificOverrides :: BackendSpecificOverrides
sqliteSpecificOverrides = setBackendSpecificSchemaEntityName (\schema name -> unescapeS $ schemaNamePair' schema name)
$ emptyBackendSpecificOverrides

migrate'
:: [EntityDef]
Expand All @@ -444,9 +450,9 @@ migrate'
-> IO (Either [Text] CautiousMigration)
migrate' allDefs getter val = do
let (cols, uniqs, fdefs) = sqliteMkColumns allDefs val
let newSql = mkCreateTable False def (filter (not . safeToRemove val . cName) cols, uniqs, fdefs)
let newSql = mkCreateTable def (filter (not . safeToRemove val . cName) cols, uniqs, fdefs)
stmt <- getter "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
oldSql' <- with (stmtQuery stmt [PersistText $ unEntityNameDB table])
oldSql' <- with (stmtQuery stmt [PersistText table])
(\src -> runConduit $ src .| go)
case oldSql' of
Nothing -> return $ Right [(False, newSql)]
Expand All @@ -458,7 +464,7 @@ migrate' allDefs getter val = do
return $ Right sql
where
def = val
table = getEntityDBName def
table = unEntityNameDB $ unescapeS $ schemaNamePair def
go = do
x <- CL.head
case x of
Expand Down Expand Up @@ -490,7 +496,7 @@ mockMigration mig = do
, connCommit = helper "COMMIT"
, connRollback = ignoreExceptions . helper "ROLLBACK"
, connEscapeFieldName = escape . unFieldNameDB
, connEscapeTableName = escape . unEntityNameDB . getEntityDBName
, connEscapeTableName = escapeS . schemaNamePair
, connEscapeRawName = escape
, connNoLimit = "LIMIT -1"
, connRDBMS = "sqlite"
Expand Down Expand Up @@ -528,7 +534,7 @@ getCopyTable :: [EntityDef]
-> EntityDef
-> IO [(Bool, Text)]
getCopyTable allDefs getter def = do
stmt <- getter $ T.concat [ "PRAGMA table_info(", escapeE table, ")" ]
stmt <- getter $ T.concat [ "PRAGMA table_info(", escapeS table, ")" ]
oldCols' <- with (stmtQuery stmt []) (\src -> runConduit $ src .| getCols)
let oldCols = map FieldNameDB oldCols'
let newCols = filter (not . safeToRemove def) $ map cName cols
Expand All @@ -549,42 +555,44 @@ getCopyTable allDefs getter def = do
names <- getCols
return $ name : names
Just y -> error $ "Invalid result from PRAGMA table_info: " ++ show y
table = getEntityDBName def
tableTmp = EntityNameDB $ unEntityNameDB table <> "_backup"
defTmp = setEntityDBSchema (Just "temp")
$ setEntityDBName (escapeWith (EntityNameDB . (<> "_backup")) $ getEntityDBName def)
def
table = schemaNamePair def
tableTmp = schemaNamePair defTmp
(cols, uniqs, fdef) = sqliteMkColumns allDefs def
cols' = filter (not . safeToRemove def . cName) cols
newSql = mkCreateTable False def (cols', uniqs, fdef)
tmpSql = mkCreateTable True (setEntityDBName tableTmp def) (cols', uniqs, [])
dropTmp = "DROP TABLE " <> escapeE tableTmp
dropOld = "DROP TABLE " <> escapeE table
newSql = mkCreateTable def (cols', uniqs, fdef)
tmpSql = mkCreateTable defTmp (cols', uniqs, [])
dropTmp = "DROP TABLE " <> escapeS tableTmp
dropOld = "DROP TABLE " <> escapeS table
copyToTemp common = T.concat
[ "INSERT INTO "
, escapeE tableTmp
, escapeS tableTmp
, "("
, T.intercalate "," $ map escapeF common
, ") SELECT "
, T.intercalate "," $ map escapeF common
, " FROM "
, escapeE table
, escapeS table
]
copyToFinal newCols = T.concat
[ "INSERT INTO "
, escapeE table
, escapeS table
, " SELECT "
, T.intercalate "," $ map escapeF newCols
, " FROM "
, escapeE tableTmp
, escapeS tableTmp
]

mkCreateTable :: Bool -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -> Text
mkCreateTable isTemp entity (cols, uniqs, fdefs) =
mkCreateTable :: EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -> Text
mkCreateTable entity (cols, uniqs, fdefs) =
T.concat (header <> columns <> footer)
where
isTemp = getEntityDBSchema entity == Just "temp"
header =
[ "CREATE"
, if isTemp then " TEMP" else ""
, " TABLE "
, escapeE $ getEntityDBName entity
[ "CREATE TABLE "
, escapeS $ schemaNamePair entity
, "("
]

Expand Down Expand Up @@ -678,6 +686,16 @@ sqlUnique (UniqueDef _ cname cols _) = T.concat
, ")"
]

schemaNamePair :: EntityDef -> SchemaEntityName
schemaNamePair ent = schemaNamePair' (getEntityDBSchema ent) (getEntityDBName ent)

schemaNamePair' :: Maybe Text -> EntityNameDB -> SchemaEntityName
schemaNamePair' mbSchema entName = case mbSchema of
Nothing -> ("", entName)
Just "main" -> ("", entName)
Just "temp" -> ("temp", entName)
Just schema -> ("", EntityNameDB $ (schema <> "_") <> unEntityNameDB entName)
Comment on lines +692 to +697
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intriguing - so sqlite only has two schema, main and temp? And we mimic that by using schema_table_name? I'd be inclined to use __ (two underscore) to make the distinction more clear - that way you can see foo__bar_baz and know that foo is the supposed schema, rather than foo_bar_baz potentially being either BarBaz schema=foo or FooBarBaz (no schema).


escapeC :: ConstraintNameDB -> Text
escapeC = escapeWith escape

Expand All @@ -687,6 +705,15 @@ escapeE = escapeWith escape
escapeF :: FieldNameDB -> Text
escapeF = escapeWith escape

escapeS :: (Text, EntityNameDB) -> Text
escapeS ("", entDBName) = escapeE entDBName
-- no need to escape schema as it is either "" or "temp"
escapeS (schema, entDBName) = schema <> "." <> escapeE entDBName

unescapeS :: (Text, EntityNameDB) -> EntityNameDB
unescapeS ("", entDBName) = entDBName
unescapeS (schema, entDBName) = EntityNameDB $ escapeWith ((schema <> ".") <>) entDBName

escape :: Text -> Text
escape s =
T.concat [q, T.concatMap go s, q]
Expand All @@ -713,7 +740,7 @@ putManySql' conflictColumns fields ent n = q
fieldDbToText = escapeF . fieldDB
mkAssignment f = T.concat [f, "=EXCLUDED.", f]

table = escapeE . getEntityDBName $ ent
table = escapeS . schemaNamePair $ ent
columns = Util.commaSeparated $ map fieldDbToText fields
placeholders = map (const "?") fields
updates = map (mkAssignment . fieldDbToText) fields
Expand Down
3 changes: 3 additions & 0 deletions persistent-sqlite/test/main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import qualified RawSqlTest
import qualified ReadWriteTest
import qualified Recursive
import qualified RenameTest
import qualified SchemaTest
import qualified SumTypeTest
import qualified TransactionLevelTest
import qualified TypeLitFieldDefsTest
Expand Down Expand Up @@ -175,6 +176,7 @@ main = do
, MigrationColumnLengthTest.migration
, TransactionLevelTest.migration
, LongIdentifierTest.migration
, SchemaTest.migration
]
PersistentTest.cleanDB
ForeignKey.cleanDB
Expand Down Expand Up @@ -243,6 +245,7 @@ main = do
MigrationTest.specsWith db
LongIdentifierTest.specsWith db
GeneratedColumnTestSQL.specsWith db
SchemaTest.specsWith db

it "issue #328" $ asIO $ runSqliteInfo (mkSqliteConnectionInfo ":memory:") $ do
void $ runMigrationSilent migrateAll
Expand Down
5 changes: 3 additions & 2 deletions persistent-test/persistent-test.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ bug-reports: https://github.com/yesodweb/persistent/issues
extra-source-files: ChangeLog.md

library
exposed-modules:
exposed-modules:
CompositeTest
CustomPersistField
CustomPersistFieldTest
Expand Down Expand Up @@ -51,6 +51,7 @@ library
RenameTest
Recursive
SumTypeTest
SchemaTest
TransactionLevelTest
TreeTest
TypeLitFieldDefsTest
Expand All @@ -60,7 +61,7 @@ library

hs-source-dirs: src

build-depends:
build-depends:
base >= 4.9 && < 5
, aeson >= 1.0
, blaze-html >= 0.9
Expand Down
32 changes: 32 additions & 0 deletions persistent-test/src/SchemaTest.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
{-# LANGUAGE UndecidableInstances #-}
module SchemaTest where

import Init
import qualified PersistentTestModels as PTM

share [mkPersist persistSettings, mkMigrate "migration"] [persistLowerCase|
Person schema=my_special_schema
name Text
age Int
weight Int
deriving Show Eq
|]

cleanDB :: (MonadIO m, PersistQuery backend, PersistEntityBackend Person ~ backend) => ReaderT backend m ()
cleanDB = do
deleteWhere ([] :: [Filter Person])
deleteWhere ([] :: [Filter PTM.Person])

specsWith :: (MonadIO m, MonadFail m) => RunDb SqlBackend m -> Spec
specsWith runDb = describe "schema support" $ do
it "insert a Person under non-default schema" $ runDb $ do
insert_ $ Person "name" 1 2
return ()
it "insert PTM.Person and Person and check they end up in different tables" $ runDb $ do
cleanDB
insert_ $ Person "name" 1 2
insert_ $ PTM.Person "name2" 3 Nothing
schemaPersonCount <- count ([] :: [Filter Person])
ptmPersoncount <- count ([] :: [Filter PTM.Person])
-- both tables should contain only one record despite similarly named Entities
schemaPersonCount + ptmPersoncount @== 2
16 changes: 16 additions & 0 deletions persistent/Database/Persist/EntityDef.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ module Database.Persist.EntityDef
-- * Accessors
, getEntityHaskellName
, getEntityDBName
, getEntityDBSchema
, getEntityFields
, getEntityFieldsDatabase
, getEntityForeignDefs
Expand All @@ -27,6 +28,7 @@ module Database.Persist.EntityDef
, setEntityId
, setEntityIdDef
, setEntityDBName
, setEntityDBSchema
, overEntityFields
-- * Related Types
, EntityIdDef(..)
Expand Down Expand Up @@ -86,6 +88,14 @@ getEntityDBName
-> EntityNameDB
getEntityDBName = entityDB

-- | Return the database schema name for the given entity.
--
-- @since XXX
getEntityDBSchema
:: EntityDef
-> Maybe Text
getEntityDBSchema = entitySchema

getEntityExtra :: EntityDef -> Map Text [[Text]]
getEntityExtra = entityExtra

Expand All @@ -98,6 +108,12 @@ setEntityDBName db ed = ed { entityDB = db }
getEntityComments :: EntityDef -> Maybe Text
getEntityComments = entityComments

-- | Sets or resets the database schema name for the given entity.
--
-- @since XXX
setEntityDBSchema :: Maybe Text -> EntityDef -> EntityDef
setEntityDBSchema schema entDef = entDef { entitySchema = schema }

-- |
--
-- @since 2.13.0.0
Expand Down
4 changes: 4 additions & 0 deletions persistent/Database/Persist/Quasi/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ mkUnboundEntityDef ps parsedEntDef =
EntityDef
{ entityHaskell = entNameHS
, entityDB = entNameDB
, entitySchema = getSchemaName (parsedEntityDefEntityAttributes parsedEntDef)
-- idField is the user-specified Id
-- otherwise useAutoIdField
-- but, adjust it if the user specified a Primary
Expand Down Expand Up @@ -939,6 +940,9 @@ getDbName :: PersistSettings -> Text -> [Text] -> Text
getDbName ps n =
fromMaybe (psToDBName ps n) . listToMaybe . mapMaybe (T.stripPrefix "sql=")

getSchemaName :: [Attr] -> Maybe Text
getSchemaName = listToMaybe . mapMaybe (T.stripPrefix "schema=")

getDbName' :: PersistSettings -> Text -> [FieldAttr] -> FieldNameDB
getDbName' ps n =
getSqlNameOr (FieldNameDB $ psToDBName ps n)
Expand Down
2 changes: 2 additions & 0 deletions persistent/Database/Persist/Sql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ module Database.Persist.Sql
, emptyBackendSpecificOverrides
, getBackendSpecificForeignKeyName
, setBackendSpecificForeignKeyName
, getBackendSpecificSchemaEntityName
, setBackendSpecificSchemaEntityName
, defaultAttribute
-- * Internal
, IsolationLevel(..)
Expand Down
Loading