Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions persistent-postgresql/ChangeLog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Changelog for persistent-postgresql

# 2.12.1.1

* [#1235](https://github.com/yesodweb/persistent/pull/1235)
* `upsertWhere` and `upsertManyWhere` only worked in cases where a `Primary`
key was defined on a record, and no other uniqueness constraints. They
have been fixed to only work with records that have a single Uniqueness
constraint defined.

## 2.12.1.0

* Added `upsertWhere` and `upsertManyWhere` to `persistent-postgresql`. [#1222](https://github.com/yesodweb/persistent/pull/1222).
Expand Down
184 changes: 109 additions & 75 deletions persistent-postgresql/Database/Persist/Postgresql.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1817,6 +1817,7 @@ upsertWhere
, MonadIO m
, PersistStore backend
, BackendCompatible SqlBackend backend
, OnlyOneUniqueKey record
)
=> record
-> [Update record]
Expand All @@ -1825,79 +1826,98 @@ upsertWhere
upsertWhere record updates filts =
upsertManyWhere [record] [] updates filts

-- | Exclude any record field if it doesn't match the filter record. Used only in `upsertWhere` and
-- `upsertManyWhere`
--
-- @since 2.12.1.0
-- TODO: we could probably make a sum type for the `Filter` record that's passed into the `upserWhere` and
-- `upsertManyWhere` methods that has similar behavior to the HandleCollisionUpdate type.
excludeNotEqualToOriginal ::
(PersistField typ
, PersistEntity rec) =>
EntityField rec typ ->
Filter rec
excludeNotEqualToOriginal field =
Filter
{ filterField =
field,
filterFilter =
Ne,
filterValue =
UnsafeValue $
PersistLiteral_
Unescaped
bsForExcludedField
}
where
bsForExcludedField =
T.encodeUtf8 $
"EXCLUDED."
<> fieldName field

-- | Postgres specific 'upsertManyWhere'. This method does the following:
-- It will insert a record if no matching unique key exists.
-- If a unique key exists, it will update the relevant field with a user-supplied value, however,
-- it will only do this update on a user-supplied condition.
-- For example, here's how this method could be called like such:
--
-- upsertManyWhere [record] [recordField =. newValue] [recordField /= newValue]
-- upsertManyWhere [record] [recordField =. newValue] [recordField !=. newValue]
--
-- Called thusly, this method will insert a new record (if none exists) OR update a recordField with a new value
-- assuming the condition in the last block is met.
--
-- @since 2.12.1.0
upsertManyWhere ::
forall record backend m.
( backend ~ PersistEntityBackend record,
BackendCompatible SqlBackend backend,
PersistEntityBackend record ~ SqlBackend,
PersistEntity record,
MonadIO m
) =>
[record] -> -- ^ A list of the records you want to insert, or update
[HandleUpdateCollision record] -> -- ^ A list of the fields you want to copy over.
[Update record] -> -- ^ A list of the updates to apply that aren't dependent on the record being inserted.
[Filter record] -> -- ^ A filter condition that dictates the scope of the updates
ReaderT backend m ()
upsertManyWhere
:: forall record backend m.
( backend ~ PersistEntityBackend record
, BackendCompatible SqlBackend backend
, PersistEntityBackend record ~ SqlBackend
, PersistEntity record
, OnlyOneUniqueKey record
, MonadIO m
)
=> [record]
-- ^ A list of the records you want to insert, or update
-> [HandleUpdateCollision record]
-- ^ A list of the fields you want to copy over.
-> [Update record]
-- ^ A list of the updates to apply that aren't dependent on the record
-- being inserted.
-> [Filter record]
-- ^ A filter condition that dictates the scope of the updates
-> ReaderT backend m ()
upsertManyWhere [] _ _ _ = return ()
upsertManyWhere records fieldValues updates filters = do
conn <- asks projectBackend
uncurry rawExecute $
mkBulkUpsertQuery records conn fieldValues updates filters
conn <- asks projectBackend
let uniqDef = -- onlyOneUniqueDef (Nothing :: Maybe record)
case entityUniques (entityDef (Nothing :: Maybe record)) of
[uniq] -> uniq
_ -> error "impossible due to OnlyOneUniqueKey constraint"
-- TODO: use onlyOneUniqueDef when it's exported
uncurry rawExecute $
mkBulkUpsertQuery records conn fieldValues updates filters uniqDef

-- | Exclude any record field if it doesn't match the filter record. Used only in `upsertWhere` and
-- `upsertManyWhere`
--
-- TODO: we could probably make a sum type for the `Filter` record that's passed into the `upsertWhere` and
-- `upsertManyWhere` methods that has similar behavior to the HandleCollisionUpdate type.
--
-- @since 2.12.1.0
excludeNotEqualToOriginal
:: (PersistField typ, PersistEntity rec)
=> EntityField rec typ
-> Filter rec
excludeNotEqualToOriginal field =
Filter
{ filterField =
field
, filterFilter =
Ne
, filterValue =
UnsafeValue $
PersistLiteral_
Unescaped
bsForExcludedField
}
where
bsForExcludedField =
T.encodeUtf8
$ "EXCLUDED."
<> fieldName field

-- | This creates the query for 'upsertManyWhere'. If you
-- provide an empty list of updates to perform, then it will generate
-- a dummy/no-op update using the first field of the record. This avoids
-- duplicate key exceptions.
mkBulkUpsertQuery
:: (PersistEntity record, PersistEntityBackend record ~ SqlBackend)
=> [record] -- ^ A list of the records you want to insert, or update
:: (PersistEntity record, PersistEntityBackend record ~ SqlBackend, OnlyOneUniqueKey record)
=> [record]
-- ^ A list of the records you want to insert, or update
-> SqlBackend
-> [HandleUpdateCollision record] -- ^ A list of the fields you want to copy over.
-> [Update record] -- ^ A list of the updates to apply that aren't dependent on the record being inserted.
-> [Filter record] -- ^ A filter condition that dictates the scope of the updates
-> [HandleUpdateCollision record]
-- ^ A list of the fields you want to copy over.
-> [Update record]
-- ^ A list of the updates to apply that aren't dependent on the record being inserted.
-> [Filter record]
-- ^ A filter condition that dictates the scope of the updates
-> UniqueDef
-- ^ The specific uniqueness constraint to use on the record. Postgres
-- rquires that we use exactly one relevant constraint, and it can't do
-- a catch-all. How frustrating!
-> (Text, [PersistValue])
mkBulkUpsertQuery records conn fieldValues updates filters =
mkBulkUpsertQuery records conn fieldValues updates filters uniqDef =
(q, recordValues <> updsValues <> copyUnlessValues <> whereVals)
where
mfieldDef x = case x of
Expand All @@ -1906,41 +1926,55 @@ mkBulkUpsertQuery records conn fieldValues updates filters =
(fieldsToMaybeCopy, updateFieldNames) = partitionEithers $ map mfieldDef fieldValues
fieldDbToText = escapeF . fieldDB
entityDef' = entityDef records
conflictColumns = escapeF . fieldDB <$> entityKeyFields entityDef'
conflictColumns =
map (escapeF . snd) $ uniqueFields uniqDef
firstField = case entityFieldNames of
[] -> error "The entity you're trying to insert does not have any fields."
(field:_) -> field
entityFieldNames = map fieldDbToText (entityFields entityDef')
nameOfTable = escapeE . entityDB $ entityDef'
copyUnlessValues = map snd fieldsToMaybeCopy
recordValues = concatMap (map toPersistValue . toPersistFields) records
recordPlaceholders = Util.commaSeparated $ map (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields) records
recordPlaceholders =
Util.commaSeparated
$ map (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields)
$ records
mkCondFieldSet n _ =
T.concat
[ n
, "=COALESCE("
, "NULLIF("
, "EXCLUDED."
, n
, ","
, "?"
, ")"
, ","
, nameOfTable
, "."
, n
,")"
]
T.concat
[ n
, "=COALESCE("
, "NULLIF("
, "EXCLUDED."
, n
, ","
, "?"
, ")"
, ","
, nameOfTable
, "."
, n
,")"
]
condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy
fieldSets = map (\n -> T.concat [n, "=EXCLUDED.", n, ""]) updateFieldNames
upds = map (Util.mkUpdateText' (escapeF) (\n -> T.concat [nameOfTable, ".", n])) updates
updsValues = map (\(Update _ val _) -> toPersistValue val) updates
(wher, whereVals) = if null filters
then ("", [])
else (filterClauseWithVals (Just PrefixTableName) conn filters)
updateText = case fieldSets <> upds <> condFieldSets of
[] -> T.concat [firstField, "=EXCLUDED.", firstField]
xs -> Util.commaSeparated xs
(wher, whereVals) =
if null filters
then ("", [])
else (filterClauseWithVals (Just PrefixTableName) conn filters)
updateText =
case fieldSets <> upds <> condFieldSets of
[] ->
-- This case is really annoying, and probably unlikely to be
-- actually hit - someone would have had to call something like
-- `upsertManyWhere [] [] []`, but that would have been caught
-- by the prior case.
-- Would be nice to have something like a `NonEmpty (These ...)`
-- instead of multiple lists...
T.concat [firstField, "=", nameOfTable, ".", firstField]
xs ->
Util.commaSeparated xs
q = T.concat
[ "INSERT INTO "
, nameOfTable
Expand Down
3 changes: 2 additions & 1 deletion persistent-postgresql/persistent-postgresql.cabal
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: persistent-postgresql
version: 2.12.1.0
version: 2.12.1.1
license: MIT
license-file: LICENSE
author: Felipe Lessa, Michael Snoyman <michael@snoyman.com>
Expand Down Expand Up @@ -68,6 +68,7 @@ test-suite test
, HUnit
, hspec >= 2.4
, hspec-expectations
, hspec-expectations-lifted
, monad-logger
, QuickCheck
, quickcheck-instances
Expand Down
109 changes: 68 additions & 41 deletions persistent-postgresql/test/PgInit.hs
Original file line number Diff line number Diff line change
@@ -1,58 +1,85 @@
{-# LANGUAGE ScopedTypeVariables, OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module PgInit (
runConn
, runConn_
, runConnAssert
, runConnAssertUseConf

, MonadIO
, persistSettings
, MkPersistSettings (..)
, BackendKey(..)
, GenerateKey(..)

-- re-exports
, module Control.Monad.Trans.Reader
, module Control.Monad
, module Database.Persist.Sql
, module Database.Persist
, module Database.Persist.Sql.Raw.QQ
, module Init
, module Test.Hspec
, module Test.HUnit
, BS.ByteString
, Int32, Int64
, liftIO
, mkPersist, mkMigrate, share, sqlSettings, persistLowerCase, persistUpperCase
, SomeException
, Text
, TestFn(..)
) where
module PgInit
( runConn
, runConn_
, runConnAssert
, runConnAssertUseConf

, MonadIO
, persistSettings
, MkPersistSettings (..)
, BackendKey(..)
, GenerateKey(..)

-- re-exports
, module Control.Monad.Trans.Reader
, module Control.Monad
, module Database.Persist.Sql
, module Database.Persist
, module Database.Persist.Sql.Raw.QQ
, module Init
, module Test.Hspec
, module Test.Hspec.Expectations.Lifted
, module Test.HUnit
, BS.ByteString
, Int32, Int64
, liftIO
, mkPersist, mkMigrate, share, sqlSettings, persistLowerCase, persistUpperCase
, SomeException
, Text
, TestFn(..)
, LoggingT
, ResourceT
) where

import Init
( TestFn(..), truncateTimeOfDay, truncateUTCTime
, truncateToMicro, arbText, liftA2, GenerateKey(..)
, (@/=), (@==), (==@), MonadFail
, assertNotEqual, assertNotEmpty, assertEmpty, asIO
, isTravis, RunDb
)
( GenerateKey(..)
, MonadFail
, RunDb
, TestFn(..)
, arbText
, asIO
, assertEmpty
, assertNotEmpty
, assertNotEqual
, isTravis
, liftA2
, truncateTimeOfDay
, truncateToMicro
, truncateUTCTime
, (==@)
, (@/=)
, (@==)
)

-- re-exports
import Control.Exception (SomeException)
import UnliftIO
import Control.Monad (void, replicateM, liftM, when, forM_)
import Control.Monad (forM_, liftM, replicateM, void, when)
import Control.Monad.Trans.Reader
import Data.Aeson (Value(..))
import Database.Persist.TH (mkPersist, mkMigrate, share, sqlSettings, persistLowerCase, persistUpperCase, MkPersistSettings(..))
import Database.Persist.Postgresql.JSON ()
import Database.Persist.Sql.Raw.QQ
import Database.Persist.Postgresql.JSON()
import Database.Persist.TH
( MkPersistSettings(..)
, mkMigrate
, mkPersist
, persistLowerCase
, persistUpperCase
, share
, sqlSettings
)
import Test.Hspec
(Spec, afterAll_, before, beforeAll, describe, fdescribe, fit, it,
before_, SpecWith, Arg, hspec)
import Test.Hspec.Expectations.Lifted
import Test.QuickCheck.Instances ()
import UnliftIO

-- testing
import Test.HUnit ((@?=),(@=?), Assertion, assertFailure, assertBool)
import Test.HUnit (Assertion, assertBool, assertFailure, (@=?), (@?=))
import Test.QuickCheck

import Control.Monad (unless, (>=>))
Expand Down
Loading