diff --git a/persistent-mysql/Database/Persist/MySQL.hs b/persistent-mysql/Database/Persist/MySQL.hs index db2952ee7..980fb70f2 100644 --- a/persistent-mysql/Database/Persist/MySQL.hs +++ b/persistent-mysql/Database/Persist/MySQL.hs @@ -349,7 +349,9 @@ migrate' connectInfo allDefs getter val = do $ map (findTypeAndMaxLen name) ucols let foreigns = do - Column { cName=cname, cReference=Just (refTblName, refConstraintName) } <- newcols + Column { cName=cname, cReference=Just cRef } <- newcols + let refConstraintName = crConstraintName cRef + let refTblName = crTableName cRef let refTarget = addReference allDefs refConstraintName refTblName cname @@ -386,7 +388,7 @@ migrate' connectInfo allDefs getter val = do ( map (\c -> case cReference c of - Just (_,fk) -> + Just ColumnReference {crConstraintName=fk} -> case find (\f -> fk == foreignConstraintNameDBName f) fdefs of Just _ -> c { cReference = Nothing } Nothing -> c @@ -614,7 +616,7 @@ getColumn -> (Text -> IO Statement) -> DBName -> [PersistValue] - -> Maybe (DBName, DBName) + -> Maybe ColumnReference -> IO (Either Text Column) getColumn connectInfo getter tname [ PersistText cname , PersistText null_ @@ -623,7 +625,7 @@ getColumn connectInfo getter tname [ PersistText cname , colMaxLen , colPrecision , colScale - , default'] refName = + , default'] cRef = fmap (either (Left . pack) Right) $ runExceptT $ do -- Default value @@ -638,7 +640,7 @@ getColumn connectInfo getter tname [ PersistText cname Right t -> return (Just t) _ -> fail $ "Invalid default column: " ++ show default' - ref <- getRef refName + ref <- getRef (crConstraintName <$> cRef) let colMaxLen' = case colMaxLen of PersistInt64 l -> Just (fromIntegral l) _ -> Nothing @@ -660,7 +662,7 @@ getColumn connectInfo getter tname [ PersistText cname , cReference = ref } where getRef Nothing = return Nothing - getRef (Just (_, refName')) = do + getRef (Just refName') = do -- Foreign key (if any) stmt <- lift . getter $ T.concat [ "SELECT REFERENCED_TABLE_NAME, " @@ -684,7 +686,9 @@ getColumn connectInfo getter tname [ PersistText cname case cntrs of [] -> return Nothing [[PersistText tab, PersistText ref, PersistInt64 pos]] -> - return $ if pos == 1 then Just (DBName tab, DBName ref) else Nothing + -- TODO: Fix cascade reference is ignored + return $ if pos == 1 then Just (ColumnReference (DBName tab) (DBName ref) noCascade) + else Nothing xs -> error $ mconcat [ "MySQL.getColumn/getRef: error fetching constraints. Expected a single result for foreign key query for table: " , T.unpack (unDBName tname) @@ -756,7 +760,7 @@ getAlters allDefs edef (c1, u1) (c2, u2) = dropColumn col = map ((,) (cName col)) $ - [DropReference n | Just (_, n) <- [cReference col]] ++ + [DropReference (crConstraintName cr) | Just cr <- [cReference col]] ++ [Drop] getAltersU [] old = map (DropUniqueConstraint . fst) old @@ -795,7 +799,7 @@ findAlters edef allDefs col@(Column name isNull type_ def _defConstraintName max [] -> case ref of Nothing -> ([(name, Add' col)],[]) - Just (tname, cname) -> + Just ColumnReference {crTableName=tname, crConstraintName=cname} -> let cnstr = [addReference allDefs cname tname name] in (map ((,) tname) (Add' col : cnstr), cols) @@ -803,13 +807,13 @@ findAlters edef allDefs col@(Column name isNull type_ def _defConstraintName max let -- Foreign key refDrop = case (ref == ref', ref') of - (False, Just (_, cname)) -> + (False, Just ColumnReference {crConstraintName=cname}) -> [(name, DropReference cname)] _ -> [] refAdd = case (ref == ref', ref) of - (False, Just (tname, cname)) + (False, Just ColumnReference {crTableName=tname, crConstraintName=cname}) | tname /= entityDB edef , cname /= fieldDB (entityId edef) -> @@ -851,7 +855,7 @@ showColumn (Column n nu t def _defConstraintName maxLen ref) = concat else " DEFAULT " ++ T.unpack s , case ref of Nothing -> "" - Just (s, _) -> " REFERENCES " ++ escapeDBName s + Just cRef -> " REFERENCES " ++ escapeDBName (crTableName cRef) ] @@ -1081,8 +1085,8 @@ mockMigrate _connectInfo allDefs _getter val = do AddUniqueConstraint uname $ map (findTypeAndMaxLen name) ucols ] let foreigns = do - Column { cName=cname, cReference=Just (refTblName, refConstraintName) } <- newcols - return $ AlterColumn name (refTblName, addReference allDefs refConstraintName refTblName cname) + Column { cName=cname, cReference= Just ColumnReference{crTableName = refTable, crConstraintName = refConstr}} <- newcols + return $ AlterColumn name (refTable, addReference allDefs refConstr refTable cname) let foreignsAlt = map (\fdef -> let (childfields, parentfields) = unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef)) in AlterColumn name (foreignRefTableDBName fdef, AddReference (foreignRefTableDBName fdef) (foreignConstraintNameDBName fdef) childfields parentfields)) fdefs diff --git a/persistent-postgresql/ChangeLog.md b/persistent-postgresql/ChangeLog.md index 7d9b14a23..7e7633963 100644 --- a/persistent-postgresql/ChangeLog.md +++ b/persistent-postgresql/ChangeLog.md @@ -2,6 +2,10 @@ ## (Unreleased) 2.11.0.0 +* Foreign Key improvements [#1121] https://github.com/yesodweb/persistent/pull/1121 + * It is now supported to refer to a table with an auto generated Primary Kay + * It is now supported to refer to non-primary fields, using the keyword `References` + * Implement interval support. [#1053](https://github.com/yesodweb/persistent/pull/1053) * [#1060](https://github.com/yesodweb/persistent/pull/1060) * The QuasiQuoter now supports `OnDelete` and `OnUpdate` cascade options. diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index e6b5153b7..7d63a2b99 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE NamedFieldPuns #-} @@ -849,7 +850,7 @@ getColumns getter def cols = do us <- with (stmtQuery stmt' vals) (\src -> runConduit $ src .| helperU) return $ cs ++ us where - refMap = Map.fromList $ foldl' ref [] cols + refMap = fmap (\cr -> (crTableName cr, crConstraintName cr)) $ Map.fromList $ foldl' ref [] cols where ref rs c = case cReference c of Nothing -> rs (Just r) -> (unDBName $ cName c, r) : rs @@ -942,7 +943,8 @@ getColumn getter tableName' [PersistText columnName, PersistText isNullable, Per , cDefault = fmap stripSuffixes d'' , cDefaultConstraintName = Nothing , cMaxLen = Nothing - , cReference = ref + -- TODO: Fix cascade reference is ignored + , cReference = fmap (\(a,b) -> ColumnReference a b noCascade) ref } where stripSuffixes t = @@ -1053,9 +1055,9 @@ findAlters defs edef col@(Column name isNull sqltype def _defConstraintName _max ([(name, Add' col)], cols) Just (Column _oldName isNull' sqltype' def' _defConstraintName' _maxLen' ref') -> let refDrop Nothing = [] - refDrop (Just (_, cname)) = [(name, DropReference cname)] + refDrop (Just ColumnReference {crConstraintName=cname}) = [(name, DropReference cname)] refAdd Nothing = [] - refAdd (Just (tname, a)) = + refAdd (Just ColumnReference {crTableName=tname, crConstraintName=a}) = case find ((==tname) . entityDB) defs of Just refdef | entityDB edef /= tname @@ -1066,6 +1068,7 @@ findAlters defs edef col@(Column name isNull sqltype def _defConstraintName _max a [name] (Util.dbIdColumnsEsc escape refdef) + -- TODO: Fix cascade reference is ignored noCascade ) ] @@ -1073,7 +1076,7 @@ findAlters defs edef col@(Column name isNull sqltype def _defConstraintName _max Nothing -> error $ "could not find the entityDef for reftable[" ++ show tname ++ "]" modRef = - if fmap snd ref == fmap snd ref' + if fmap crConstraintName ref == fmap crConstraintName ref' then [] else refDrop ref' ++ refAdd ref modNull = case (isNull, isNull') of @@ -1113,13 +1116,14 @@ getAddReference :: [EntityDef] -> EntityDef -> DBName - -> (DBName, DBName) + -> ColumnReference -> Maybe AlterDB -getAddReference allDefs entity cname (s, constraintName) = do +getAddReference allDefs entity cname ColumnReference {crTableName = s, crConstraintName=constraintName} = do guard $ table /= s && cname /= fieldDB (entityId entity) pure $ AlterColumn table ( s + -- TODO: Fix cascade reference is ignored , AddReference constraintName [cname] id_ noCascade ) where diff --git a/persistent-sqlite/ChangeLog.md b/persistent-sqlite/ChangeLog.md index 7059d066f..c52e5754b 100644 --- a/persistent-sqlite/ChangeLog.md +++ b/persistent-sqlite/ChangeLog.md @@ -2,6 +2,11 @@ ## (Unreleased) 2.11.0.0 +* Foreign Key improvements [#1121] (https://github.com/yesodweb/persistent/pull/1121) + * It is now supported to refer to a table with an auto generated Primary Kay + * It is now supported to refer to non-primary fields, using the keyword `References` + * It is now supported to have cascade options for simple/single-field Foreign Keys + * [#1060](https://github.com/yesodweb/persistent/pull/1060) * The QuasiQuoter now supports `OnDelete` and `OnUpdate` cascade options. diff --git a/persistent-sqlite/Database/Persist/Sqlite.hs b/persistent-sqlite/Database/Persist/Sqlite.hs index 50c58ec00..2ecbc04d4 100644 --- a/persistent-sqlite/Database/Persist/Sqlite.hs +++ b/persistent-sqlite/Database/Persist/Sqlite.hs @@ -581,8 +581,14 @@ sqlColumn noRef (Column name isNull typ def _cn _maxLen ref) = T.concat , mayDefault def , case ref of Nothing -> "" - Just (table, _) -> if noRef then "" else " REFERENCES " <> escape table + Just ColumnReference {crTableName=table, crFieldCascade=cascadeOpts} -> + if noRef then "" else " REFERENCES " <> escape table + <> onDelete cascadeOpts <> onUpdate cascadeOpts ] + where + + onDelete opts = maybe "" (T.append " ON DELETE " . renderCascadeAction) (fcOnDelete opts) + onUpdate opts = maybe "" (T.append " ON UPDATE " . renderCascadeAction) (fcOnUpdate opts) sqlForeign :: ForeignDef -> Text sqlForeign fdef = T.concat $ diff --git a/persistent-template/Database/Persist/TH.hs b/persistent-template/Database/Persist/TH.hs index b4c74b4ac..e5a512039 100644 --- a/persistent-template/Database/Persist/TH.hs +++ b/persistent-template/Database/Persist/TH.hs @@ -290,7 +290,7 @@ data FieldSqlTypeExp = FieldSqlTypeExp FieldDef SqlTypeExp instance Lift FieldSqlTypeExp where lift (FieldSqlTypeExp FieldDef{..} sqlTypeExp) = - [|FieldDef fieldHaskell fieldDB fieldType $(lift sqlTypeExp) fieldAttrs fieldStrict fieldReference fieldComments|] + [|FieldDef fieldHaskell fieldDB fieldType $(lift sqlTypeExp) fieldAttrs fieldStrict fieldReference fieldComments fieldCascadeOpts|] #if MIN_VERSION_template_haskell(2,16,0) liftTyped = unsafeTExpCoerce . lift #endif @@ -1104,11 +1104,11 @@ mkEntity entityMap mps t = do fpv <- mkFromPersistValues mps t utv <- mkUniqueToValues $ entityUniques t puk <- mkUniqueKeys t + let primaryField = entityId t + fields <- mapM (mkField mps t) $ primaryField : entityFields t fkc <- mapM (mkForeignKeysComposite mps t) $ entityForeigns t - let primaryField = entityId t - fields <- mapM (mkField mps t) $ primaryField : entityFields t toFieldNames <- mkToFieldNames $ entityUniques t (keyTypeDec, keyInstanceDecs) <- mkKeyTypeDec mps t @@ -1335,7 +1335,8 @@ mkLenses mps ent = fmap mconcat $ forM (entityFields ent) $ \field -> do ] mkForeignKeysComposite :: MkPersistSettings -> EntityDef -> ForeignDef -> Q [Dec] -mkForeignKeysComposite mps t ForeignDef {..} = do +mkForeignKeysComposite mps t ForeignDef {..} = + if not foreignToPrimary then return [] else do let fieldName f = mkName $ unpack $ recName mps (entityHaskell t) f let fname = fieldName foreignConstraintNameHaskell let reftableString = unpack $ unHaskellName foreignRefTableHaskell @@ -1343,8 +1344,12 @@ mkForeignKeysComposite mps t ForeignDef {..} = do let tablename = mkName $ unpack $ entityText t recordName <- newName "record" - let fldsE = map (\((foreignName, _),_) -> VarE (fieldName foreignName) - `AppE` VarE recordName) foreignFields + let mkFldE ((foreignName, _),ff) = case ff of + (HaskellName {unHaskellName = "Id"}, DBName {unDBName = "id"}) + -> AppE (VarE $ mkName "toBackendKey") $ + VarE (fieldName foreignName) `AppE` VarE recordName + _ -> VarE (fieldName foreignName) `AppE` VarE recordName + let fldsE = map mkFldE foreignFields let mkKeyE = foldl' AppE (maybeExp foreignNullable $ ConE reftableKeyName) fldsE let fn = FunD fname [normalClause [VarP recordName] mkKeyE] @@ -1689,18 +1694,22 @@ liftAndFixKeys entityMap EntityDef{..} = |] liftAndFixKey :: EntityMap -> FieldDef -> Q Exp -liftAndFixKey entityMap (FieldDef a b c sqlTyp e f fieldRef mcomments) = - [|FieldDef a b c $(sqlTyp') e f fieldRef' mcomments|] +liftAndFixKey entityMap (FieldDef a b c sqlTyp e f fieldRef mcomments casc) = + [|FieldDef a b c $(sqlTyp') e f fieldRef' mcomments casc|] where (fieldRef', sqlTyp') = fromMaybe (fieldRef, lift sqlTyp) $ case fieldRef of ForeignRef refName _ft -> case M.lookup refName entityMap of - Nothing -> Nothing + Nothing -> checkCascade Just ent -> case fieldReference $ entityId ent of fr@(ForeignRef _Name ft) -> Just (fr, lift $ SqlTypeExp ft) - _ -> Nothing - _ -> Nothing + _ -> checkCascade + _ -> checkCascade + checkCascade = case casc of + FieldCascade Nothing Nothing -> Nothing + _ -> error $ "cascade field is not allown for field " <> show a + <> ". It doesn't reference any other tables." deriving instance Lift EntityDef diff --git a/persistent-test/src/ForeignKey.hs b/persistent-test/src/ForeignKey.hs index 756952358..a7616f7b2 100644 --- a/persistent-test/src/ForeignKey.hs +++ b/persistent-test/src/ForeignKey.hs @@ -16,6 +16,15 @@ share [mkPersist persistSettings { mpsGeneric = False }, mkMigrate "compositeMig Foreign Parent OnDeleteCascade OnUpdateCascade fkparent pname deriving Show Eq + ParentImplicit + name String + + ChildImplicit + pname String + parentId ParentImplicitId noreference + Foreign ParentImplicit OnDeleteCascade OnUpdateCascade fkparent parentId + deriving Show Eq + ParentComposite name String lastName String @@ -33,6 +42,45 @@ share [mkPersist persistSettings { mpsGeneric = False }, mkMigrate "compositeMig Primary name Foreign SelfReferenced OnDeleteCascade fkparent pname deriving Show Eq + + A + aa String + ab Int + U1 aa + + B + ba String + bb Int + Foreign A OnDeleteCascade fkA ba References aa + deriving Show Eq + + AComposite + aa String + ab Int + U2 aa ab + + BComposite + ba String + bb Int + Foreign AComposite OnDeleteCascade fkAComposite ba bb References aa ab + deriving Show Eq + + BExplicit + ba AId noreference + Foreign A OnDeleteCascade fkAI ba References Id + deriving Show Eq + + Chain + name String + previous ChainId Maybe noreference + Foreign Chain OnDeleteSetNull fkChain previous References Id + deriving Show Eq + + Chain2 + name String + previous Chain2Id Maybe noreference + Foreign Chain2 OnDeleteCascade fkChain previous References Id + deriving Show Eq |] specsWith :: (MonadIO m, MonadFail m) => RunDb SqlBackend m -> Spec @@ -50,6 +98,13 @@ specsWith runDb = describe "foreign keys options" $ do update kf [ParentName =. "B"] cs <- selectList [] [] fmap (childPname . entityVal) cs @== ["B"] + it "delete cascades on implicit Primary key" $ runDb $ do + kf <- insert $ ParentImplicit "A" + kc <- insert $ ChildImplicit "B" kf + delete kf + cs <- selectList [] [] + let expected = [] :: [Entity ChildImplicit] + cs @== expected it "delete Composite cascades" $ runDb $ do kf <- insert $ ParentComposite "A" "B" kc <- insert $ ChildComposite "A" "B" @@ -64,3 +119,50 @@ specsWith runDb = describe "foreign keys options" $ do srs <- selectList [] [] let expected = [] :: [Entity SelfReferenced] srs @== expected + it "delete cascades with explicit Reference" $ runDb $ do + kf <- insert $ A "A" 40 + kc <- insert $ B "A" 15 + delete kf + return () + cs <- selectList [] [] + let expected = [] :: [Entity B] + cs @== expected + it "delete cascades with explicit Composite Reference" $ runDb $ do + kf <- insert $ AComposite "A" 20 + kc <- insert $ BComposite "A" 20 + delete kf + return () + cs <- selectList [] [] + let expected = [] :: [Entity B] + cs @== expected + it "delete cascades with explicit Composite Reference" $ runDb $ do + kf <- insert $ AComposite "A" 20 + kc <- insert $ BComposite "A" 20 + delete kf + return () + cs <- selectList [] [] + let expected = [] :: [Entity B] + cs @== expected + it "delete cascades with explicit Id field" $ runDb $ do + kf <- insert $ A "A" 20 + kc <- insert $ BExplicit kf + delete kf + return () + cs <- selectList [] [] + let expected = [] :: [Entity B] + cs @== expected + it "deletes sets null with self reference" $ runDb $ do + kf <- insert $ Chain "A" Nothing + insert $ Chain "B" (Just kf) + delete kf + cs <- selectList [] [] + let expected = [Entity {entityKey = ChainKey 2, entityVal = Chain "B" Nothing}] + cs @== expected + it "deletes cascades with self reference to the whole chain" $ runDb $ do + k1 <- insert $ Chain2 "A" Nothing + k2 <- insert $ Chain2 "B" (Just k1) + k3 <- insert $ Chain2 "C" (Just k2) + delete k1 + cs <- selectList [] [] + let expected = [] :: [Entity Chain2] + cs @== expected diff --git a/persistent/ChangeLog.md b/persistent/ChangeLog.md index 7e84def32..05dd6ac42 100644 --- a/persistent/ChangeLog.md +++ b/persistent/ChangeLog.md @@ -2,6 +2,10 @@ ## (Unreleased) 2.11.0.0 +* Foreign Key improvements [#1121] https://github.com/yesodweb/persistent/pull/1121 + * It is now supported to refer to a table with an auto generated Primary Kay + * It is now supported to refer to non-primary fields, using the keyword `References` + * It is now supported to have cascade options for simple/single-field Foreign Keys * Introduces a breaking change to the internal function `mkColumns`, which can now be passed a record of functions to override its default behavior. [#996](https://github.com/yesodweb/persistent/pull/996) * Added explicit `forall` notation to make most API functions play nice when using `TypeApplications`. (e.g. instead of `selectList @_ @_ @User [] []`, you can now write `selectList @User [] []`) [#1006](https://github.com/yesodweb/persistent/pull/1006) * [#1060](https://github.com/yesodweb/persistent/pull/1060) diff --git a/persistent/Database/Persist/Quasi.hs b/persistent/Database/Persist/Quasi.hs index c27a0dcc8..edb069bbd 100644 --- a/persistent/Database/Persist/Quasi.hs +++ b/persistent/Database/Persist/Quasi.hs @@ -562,45 +562,48 @@ fixForeignKeysAll unEnts = map fixForeignKeys unEnts fixForeignKeys (UnboundEntityDef foreigns ent) = ent { entityForeigns = map (fixForeignKey ent) foreigns } - -- check the count and the sqltypes match and update the foreignFields with the names of the primary columns + -- check the count and the sqltypes match and update the foreignFields with the names of the referenced columns fixForeignKey :: EntityDef -> UnboundForeignDef -> ForeignDef - fixForeignKey ent (UnboundForeignDef foreignFieldTexts fdef) = - let pentError = - error $ "could not find table " ++ show (foreignRefTableHaskell fdef) - ++ " fdef=" ++ show fdef ++ " allnames=" - ++ show (map (unHaskellName . entityHaskell . unboundEntityDef) unEnts) - ++ "\n\nents=" ++ show ents - pent = - fromMaybe pentError $ M.lookup (foreignRefTableHaskell fdef) entLookup - in - case entityPrimary pent of - Just pdef -> - if length foreignFieldTexts /= length (compositeFields pdef) - then - lengthError pdef - else - let - fds_ffs = - zipWith (toForeignFields pent) - foreignFieldTexts - (compositeFields pdef) - dbname = - unDBName (entityDB pent) - oldDbName = - unDBName (foreignRefTableDBName fdef) - in fdef - { foreignFields = map snd fds_ffs - , foreignNullable = setNull $ map fst fds_ffs - , foreignRefTableDBName = - DBName dbname - , foreignConstraintNameDBName = - DBName - . T.replace oldDbName dbname . unDBName - $ foreignConstraintNameDBName fdef - } - Nothing -> - error $ "no explicit primary key fdef="++show fdef++ " ent="++show ent + fixForeignKey ent (UnboundForeignDef foreignFieldTexts parentFieldTexts fdef) = + case mfdefs of + Just fdefs -> + if length foreignFieldTexts /= length fdefs + then + lengthError fdefs + else + let + fds_ffs = + zipWith toForeignFields + foreignFieldTexts + fdefs + dbname = + unDBName (entityDB pent) + oldDbName = + unDBName (foreignRefTableDBName fdef) + in fdef + { foreignFields = map snd fds_ffs + , foreignNullable = setNull $ map fst fds_ffs + , foreignRefTableDBName = + DBName dbname + , foreignConstraintNameDBName = + DBName + . T.replace oldDbName dbname . unDBName + $ foreignConstraintNameDBName fdef + } + Nothing -> + error $ "no primary key found fdef="++show fdef++ " ent="++show ent where + pentError = + error $ "could not find table " ++ show (foreignRefTableHaskell fdef) + ++ " fdef=" ++ show fdef ++ " allnames=" + ++ show (map (unHaskellName . entityHaskell . unboundEntityDef) unEnts) + ++ "\n\nents=" ++ show ents + pent = + fromMaybe pentError $ M.lookup (foreignRefTableHaskell fdef) entLookup + mfdefs = case parentFieldTexts of + [] -> entitiesPrimary pent + _ -> Just $ map (getFd pent . HaskellName) parentFieldTexts + setNull :: [FieldDef] -> Bool setNull [] = error "setNull: impossible!" setNull (fd:fds) = let nullSetting = isNull fd in @@ -609,31 +612,32 @@ fixForeignKeysAll unEnts = map fixForeignKeys unEnts ++ show (map (unHaskellName . fieldHaskell) (fd:fds)) isNull = (NotNullable /=) . nullable . fieldAttrs - toForeignFields pent fieldText pfd = - case chktypes fd haskellField (entityFields pent) pfh of + toForeignFields :: Text -> FieldDef + -> (FieldDef, (ForeignFieldDef, ForeignFieldDef)) + toForeignFields fieldText pfd = + case chktypes fd haskellField pfd of Just err -> error err Nothing -> (fd, ((haskellField, fieldDB fd), (pfh, pfdb))) where - fd = getFd (entityFields ent) haskellField + fd = getFd ent haskellField haskellField = HaskellName fieldText (pfh, pfdb) = (fieldHaskell pfd, fieldDB pfd) - chktypes :: FieldDef -> HaskellName -> [FieldDef] -> HaskellName -> Maybe String - chktypes ffld _fkey pflds pkey = + chktypes ffld _fkey pfld = if fieldType ffld == fieldType pfld then Nothing else Just $ "fieldType mismatch: " ++ show (fieldType ffld) ++ ", " ++ show (fieldType pfld) - where - pfld = getFd pflds pkey - entName = entityHaskell ent - getFd [] t = error $ "foreign key constraint for: " ++ show (unHaskellName entName) - ++ " unknown column: " ++ show t - getFd (f:fs) t + getFd :: EntityDef -> HaskellName -> FieldDef + getFd entity t = go (keyAndEntityFields entity) + where + go [] = error $ "foreign key constraint for: " ++ show (unHaskellName $ entityHaskell entity) + ++ " unknown column: " ++ show t + go (f:fs) | fieldHaskell f == t = f - | otherwise = getFd fs t + | otherwise = go fs - lengthError pdef = error $ "found " ++ show (length foreignFieldTexts) ++ " fkeys and " ++ show (length (compositeFields pdef)) ++ " pkeys: fdef=" ++ show fdef ++ " pdef=" ++ show pdef + lengthError pdef = error $ "found " ++ show (length foreignFieldTexts) ++ " fkeys and " ++ show (length pdef) ++ " pkeys: fdef=" ++ show fdef ++ " pdef=" ++ show pdef data UnboundEntityDef = UnboundEntityDef @@ -737,6 +741,7 @@ mkAutoIdField ps entName idName idSqlType = FieldDef , fieldAttrs = [] , fieldStrict = True , fieldComments = Nothing + , fieldCascadeOpts = FieldCascade Nothing Nothing } defaultReferenceTypeCon :: FieldType @@ -768,25 +773,43 @@ takeCols -> [Text] -> Maybe FieldDef takeCols _ _ ("deriving":_) = Nothing -takeCols onErr ps (n':typ:rest) +takeCols onErr ps (n':typ:rest') | not (T.null n) && isLower (T.head n) = case parseFieldType typ of Left err -> onErr typ err Right ft -> Just FieldDef { fieldHaskell = HaskellName n - , fieldDB = DBName $ getDbName ps n rest + , fieldDB = DBName $ getDbName ps n attr , fieldType = ft , fieldSqlType = SqlOther $ "SqlType unset for " `mappend` n - , fieldAttrs = rest + , fieldAttrs = attr , fieldStrict = fromMaybe (psStrictFields ps) mstrict , fieldReference = NoReference , fieldComments = Nothing + , fieldCascadeOpts = FieldCascade onUpd onDel } where (mstrict, n) | Just x <- T.stripPrefix "!" n' = (Just True, x) | Just x <- T.stripPrefix "~" n' = (Just False, x) | otherwise = (Nothing, n') + (onDel, onUpd, attr) = go rest' Nothing Nothing + + go (txt : rest) onDelete' onUpdate' = + case (T.stripPrefix "OnDelete" txt, T.stripPrefix "OnUpdate" txt) of + (Just onDelete, _) -> case (readEither $ T.unpack onDelete, onDelete') of + (Right action, Nothing) -> go rest (Just action) onUpdate' + (Right _, Just _) -> error $ + "found more than one OnDelete actions at field " ++ show (n':typ:rest') + (Left _, _) -> (onDelete', onUpdate', txt : rest) + (_, Just onUpdate) -> case (readEither $ T.unpack onUpdate, onUpdate') of + (Right action, Nothing) -> go rest onDelete' (Just action) + (Right _, Just _) -> error $ + "found more than one OnUpdate actions at field " ++ show (n':typ:rest') + _ -> (onDelete', onUpdate', txt : rest) + _ -> (onDelete', onUpdate', txt : rest) + go [] onDelete' onUpdate' = (onDelete', onUpdate', []) + takeCols _ _ _ = Nothing getDbName :: PersistSettings -> Text -> [Text] -> Text @@ -900,7 +923,8 @@ takeUniq _ tableName _ xs = ++ show xs data UnboundForeignDef = UnboundForeignDef - { _unboundFields :: [Text] -- ^ fields in other entity + { _unboundForeignFields :: [Text] -- ^ fields in the parent entity + , _unboundParentFields :: [Text] -- ^ fields in parent entity , _unboundForeignDef :: ForeignDef } @@ -920,7 +944,7 @@ takeForeign ps tableName _defs = takeRefTable where go :: [Text] -> Maybe CascadeAction -> Maybe CascadeAction -> UnboundForeignDef go (n:rest) onDelete onUpdate | not (T.null n) && isLower (T.head n) - = UnboundForeignDef fields $ ForeignDef + = UnboundForeignDef fFields pFields $ ForeignDef { foreignRefTableHaskell = HaskellName refTableName , foreignRefTableDBName = @@ -939,9 +963,18 @@ takeForeign ps tableName _defs = takeRefTable attrs , foreignNullable = False + , foreignToPrimary = + null pFields } where - (fields,attrs) = break ("!" `T.isPrefixOf`) rest + (fields ,attrs) = break ("!" `T.isPrefixOf`) rest + (fFields, pFields) = case break (== "References") fields of + (ffs, []) -> (ffs, []) + (ffs, _ : pfs) -> case (length ffs, length pfs) of + (flen, plen) | flen == plen -> (ffs, pfs) + (flen, plen) -> error $ errorPrefix ++ concat + [ "Found ", show flen, " foreign fields but " + , show plen, " parent fields" ] go ((T.stripPrefix "OnDelete" -> Just onDelete) : rest) onDelete' onUpdate = case (onDelete', readEither $ T.unpack onDelete) of (Nothing, Right cascadingAction) -> go rest (Just cascadingAction) onUpdate diff --git a/persistent/Database/Persist/Sql/Internal.hs b/persistent/Database/Persist/Sql/Internal.hs index 9a19c7520..cf9330c3b 100644 --- a/persistent/Database/Persist/Sql/Internal.hs +++ b/persistent/Database/Persist/Sql/Internal.hs @@ -1,5 +1,6 @@ {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} -- | Intended for creating new backends. module Database.Persist.Sql.Internal @@ -87,12 +88,13 @@ mkColumns allDefs t overrides = , cDefaultConstraintName = Nothing , cMaxLen = maxLen $ fieldAttrs fd - , cReference = ref (fieldDB fd) (fieldReference fd) (fieldAttrs fd) + , cReference = mkColumnReference fd } tableName :: DBName tableName = entityDB t + go :: FieldDef -> Column go fd = Column @@ -102,7 +104,7 @@ mkColumns allDefs t overrides = , cDefault = defaultAttribute $ fieldAttrs fd , cDefaultConstraintName = Nothing , cMaxLen = maxLen $ fieldAttrs fd - , cReference = ref (fieldDB fd) (fieldReference fd) (fieldAttrs fd) + , cReference = mkColumnReference fd } maxLen :: [Attr] -> Maybe Integer @@ -117,6 +119,11 @@ mkColumns allDefs t overrides = refNameFn = fromMaybe refName (backendSpecificForeignKeyName overrides) + mkColumnReference :: FieldDef -> Maybe ColumnReference + mkColumnReference fd = + fmap (\(tName, cName) -> ColumnReference tName cName (fieldCascadeOpts fd)) + $ ref (fieldDB fd) (fieldReference fd) (fieldAttrs fd) + ref :: DBName -> ReferenceDef -> [Attr] diff --git a/persistent/Database/Persist/Sql/Types.hs b/persistent/Database/Persist/Sql/Types.hs index f00339ad6..da687b88f 100644 --- a/persistent/Database/Persist/Sql/Types.hs +++ b/persistent/Database/Persist/Sql/Types.hs @@ -7,6 +7,8 @@ module Database.Persist.Sql.Types , OverflowNatural(..) ) where +import Database.Persist.Types.Base (FieldCascade) + import Control.Exception (Exception(..)) import Control.Monad.Logger (NoLoggingT) import Control.Monad.Trans.Reader (ReaderT (..)) @@ -25,7 +27,14 @@ data Column = Column , cDefault :: !(Maybe Text) , cDefaultConstraintName :: !(Maybe DBName) , cMaxLen :: !(Maybe Integer) - , cReference :: !(Maybe (DBName, DBName)) -- table name, constraint name + , cReference :: !(Maybe ColumnReference) + } + deriving (Eq, Ord, Show) + +data ColumnReference = ColumnReference + { crTableName :: DBName + , crConstraintName :: DBName + , crFieldCascade :: FieldCascade } deriving (Eq, Ord, Show) diff --git a/persistent/Database/Persist/Types/Base.hs b/persistent/Database/Persist/Types/Base.hs index 4788f7b81..deaffe0e5 100644 --- a/persistent/Database/Persist/Types/Base.hs +++ b/persistent/Database/Persist/Types/Base.hs @@ -139,6 +139,14 @@ data EntityDef = EntityDef } deriving (Show, Eq, Read, Ord) +entitiesPrimary :: EntityDef -> Maybe [FieldDef] +entitiesPrimary t = case fieldReference primaryField of + CompositeRef c -> Just $ (compositeFields c) + ForeignRef _ _ -> Just [primaryField] + _ -> Nothing + where + primaryField = entityId t + entityPrimary :: EntityDef -> Maybe CompositeDef entityPrimary t = case fieldReference (entityId t) of CompositeRef c -> Just c @@ -201,6 +209,11 @@ data FieldDef = FieldDef -- attach comments to a field in the quasiquoter. -- -- @since 2.10.0 + , fieldCascadeOpts :: !FieldCascade + -- ^ The cascade options of this fields. Used when this field refers to + -- another field. + -- + -- @since 2.11.0 } deriving (Show, Eq, Read, Ord) @@ -300,6 +313,10 @@ data ForeignDef = ForeignDef , foreignFields :: ![(ForeignFieldDef, ForeignFieldDef)] -- this entity plus the primary entity , foreignAttrs :: ![Attr] , foreignNullable :: Bool + , foreignToPrimary :: Bool + -- ^ Determines if the reference is towards a Primary Key or not. + -- + -- @since 2.11.0 } deriving (Show, Eq, Read, Ord)