diff --git a/.github/workflows/haskell.yml b/.github/workflows/haskell.yml index df9e14926..af8b007f5 100644 --- a/.github/workflows/haskell.yml +++ b/.github/workflows/haskell.yml @@ -82,4 +82,5 @@ jobs: - run: cabal v2-build all --disable-optimization $CONFIG - run: cabal v2-test all --disable-optimization $CONFIG - run: cabal v2-haddock all $CONFIG + continue-on-error: true - run: cabal v2-sdist all diff --git a/persistent-mysql/Database/Persist/MySQL.hs b/persistent-mysql/Database/Persist/MySQL.hs index 4f3476abb..addaaf3bc 100644 --- a/persistent-mysql/Database/Persist/MySQL.hs +++ b/persistent-mysql/Database/Persist/MySQL.hs @@ -1428,7 +1428,7 @@ copyField = CopyField -- [] -- @ -- --- Once we run that code on the datahase, the new data set looks like this: +-- Once we run that code on the database, the new data set looks like this: -- -- > items: -- > +------+-------------+-------+----------+ diff --git a/persistent-mysql/test/InsertDuplicateUpdate.hs b/persistent-mysql/test/InsertDuplicateUpdate.hs index 595d13b60..437120792 100644 --- a/persistent-mysql/test/InsertDuplicateUpdate.hs +++ b/persistent-mysql/test/InsertDuplicateUpdate.hs @@ -61,12 +61,15 @@ specs = describe "DuplicateKeyUpdate" $ do dbItems <- map entityVal <$> selectList [] [] sort dbItems @== sort (newItem : items) it "updates existing records" $ db $ do + let postUpdate = map (\i -> i { itemQuantity = fmap (+1) (itemQuantity i) }) items deleteWhere ([] :: [Filter Item]) insertMany_ items insertManyOnDuplicateKeyUpdate items [] [ItemQuantity +=. Just 1] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort postUpdate it "only copies passing values" $ db $ do deleteWhere ([] :: [Filter Item]) insertMany_ items diff --git a/persistent-postgresql/ChangeLog.md b/persistent-postgresql/ChangeLog.md index 6350581ef..7fce26302 100644 --- a/persistent-postgresql/ChangeLog.md +++ b/persistent-postgresql/ChangeLog.md @@ -1,5 +1,9 @@ # Changelog for persistent-postgresql +## 2.12.1.0 + +* Added `upsertWhere` and `upsertManyWhere` to `persistent-postgresql`. [#1222](https://github.com/yesodweb/persistent/pull/1222). + ## 2.12.0.0 * Decomposed `HaskellName` into `ConstraintNameHS`, `EntityNameHS`, `FieldNameHS`. Decomposed `DBName` into `ConstraintNameDB`, `EntityNameDB`, `FieldNameDB` respectively. [#1174](https://github.com/yesodweb/persistent/pull/1174) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 459d69ffc..2fa1fed6b 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -1,5 +1,9 @@ {-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -21,8 +25,16 @@ module Database.Persist.Postgresql , createPostgresqlPoolWithConf , module Database.Persist.Sql , ConnectionString + , HandleUpdateCollision + , copyField + , copyUnlessNull + , copyUnlessEmpty + , copyUnlessEq + , excludeNotEqualToOriginal , PostgresConf (..) , PgInterval (..) + , upsertWhere + , upsertManyWhere , openSimpleConn , openSimpleConnWithVersion , tableName @@ -50,7 +62,7 @@ import Control.Monad import Control.Monad.Except import Control.Monad.IO.Unlift (MonadIO (..), MonadUnliftIO) import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT) -import Control.Monad.Trans.Reader (runReaderT) +import Control.Monad.Trans.Reader (ReaderT(..), runReaderT, asks) import Control.Monad.Trans.Writer (WriterT(..), runWriterT) import qualified Blaze.ByteString.Builder.Char8 as BBB @@ -66,7 +78,7 @@ import qualified Data.ByteString.Char8 as B8 import Data.Char (ord) import Data.Conduit import qualified Data.Conduit.List as CL -import Data.Data +import Data.Data ( Data, Typeable ) import Data.Either (partitionEithers) import Data.Fixed (Fixed(..), Pico) import Data.Function (on) @@ -80,6 +92,7 @@ import qualified Data.List.NonEmpty as NEL import qualified Data.Map as Map import Data.Maybe import Data.Monoid ((<>)) +import qualified Data.Monoid as Monoid import Data.Pool (Pool) import Data.String.Conversions.Monomorphic (toStrictByteString) import Data.Text (Text) @@ -397,7 +410,6 @@ insertSql' ent vals = ] ] - upsertSql' :: EntityDef -> NonEmpty (FieldNameHS, FieldNameDB) -> Text -> Text upsertSql' ent uniqs updateVal = T.concat @@ -1496,6 +1508,7 @@ escapeE = escapeWith escape escapeF :: FieldNameDB -> Text escapeF = escapeWith escape + escape :: Text -> Text escape s = T.pack $ '"' : go (T.unpack s) ++ "\"" @@ -1738,6 +1751,214 @@ repsertManySql ent n = putManySql' conflictColumns fields ent n fields = keyAndEntityFields ent conflictColumns = escapeF . fieldDB <$> entityKeyFields ent +-- | This type is used to determine how to update rows using Postgres' +-- @INSERT ... ON CONFLICT KEY UPDATE@ functionality, exposed via +-- 'upsertWhere' and 'upsertManyWhere' in this library. +-- +-- @since 2.12.1.0 +data HandleUpdateCollision record where + -- | Copy the field directly from the record. + CopyField :: EntityField record typ -> HandleUpdateCollision record + -- | Only copy the field if it is not equal to the provided value. + CopyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record + +-- | Copy the field into the database only if the value in the +-- corresponding record is non-@NULL@. +-- +-- @since 2.12.1.0 +copyUnlessNull :: PersistField typ => EntityField record (Maybe typ) -> HandleUpdateCollision record +copyUnlessNull field = CopyUnlessEq field Nothing + +-- | Copy the field into the database only if the value in the +-- corresponding record is non-empty, where "empty" means the Monoid +-- definition for 'mempty'. Useful for 'Text', 'String', 'ByteString', etc. +-- +-- The resulting 'HandleUpdateCollision' type is useful for the +-- 'upsertManyWhere' function. +-- +-- @since 2.12.1.0 +copyUnlessEmpty :: (Monoid.Monoid typ, PersistField typ) => EntityField record typ -> HandleUpdateCollision record +copyUnlessEmpty field = CopyUnlessEq field Monoid.mempty + +-- | Copy the field into the database only if the field is not equal to the +-- provided value. This is useful to avoid copying weird nullary data into +-- the database. +-- +-- The resulting 'HandleUpdateCollision' type is useful for the +-- 'upsertMany' function. +-- +-- @since 2.12.1.0 +copyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record +copyUnlessEq = CopyUnlessEq + +-- | Copy the field directly from the record. +-- +-- @since 2.12.1.0 +copyField :: PersistField typ => EntityField record typ -> HandleUpdateCollision record +copyField = CopyField + +-- | Postgres specific 'upsertWhere'. This method does the following: +-- It will insert a record if no matching unique key exists. +-- If a unique key exists, it will update the relevant field with a user-supplied value, however, +-- it will only do this update on a user-supplied condition. +-- For example, here's how this method could be called like such: +-- +-- @ +-- upsertWhere record [recordField =. newValue] [recordField /= newValue] +-- @ +-- +-- Called thusly, this method will insert a new record (if none exists) OR update a recordField with a new value +-- assuming the condition in the last block is met. +-- +-- @since 2.12.1.0 +upsertWhere + :: ( backend ~ PersistEntityBackend record + , PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + , PersistStore backend + , BackendCompatible SqlBackend backend + ) + => record + -> [Update record] + -> [Filter record] + -> ReaderT backend m () +upsertWhere record updates filts = + upsertManyWhere [record] [] updates filts + +-- | Exclude any record field if it doesn't match the filter record. Used only in `upsertWhere` and +-- `upsertManyWhere` +-- +-- @since 2.12.1.0 +-- TODO: we could probably make a sum type for the `Filter` record that's passed into the `upserWhere` and +-- `upsertManyWhere` methods that has similar behavior to the HandleCollisionUpdate type. +excludeNotEqualToOriginal :: + (PersistField typ + , PersistEntity rec) => + EntityField rec typ -> + Filter rec +excludeNotEqualToOriginal field = + Filter + { filterField = + field, + filterFilter = + Ne, + filterValue = + UnsafeValue $ + PersistLiteral_ + Unescaped + bsForExcludedField + } + where + bsForExcludedField = + T.encodeUtf8 $ + "EXCLUDED." + <> fieldName field + +-- | Postgres specific 'upsertManyWhere'. This method does the following: +-- It will insert a record if no matching unique key exists. +-- If a unique key exists, it will update the relevant field with a user-supplied value, however, +-- it will only do this update on a user-supplied condition. +-- For example, here's how this method could be called like such: +-- +-- upsertManyWhere [record] [recordField =. newValue] [recordField /= newValue] +-- +-- Called thusly, this method will insert a new record (if none exists) OR update a recordField with a new value +-- assuming the condition in the last block is met. +-- +-- -- @since 2.12.1.0 +upsertManyWhere :: + forall record backend m. + ( backend ~ PersistEntityBackend record, + BackendCompatible SqlBackend backend, + PersistEntityBackend record ~ SqlBackend, + PersistEntity record, + MonadIO m + ) => + -- | A list of the records you want to insert, or update + [record] -> + -- | A list of the fields you want to copy over. + [HandleUpdateCollision record] -> + -- | A list of the updates to apply that aren't dependent on the record being inserted. + [Update record] -> + -- | A filter condition that dictates the scope of the updates + [Filter record] -> + ReaderT backend m () +upsertManyWhere [] _ _ _ = return () +upsertManyWhere records fieldValues updates filters = do + conn <- asks projectBackend + uncurry rawExecute $ + mkBulkUpsertQuery records conn fieldValues updates filters + +-- | This creates the query for 'upsertManyWhere'. If you +-- provide an empty list of updates to perform, then it will generate +-- a dummy/no-op update using the first field of the record. This avoids +-- duplicate key exceptions. +mkBulkUpsertQuery + :: (PersistEntity record, PersistEntityBackend record ~ SqlBackend) + => [record] -- ^ A list of the records you want to insert, or update + -> SqlBackend + -> [HandleUpdateCollision record] -- ^ A list of the fields you want to copy over. + -> [Update record] -- ^ A list of the updates to apply that aren't dependent on the record being inserted. + -> [Filter record] -- ^ A filter condition that dictates the scope of the updates + -> (Text, [PersistValue]) +mkBulkUpsertQuery records conn fieldValues updates filters = + (q, recordValues <> updsValues <> copyUnlessValues <> whereVals) + where + mfieldDef x = case x of + CopyField rec -> Right (fieldDbToText (persistFieldDef rec)) + CopyUnlessEq rec val -> Left (fieldDbToText (persistFieldDef rec), toPersistValue val) + (fieldsToMaybeCopy, updateFieldNames) = partitionEithers $ map mfieldDef fieldValues + fieldDbToText = escapeF . fieldDB + entityDef' = entityDef records + conflictColumns = escapeF . fieldDB <$> entityKeyFields entityDef' + firstField = case entityFieldNames of + [] -> error "The entity you're trying to insert does not have any fields." + (field:_) -> field + entityFieldNames = map fieldDbToText (entityFields entityDef') + nameOfTable = escapeE . entityDB $ entityDef' + copyUnlessValues = map snd fieldsToMaybeCopy + recordValues = concatMap (map toPersistValue . toPersistFields) records + recordPlaceholders = Util.commaSeparated $ map (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields) records + mkCondFieldSet n _ = + T.concat + [ n + , "=COALESCE(" + , "NULLIF(" + , "EXCLUDED." + , n + , "," + , "?" + , ")" + , "," + , nameOfTable + , "." + , n + ,")" + ] + condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy + fieldSets = map (\n -> T.concat [n, "=EXCLUDED.", n, ""]) updateFieldNames + upds = map (Util.mkUpdateText' (escapeF) (\n -> T.concat [nameOfTable, ".", n])) updates + updsValues = map (\(Update _ val _) -> toPersistValue val) updates + (wher, whereVals) = if null filters + then ("", []) + else (filterClauseWithVals (Just PrefixTableName) conn filters) + updateText = case fieldSets <> upds <> condFieldSets of + [] -> T.concat [firstField, "=EXCLUDED.", firstField] + xs -> Util.commaSeparated xs + q = T.concat + [ "INSERT INTO " + , nameOfTable + , Util.parenWrapped . Util.commaSeparated $ entityFieldNames + , " VALUES " + , recordPlaceholders + , " ON CONFLICT " + , Util.parenWrapped $ Util.commaSeparated $ conflictColumns + , " DO UPDATE SET " + , updateText + , wher + ] + putManySql' :: [Text] -> [FieldDef] -> EntityDef -> Int -> Text putManySql' conflictColumns (filter isFieldNotGenerated -> fields) ent n = q where diff --git a/persistent-postgresql/README.md b/persistent-postgresql/README.md index 7318e21ea..219bb184a 100644 --- a/persistent-postgresql/README.md +++ b/persistent-postgresql/README.md @@ -21,5 +21,5 @@ $ createdb test The tests do not pass a test and expect to connect with the `postgres` user. Ensure that peer authentication is allowed for this. -An easy/insecure way to do this is to set the `METHOD` to `trust` for all the login methods in `/etc/postgresql/XX/main/pg_hba.coinf`. +An easy/insecure way to do this is to set the `METHOD` to `trust` for all the login methods in `/etc/postgresql/XX/main/pg_hba.conf`. (TODO: make this better?) diff --git a/persistent-postgresql/persistent-postgresql.cabal b/persistent-postgresql/persistent-postgresql.cabal index c7ea5e3f4..4f7a8eb69 100644 --- a/persistent-postgresql/persistent-postgresql.cabal +++ b/persistent-postgresql/persistent-postgresql.cabal @@ -53,6 +53,7 @@ test-suite test JSONTest CustomConstraintTest PgIntervalTest + UpsertWhere ghc-options: -Wall build-depends: base >= 4.9 && < 5 diff --git a/persistent-postgresql/test/PgInit.hs b/persistent-postgresql/test/PgInit.hs index d2fcb85dd..8c9906ce3 100644 --- a/persistent-postgresql/test/PgInit.hs +++ b/persistent-postgresql/test/PgInit.hs @@ -56,7 +56,7 @@ import Test.HUnit ((@?=),(@=?), Assertion, assertFailure, assertBool) import Test.QuickCheck import Control.Monad (unless, (>=>)) -import Control.Monad.IO.Class + import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Logger import Control.Monad.Trans.Resource (ResourceT, runResourceT) diff --git a/persistent-postgresql/test/UpsertWhere.hs b/persistent-postgresql/test/UpsertWhere.hs new file mode 100644 index 000000000..626f83443 --- /dev/null +++ b/persistent-postgresql/test/UpsertWhere.hs @@ -0,0 +1,178 @@ +{-# LANGUAGE DataKinds, FlexibleInstances, MultiParamTypeClasses, ExistentialQuantification #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE StandaloneDeriving #-} + +module UpsertWhere where + +import Data.List (sort) + +import Database.Persist.Postgresql +import PgInit + +share [mkPersist sqlSettings, mkMigrate "upsertWhereMigrate"] [persistLowerCase| + Item + name Text sqltype=varchar(80) + description Text + price Double Maybe + quantity Int Maybe + + Primary name + deriving Eq Show Ord + +|] + +specs :: Spec +specs = describe "UpsertWhere" $ do + let item1 = Item "item1" "" (Just 3) Nothing + item2 = Item "item2" "hello world" Nothing (Just 2) + items = [item1, item2] + + describe "upsertWhere" $ do + it "inserts appropriately" $ runConnAssert $ do + deleteWhere ([] :: [Filter Item]) + upsertWhere item1 [ItemDescription =. "i am item 1"] [] + Just item <- get (ItemKey "item1") + item @== item1 + it "performs only updates given if record already exists" $ runConnAssert $ do + deleteWhere ([] :: [Filter Item]) + let newDescription = "I am a new description" + insert_ item1 + upsertWhere + (Item "item1" "i am inserted description" (Just 1) (Just 2)) + [ItemDescription =. newDescription] + [] + Just item <- get (ItemKey "item1") + item @== item1 { itemDescription = newDescription } + + describe "upsertManyWhere" $ do + it "inserts fresh records" $ runConnAssert $ do + deleteWhere ([] :: [Filter Item]) + insertMany_ items + let newItem = Item "item3" "fresh" Nothing Nothing + upsertManyWhere + (newItem : items) + [copyField ItemDescription] + [] + [] + dbItems <- map entityVal <$> selectList [] [] + sort dbItems @== sort (newItem : items) + it "updates existing records" $ runConnAssert $ do + deleteWhere ([] :: [Filter Item]) + let postUpdate = map (\i -> i { itemQuantity = fmap (+1) (itemQuantity i) }) items + insertMany_ items + upsertManyWhere + items + [] + [ItemQuantity +=. Just 1] + [] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort postUpdate + it "only copies passing values" $ runConnAssert $ do + deleteWhere ([] :: [Filter Item]) + insertMany_ items + let newItems = map (\i -> i { itemQuantity = Just 0, itemPrice = fmap (*2) (itemPrice i) }) items + postUpdate = map (\i -> i { itemPrice = fmap (*2) (itemPrice i) }) items + upsertManyWhere + newItems + [ + copyUnlessEq ItemQuantity (Just 0) + , copyField ItemPrice + ] + [] + [] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort postUpdate + it "inserts without modifying existing records if no updates specified" $ runConnAssert $ do + let newItem = Item "item3" "hi friends!" Nothing Nothing + deleteWhere ([] :: [Filter Item]) + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [] + [] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort (newItem : items) + it "inserts without modifying existing records if no updates specified and there's a filter with True condition" $ + runConnAssert $ do + let newItem = Item "item3" "hi friends!" Nothing Nothing + deleteWhere ([] :: [Filter Item]) + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [] + [ItemDescription ==. "hi friends!"] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort (newItem : items) + it "inserts without updating existing records if there are updates specified but there's a filter with a False condition" $ + runConnAssert $ do + let newItem = Item "item3" "hi friends!" Nothing Nothing + deleteWhere ([] :: [Filter Item]) + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [ItemQuantity +=. Just 1] + [ItemDescription ==. "hi friends!"] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort (newItem : items) + it "inserts new records but does not update existing records if there are updates specified but the modification condition is False" $ + runConnAssert $ do + let newItem = Item "item3" "hi friends!" Nothing Nothing + deleteWhere ([] :: [Filter Item]) + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [ItemQuantity +=. Just 1] + [excludeNotEqualToOriginal ItemDescription] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort (newItem : items) + it "inserts new records and updates existing records if there are updates specified and the modification condition is True (because it's empty)" $ + runConnAssert $ do + let newItem = Item "item3" "hello world" Nothing Nothing + postUpdate = map (\i -> i {itemQuantity = fmap (+ 1) (itemQuantity i)}) items + deleteWhere ([] :: [Filter Item]) + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [ItemQuantity +=. Just 1] + [] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort (newItem : postUpdate) + it "inserts new records and updates existing records if there are updates specified and the modification filter condition is triggered" $ + runConnAssert $ do + let newItem = Item "item3" "hi friends!" Nothing Nothing + postUpdate = map (\i -> i {itemQuantity = fmap (+1) (itemQuantity i)}) items + deleteWhere ([] :: [Filter Item]) + insertMany_ items + upsertManyWhere + (newItem : items) + [ + copyUnlessEq ItemDescription "hi friends!" + , copyField ItemPrice + ] + [ItemQuantity +=. Just 1] + [ItemDescription !=. "bye friends!"] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort (newItem : postUpdate) + it "inserts an item and doesn't apply the update if the filter condition is triggered" $ + runConnAssert $ do + let newItem = Item "item3" "hello world" Nothing Nothing + deleteWhere ([] :: [Filter Item]) + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [ItemQuantity +=. Just 1] + [excludeNotEqualToOriginal ItemDescription] + dbItems <- sort . fmap entityVal <$> selectList [] [] + dbItems @== sort (newItem : items) diff --git a/persistent-postgresql/test/main.hs b/persistent-postgresql/test/main.hs index 6c0c47ee6..60543a349 100644 --- a/persistent-postgresql/test/main.hs +++ b/persistent-postgresql/test/main.hs @@ -51,6 +51,7 @@ import qualified TransactionLevelTest import qualified TreeTest import qualified UniqueTest import qualified UpsertTest +import qualified UpsertWhere import qualified CustomConstraintTest import qualified LongIdentifierTest import qualified PgIntervalTest @@ -128,6 +129,7 @@ main = do , ForeignKey.compositeMigrate , MigrationTest.migrationMigrate , PgIntervalTest.pgIntervalMigrate + , UpsertWhere.upsertWhereMigrate ] PersistentTest.cleanDB ForeignKey.cleanDB @@ -195,6 +197,7 @@ main = do LongIdentifierTest.specsWith runConnAssertUseConf -- Have at least one test use the conf variant of connecting to Postgres, to improve test coverage. JSONTest.specs CustomConstraintTest.specs + UpsertWhere.specs PgIntervalTest.specs ArrayAggTest.specs GeneratedColumnTestSQL.specsWith runConnAssert diff --git a/persistent/Database/Persist/Sql.hs b/persistent/Database/Persist/Sql.hs index a0e802507..33676da55 100644 --- a/persistent/Database/Persist/Sql.hs +++ b/persistent/Database/Persist/Sql.hs @@ -12,6 +12,10 @@ module Database.Persist.Sql , rawSql , deleteWhereCount , updateWhereCount + , filterClause + , filterClauseHelper + , filterClauseWithVals + , FilterTablePrefix (..) , transactionSave , transactionSaveWithIsolation , transactionUndo diff --git a/persistent/Database/Persist/Sql/Orphan/PersistQuery.hs b/persistent/Database/Persist/Sql/Orphan/PersistQuery.hs index 3dc784292..a593bf4e1 100644 --- a/persistent/Database/Persist/Sql/Orphan/PersistQuery.hs +++ b/persistent/Database/Persist/Sql/Orphan/PersistQuery.hs @@ -6,6 +6,10 @@ module Database.Persist.Sql.Orphan.PersistQuery ( deleteWhereCount , updateWhereCount + , filterClause + , filterClauseHelper + , filterClauseWithVals + , FilterTablePrefix (..) , decorateSQLWithLimitOffset ) where @@ -36,7 +40,7 @@ instance PersistQueryRead SqlBackend where conn <- ask let wher = if null filts then "" - else filterClause False conn filts + else filterClause Nothing conn filts let sql = mconcat [ "SELECT COUNT(*) FROM " , connEscapeTableName conn t @@ -59,7 +63,7 @@ instance PersistQueryRead SqlBackend where conn <- ask let wher = if null filts then "" - else filterClause False conn filts + else filterClause Nothing conn filts let sql = mconcat [ "SELECT EXISTS(SELECT 1 FROM " , connEscapeTableName conn t @@ -93,7 +97,7 @@ instance PersistQueryRead SqlBackend where t = entityDef $ dummyFromFilts filts wher conn = if null filts then "" - else filterClause False conn filts + else filterClause Nothing conn filts ord conn = case map (orderClause False conn) orders of [] -> "" @@ -119,7 +123,7 @@ instance PersistQueryRead SqlBackend where wher conn = if null filts then "" - else filterClause False conn filts + else filterClause Nothing conn filts sql conn = connLimitOffset conn (limit,offset) (not (null orders)) $ mconcat [ "SELECT " , cols conn @@ -183,7 +187,7 @@ deleteWhereCount filts = withCompatibleBackend $ do let t = entityDef $ dummyFromFilts filts let wher = if null filts then "" - else filterClause False conn filts + else filterClause Nothing conn filts sql = mconcat [ "DELETE FROM " , connEscapeTableName conn t @@ -203,7 +207,7 @@ updateWhereCount filts upds = withCompatibleBackend $ do conn <- ask let wher = if null filts then "" - else filterClause False conn filts + else filterClause Nothing conn filts let sql = mconcat [ "UPDATE " , connEscapeTableName conn t @@ -217,26 +221,30 @@ updateWhereCount filts upds = withCompatibleBackend $ do where t = entityDef $ dummyFromFilts filts -fieldName :: forall record typ. (PersistEntity record, PersistEntityBackend record ~ SqlBackend) => EntityField record typ -> FieldNameDB +fieldName :: forall record typ. (PersistEntity record) => EntityField record typ -> FieldNameDB fieldName f = fieldDB $ persistFieldDef f dummyFromFilts :: [Filter v] -> Maybe v dummyFromFilts _ = Nothing -getFiltsValues :: forall val. (PersistEntity val, PersistEntityBackend val ~ SqlBackend) +getFiltsValues :: forall val. (PersistEntity val) => SqlBackend -> [Filter val] -> [PersistValue] -getFiltsValues conn = snd . filterClauseHelper False False conn OrNullNo +getFiltsValues conn = snd . filterClauseHelper Nothing False conn OrNullNo data OrNull = OrNullYes | OrNullNo -filterClauseHelper :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend) - => Bool -- ^ include table name? - -> Bool -- ^ include WHERE? +data FilterTablePrefix + = PrefixTableName + | PrefixExcluded + +filterClauseHelper :: (PersistEntity val) + => Maybe FilterTablePrefix -- ^ include table name or PostgresSQL EXCLUDED + -> Bool -- ^ include WHERE -> SqlBackend -> OrNull -> [Filter val] -> (Text, [PersistValue]) -filterClauseHelper includeTable includeWhere conn orNull filters = +filterClauseHelper tablePrefix includeWhere conn orNull filters = (if not (T.null sql) && includeWhere then " WHERE " <> sql else sql, vals) @@ -356,7 +364,9 @@ filterClauseHelper includeTable includeWhere conn orNull filters = orNullSuffix = case orNull of - OrNullYes -> mconcat [" OR ", name, " IS NULL"] + OrNullYes -> mconcat [" OR " + , name + , " IS NULL"] OrNullNo -> "" isNull = PersistNull `elem` allVals @@ -364,10 +374,10 @@ filterClauseHelper includeTable includeWhere conn orNull filters = allVals = filterValueToPersistValues value tn = connEscapeTableName conn $ entityDef $ dummyFromFilts [Filter field value pfilter] name = - (if includeTable - then ((tn <> ".") <>) - else id) - $ connEscapeFieldName conn (fieldName field) + case tablePrefix of + Just PrefixTableName -> ((tn <> ".") <>) $ connEscapeFieldName conn (fieldName field) + Just PrefixExcluded -> (("EXCLUDED.") <>) $ connEscapeFieldName conn (fieldName field) + _ -> id $ connEscapeFieldName conn (fieldName field) qmarks = case value of FilterValue{} -> "(?)" UnsafeValue{} -> "(?)" @@ -387,14 +397,21 @@ filterClauseHelper includeTable includeWhere conn orNull filters = showSqlFilter NotIn = " NOT IN " showSqlFilter (BackendSpecificFilter s) = s -filterClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend) - => Bool -- ^ include table name? +filterClause :: (PersistEntity val) + => Maybe FilterTablePrefix -- ^ include table name or EXCLUDED -> SqlBackend -> [Filter val] -> Text filterClause b c = fst . filterClauseHelper b True c OrNullNo -orderClause :: (PersistEntity val, PersistEntityBackend val ~ SqlBackend) +filterClauseWithVals :: (PersistEntity val) + => Maybe FilterTablePrefix -- ^ include table name or EXCLUDED + -> SqlBackend + -> [Filter val] + -> (Text, [PersistValue]) +filterClauseWithVals b c = filterClauseHelper b True c OrNullNo + +orderClause :: (PersistEntity val) => Bool -- ^ include the table name -> SqlBackend -> SelectOpt val @@ -410,7 +427,7 @@ orderClause includeTable conn o = tn = connEscapeTableName conn (entityDef $ dummyFromOrder o) - name :: (PersistEntityBackend record ~ SqlBackend, PersistEntity record) + name :: (PersistEntity record) => EntityField record typ -> Text name x = (if includeTable diff --git a/persistent/Database/Persist/Sql/Util.hs b/persistent/Database/Persist/Sql/Util.hs index 980cc6e08..d68e55320 100644 --- a/persistent/Database/Persist/Sql/Util.hs +++ b/persistent/Database/Persist/Sql/Util.hs @@ -207,6 +207,7 @@ commaSeparated = T.intercalate ", " mkUpdateText :: PersistEntity record => SqlBackend -> Update record -> Text mkUpdateText conn = mkUpdateText' (connEscapeFieldName conn) id +-- TODO: incorporate the table names into a sum type mkUpdateText' :: PersistEntity record => (FieldNameDB -> Text) -> (Text -> Text) -> Update record -> Text mkUpdateText' escapeName refColumn x = case updateUpdate x of @@ -223,7 +224,7 @@ mkUpdateText' escapeName refColumn x = parenWrapped :: Text -> Text parenWrapped t = T.concat ["(", t, ")"] --- | Make a list 'PersistValue' suitable for detabase inserts. Pairs nicely +-- | Make a list 'PersistValue' suitable for database inserts. Pairs nicely -- with the function 'mkInsertPlaceholders'. -- -- Does not include generated columns.