Skip to content

Commit

Permalink
Schema support for Sqlite backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
arrowd committed Jan 27, 2023
1 parent 6fdbf84 commit e7d27b6
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 26 deletions.
67 changes: 41 additions & 26 deletions persistent-sqlite/Database/Persist/Sqlite.hs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ wrapConnectionInfo connInfo conn logFunc = do
, connCommit = helper "COMMIT"
, connRollback = ignoreExceptions . helper "ROLLBACK"
, connEscapeFieldName = escape . unFieldNameDB
, connEscapeTableName = escape . unEntityNameDB . getEntityDBName
, connEscapeTableName = escapeS . getEntityDBNameWithSchema
, connEscapeRawName = escape
, connNoLimit = "LIMIT -1"
, connRDBMS = "sqlite"
Expand Down Expand Up @@ -358,7 +358,7 @@ insertSql' ent vals =
ISRManyKeys sql vals
where sql = T.concat
[ "INSERT INTO "
, escapeE $ getEntityDBName ent
, escapeS $ getEntityDBNameWithSchema ent
, "("
, T.intercalate "," $ map (escapeF . fieldDB) cols
, ") VALUES("
Expand All @@ -372,12 +372,12 @@ insertSql' ent vals =
[ "SELECT "
, escapeF $ fieldDB fd
, " FROM "
, escapeE $ getEntityDBName ent
, escapeS $ getEntityDBNameWithSchema ent
, " WHERE _ROWID_=last_insert_rowid()"
]
ins = T.concat
[ "INSERT INTO "
, escapeE $ getEntityDBName ent
, escapeS $ getEntityDBNameWithSchema ent
, if null cols
then " VALUES(null)"
else T.concat
Expand Down Expand Up @@ -444,9 +444,9 @@ migrate'
-> IO (Either [Text] CautiousMigration)
migrate' allDefs getter val = do
let (cols, uniqs, fdefs) = sqliteMkColumns allDefs val
let newSql = mkCreateTable False def (filter (not . safeToRemove val . cName) cols, uniqs, fdefs)
let newSql = mkCreateTable def (filter (not . safeToRemove val . cName) cols, uniqs, fdefs)
stmt <- getter "SELECT sql FROM sqlite_master WHERE type='table' AND name=?"
oldSql' <- with (stmtQuery stmt [PersistText $ unEntityNameDB table])
oldSql' <- with (stmtQuery stmt [PersistText table])
(\src -> runConduit $ src .| go)
case oldSql' of
Nothing -> return $ Right [(False, newSql)]
Expand All @@ -458,7 +458,7 @@ migrate' allDefs getter val = do
return $ Right sql
where
def = val
table = getEntityDBName def
table = unescapeS $ getEntityDBNameWithSchema def
go = do
x <- CL.head
case x of
Expand Down Expand Up @@ -490,7 +490,7 @@ mockMigration mig = do
, connCommit = helper "COMMIT"
, connRollback = ignoreExceptions . helper "ROLLBACK"
, connEscapeFieldName = escape . unFieldNameDB
, connEscapeTableName = escape . unEntityNameDB . getEntityDBName
, connEscapeTableName = escapeS . getEntityDBNameWithSchema
, connEscapeRawName = escape
, connNoLimit = "LIMIT -1"
, connRDBMS = "sqlite"
Expand Down Expand Up @@ -528,7 +528,7 @@ getCopyTable :: [EntityDef]
-> EntityDef
-> IO [(Bool, Text)]
getCopyTable allDefs getter def = do
stmt <- getter $ T.concat [ "PRAGMA table_info(", escapeE table, ")" ]
stmt <- getter $ T.concat [ "PRAGMA table_info(", escapeS table, ")" ]
oldCols' <- with (stmtQuery stmt []) (\src -> runConduit $ src .| getCols)
let oldCols = map FieldNameDB oldCols'
let newCols = filter (not . safeToRemove def) $ map cName cols
Expand All @@ -549,42 +549,44 @@ getCopyTable allDefs getter def = do
names <- getCols
return $ name : names
Just y -> error $ "Invalid result from PRAGMA table_info: " ++ show y
table = getEntityDBName def
tableTmp = EntityNameDB $ unEntityNameDB table <> "_backup"
defTmp = setEntityDBSchema (Just "temp")
$ setEntityDBName (escapeWith (EntityNameDB . (<> "_backup")) $ getEntityDBName def)
def
table = getEntityDBNameWithSchema def
tableTmp = getEntityDBNameWithSchema defTmp
(cols, uniqs, fdef) = sqliteMkColumns allDefs def
cols' = filter (not . safeToRemove def . cName) cols
newSql = mkCreateTable False def (cols', uniqs, fdef)
tmpSql = mkCreateTable True (setEntityDBName tableTmp def) (cols', uniqs, [])
dropTmp = "DROP TABLE " <> escapeE tableTmp
dropOld = "DROP TABLE " <> escapeE table
newSql = mkCreateTable def (cols', uniqs, fdef)
tmpSql = mkCreateTable defTmp (cols', uniqs, [])
dropTmp = "DROP TABLE " <> escapeS tableTmp
dropOld = "DROP TABLE " <> escapeS table
copyToTemp common = T.concat
[ "INSERT INTO "
, escapeE tableTmp
, escapeS tableTmp
, "("
, T.intercalate "," $ map escapeF common
, ") SELECT "
, T.intercalate "," $ map escapeF common
, " FROM "
, escapeE table
, escapeS table
]
copyToFinal newCols = T.concat
[ "INSERT INTO "
, escapeE table
, escapeS table
, " SELECT "
, T.intercalate "," $ map escapeF newCols
, " FROM "
, escapeE tableTmp
, escapeS tableTmp
]

mkCreateTable :: Bool -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -> Text
mkCreateTable isTemp entity (cols, uniqs, fdefs) =
mkCreateTable :: EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -> Text
mkCreateTable entity (cols, uniqs, fdefs) =
T.concat (header <> columns <> footer)
where
isTemp = getEntityDBSchema entity == Just "temp"
header =
[ "CREATE"
, if isTemp then " TEMP" else ""
, " TABLE "
, escapeE $ getEntityDBName entity
[ "CREATE TABLE "
, escapeS $ getEntityDBNameWithSchema entity
, "("
]

Expand Down Expand Up @@ -687,6 +689,12 @@ escapeE = escapeWith escape
escapeF :: FieldNameDB -> Text
escapeF = escapeWith escape

escapeS :: (Text, EntityNameDB) -> Text
escapeS (schema, entDBName) = schema <> escapeE entDBName

unescapeS :: (Text, EntityNameDB) -> Text
unescapeS (schema, entDBName) = schema <> unEntityNameDB entDBName

escape :: Text -> Text
escape s =
T.concat [q, T.concatMap go s, q]
Expand All @@ -695,6 +703,13 @@ escape s =
go '"' = "\"\""
go c = T.singleton c

getEntityDBNameWithSchema :: EntityDef -> (Text, EntityNameDB)
getEntityDBNameWithSchema entDef = case getEntityDBSchema entDef of
Nothing -> ("", getEntityDBName entDef)
Just "main" -> ("", getEntityDBName entDef)
Just "temp" -> ("temp.", getEntityDBName entDef)
Just schema -> ("", escapeWith (EntityNameDB . ((schema <> "_") <>)) (getEntityDBName entDef))

putManySql :: EntityDef -> Int -> Text
putManySql ent n = putManySql' conflictColumns (toList fields) ent n
where
Expand All @@ -713,7 +728,7 @@ putManySql' conflictColumns fields ent n = q
fieldDbToText = escapeF . fieldDB
mkAssignment f = T.concat [f, "=EXCLUDED.", f]

table = escapeE . getEntityDBName $ ent
table = escapeS . getEntityDBNameWithSchema $ ent
columns = Util.commaSeparated $ map fieldDbToText fields
placeholders = map (const "?") fields
updates = map (mkAssignment . fieldDbToText) fields
Expand Down
11 changes: 11 additions & 0 deletions persistent-sqlite/test/Database/Persist/Sqlite/CompositeSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import Control.Monad.Trans.Resource (MonadResource)
import qualified Data.Conduit.List as CL
import Conduit
import Database.Persist.Sqlite
import Database.Persist.SqlBackend.Internal
import Database.Persist.EntityDef.Internal
import System.IO (hClose)
import Control.Exception (handle, IOException, throwIO)
import System.IO.Temp (withSystemTempFile)
Expand Down Expand Up @@ -83,6 +85,15 @@ spec = describe "CompositeSpec" $ do
runSqliteInfo connInfo $ do
void $ runMigrationSilent compositeMigrateTest
validateForeignKeys
it "test schema handling (issues #93 and #1454)" $ asIO $ runSqliteInfo (mkSqliteConnectionInfo ":memory:") $ do
backend <- ask
let tableName rec = connEscapeTableName backend (entityDef $ Just rec)
tableNameWithSchema schema rec = connEscapeTableName backend ((entityDef (Just rec)) { entitySchema = Just schema })

tableName (undefined :: SimpleComposite) @== "\"simple_composite\""
tableNameWithSchema "main" (undefined :: SimpleComposite) @== "\"simple_composite\""
tableNameWithSchema "temp" (undefined :: SimpleComposite) @== "temp.\"simple_composite\""
tableNameWithSchema "foo" (undefined :: SimpleComposite) @== "\"foo_simple_composite\""


validateForeignKeys
Expand Down
16 changes: 16 additions & 0 deletions persistent/Database/Persist/EntityDef.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ module Database.Persist.EntityDef
-- * Accessors
, getEntityHaskellName
, getEntityDBName
, getEntityDBSchema
, getEntityFields
, getEntityFieldsDatabase
, getEntityForeignDefs
Expand All @@ -27,6 +28,7 @@ module Database.Persist.EntityDef
, setEntityId
, setEntityIdDef
, setEntityDBName
, setEntityDBSchema
, overEntityFields
-- * Related Types
, EntityIdDef(..)
Expand Down Expand Up @@ -86,6 +88,14 @@ getEntityDBName
-> EntityNameDB
getEntityDBName = entityDB

-- | Return the database schema name for the given entity.
--
-- @since 2.14.4.5
getEntityDBSchema
:: EntityDef
-> Maybe Text
getEntityDBSchema = entitySchema

getEntityExtra :: EntityDef -> Map Text [[Text]]
getEntityExtra = entityExtra

Expand All @@ -98,6 +108,12 @@ setEntityDBName db ed = ed { entityDB = db }
getEntityComments :: EntityDef -> Maybe Text
getEntityComments = entityComments

-- | Sets or resets the database schema name for the given entity.
--
-- @since 2.14.4.5
setEntityDBSchema :: Maybe Text -> EntityDef -> EntityDef
setEntityDBSchema schema entDef = entDef { entitySchema = schema }

-- |
--
-- @since 2.13.0.0
Expand Down
1 change: 1 addition & 0 deletions persistent/Database/Persist/Quasi/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ mkUnboundEntityDef ps parsedEntDef =
EntityDef
{ entityHaskell = entNameHS
, entityDB = entNameDB
, entitySchema = Nothing
-- idField is the user-specified Id
-- otherwise useAutoIdField
-- but, adjust it if the user specified a Primary
Expand Down
2 changes: 2 additions & 0 deletions persistent/Database/Persist/Types/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ data EntityDef = EntityDef
-- ^ The name of the entity as Haskell understands it.
, entityDB :: !EntityNameDB
-- ^ The name of the database table corresponding to the entity.
, entitySchema :: !(Maybe Text)
-- ^ The schema name of the database table.
, entityId :: !EntityIdDef
-- ^ The entity's primary key or identifier.
, entityAttrs :: ![Attr]
Expand Down
1 change: 1 addition & 0 deletions persistent/test/Database/Persist/THSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ spec = describe "THSpec" $ do
EntityDef
{ entityHaskell = EntityNameHS "HasSimpleCascadeRef"
, entityDB = EntityNameDB "HasSimpleCascadeRef"
, entitySchema = Nothing
, entityId =
EntityIdField FieldDef
{ fieldHaskell = FieldNameHS "Id"
Expand Down

0 comments on commit e7d27b6

Please sign in to comment.