diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index f6c130633..022d71a1c 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -437,7 +437,7 @@ createBackend logFunc serverVersion smap conn = , connCommit = const $ PG.commit conn , connRollback = const $ PG.rollback conn , connEscapeFieldName = escapeF - , connEscapeTableName = escapeE . getEntityDBName + , connEscapeTableName = \entDef -> prependSchemaAndEscape entDef $ getEntityDBName entDef , connEscapeRawName = escape , connNoLimit = "LIMIT ALL" , connRDBMS = "postgresql" @@ -760,7 +760,8 @@ getColumns :: (Text -> IO Statement) -> EntityDef -> [Column] -> IO [Either Text (Either Column (ConstraintNameDB, [FieldNameDB]))] getColumns getter def cols = do - let sqlv = T.concat + let tableSchema = fromMaybe "current_schema()" (getEntityDBSchema def) + sqlv = T.concat [ "SELECT " , "column_name " , ",is_nullable " @@ -772,7 +773,7 @@ getColumns getter def cols = do , ",character_maximum_length " , "FROM information_schema.columns " , "WHERE table_catalog=current_database() " - , "AND table_schema=current_schema() " + , "AND table_schema=" <> tableSchema <> " " , "AND table_name=? " ] @@ -797,7 +798,7 @@ getColumns getter def cols = do , "information_schema.table_constraints AS k " , "WHERE c.table_catalog=current_database() " , "AND c.table_catalog=k.table_catalog " - , "AND c.table_schema=current_schema() " + , "AND c.table_schema=" <> tableSchema <> " " , "AND c.table_schema=k.table_schema " , "AND c.table_name=? " , "AND c.table_name=k.table_name " @@ -1338,6 +1339,12 @@ showAlter table (DropReference cname) = T.concat , escapeC cname ] +prependSchemaAndEscape :: EntityDef -> EntityNameDB -> Text +prependSchemaAndEscape entDef entDBName = case getEntityDBSchema entDef of + Nothing -> escapeE entDBName + Just "public" -> escapeE entDBName + Just schema -> escape schema <> "." <> escapeE entDBName + -- | Get the SQL string for the table that a PeristEntity represents. -- Useful for raw SQL queries. tableName :: (PersistEntity record) => record -> Text @@ -1574,7 +1581,7 @@ mockMigration mig = do , connCommit = undefined , connRollback = undefined , connEscapeFieldName = escapeF - , connEscapeTableName = escapeE . getEntityDBName + , connEscapeTableName = \entDef -> prependSchemaAndEscape entDef $ getEntityDBName entDef , connEscapeRawName = escape , connNoLimit = undefined , connRDBMS = undefined