diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index a34aafda6..75b322faa 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -469,7 +469,7 @@ createBackend logFunc serverVersion smap conn = , connCommit = const $ PG.commit conn , connRollback = const $ PG.rollback conn , connEscapeFieldName = escapeF - , connEscapeTableName = escapeE . getEntityDBName + , connEscapeTableName = entityIdentifier , connEscapeRawName = escape , connNoLimit = "LIMIT ALL" , connRDBMS = "postgresql" @@ -498,7 +498,7 @@ insertSql' ent vals = (fieldNames, placeholders) = unzip (Util.mkInsertPlaceholders ent escapeF) sql = T.concat [ "INSERT INTO " - , escapeE $ getEntityDBName ent + , entityIdentifier ent , if null (getEntityFields ent) then " DEFAULT VALUES" else T.concat @@ -514,7 +514,7 @@ upsertSql' :: EntityDef -> NonEmpty (FieldNameHS, FieldNameDB) -> Text -> Text upsertSql' ent uniqs updateVal = T.concat [ "INSERT INTO " - , escapeE (getEntityDBName ent) + , entityIdentifier ent , "(" , T.intercalate "," fieldNames , ") VALUES (" @@ -543,7 +543,7 @@ insertManySql' ent valss = (fieldNames, placeholders)= unzip (Util.mkInsertPlaceholders ent escapeF) sql = T.concat [ "INSERT INTO " - , escapeE (getEntityDBName ent) + , entityIdentifier ent , "(" , T.intercalate "," fieldNames , ") VALUES (" @@ -626,14 +626,16 @@ withStmt' conn query vals = doesTableExist :: (Text -> IO Statement) -> EntityNameDB + -> (Maybe SchemaNameDB) -> IO Bool -doesTableExist getter (EntityNameDB name) = do +doesTableExist getter (EntityNameDB name) mSchema = do stmt <- getter sql with (stmtQuery stmt vals) (\src -> runConduit $ src .| start) where - sql = "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" - <> " AND schemaname != 'information_schema' AND tablename=?" - vals = [PersistText name] + schema = maybe "public" escapeS mSchema + sql = "SELECT COUNT(*) FROM pg_catalog.pg_tables " + <> "WHERE tablename=? AND schemaname=?" + vals = [PersistText name, PersistText schema] start = await >>= maybe (error "No results when checking doesTableExist") start' start' [PersistInt64 0] = finish False @@ -651,12 +653,13 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do ([], old'') -> do exists' <- if null old - then doesTableExist getter name + then doesTableExist getter name schema else return True return $ Right $ migrationText exists' old'' (errs, _) -> return $ Left errs where name = getEntityDBName entity + schema = getEntitySchema entity (newcols', udefs, fdefs) = postgresMkColumns allDefs entity migrationText exists' old'' | not exists' = @@ -664,8 +667,8 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do | otherwise = let (acs, ats) = getAlters allDefs entity (newcols, udspair) old' - acs' = map (AlterColumn name) acs - ats' = map (AlterTable name) ats + acs' = map (AlterColumn name schema) acs + ats' = map (AlterTable name schema) ats in acs' ++ ats' where @@ -679,7 +682,7 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do (addTable newcols entity) : uniques ++ references ++ foreignsAlt where uniques = flip concatMap udspair $ \(uname, ucols) -> - [AlterTable name $ AddUniqueConstraint uname ucols] + [AlterTable name schema $ AddUniqueConstraint uname ucols] references = mapMaybe (\Column { cName, cReference } -> @@ -692,12 +695,14 @@ mkForeignAlt :: EntityDef -> ForeignDef -> Maybe AlterDB -mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference +mkForeignAlt entity fdef = pure $ AlterColumn tableName_ schemaName_ addReference where tableName_ = getEntityDBName entity + schemaName_ = getEntitySchema entity addReference = AddReference (foreignRefTableDBName fdef) + (foreignRefSchemaDBName fdef) constraintName childfields escapedParentFields @@ -711,17 +716,22 @@ mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference addTable :: [Column] -> EntityDef -> AlterDB addTable cols entity = - AddTable $ T.concat + AddTable $ T.concat $ + case schema of + Nothing -> stmt + -- Lower case e: see Database.Persist.Sql.Migration + Just s -> "CREATe SCHEMA IF NOT EXISTS " <> s <> ";\n" : stmt + where + stmt = -- Lower case e: see Database.Persist.Sql.Migration [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! - , escapeE name + , entityIdentifier entity , "(" , idtxt , if null nonIdCols then "" else "," , T.intercalate "," $ map showColumn nonIdCols , ")" ] - where nonIdCols = case entityPrimary entity of Just _ -> @@ -735,6 +745,8 @@ addTable cols entity = name = getEntityDBName entity + schema = + escapeS <$> getEntitySchema entity idtxt = case getEntityId entity of EntityIdNaturalKey pdef -> @@ -773,7 +785,7 @@ data AlterColumn | Default Column Text | NoDefault Column | Update' Column Text - | AddReference EntityNameDB ConstraintNameDB [FieldNameDB] [Text] FieldCascade + | AddReference EntityNameDB (Maybe SchemaNameDB) ConstraintNameDB [FieldNameDB] [Text] FieldCascade | DropReference ConstraintNameDB deriving Show @@ -783,8 +795,8 @@ data AlterTable deriving Show data AlterDB = AddTable Text - | AlterColumn EntityNameDB AlterColumn - | AlterTable EntityNameDB AlterTable + | AlterColumn EntityNameDB (Maybe SchemaNameDB) AlterColumn + | AlterTable EntityNameDB (Maybe SchemaNameDB) AlterTable deriving Show -- | Returns all of the columns in the given table currently in the database. @@ -865,7 +877,7 @@ getColumns getter def cols = do $ groupBy ((==) `on` fst) rows processColumns = CL.mapM $ \x'@((PersistText cname) : _) -> do - col <- liftIO $ getColumn getter (getEntityDBName def) x' (Map.lookup cname refMap) + col <- liftIO $ getColumn getter (getEntityDBName def) (getEntitySchema def) x' (Map.lookup cname refMap) pure $ case col of Left e -> Left e Right c -> Right $ Left c @@ -923,18 +935,22 @@ getAlters defs def (c1, u1) (c2, u2) = getColumn :: (Text -> IO Statement) -> EntityNameDB + -> Maybe SchemaNameDB -> [PersistValue] -> Maybe (EntityNameDB, ConstraintNameDB) -> IO (Either Text Column) -getColumn getter tableName' [ PersistText columnName - , PersistText isNullable - , PersistText typeName - , defaultValue - , generationExpression - , numericPrecision - , numericScale - , maxlen - ] refName_ = runExceptT $ do +getColumn getter + tableName' + schemaName' + [ PersistText columnName + , PersistText isNullable + , PersistText typeName + , defaultValue + , generationExpression + , numericPrecision + , numericScale + , maxlen + ] refName_ = runExceptT $ do defaultValue' <- case defaultValue of PersistNull -> @@ -974,7 +990,11 @@ getColumn getter tableName' [ PersistText columnName , cGenerated = fmap stripSuffixes generationExpression' , cDefaultConstraintName = Nothing , cMaxLen = Nothing - , cReference = fmap (\(a,b,c,d) -> ColumnReference a b (mkCascade c d)) ref + , -- The ColumnReference always has a non-null SchemaNameDB. The default schema name + -- in Postgres is "public", but Postgres doesn't know whether a table with + -- schema "public" was explicitly given that schema by the Persistent + -- app developer. + cReference = fmap (\(a,b,c,d,e) -> ColumnReference a (Just b) c (mkCascade d e)) ref } where @@ -1012,10 +1032,15 @@ getColumn getter tableName' [ PersistText columnName Nothing -> loop' ps Just t' -> t' + getRef + :: FieldNameDB + -> (a, ConstraintNameDB) + -> IO (Maybe (EntityNameDB, SchemaNameDB, ConstraintNameDB, Text, Text)) getRef cname (_, refName') = do let sql = T.concat [ "SELECT DISTINCT " , "ccu.table_name, " + , "ccu.table_schema, " , "tc.constraint_name, " , "rc.update_rule, " , "rc.delete_rule " @@ -1029,6 +1054,7 @@ getColumn getter tableName' [ PersistText columnName , "WHERE tc.constraint_type='FOREIGN KEY' " , "AND kcu.ordinal_position=1 " , "AND kcu.table_name=? " + , "AND kcu.table_schema=? " , "AND kcu.column_name=? " , "AND tc.constraint_name=?" ] @@ -1037,6 +1063,7 @@ getColumn getter tableName' [ PersistText columnName with (stmtQuery stmt [ PersistText $ unEntityNameDB tableName' + , PersistText $ fromMaybe "public" $ unSchemaNameDB <$> schemaName' , PersistText $ unFieldNameDB cname , PersistText $ unConstraintNameDB refName' ] @@ -1045,8 +1072,8 @@ getColumn getter tableName' [ PersistText columnName case cntrs of [] -> return Nothing - [[PersistText table, PersistText constraint, PersistText updRule, PersistText delRule]] -> - return $ Just (EntityNameDB table, ConstraintNameDB constraint, updRule, delRule) + [[PersistText table, PersistText schema, PersistText constraint, PersistText updRule, PersistText delRule]] -> + return $ Just (EntityNameDB table, SchemaNameDB schema, ConstraintNameDB constraint, updRule, delRule) xs -> error $ mconcat [ "Postgresql.getColumn: error fetching constraints. Expected a single result for foreign key query for table: " @@ -1098,7 +1125,7 @@ getColumn getter tableName' [ PersistText columnName , " Specify the values as numeric(total_digits, digits_after_decimal_place)." ] -getColumn _ _ columnName _ = +getColumn _ _ _ columnName _ = return $ Left $ T.pack $ "Invalid result from information_schema: " ++ show columnName -- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer @@ -1140,6 +1167,7 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName -> [AddReference (crTableName colRef) + (crSchemaName colRef) (crConstraintName colRef) [name] (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) @@ -1217,14 +1245,16 @@ getAddReference -> FieldNameDB -> ColumnReference -> Maybe AlterDB -getAddReference allDefs entity cname cr@ColumnReference {crTableName = s, crConstraintName=constraintName} = do +getAddReference allDefs entity cname cr@ColumnReference {crTableName = s, crSchemaName = refschema, crConstraintName=constraintName} = do guard $ Just cname /= fmap fieldDB (getEntityIdField entity) pure $ AlterColumn table - (AddReference s constraintName [cname] id_ (crFieldCascade cr) + schema + (AddReference s refschema constraintName [cname] id_ (crFieldCascade cr) ) where table = getEntityDBName entity + schema = getEntitySchema entity id_ = fromMaybe (error $ "Could not find ID of entity " ++ show s) @@ -1266,90 +1296,90 @@ showSqlType (SqlOther t) = t showAlterDb :: AlterDB -> (Bool, Text) showAlterDb (AddTable s) = (False, s) -showAlterDb (AlterColumn t ac) = - (isUnsafe ac, showAlter t ac) +showAlterDb (AlterColumn t s ac) = + (isUnsafe ac, showAlter t s ac) where isUnsafe (Drop _ safeRemove) = not safeRemove isUnsafe _ = False -showAlterDb (AlterTable t at) = (False, showAlterTable t at) +showAlterDb (AlterTable t s at) = (False, showAlterTable t s at) -showAlterTable :: EntityNameDB -> AlterTable -> Text -showAlterTable table (AddUniqueConstraint cname cols) = T.concat +showAlterTable :: EntityNameDB -> Maybe SchemaNameDB -> AlterTable -> Text +showAlterTable table schema (AddUniqueConstraint cname cols) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " ADD CONSTRAINT " , escapeC cname , " UNIQUE(" , T.intercalate "," $ map escapeF cols , ")" ] -showAlterTable table (DropConstraint cname) = T.concat +showAlterTable table schema (DropConstraint cname) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " DROP CONSTRAINT " , escapeC cname ] -showAlter :: EntityNameDB -> AlterColumn -> Text -showAlter table (ChangeType c t extra) = +showAlter :: EntityNameDB -> Maybe SchemaNameDB -> AlterColumn -> Text +showAlter table schema (ChangeType c t extra) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " ALTER COLUMN " , escapeF (cName c) , " TYPE " , showSqlType t , extra ] -showAlter table (IsNull c) = +showAlter table schema (IsNull c) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " ALTER COLUMN " , escapeF (cName c) , " DROP NOT NULL" ] -showAlter table (NotNull c) = +showAlter table schema (NotNull c) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " ALTER COLUMN " , escapeF (cName c) , " SET NOT NULL" ] -showAlter table (Add' col) = +showAlter table schema (Add' col) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " ADD COLUMN " , showColumn col ] -showAlter table (Drop c _) = +showAlter table schema (Drop c _) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " DROP COLUMN " , escapeF (cName c) ] -showAlter table (Default c s) = +showAlter table schema (Default c s) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " ALTER COLUMN " , escapeF (cName c) , " SET DEFAULT " , s ] -showAlter table (NoDefault c) = T.concat +showAlter table schema (NoDefault c) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " ALTER COLUMN " , escapeF (cName c) , " DROP DEFAULT" ] -showAlter table (Update' c s) = T.concat +showAlter table schema (Update' c s) = T.concat [ "UPDATE " - , escapeE table + , escapeES table schema , " SET " , escapeF (cName c) , "=" @@ -1358,22 +1388,22 @@ showAlter table (Update' c s) = T.concat , escapeF (cName c) , " IS NULL" ] -showAlter table (AddReference reftable fkeyname t2 id2 cascade) = T.concat +showAlter table schema (AddReference reftable refschema fkeyname t2 id2 cascade) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " ADD CONSTRAINT " , escapeC fkeyname , " FOREIGN KEY(" , T.intercalate "," $ map escapeF t2 , ") REFERENCES " - , escapeE reftable + , escapeES reftable refschema , "(" , T.intercalate "," id2 , ")" ] <> renderFieldCascade cascade -showAlter table (DropReference cname) = T.concat +showAlter table schema (DropReference cname) = T.concat [ "ALTER TABLE " - , escapeE table + , escapeES table schema , " DROP CONSTRAINT " , escapeC cname ] @@ -1397,6 +1427,8 @@ escapeE = escapeWith escape escapeF :: FieldNameDB -> Text escapeF = escapeWith escape +escapeS :: SchemaNameDB -> Text +escapeS = escapeWith escape escape :: Text -> Text escape s = @@ -1406,6 +1438,14 @@ escape s = go ('"':xs) = "\"\"" ++ go xs go (x:xs) = x : go xs +entityIdentifier :: EntityDef -> Text +entityIdentifier ed = escapeES (getEntityDBName ed) (getEntitySchema ed) + +escapeES :: EntityNameDB -> Maybe SchemaNameDB -> Text +escapeES entityName schemaName = case schemaName of + Nothing -> escapeE entityName + Just schema -> escapeS schema <> "." <> escapeE entityName + -- | Information required to connect to a PostgreSQL database -- using @persistent@'s generic facilities. These values are the -- same that are given to 'withPostgresqlPool'. @@ -1563,12 +1603,13 @@ mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do (errs, _) -> return $ Left errs where name = getEntityDBName entity + schema = getEntitySchema entity migrationText exists' old'' = if not exists' then createText newcols fdefs udspair else let (acs, ats) = getAlters allDefs entity (newcols, udspair) old' - acs' = map (AlterColumn name) acs - ats' = map (AlterTable name) ats + acs' = map (AlterColumn name schema) acs + ats' = map (AlterTable name schema) ats in acs' ++ ats' where old' = partitionEithers old'' @@ -1582,7 +1623,7 @@ mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do (addTable newcols entity) : uniques ++ references ++ foreignsAlt where uniques = flip concatMap udspair $ \(uname, ucols) -> - [AlterTable name $ AddUniqueConstraint uname ucols] + [AlterTable name schema $ AddUniqueConstraint uname ucols] references = mapMaybe (\Column { cName, cReference } -> @@ -2065,4 +2106,3 @@ instance (PersistUniqueWrite b) => PersistUniqueWrite (RawPostgresql b) where upsertBy uniq rec = withReaderT persistentBackend . upsertBy uniq rec putMany = withReaderT persistentBackend . putMany #endif - diff --git a/persistent/Database/Persist/Quasi/Internal.hs b/persistent/Database/Persist/Quasi/Internal.hs index ce817b216..9e418c4f3 100644 --- a/persistent/Database/Persist/Quasi/Internal.hs +++ b/persistent/Database/Persist/Quasi/Internal.hs @@ -1392,6 +1392,14 @@ takeForeign ps entityName = takeRefTable EntityNameHS refTableName , foreignRefTableDBName = EntityNameDB $ psToDBName ps refTableName + , -- TODO: The existing foreign key syntax for + -- UnboundForeignDef is not sufficiently rich to + -- allow specifying the schema of the foreign + -- relation. We need to add the ability to parse + -- schema=foo directives inline for foreign keys + -- and insert those values here. + foreignRefSchemaDBName = + Nothing , foreignConstraintNameHaskell = constraintName , foreignConstraintNameDBName = diff --git a/persistent/Database/Persist/Sql/Internal.hs b/persistent/Database/Persist/Sql/Internal.hs index c8e099fee..22bf143cd 100644 --- a/persistent/Database/Persist/Sql/Internal.hs +++ b/persistent/Database/Persist/Sql/Internal.hs @@ -161,8 +161,8 @@ mkColumns allDefs t overrides = mkColumnReference :: FieldDef -> Maybe ColumnReference mkColumnReference fd = fmap - (\(tName, cName) -> - ColumnReference tName cName $ overrideNothings $ fieldCascade fd + (\(tName, sName, cName) -> + ColumnReference tName sName cName $ overrideNothings $ fieldCascade fd ) $ ref (fieldDB fd) (fieldReference fd) (fieldAttrs fd) @@ -178,27 +178,28 @@ mkColumns allDefs t overrides = ref :: FieldNameDB -> ReferenceDef -> [FieldAttr] - -> Maybe (EntityNameDB, ConstraintNameDB) -- table name, constraint name + -> Maybe (EntityNameDB, Maybe SchemaNameDB, ConstraintNameDB) -- table name, schema name, constraint name ref c fe [] | ForeignRef f <- fe = - Just (resolveTableName allDefs f, refNameFn tableName c) + let (table, schema) = resolveTableName allDefs f + in Just (table, schema, refNameFn tableName c) | otherwise = Nothing ref _ _ (FieldAttrNoreference:_) = Nothing ref c fe (a:as) = case a of FieldAttrReference x -> do - (_, constraintName) <- ref c fe as - pure (EntityNameDB x, constraintName) + (_, schema, constraintName) <- ref c fe as + pure (EntityNameDB x, schema, constraintName) FieldAttrConstraint x -> do - (tableName_, _) <- ref c fe as - pure (tableName_, ConstraintNameDB x) + (tableName_, schema, _) <- ref c fe as + pure (tableName_, schema, ConstraintNameDB x) _ -> ref c fe as refName :: EntityNameDB -> FieldNameDB -> ConstraintNameDB refName (EntityNameDB table) (FieldNameDB column) = ConstraintNameDB $ Data.Monoid.mconcat [table, "_", column, "_fkey"] -resolveTableName :: [EntityDef] -> EntityNameHS -> EntityNameDB +resolveTableName :: [EntityDef] -> EntityNameHS -> (EntityNameDB, Maybe SchemaNameDB) resolveTableName [] (EntityNameHS t) = error $ "Table not found: " `Data.Monoid.mappend` T.unpack t resolveTableName (e:es) hn - | getEntityHaskellName e == hn = getEntityDBName e + | getEntityHaskellName e == hn = (getEntityDBName e, getEntitySchema e) | otherwise = resolveTableName es hn diff --git a/persistent/Database/Persist/Sql/Types.hs b/persistent/Database/Persist/Sql/Types.hs index a9f592d86..f2fe2e18d 100644 --- a/persistent/Database/Persist/Sql/Types.hs +++ b/persistent/Database/Persist/Sql/Types.hs @@ -39,6 +39,7 @@ data ColumnReference = ColumnReference -- ^ The table name that the -- -- @since 2.11.0.0 + , crSchemaName :: !(Maybe SchemaNameDB) , crConstraintName :: !ConstraintNameDB -- ^ The name of the foreign key constraint. -- @@ -137,4 +138,3 @@ defaultConnectionPoolConfig = ConnectionPoolConfig 1 600 10 -- processing). newtype Single a = Single {unSingle :: a} deriving (Eq, Ord, Show, Read) - diff --git a/persistent/Database/Persist/Types/Base.hs b/persistent/Database/Persist/Types/Base.hs index e7f88d353..e802b1fac 100644 --- a/persistent/Database/Persist/Types/Base.hs +++ b/persistent/Database/Persist/Types/Base.hs @@ -554,6 +554,7 @@ type ForeignFieldDef = (FieldNameHS, FieldNameDB) data ForeignDef = ForeignDef { foreignRefTableHaskell :: !EntityNameHS , foreignRefTableDBName :: !EntityNameDB + , foreignRefSchemaDBName :: !(Maybe SchemaNameDB) , foreignConstraintNameHaskell :: !ConstraintNameHS , foreignConstraintNameDBName :: !ConstraintNameDB , foreignFieldCascade :: !FieldCascade