diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/CustomType.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/CustomType.hs
new file mode 100644
index 000000000..b71f082af
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/CustomType.hs
@@ -0,0 +1,275 @@
+{-# LANGUAGE OverloadedStrings #-}
+
+-- |
+-- Module : Database.Persist.Postgresql.CustomType
+-- Description : Guide for writing custom PersistField instances with the binary protocol backend
+--
+-- = Overview
+--
+-- The pipeline backend uses PostgreSQL's binary wire protocol instead of the
+-- text protocol used by @persistent-postgresql@ (via @postgresql-simple@). This
+-- is faster — no parsing/rendering of text representations — but it changes
+-- how some custom types need to be encoded.
+--
+-- Most custom @PersistField@ instances work without modification. This module
+-- documents the cases where you may need to adjust your approach.
+--
+-- = How PersistValue maps to PostgreSQL
+--
+-- The binary backend encodes each @PersistValue@ with a specific PostgreSQL
+-- OID (type identifier) and binary representation:
+--
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | PersistValue                 | PG Type      | OID      | Notes                             |
+-- +==============================+==============+==========+===================================+
+-- | @PersistText@                | text         | 25       | Binary UTF-8                      |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistByteString@          | bytea        | 17       | Raw binary                        |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistInt64@               | int4 / int8  | 23 / 20  | int4 if it fits, else int8        |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistDouble@              | float8       | 701      | IEEE 754 double                   |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistRational@            | numeric      | 1700     | Arbitrary precision               |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistBool@                | bool         | 16       | Single byte                       |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistDay@                 | date         | 1082     | Days since 2000-01-01             |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistTimeOfDay@           | time         | 1083     | Microseconds since midnight       |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistUTCTime@             | timestamptz  | 1184     | Microseconds since 2000-01-01 UTC |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistList@                | (unknown)    | 0        | JSON text, PG infers type         |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistArray@               | \[]          | varies   | Native PostgreSQL array           |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistMap@                 | (unknown)    | 0        | JSON text, PG infers type         |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistLiteral_ Escaped@    | (unknown)    | 0        | Text format, PG infers type       |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistLiteral_ DbSpecific@ | (unknown)    | 0        | Text format, PG infers type       |
+-- +------------------------------+--------------+----------+-----------------------------------+
+-- | @PersistLiteral_ Unescaped@  | —            | —        | Inlined into SQL text             |
+-- +------------------------------+--------------+----------+-----------------------------------+
+--
+-- = Custom Types That Just Work
+--
+-- If your custom type stores as one of the standard @PersistValue@ constructors,
+-- it will work without any changes:
+--
+-- @
+-- newtype Email = Email Text
+--
+-- instance PersistField Email where
+-- toPersistValue (Email t) = PersistText t
+-- fromPersistValue (PersistText t) = Right (Email t)
+-- fromPersistValue _ = Left "Expected PersistText for Email"
+--
+-- instance PersistFieldSql Email where
+-- sqlType _ = SqlString
+-- @
+--
+-- This covers the vast majority of custom types: newtypes over @Text@, @Int@,
+-- @Double@, @Bool@, @Day@, @UTCTime@, @ByteString@, etc.
+--
+-- = Types Requiring OID 0 (Type Inference)
+--
+-- When PostgreSQL's column type doesn't match any built-in @PersistValue@ OID,
+-- you need the value to be sent with OID 0 so PostgreSQL infers the type from
+-- the column. Use @PersistLiteral_ Escaped@ or @PersistLiteral_ DbSpecific@:
+--
+-- @
+-- newtype UUID = UUID Text
+--
+-- instance PersistField UUID where
+-- toPersistValue (UUID t) = PersistLiteral_ Escaped (encodeUtf8 t)
+-- fromPersistValue (PersistLiteral_ Escaped bs) = Right (UUID (decodeUtf8 bs))
+-- fromPersistValue _ = Left "Expected PersistLiteral_ Escaped for UUID"
+--
+-- instance PersistFieldSql UUID where
+-- sqlType _ = SqlOther "UUID"
+-- @
+--
+-- The bytes in @PersistLiteral_ Escaped@ are sent as-is with OID 0 in text
+-- format. PostgreSQL sees the text representation and casts it to the column
+-- type (UUID in this case). This is the same behavior as @postgresql-simple@'s
+-- @Unknown@ type.
+--
+-- __Other types that use this pattern:__ @INET@, @CIDR@, @MACADDR@, @LTREE@,
+-- @HSTORE@, custom enum types, PostGIS geometry types.
+--
+-- = Types Requiring SQL Inlining
+--
+-- Some PostgreSQL types cannot accept /any/ parameterized input — even with
+-- OID 0, PostgreSQL won't cast. For these, use @PersistLiteral_ Unescaped@
+-- to inline the value directly into the SQL text:
+--
+-- @
+-- newtype PgInterval = PgInterval NominalDiffTime
+--
+-- instance PersistField PgInterval where
+-- toPersistValue (PgInterval ndt) =
+-- let (sci, _) = fromRationalRepetendUnlimited (toRational ndt)
+-- s = formatScientific Fixed Nothing sci
+-- in PersistLiteral_ Unescaped (encodeUtf8 ("'" <> pack s <> " seconds'::interval"))
+-- fromPersistValue (PersistRational r) = Right (PgInterval (fromRational r))
+-- fromPersistValue _ = Left "Expected PersistRational for PgInterval"
+-- @
+--
+-- __Important:__ @PersistLiteral_ Unescaped@ is inlined into the SQL string
+-- /before/ parameter encoding. The value must be a valid SQL expression. Always
+-- use explicit casts (e.g. @'...'::interval@) and ensure the content is safe
+-- (no user-controlled input without validation).
+--
+-- = JSON and JSONB Columns
+--
+-- Aeson's @Value@ type is stored as @PersistLiteral_ Escaped@ with the JSON
+-- bytes. The binary backend sends this with OID 0, so PostgreSQL infers the
+-- @jsonb@ type from the column.
+--
+-- If you have a custom type that serializes to JSON for a @jsonb@ column:
+--
+-- @
+-- import Data.Aeson (ToJSON, FromJSON, encode, eitherDecodeStrict)
+-- import qualified Data.ByteString.Lazy as BSL
+-- import Data.Text (pack)
+--
+-- instance PersistField MyJsonType where
+-- toPersistValue = PersistLiteralEscaped . BSL.toStrict . encode
+-- fromPersistValue (PersistByteString bs) =
+-- case eitherDecodeStrict bs of
+-- Right v -> Right v
+-- Left err -> Left (pack err)
+-- fromPersistValue _ = Left "Expected PersistByteString for MyJsonType"
+--
+-- instance PersistFieldSql MyJsonType where
+-- sqlType _ = SqlOther "jsonb"
+-- @
+--
+-- The @PersistLiteralEscaped@ (= @PersistLiteral_ Escaped@) encoding sends
+-- with OID 0, so PostgreSQL accepts the JSON text for a @jsonb@ column.
+--
+-- = PersistList vs PersistArray
+--
+-- These two constructors serve different purposes in this backend:
+--
+-- [@PersistList@] Encodes as __JSON text__ with OID 0 (unknown). PostgreSQL
+-- infers the column type, so this works for both @VARCHAR@ and @jsonb@
+-- columns. Used by persistent internally for embedded entities. If your
+-- custom type uses @PersistList@, it will be stored as a JSON text array
+-- like @[1,2,3]@.
+--
+-- [@PersistArray@] Encodes as a __native PostgreSQL array__ (e.g. @int8[]@,
+-- @text[]@). The element type is inferred from the first non-null element.
+--   This is used by the @IN → ANY@ rewriting: @WHERE id IN (?,?,?)@
+-- becomes @WHERE id = ANY($1)@ with a single @PersistArray@ parameter.
+--
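+-- For example, a filter such as @[PersonAge \<-. [1,2,3]]@ (with an
+-- illustrative @Person@ entity) is sent as one native-array parameter:
+--
+-- @
+-- SELECT ... WHERE "age" = ANY($1)   -- $1 is a PersistArray, sent as int8[]
+-- @
+--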
+-- If you want your custom list type to use native arrays:
+--
+-- @
+-- newtype IntList = IntList [Int64]
+--
+-- instance PersistField IntList where
+-- toPersistValue (IntList xs) = PersistArray (map PersistInt64 xs)
+--   -- Native arrays decode back as PersistList, hence the asymmetry:
+--   fromPersistValue (PersistList xs) = IntList \<$\> mapM fromPersistValue xs
+-- fromPersistValue _ = Left "Expected PersistList for IntList"
+--
+-- instance PersistFieldSql IntList where
+-- sqlType _ = SqlOther "int8[]"
+-- @
+--
+-- = PostgreSQL Enum Types
+--
+-- Custom PostgreSQL enums work with @PersistLiteral_ Escaped@ or @PersistText@:
+--
+-- @
+-- data Color = Red | Green | Blue
+--
+-- instance PersistField Color where
+-- toPersistValue Red = PersistLiteral_ Escaped "red"
+-- toPersistValue Green = PersistLiteral_ Escaped "green"
+-- toPersistValue Blue = PersistLiteral_ Escaped "blue"
+-- fromPersistValue (PersistLiteral_ Escaped "red") = Right Red
+-- fromPersistValue (PersistLiteral_ Escaped "green") = Right Green
+-- fromPersistValue (PersistLiteral_ Escaped "blue") = Right Blue
+-- fromPersistValue _ = Left "Invalid Color"
+--
+-- instance PersistFieldSql Color where
+-- sqlType _ = SqlOther "color" -- must match CREATE TYPE color AS ENUM (...)
+-- @
+--
+-- Using @PersistLiteral_ Escaped@ sends with OID 0, letting PostgreSQL match
+-- the text @\"red\"@ against the enum type.
+--
+-- = When Do You Need Explicit SQL Casts?
+--
+-- The short answer: __almost never__ with standard persistent operations.
+--
+-- Standard persistent operations (@insert@, @update@, @selectList@,
+-- @deleteWhere@, etc.) generate SQL like @INSERT INTO "table" ("col") VALUES (?)@
+-- or @UPDATE "table" SET "col" = ? WHERE ...@. In these contexts, PostgreSQL
+-- knows the column type from the table definition and will infer the parameter
+-- type from OID 0. This covers:
+--
+-- * Custom enums via @PersistLiteral_ Escaped@
+-- * UUIDs via @PersistLiteral_ Escaped@
+-- * JSONB via @PersistLiteral_ Escaped@ or @PersistList@/@PersistMap@
+-- * All standard @PersistValue@ types
+--
+-- You __do__ need explicit casts (@::type@) in these cases:
+--
+-- * __@rawSql@ / @rawExecute@__ where PostgreSQL can't infer the type from
+-- context. For example, @SELECT ? + 1@ is ambiguous — PostgreSQL doesn't
+-- know if @?@ is @int4@, @int8@, @numeric@, etc. Write
+--   @SELECT ?::int8 + 1@ instead (see the sketch after this list).
+--
+-- * __Function arguments__ where PostgreSQL has multiple overloads. For example,
+-- @jsonb_build_object('key', ?)@ may need @?::text@ if PostgreSQL can't
+-- resolve the overload.
+--
+-- * __@PersistLiteral_ Unescaped@__ values are always inlined into SQL and
+-- should include their own cast (e.g. @'5 seconds'::interval@).
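+--
+-- A minimal @rawSql@ sketch for the first case (the @addOne@ helper is
+-- hypothetical, not part of this package):
+--
+-- @
+-- addOne :: MonadIO m => Int64 -> SqlPersistT m [Single Int64]
+-- addOne n = rawSql "SELECT ?::int8 + 1" [toPersistValue n]
+-- @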
+--
+-- = Unsupported Operations
+--
+-- The following PostgreSQL features are __incompatible with pipeline mode__ and
+-- will not work with this backend:
+--
+-- * __COPY__ (@COPY FROM@ / @COPY TO@) — uses a separate sub-protocol that
+-- conflicts with pipeline mode.
+--
+-- * __LISTEN/NOTIFY__ — asynchronous notifications require consuming results
+-- outside the pipeline flow, which conflicts with the pending result counter.
+--
+-- * __Large Objects__ — large object operations use a separate API that is not
+-- supported in pipeline mode.
+--
+-- These operations are also uncommon in persistent-based applications. If you
+-- need them, use the raw @LibPQ.Connection@ via @getPipelineConn@ and manage
+-- the pipeline state yourself.
+--
+-- = Migration from persistent-postgresql
+--
+-- When migrating from @persistent-postgresql@ to @persistent-postgresql-ng@:
+--
+-- 1. __Most types work unchanged.__ Standard @PersistField@ instances using
+-- @PersistText@, @PersistInt64@, @PersistBool@, etc. need no changes.
+--
+-- 2. __Types using @PersistRational@ for non-numeric columns__ (like
+-- @interval@) need to switch to @PersistLiteral_ Unescaped@ with an
+-- explicit cast. In the text protocol, PostgreSQL would auto-cast a text
+--    number to @interval@; the binary protocol sends OID 1700 (numeric),
+--    which PostgreSQL will not implicitly cast to @interval@.
+--
+-- 3. __UUID types__ work if they use @PersistLiteral_ Escaped@ with the
+-- hex text representation. The backend sends with OID 0 and decodes
+-- binary UUIDs to hex text on the way back.
+--
+-- 4. __Array columns__ can now use @PersistArray@ for native PostgreSQL
+-- arrays instead of JSON-encoded text. The @IN@ operator automatically
+-- uses native arrays via the @= ANY(...)@ rewriting.
+--
+-- 5. __JSONB__ values stored via @PersistLiteral_ Escaped@ (the standard
+-- pattern from the JSON module) work correctly — OID 0 lets PostgreSQL
+-- accept the JSON text for @jsonb@ columns.
+module Database.Persist.Postgresql.CustomType () where
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal.hs
new file mode 100644
index 000000000..c0404ee6a
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal.hs
@@ -0,0 +1,70 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE OverloadedStrings #-}
+
+-- | Shared internal utilities for the pipeline backend.
+--
+-- Contains escape functions, the 'PgInterval' type, and re-exports from the
+-- migration module.
+module Database.Persist.Postgresql.Internal
+ ( PgInterval (..)
+ , AlterDB (..)
+ , AlterTable (..)
+ , AlterColumn (..)
+ , SafeToRemove
+ , migrateStructured
+ , migrateEntitiesStructured
+ , mockMigrateStructured
+ , addTable
+ , findAlters
+ , maySerial
+ , mayDefault
+ , showSqlType
+ , showColumn
+ , showAlter
+ , showAlterDb
+ , showAlterTable
+ , getAddReference
+ , udToPair
+ , safeToRemove
+ , postgresMkColumns
+ , getAlters
+ , escapeE
+ , escapeF
+ , escape
+ ) where
+
+import Data.Scientific (FPFormat (Fixed), formatScientific, fromRationalRepetendUnlimited)
+import qualified Data.Text
+import qualified Data.Text.Encoding
+import Data.Time (NominalDiffTime)
+import Database.Persist.Sql
+import Database.Persist.Postgresql.Internal.Migration
+
+-- | Represent a Postgres @interval@ using 'NominalDiffTime'.
+--
+-- Note that this type cannot be losslessly round-tripped through PostgreSQL,
+-- which stores intervals with microsecond precision. For example, the value
+-- @'PgInterval' 0.0000009@ will be truncated, and the value
+-- @'PgInterval' 9223372036854.775808@ will overflow.
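+--
+-- For example, a sketch of the encoding performed by the instance below:
+--
+-- @
+-- toPersistValue (PgInterval 1.5)
+--   == PersistLiteral_ Unescaped "'1.5 seconds'::interval"
+-- @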
+newtype PgInterval = PgInterval {getPgInterval :: NominalDiffTime}
+ deriving (Eq, Show)
+
+instance PersistField PgInterval where
+ toPersistValue (PgInterval ndt) =
+ -- Inline the interval literal into the SQL text because binary-protocol
+ -- numeric (PersistRational) can't be implicitly cast to interval, and
+ -- text with OID 25 also can't. Unescaped literals get inlined before
+ -- parameter encoding.
+ let r = toRational ndt
+ -- Format as a quoted interval literal: '123.456000 seconds'::interval
+ -- Use Scientific for exact decimal representation (no Double precision loss)
+ (sci, _) = fromRationalRepetendUnlimited r
+ s = formatScientific Fixed Nothing sci
+ in PersistLiteral_ Unescaped (Data.Text.Encoding.encodeUtf8
+ (Data.Text.pack ("'" <> s <> " seconds'::interval")))
+ fromPersistValue (PersistRational r) =
+ Right $ PgInterval (fromRational r)
+ fromPersistValue x =
+ Left $ "PgInterval: expected PersistRational, got: " <> Data.Text.pack (show x)
+
+instance PersistFieldSql PgInterval where
+ sqlType _ = SqlOther "interval"
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Decoding.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Decoding.hs
new file mode 100644
index 000000000..88d4655d0
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Decoding.hs
@@ -0,0 +1,162 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE ViewPatterns #-}
+
+-- | Decode binary result columns from PostgreSQL into 'PersistValue'.
+--
+-- Replaces the @getGetter@/@builtinGetters@ from the existing
+-- @persistent-postgresql@ package, using @postgresql-binary@ decoders
+-- directly instead of going through @postgresql-simple@.
+module Database.Persist.Postgresql.Internal.Decoding
+ ( decodePersistValue
+ ) where
+
+import Data.ByteString (ByteString)
+import qualified Data.ByteString as BS
+import Data.Int (Int16, Int32, Int64)
+import Data.Maybe (fromMaybe)
+import Control.Monad (replicateM)
+import Data.Word (Word8)
+import Data.Scientific (Scientific)
+import Data.Text (Text)
+import Data.Time (localTimeToUTC, utc)
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+import Database.Persist.Sql (PersistValue (..))
+import qualified PostgreSQL.Binary.Decoding as PD
+
+import Database.Persist.Postgresql.Internal.PgType
+
+-- | Decode a binary column value from PostgreSQL into a 'PersistValue'.
+--
+-- If the value is @Nothing@, returns @PersistNull@.
+-- Otherwise, dispatches on the column OID to the appropriate decoder.
+-- Falls back to @PersistLiteralEscaped@ for unknown OIDs.
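+--
+-- For example, @bool@ (OID 16) is a single byte on the wire:
+--
+-- @
+-- decodePersistValue (LibPQ.Oid 16) Nothing   == Right PersistNull
+-- decodePersistValue (LibPQ.Oid 16) (Just b)  == Right (PersistBool True)  -- when b is 0x01
+-- @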
+decodePersistValue :: LibPQ.Oid -> Maybe ByteString -> Either Text PersistValue
+decodePersistValue _ Nothing = Right PersistNull
+decodePersistValue (fromOid -> pt) (Just bs) = decodeByType pt bs
+
+-- | Decode based on classified 'PgType'.
+decodeByType :: PgType -> ByteString -> Either Text PersistValue
+
+-- Scalar types
+decodeByType (Scalar PgBool) bs = PersistBool <$> run PD.bool bs
+decodeByType (Scalar PgBytea) bs = PersistByteString <$> run PD.bytea_strict bs
+decodeByType (Scalar PgChar) bs = PersistText <$> run PD.text_strict bs
+decodeByType (Scalar PgName) bs = PersistText <$> run PD.text_strict bs
+decodeByType (Scalar PgInt8) bs = PersistInt64 <$> run (PD.int :: PD.Value Int64) bs
+decodeByType (Scalar PgInt2) bs = PersistInt64 . fromIntegral <$> run (PD.int :: PD.Value Int16) bs
+decodeByType (Scalar PgInt4) bs = PersistInt64 . fromIntegral <$> run (PD.int :: PD.Value Int32) bs
+decodeByType (Scalar PgText) bs = PersistText <$> run PD.text_strict bs
+decodeByType (Scalar PgXml) bs = PersistByteString <$> run PD.bytea_strict bs
+decodeByType (Scalar PgFloat4) bs = PersistDouble . realToFrac <$> run PD.float4 bs
+decodeByType (Scalar PgFloat8) bs = PersistDouble <$> run PD.float8 bs
+decodeByType (Scalar PgMoney) bs = PersistRational . fromIntegral <$> run (PD.int :: PD.Value Int64) bs
+decodeByType (Scalar PgBpchar) bs = PersistText <$> run PD.text_strict bs
+decodeByType (Scalar PgVarchar) bs = PersistText <$> run PD.text_strict bs
+decodeByType (Scalar PgDate) bs = PersistDay <$> run PD.date bs
+decodeByType (Scalar PgTime) bs = PersistTimeOfDay <$> run PD.time_int bs
+decodeByType (Scalar PgTimestamp) bs = PersistUTCTime . localTimeToUTC utc <$> run PD.timestamp_int bs
+decodeByType (Scalar PgTimestamptz) bs = PersistUTCTime <$> run PD.timestamptz_int bs
+decodeByType (Scalar PgInterval) bs = PersistRational . toRational <$> run PD.interval_int bs
+decodeByType (Scalar PgBit) bs = PersistByteString <$> run PD.bytea_strict bs
+decodeByType (Scalar PgVarbit) bs = PersistByteString <$> run PD.bytea_strict bs
+decodeByType (Scalar PgNumeric) bs = decodeNumeric bs
+decodeByType (Scalar PgVoid) _ = Right PersistNull
+decodeByType (Scalar PgJson) bs = PersistByteString <$> run PD.bytea_strict bs
+decodeByType (Scalar PgJsonb) bs = PersistByteString <$> decodeJsonb bs
+decodeByType (Scalar PgUnknown) bs = PersistByteString <$> run PD.bytea_strict bs
+decodeByType (Scalar PgUuid) bs = do
+ raw <- run PD.bytea_strict bs
+ Right $ PersistLiteralEscaped (uuidBytesToText raw)
+
+-- Array types
+decodeByType (Array PgBool) bs = decodeArray (PersistBool <$> PD.bool) bs
+decodeByType (Array PgBytea) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs
+decodeByType (Array PgChar) bs = decodeArray (PersistText <$> PD.text_strict) bs
+decodeByType (Array PgName) bs = decodeArray (PersistText <$> PD.text_strict) bs
+decodeByType (Array PgInt8) bs = decodeArray (PersistInt64 <$> (PD.int :: PD.Value Int64)) bs
+decodeByType (Array PgInt2) bs = decodeArray (PersistInt64 . fromIntegral <$> (PD.int :: PD.Value Int16)) bs
+decodeByType (Array PgInt4) bs = decodeArray (PersistInt64 . fromIntegral <$> (PD.int :: PD.Value Int32)) bs
+decodeByType (Array PgText) bs = decodeArray (PersistText <$> PD.text_strict) bs
+decodeByType (Array PgXml) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs
+decodeByType (Array PgFloat4) bs = decodeArray (PersistDouble . realToFrac <$> PD.float4) bs
+decodeByType (Array PgFloat8) bs = decodeArray (PersistDouble <$> PD.float8) bs
+decodeByType (Array PgTimestamp) bs = decodeArray (PersistUTCTime . localTimeToUTC utc <$> PD.timestamp_int) bs
+decodeByType (Array PgTimestamptz) bs = decodeArray (PersistUTCTime <$> PD.timestamptz_int) bs
+decodeByType (Array PgDate) bs = decodeArray (PersistDay <$> PD.date) bs
+decodeByType (Array PgTime) bs = decodeArray (PersistTimeOfDay <$> PD.time_int) bs
+decodeByType (Array PgMoney) bs = decodeArray (PersistRational . fromIntegral <$> (PD.int :: PD.Value Int64)) bs
+decodeByType (Array PgBpchar) bs = decodeArray (PersistText <$> PD.text_strict) bs
+decodeByType (Array PgVarchar) bs = decodeArray (PersistText <$> PD.text_strict) bs
+decodeByType (Array PgInterval) bs = decodeArray (PersistRational . toRational <$> PD.interval_int) bs
+decodeByType (Array PgBit) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs
+decodeByType (Array PgVarbit) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs
+decodeByType (Array PgNumeric) bs = decodeArrayNumeric bs
+decodeByType (Array PgJson) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs
+decodeByType (Array PgJsonb) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs
+decodeByType (Array PgUuid) bs = decodeArray (PersistLiteralEscaped . uuidBytesToText <$> PD.bytea_strict) bs
+
+-- Unhandled array element types fall through to raw bytes
+decodeByType (Array PgVoid) bs = Right $ PersistLiteralEscaped bs
+decodeByType (Array PgUnknown) bs = Right $ PersistLiteralEscaped bs
+
+-- User-defined types: return raw bytes (callers should use the direct path)
+decodeByType (Composite _ _) bs = Right $ PersistLiteralEscaped bs
+decodeByType (CompositeArray _) bs = Right $ PersistLiteralEscaped bs
+decodeByType (Enum _ _) bs = PersistText <$> run PD.text_strict bs
+decodeByType (EnumArray _) bs = decodeArray (PersistText <$> PD.text_strict) bs
+
+-- Unrecognized OID: return raw bytes
+decodeByType (Unrecognized _) bs = Right $ PersistLiteralEscaped bs
+
+run :: PD.Value a -> ByteString -> Either Text a
+run = PD.valueParser
+
+decodeNumeric :: ByteString -> Either Text PersistValue
+decodeNumeric bs = do
+ sci <- run PD.numeric bs
+ Right $ PersistRational (toRational (sci :: Scientific))
+
+-- | Decode JSONB binary format. The binary format has a 1-byte version prefix
+-- that @jsonb_bytes@ strips for us.
+decodeJsonb :: ByteString -> Either Text ByteString
+decodeJsonb = PD.valueParser (PD.jsonb_bytes Right)
+
+decodeArray :: PD.Value PersistValue -> ByteString -> Either Text PersistValue
+decodeArray elemDecoder bs =
+ PersistList . map nullable <$>
+ run (PD.array (PD.dimensionArray replicateM (PD.nullableValueArray elemDecoder))) bs
+ where
+ nullable = fromMaybe PersistNull
+
+decodeArrayNumeric :: ByteString -> Either Text PersistValue
+decodeArrayNumeric bs =
+ PersistList . map nullable <$>
+ run (PD.array (PD.dimensionArray replicateM (PD.nullableValueArray numDecoder))) bs
+ where
+ numDecoder = do
+ sci <- PD.numeric
+ pure $ PersistRational (toRational (sci :: Scientific))
+ nullable = fromMaybe PersistNull
+
+-- | Convert 16 raw UUID bytes to the text representation
+-- "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" as a ByteString.
+uuidBytesToText :: ByteString -> ByteString
+uuidBytesToText raw
+ | BS.length raw /= 16 = raw
+ | otherwise = BS.pack $ concat
+ [ hexBytes 0 4, [0x2D]
+ , hexBytes 4 2, [0x2D]
+ , hexBytes 6 2, [0x2D]
+ , hexBytes 8 2, [0x2D]
+ , hexBytes 10 6
+ ]
+ where
+ hexBytes offset count =
+ concatMap byteToHex [BS.index raw (offset + i) | i <- [0 .. count - 1]]
+ byteToHex :: Word8 -> [Word8]
+ byteToHex b = [hexNibble (b `div` 16), hexNibble (b `mod` 16)]
+ hexNibble :: Word8 -> Word8
+ hexNibble n
+ | n < 10 = 0x30 + n
+ | otherwise = 0x61 + n - 10
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectDecode.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectDecode.hs
new file mode 100644
index 000000000..997d01013
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectDecode.hs
@@ -0,0 +1,255 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- | Direct field decoding for PostgreSQL.
+--
+-- 'prepareField' inspects the column OID once (classifying it into a
+-- 'PgType'), then returns a 'FieldRunner' closure that calls the
+-- appropriate @postgresql-binary@ decoder on each row without
+-- re-dispatching on the OID.
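+--
+-- A sketch of the intended flow, using only names from this module and
+-- "Database.Persist.DirectDecode":
+--
+-- @
+-- prepareField env name col onErr $ \\runner ->
+--   -- the runner is then reused for every row of the result
+--   runField runner env onErr $ \\value -> ...
+-- @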
+module Database.Persist.Postgresql.Internal.DirectDecode
+ ( PgRowEnv (..)
+ , compositeFieldDecode
+ ) where
+
+import Data.ByteString (ByteString)
+import Data.Int (Int16, Int32, Int64)
+import Data.Scientific (Scientific)
+import Data.Text (Text)
+import qualified Data.Text as T
+import Data.Time (Day, TimeOfDay, UTCTime, localTimeToUTC, utc)
+import qualified Data.Vector as V
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+import qualified PostgreSQL.Binary.Decoding as PD
+
+import Database.Persist.DirectDecode (FieldDecode (..), FieldRunner (..))
+import Database.Persist.Names (FieldNameDB (..))
+import Database.Persist.Postgresql.Internal.PgCodec (PgDecode (..), runPgDecoder)
+import Database.Persist.Postgresql.Internal.PgType
+
+-- | Row environment for PostgreSQL direct decoding.
+data PgRowEnv = PgRowEnv
+ { pgResult :: !LibPQ.Result
+ , pgRow :: !LibPQ.Row
+ , pgCols :: !(V.Vector (LibPQ.Column, PgType))
+ , pgRowCache :: !OidCache
+ }
+
+---------------------------------------------------------------------------
+-- Helpers
+---------------------------------------------------------------------------
+
+{-# INLINE decodeWith #-}
+decodeWith :: PD.Value a -> ByteString -> (Text -> IO r) -> (a -> IO r) -> IO r
+decodeWith decoder bs onErr onOk =
+ case PD.valueParser decoder bs of
+ Left err -> onErr err
+ Right v -> onOk v
+
+{-# INLINE readBytes #-}
+readBytes :: PgRowEnv -> Int -> (Text -> IO r) -> (ByteString -> IO r) -> IO r
+readBytes env col onErr onOk = do
+ let (c, _) = pgCols env V.! col
+ mbs <- LibPQ.getvalue' (pgResult env) (pgRow env) c
+ case mbs of
+ Nothing -> onErr "unexpected NULL"
+ Just bs -> onOk bs
+
+{-# INLINE mismatch #-}
+mismatch :: PgType -> Text -> (Text -> IO r) -> IO r
+mismatch pt expected onErr =
+ onErr ("cannot decode " <> T.pack (show pt) <> " as " <> expected)
+
+-- | Build a FieldRunner that reads bytes from the row and decodes them.
+{-# INLINE mkRunner #-}
+mkRunner :: Int -> PD.Value a -> FieldRunner PgRowEnv a
+mkRunner col decoder = FieldRunner $ \env onErr onOk ->
+ readBytes env col onErr $ \bs -> decodeWith decoder bs onErr onOk
+
+-- | Build a FieldRunner that reads bytes and applies a post-decode conversion.
+{-# INLINE mkRunnerWith #-}
+mkRunnerWith :: Int -> PD.Value a -> (a -> b) -> FieldRunner PgRowEnv b
+mkRunnerWith col decoder f = FieldRunner $ \env onErr onOk ->
+ readBytes env col onErr $ \bs -> decodeWith decoder bs onErr (onOk . f)
+
+---------------------------------------------------------------------------
+-- FieldDecode instances: prepareField inspects OID once, returns FieldRunner
+---------------------------------------------------------------------------
+
+instance FieldDecode PgRowEnv Bool where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgBool -> onOk (mkRunner col PD.bool)
+ _ -> mismatch pt "Bool" onErr
+
+instance FieldDecode PgRowEnv Int16 where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgInt2 -> onOk (mkRunner col (PD.int :: PD.Value Int16))
+ _ -> mismatch pt "Int16" onErr
+
+instance FieldDecode PgRowEnv Int32 where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgInt4 -> onOk (mkRunner col (PD.int :: PD.Value Int32))
+ Scalar PgInt2 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int16) fromIntegral)
+ _ -> mismatch pt "Int32" onErr
+
+instance FieldDecode PgRowEnv Int64 where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgInt8 -> onOk (mkRunner col (PD.int :: PD.Value Int64))
+ Scalar PgInt4 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int32) fromIntegral)
+ Scalar PgInt2 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int16) fromIntegral)
+ _ -> mismatch pt "Int64" onErr
+
+instance FieldDecode PgRowEnv Int where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgInt8 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int64) fromIntegral)
+ Scalar PgInt4 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int32) fromIntegral)
+ Scalar PgInt2 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int16) fromIntegral)
+ _ -> mismatch pt "Int" onErr
+
+instance FieldDecode PgRowEnv Double where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgFloat8 -> onOk (mkRunner col PD.float8)
+ Scalar PgFloat4 -> onOk (mkRunnerWith col PD.float4 realToFrac)
+ _ -> mismatch pt "Double" onErr
+
+instance FieldDecode PgRowEnv Scientific where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgNumeric -> onOk (mkRunner col PD.numeric)
+ _ -> mismatch pt "Scientific" onErr
+
+instance FieldDecode PgRowEnv Rational where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgNumeric -> onOk (mkRunnerWith col PD.numeric (toRational :: Scientific -> Rational))
+ Scalar PgMoney -> onOk (mkRunnerWith col (PD.int :: PD.Value Int64) fromIntegral)
+ _ -> mismatch pt "Rational" onErr
+
+instance FieldDecode PgRowEnv Text where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgText -> onOk (mkRunner col PD.text_strict)
+ Scalar PgVarchar -> onOk (mkRunner col PD.text_strict)
+ Scalar PgBpchar -> onOk (mkRunner col PD.text_strict)
+ Scalar PgChar -> onOk (mkRunner col PD.text_strict)
+ Scalar PgName -> onOk (mkRunner col PD.text_strict)
+ _ -> mismatch pt "Text" onErr
+
+instance FieldDecode PgRowEnv ByteString where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgBytea -> onOk (mkRunner col PD.bytea_strict)
+ Scalar PgXml -> onOk (mkRunner col PD.bytea_strict)
+ Scalar PgJson -> onOk (mkRunner col PD.bytea_strict)
+ Scalar PgJsonb -> onOk (mkRunner col (PD.jsonb_bytes Right))
+ Scalar PgBit -> onOk (mkRunner col PD.bytea_strict)
+ Scalar PgVarbit -> onOk (mkRunner col PD.bytea_strict)
+ Scalar PgUnknown -> onOk (mkRunner col PD.bytea_strict)
+ _ -> mismatch pt "ByteString" onErr
+
+instance FieldDecode PgRowEnv Day where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgDate -> onOk (mkRunner col PD.date)
+ _ -> mismatch pt "Day" onErr
+
+instance FieldDecode PgRowEnv TimeOfDay where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgTime -> onOk (mkRunner col PD.time_int)
+ _ -> mismatch pt "TimeOfDay" onErr
+
+instance FieldDecode PgRowEnv UTCTime where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ in case pt of
+ Scalar PgTimestamptz -> onOk (mkRunner col PD.timestamptz_int)
+ Scalar PgTimestamp -> onOk (mkRunnerWith col PD.timestamp_int (localTimeToUTC utc))
+ _ -> mismatch pt "UTCTime" onErr
+
+---------------------------------------------------------------------------
+-- Maybe wrapper: NULL -> Nothing, otherwise delegate
+---------------------------------------------------------------------------
+
+instance FieldDecode PgRowEnv a => FieldDecode PgRowEnv (Maybe a) where
+ prepareField env name col onErr onOk =
+ prepareField @PgRowEnv @a env name col onErr $ \inner ->
+ onOk $ FieldRunner $ \env' onErr' onOk' -> do
+ let (c, _) = pgCols env' V.! col
+ mbs <- LibPQ.getvalue' (pgResult env') (pgRow env') c
+ case mbs of
+ Nothing -> onOk' Nothing
+ Just _ -> runField inner env' onErr' (onOk' . Just)
+
+---------------------------------------------------------------------------
+-- Generic array instance via PgDecode
+---------------------------------------------------------------------------
+
+instance {-# OVERLAPPABLE #-} PgDecode a => FieldDecode PgRowEnv [a] where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ cache = pgRowCache env
+ decoder = runPgDecoder (pgDecoder @[a]) cache
+ in case pt of
+ Array _ -> onOk (mkRunner col decoder)
+ Unrecognized _ -> onOk (mkRunner col decoder)
+ _ -> mismatch pt "array" onErr
+
+instance {-# OVERLAPPING #-} PgDecode a => FieldDecode PgRowEnv [Maybe a] where
+ prepareField env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ cache = pgRowCache env
+ decoder = runPgDecoder (pgDecoder @[Maybe a]) cache
+ in case pt of
+ Array _ -> onOk (mkRunner col decoder)
+ Unrecognized _ -> onOk (mkRunner col decoder)
+ _ -> mismatch pt "array" onErr
+
+---------------------------------------------------------------------------
+-- Composite helper
+---------------------------------------------------------------------------
+
+-- | Build a 'FieldDecode' for a PostgreSQL composite type from its
+-- 'PgDecode' instance. Accepts any 'Unrecognized' OID (composites
+-- have dynamically-assigned OIDs) or a resolved 'Composite' name.
+--
+-- @
+-- instance FieldDecode PgRowEnv Address where
+-- prepareField = compositeFieldDecode
+-- @
+compositeFieldDecode
+ :: PgDecode a
+ => PgRowEnv -> FieldNameDB -> Int
+ -> (Text -> IO r) -> (FieldRunner PgRowEnv a -> IO r) -> IO r
+compositeFieldDecode env _ col onErr onOk =
+ let (_, pt) = pgCols env V.! col
+ cache = pgRowCache env
+ decoder = runPgDecoder pgDecoder cache
+ in case pt of
+ Composite _ _ -> onOk (mkRunner col decoder)
+ Unrecognized _ -> onOk (mkRunner col decoder)
+ _ -> mismatch pt "composite" onErr
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectEncode.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectEncode.hs
new file mode 100644
index 000000000..14fb7dafc
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectEncode.hs
@@ -0,0 +1,214 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+
+-- | Direct parameter encoding for PostgreSQL, bypassing 'PersistValue'.
+--
+-- Each Haskell type is encoded directly to the binary wire format that
+-- @libpq@ expects, using @postgresql-binary@ encoders. No intermediate
+-- 'PersistValue' wrapper is allocated.
+--
+-- 'PgParam' is an unpacked ADT: 'PgNull' for SQL NULL,
+-- 'PgValue' for a typed value with OID, payload, and format.
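+--
+-- For example, a sketch:
+--
+-- @
+-- encodeField (42 :: Int64)            -- PgValue (Oid 20) \<8 bytes\> Binary
+-- encodeField (Nothing :: Maybe Text)  -- PgNull
+-- @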
+module Database.Persist.Postgresql.Internal.DirectEncode
+ ( -- * Parameter type
+ PgParam (..)
+ , pgParamToLibPQ
+ -- * Re-export for convenience
+ , module Database.Persist.DirectEncode
+ ) where
+
+import Data.ByteString (ByteString)
+import Data.Int (Int16, Int32, Int64)
+import Data.Scientific (Scientific, fromRationalRepetendUnlimited)
+import Data.Text (Text)
+import Data.Time (Day, TimeOfDay, UTCTime)
+import Data.Word (Word32)
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+import qualified PostgreSQL.Binary.Encoding as PE
+
+import Database.Persist (PersistValue (..))
+import Database.Persist.DirectEncode
+import Database.Persist.Postgresql.Internal.Encoding (encodePersistValue)
+import Database.Persist.Postgresql.Internal.PgType
+
+-- | A PostgreSQL query parameter in binary wire format.
+--
+-- Unpacked ADT to avoid the boxing overhead of
+-- @Maybe (Oid, ByteString, Format)@.
+data PgParam
+ = PgNull
+ | PgValue
+ {-# UNPACK #-} !LibPQ.Oid
+ !ByteString
+ !LibPQ.Format
+
+-- | Convert to the representation that @libpq@ expects.
+pgParamToLibPQ :: PgParam -> Maybe (LibPQ.Oid, ByteString, LibPQ.Format)
+pgParamToLibPQ PgNull = Nothing
+pgParamToLibPQ (PgValue oid bs fmt) = Just (oid, bs, fmt)
+{-# INLINE pgParamToLibPQ #-}
+
+bin :: PgScalar -> PE.Encoding -> PgParam
+bin s enc = PgValue (scalarOid s) (PE.encodingBytes enc) LibPQ.Binary
+{-# INLINE bin #-}
+
+-- Scalar types
+
+instance FieldEncode PgParam Bool where
+ encodeField = bin PgBool . PE.bool
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Int16 where
+ encodeField = bin PgInt2 . PE.int2_int16
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Int32 where
+ encodeField = bin PgInt4 . PE.int4_int32
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Int64 where
+ encodeField = bin PgInt8 . PE.int8_int64
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Int where
+ encodeField i
+ | i >= fromIntegral (minBound :: Int32)
+ && i <= fromIntegral (maxBound :: Int32)
+ = bin PgInt4 (PE.int4_int32 (fromIntegral i))
+ | otherwise
+ = bin PgInt8 (PE.int8_int64 (fromIntegral i))
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Double where
+ encodeField = bin PgFloat8 . PE.float8
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Scientific where
+ encodeField = bin PgNumeric . PE.numeric
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Rational where
+ encodeField r =
+ let (sci, _) = fromRationalRepetendUnlimited r
+ in bin PgNumeric (PE.numeric sci)
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Text where
+ encodeField = bin PgText . PE.text_strict
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam ByteString where
+ encodeField = bin PgBytea . PE.bytea_strict
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam Day where
+ encodeField = bin PgDate . PE.date
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam TimeOfDay where
+ encodeField = bin PgTime . PE.time_int
+ {-# INLINE encodeField #-}
+
+instance FieldEncode PgParam UTCTime where
+ encodeField = bin PgTimestamptz . PE.timestamptz_int
+ {-# INLINE encodeField #-}
+
+-- Nullable wrapper: Nothing becomes SQL NULL
+
+instance FieldEncode PgParam a => FieldEncode PgParam (Maybe a) where
+ encodeField Nothing = PgNull
+ encodeField (Just a) = encodeField a
+ {-# INLINE encodeField #-}
+
+-- | Backward-compatibility instance: encode a 'PersistValue' directly.
+-- Converts through the existing 'encodePersistValue' function.
+instance FieldEncode PgParam PersistValue where
+ encodeField pv = case encodePersistValue pv of
+ Nothing -> PgNull
+ Just (oid, bs, fmt) -> PgValue oid bs fmt
+ {-# INLINE encodeField #-}
+
+-- Array encoding for = ANY($1) patterns
+
+instance FieldEncode PgParam [PersistValue] where
+ encodeField xs = encodeAsArray xs
+
+-- | Encode a list of 'PersistValue' as a native PostgreSQL array.
+-- Mirrors the array-encoding logic in "Database.Persist.Postgresql.Internal.Encoding".
+encodeAsArray :: [PersistValue] -> PgParam
+encodeAsArray xs =
+ case inferElementType xs of
+ Just (elemOidW, arrOid, encElem) ->
+ let encoding = PE.array_foldable elemOidW encElem xs
+ in PgValue arrOid (PE.encodingBytes encoding) LibPQ.Binary
+ Nothing ->
+ PgValue (LibPQ.Oid 0) "{}" LibPQ.Text
+
+inferElementType
+ :: [PersistValue]
+ -> Maybe (Word32, LibPQ.Oid, PersistValue -> Maybe PE.Encoding)
+inferElementType = go
+ where
+ go [] = Nothing
+ go (PersistNull : rest) = go rest
+ go (PersistInt64 _ : _) = arrayEnc PgInt8 encInt
+ go (PersistText _ : _) = arrayEnc PgText encText
+ go (PersistBool _ : _) = arrayEnc PgBool encBool
+ go (PersistDouble _ : _) = arrayEnc PgFloat8 encDouble
+ go (PersistByteString _ : _) = arrayEnc PgBytea encBytea
+ go (PersistDay _ : _) = arrayEnc PgDate encDay
+ go (PersistTimeOfDay _ : _) = arrayEnc PgTime encTime
+ go (PersistUTCTime _ : _) = arrayEnc PgTimestamptz encTimestamptz
+ go (PersistRational _ : _) = arrayEnc PgNumeric encNumeric
+ go _ = Nothing
+
+ arrayEnc
+ :: PgScalar
+ -> (PersistValue -> Maybe PE.Encoding)
+ -> Maybe (Word32, LibPQ.Oid, PersistValue -> Maybe PE.Encoding)
+ arrayEnc s enc = do
+ arrOid <- arrayOid s
+ pure (scalarOidWord32 s, arrOid, enc)
+
+ encInt (PersistInt64 i) = Just (PE.int8_int64 i)
+ encInt PersistNull = Nothing
+ encInt _ = Nothing
+
+ encText (PersistText t) = Just (PE.text_strict t)
+ encText PersistNull = Nothing
+ encText _ = Nothing
+
+ encBool (PersistBool b) = Just (PE.bool b)
+ encBool PersistNull = Nothing
+ encBool _ = Nothing
+
+ encDouble (PersistDouble d) = Just (PE.float8 d)
+ encDouble PersistNull = Nothing
+ encDouble _ = Nothing
+
+ encBytea (PersistByteString bs) = Just (PE.bytea_strict bs)
+ encBytea PersistNull = Nothing
+ encBytea _ = Nothing
+
+ encDay (PersistDay d) = Just (PE.date d)
+ encDay PersistNull = Nothing
+ encDay _ = Nothing
+
+ encTime (PersistTimeOfDay t) = Just (PE.time_int t)
+ encTime PersistNull = Nothing
+ encTime _ = Nothing
+
+ encTimestamptz (PersistUTCTime t) = Just (PE.timestamptz_int t)
+ encTimestamptz PersistNull = Nothing
+ encTimestamptz _ = Nothing
+
+ encNumeric (PersistRational r) =
+ let (sci :: Scientific, _) = fromRationalRepetendUnlimited r
+ in Just (PE.numeric sci)
+ encNumeric PersistNull = Nothing
+ encNumeric _ = Nothing
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Encoding.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Encoding.hs
new file mode 100644
index 000000000..19c549c7a
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Encoding.hs
@@ -0,0 +1,179 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+
+-- | Encode 'PersistValue' to binary parameters for @LibPQ.execParams@.
+--
+-- Replaces the @P@ newtype and @ToField@ instance from the existing
+-- @persistent-postgresql@ package, using @postgresql-binary@ encoders
+-- directly instead of going through @postgresql-simple@.
+--
+-- 'PersistList' and 'PersistArray' are encoded as native PostgreSQL arrays
+-- (not JSONB), with element type inferred from the first non-null element.
+module Database.Persist.Postgresql.Internal.Encoding
+ ( encodePersistValue
+ ) where
+
+import Data.ByteString (ByteString)
+import Data.Int (Int32)
+import Data.Scientific (Scientific, fromRationalRepetendUnlimited)
+import qualified Data.Text.Encoding as T
+import Data.Word (Word32)
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+import Database.Persist (PersistValue (..), listToJSON, mapToJSON)
+import Database.Persist.Types (LiteralType (..))
+import qualified PostgreSQL.Binary.Encoding as PE
+
+import Database.Persist.Postgresql.Internal.PgType
+
+-- | Encode a 'PersistValue' to a libpq parameter triple.
+--
+-- Returns 'Nothing' for 'PersistNull'.
+-- Returns @Just (oid, bytes, format)@ for non-null values.
+--
+-- @PersistLiteral_ Unescaped@ values should have been inlined into the SQL
+-- text before calling this function. If encountered here, they are sent
+-- as text-format parameters as a fallback.
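+--
+-- For example (where @oneByte@ stands for the single byte 0x01):
+--
+-- @
+-- encodePersistValue PersistNull        == Nothing
+-- encodePersistValue (PersistBool True) == Just (LibPQ.Oid 16, oneByte, LibPQ.Binary)
+-- @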
+encodePersistValue :: PersistValue -> Maybe (LibPQ.Oid, ByteString, LibPQ.Format)
+encodePersistValue PersistNull = Nothing
+encodePersistValue (PersistText t) =
+ Just $ bin PgText (PE.text_strict t)
+encodePersistValue (PersistByteString bs) =
+ Just $ bin PgBytea (PE.bytea_strict bs)
+encodePersistValue (PersistInt64 i)
+ | i >= fromIntegral (minBound :: Int32) && i <= fromIntegral (maxBound :: Int32)
+ = Just $ bin PgInt4 (PE.int4_int32 (fromIntegral i))
+ | otherwise
+ = Just $ bin PgInt8 (PE.int8_int64 i)
+encodePersistValue (PersistDouble d) =
+ Just $ bin PgFloat8 (PE.float8 d)
+encodePersistValue (PersistRational r) =
+ let (sci, _) = fromRationalRepetendUnlimited r
+ in Just $ bin PgNumeric (PE.numeric sci)
+encodePersistValue (PersistBool b) =
+ Just $ bin PgBool (PE.bool b)
+encodePersistValue (PersistDay d) =
+ Just $ bin PgDate (PE.date d)
+encodePersistValue (PersistTimeOfDay t) =
+ Just $ bin PgTime (PE.time_int t)
+encodePersistValue (PersistUTCTime t) =
+ Just $ bin PgTimestamptz (PE.timestamptz_int t)
+encodePersistValue (PersistList xs) =
+ -- PersistList is used by persistent for embedded entities stored in varchar
+ -- columns as JSON text. Send with OID 0 so PostgreSQL infers the type from
+ -- column context, handling both text/varchar and jsonb columns correctly.
+ -- PersistArray is used for native PostgreSQL arrays (e.g. from IN→ANY).
+ let jsonText = listToJSON xs
+ in Just (LibPQ.Oid 0, T.encodeUtf8 jsonText, LibPQ.Text)
+encodePersistValue (PersistArray xs) = encodeAsArray xs
+encodePersistValue (PersistMap m) =
+ -- Send with OID 0 (unknown) in text format so PostgreSQL infers the type
+ -- from column context. This handles both text/varchar columns (embedded
+ -- entities as JSON text) and jsonb columns. Using oidText (25) would fail
+ -- for jsonb columns since there is no implicit cast from text to jsonb.
+ let jsonText = mapToJSON m
+ in Just (LibPQ.Oid 0, T.encodeUtf8 jsonText, LibPQ.Text)
+encodePersistValue (PersistLiteral_ DbSpecific s) =
+ -- Send with OID 0 (unknown) in text format, letting PostgreSQL infer the
+ -- type from context. This matches postgresql-simple's Unknown behavior and
+ -- handles UUID, bytea, and other backend-specific types correctly.
+ Just (LibPQ.Oid 0, s, LibPQ.Text)
+encodePersistValue (PersistLiteral_ Escaped e) =
+ -- Same as DbSpecific: send with OID 0 in text format for PostgreSQL to
+ -- infer the type. Used for UUIDs and other opaque byte values.
+ Just (LibPQ.Oid 0, e, LibPQ.Text)
+encodePersistValue (PersistLiteral_ Unescaped l) =
+ -- Unescaped literals are raw SQL fragments. They should normally be
+ -- inlined into the SQL text by inlineUnescaped before reaching here.
+ -- As a fallback, send as text format.
+ Just (scalarOid PgText, l, LibPQ.Text)
+encodePersistValue (PersistObjectId _) =
+ error "Refusing to serialize a PersistObjectId to a PostgreSQL value"
+
+-- | Construct a binary-format parameter triple from a scalar type and encoding.
+bin :: PgScalar -> PE.Encoding -> (LibPQ.Oid, ByteString, LibPQ.Format)
+bin s enc = (scalarOid s, PE.encodingBytes enc, LibPQ.Binary)
+
+-- | Encode a list of 'PersistValue' as a native PostgreSQL array.
+--
+-- Infers the element type from the first non-null element; elements that
+-- don't match the inferred type are encoded as NULL. Falls back to an
+-- untyped (OID 0) @"{}"@ text parameter when the list is empty or all-null.
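+--
+-- For example, @encodeAsArray [PersistInt64 1, PersistNull]@ yields a binary
+-- @int8[]@ parameter whose second element is NULL, while @encodeAsArray []@
+-- yields @Just (Oid 0, "{}", Text)@.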
+encodeAsArray :: [PersistValue] -> Maybe (LibPQ.Oid, ByteString, LibPQ.Format)
+encodeAsArray xs =
+ case inferElementType xs of
+ Just (elemOidW, arrOid, encElem) ->
+ let encoding = PE.array_foldable elemOidW encElem xs
+ in Just (arrOid, PE.encodingBytes encoding, LibPQ.Binary)
+ Nothing ->
+ -- Empty list or all nulls: send as empty text representation with
+ -- OID 0 so PostgreSQL infers the array type from column context.
+ -- Using a specific array OID (e.g. text[] 1009) would fail for
+ -- non-text array columns like int8[], bool[], etc.
+ Just (LibPQ.Oid 0, "{}", LibPQ.Text)
+
+-- | Infer the PostgreSQL element type from the first non-null element
+-- of a list. Returns (element OID as Word32, array OID, element encoder).
+--
+-- Returns 'Nothing' if the list is empty or contains only nulls.
+inferElementType
+ :: [PersistValue]
+ -> Maybe (Word32, LibPQ.Oid, PersistValue -> Maybe PE.Encoding)
+inferElementType = go
+ where
+ go [] = Nothing
+ go (PersistNull : rest) = go rest
+ go (PersistInt64 _ : _) = arrayEnc PgInt8 encInt
+ go (PersistText _ : _) = arrayEnc PgText encText
+ go (PersistBool _ : _) = arrayEnc PgBool encBool
+ go (PersistDouble _ : _) = arrayEnc PgFloat8 encDouble
+ go (PersistByteString _ : _) = arrayEnc PgBytea encBytea
+ go (PersistDay _ : _) = arrayEnc PgDate encDay
+ go (PersistTimeOfDay _ : _) = arrayEnc PgTime encTime
+ go (PersistUTCTime _ : _) = arrayEnc PgTimestamptz encTimestamptz
+ go (PersistRational _ : _) = arrayEnc PgNumeric encNumeric
+ go _ = Nothing
+
+ arrayEnc
+ :: PgScalar
+ -> (PersistValue -> Maybe PE.Encoding)
+ -> Maybe (Word32, LibPQ.Oid, PersistValue -> Maybe PE.Encoding)
+ arrayEnc s enc = do
+ arrOid <- arrayOid s
+ pure (scalarOidWord32 s, arrOid, enc)
+
+ encInt (PersistInt64 i) = Just (PE.int8_int64 i)
+ encInt PersistNull = Nothing
+ encInt _ = Nothing
+
+ encText (PersistText t) = Just (PE.text_strict t)
+ encText PersistNull = Nothing
+ encText _ = Nothing
+
+ encBool (PersistBool b) = Just (PE.bool b)
+ encBool PersistNull = Nothing
+ encBool _ = Nothing
+
+ encDouble (PersistDouble d) = Just (PE.float8 d)
+ encDouble PersistNull = Nothing
+ encDouble _ = Nothing
+
+ encBytea (PersistByteString bs) = Just (PE.bytea_strict bs)
+ encBytea PersistNull = Nothing
+ encBytea _ = Nothing
+
+ encDay (PersistDay d) = Just (PE.date d)
+ encDay PersistNull = Nothing
+ encDay _ = Nothing
+
+ encTime (PersistTimeOfDay t) = Just (PE.time_int t)
+ encTime PersistNull = Nothing
+ encTime _ = Nothing
+
+ encTimestamptz (PersistUTCTime t) = Just (PE.timestamptz_int t)
+ encTimestamptz PersistNull = Nothing
+ encTimestamptz _ = Nothing
+
+ encNumeric (PersistRational r) =
+ let (sci :: Scientific, _) = fromRationalRepetendUnlimited r
+ in Just (PE.numeric sci)
+ encNumeric PersistNull = Nothing
+ encNumeric _ = Nothing
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Migration.hs
new file mode 100644
index 000000000..e78c44ad3
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Migration.hs
@@ -0,0 +1,1186 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE NamedFieldPuns #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE ViewPatterns #-}
+
+-- | Generate PostgreSQL migrations for a set of EntityDefs, either from scratch
+-- or based on the current state of a database.
+module Database.Persist.Postgresql.Internal.Migration where
+
+import Control.Arrow
+import Control.Monad
+import Control.Monad.Except
+import Control.Monad.IO.Class
+import Data.Acquire (with)
+import Data.Conduit
+import qualified Data.Conduit.List as CL
+import Data.Either (partitionEithers)
+import Data.FileEmbed (embedFileRelative)
+import Data.List as List
+import qualified Data.List.NonEmpty as NEL
+import Data.Map (Map)
+import qualified Data.Map as Map
+import Data.Maybe
+import Data.Set (Set)
+import qualified Data.Set as Set
+import Data.Text (Text)
+import qualified Data.Text as T
+import qualified Data.Text.Encoding as T
+import Data.Traversable
+import Database.Persist.Sql
+import qualified Database.Persist.Sql.Util as Util
+
+-- | Returns a structured representation of all of the
+-- DB changes required to migrate the Entity from its
+-- current state in the database to the state described in
+-- Haskell.
+--
+-- @since 2.17.1.0
+migrateStructured
+ :: BackendSpecificOverrides
+ -> [EntityDef]
+ -> (Text -> IO Statement)
+ -> EntityDef
+ -> IO (Either [Text] [AlterDB])
+migrateStructured overrides allDefs getter entity =
+ migrateEntitiesStructured overrides getter allDefs [entity]
+
+-- | Returns a structured representation of all of the DB changes required to
+-- migrate the listed entities from their current state in the database to the
+-- state described in Haskell. This function avoids N+1 queries, so if you
+-- have a lot of entities to migrate, it's much faster to use this rather than
+-- using 'migrateStructured' in a loop.
+--
+-- @since 2.14.1.0
+migrateEntitiesStructured
+ :: BackendSpecificOverrides
+ -> (Text -> IO Statement)
+ -> [EntityDef]
+ -> [EntityDef]
+ -> IO (Either [Text] [AlterDB])
+migrateEntitiesStructured overrides getStmt allDefs defsToMigrate = do
+ r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate)
+ pure $ case r of
+ Right schemaState ->
+ migrateEntitiesFromSchemaState overrides schemaState allDefs defsToMigrate
+ Left err ->
+ Left [err]
+
+-- | Returns a structured representation of all of the
+-- DB changes required to migrate the Entity to the state
+-- described in Haskell, assuming it currently does not
+-- exist in the database.
+--
+-- @since 2.17.1.0
+mockMigrateStructured
+ :: BackendSpecificOverrides
+ -> [EntityDef]
+ -> EntityDef
+ -> [AlterDB]
+mockMigrateStructured overrides allDefs entity =
+ migrateEntityFromSchemaState overrides EntityDoesNotExist allDefs entity
+
+-- | In order to ensure that generating migrations is fast and avoids N+1
+-- queries, we split it into two phases. The first phase involves querying the
+-- database to gather all of the information we need about the existing schema.
+-- The second phase then generates migrations based on the information from the
+-- first phase. This data type represents all of the data that's gathered during
+-- the first phase: information about the current state of the entities we're
+-- migrating in the database.
+newtype SchemaState = SchemaState (Map EntityNameDB EntitySchemaState)
+ deriving (Eq, Show)
+
+-- | The state of a particular entity (i.e. table) in the database; we generate
+-- migrations based on the diff of this versus an EntityDef.
+data EntitySchemaState
+ = -- | The table does not exist in the database
+ EntityDoesNotExist
+ | -- | The table does exist in the database
+ EntityExists ExistingEntitySchemaState
+ deriving (Eq, Show)
+
+-- | Information about an existing table in the database
+data ExistingEntitySchemaState = ExistingEntitySchemaState
+    { essColumns :: Map FieldNameDB (Column, Set ColumnReference)
+    -- ^ The columns in this entity, together with the set of foreign key
+    -- constraints that they are subject to. Usually the ColumnReference set
+    -- will contain at most one element, but in the event that there are multiple FK
+ -- constraints applying to a given column in the database we need to keep
+ -- track of them all because we don't yet know which one has the right name
+ -- (based on what is in the corresponding model's EntityDef).
+ --
+ -- Note that cReference will be unset for these columns, for the same reason:
+ -- there may be multiple FK constraints and we don't yet know which one to
+ -- use.
+ , essUniqueConstraints :: Map ConstraintNameDB [FieldNameDB]
+ -- ^ A map of unique constraint names to the columns that are affected by
+ -- those constraints.
+ }
+ deriving (Eq, Show)
+
+-- | Query a database in order to assemble a SchemaState containing information
+-- about each of the entities in the given list. Every entity name in the input
+-- should be present in the returned Map.
+collectSchemaState
+ :: (Text -> IO Statement) -> [EntityNameDB] -> IO (Either Text SchemaState)
+collectSchemaState getStmt entityNames = runExceptT $ do
+ existence <- getTableExistence getStmt entityNames
+ columns <- getColumnsWithoutReferences getStmt entityNames
+ constraints <- getConstraints getStmt entityNames
+ foreignKeyReferences <- getForeignKeyReferences getStmt entityNames
+
+ fmap (SchemaState . Map.fromList) $
+ for entityNames $ \entityNameDB -> do
+ tableExists <- case Map.lookup entityNameDB existence of
+ Just e -> pure e
+ Nothing ->
+ throwError
+ ("Missing entity name from existence map: " <> unEntityNameDB entityNameDB)
+
+ if tableExists
+ then do
+ essColumns <- case Map.lookup entityNameDB columns of
+ Just cols ->
+ pure $ Map.fromList $ flip map cols $ \c ->
+ ( cName c
+ ,
+ ( c
+ , fromMaybe Set.empty $
+ Map.lookup (cName c) =<< Map.lookup entityNameDB foreignKeyReferences
+ )
+ )
+ Nothing ->
+ throwError
+ ("Missing entity name from columns map: " <> unEntityNameDB entityNameDB)
+
+ let
+ essUniqueConstraints = fromMaybe Map.empty (Map.lookup entityNameDB constraints)
+ pure
+ ( entityNameDB
+ , EntityExists $ ExistingEntitySchemaState{essColumns, essUniqueConstraints}
+ )
+ else
+ pure
+ ( entityNameDB
+ , EntityDoesNotExist
+ )
+
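+-- | Prepare the given SQL, run it with the given parameter values, and apply
+-- @process@ to every result row, returning the processed rows as a list.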
+runStmt
+ :: (Show a)
+ => (Text -> IO Statement)
+ -> Text
+ -> [PersistValue]
+ -> ([PersistValue] -> a)
+ -> IO [a]
+runStmt getStmt sql values process = do
+ stmt <- getStmt sql
+ results <-
+ with
+ (stmtQuery stmt values)
+ (\src -> runConduit $ src .| CL.map process .| CL.consume)
+ pure results
+
+-- | Check for the existence of each of the input tables. The keys in the
+-- returned Map are exactly the entity names in the argument; True means the
+-- table exists.
+getTableExistence
+ :: (Text -> IO Statement)
+ -> [EntityNameDB]
+ -> ExceptT Text IO (Map EntityNameDB Bool)
+getTableExistence getStmt entityNames = do
+ results <-
+ liftIO $
+ runStmt
+ getStmt
+ getTableExistenceSql
+ [PersistArray (map (PersistText . unEntityNameDB) entityNames)]
+ processTable
+ case partitionEithers results of
+ ([], xs) ->
+ let
+ existing = Set.fromList xs
+ in
+ pure $ Map.fromList $ map (\n -> (n, Set.member n existing)) entityNames
+ (errs, _) -> throwError (T.intercalate "\n" errs)
+ where
+ getTableExistenceSql =
+ "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'"
+ <> " AND schemaname != 'information_schema' AND tablename=ANY (?)"
+
+ processTable :: [PersistValue] -> Either Text EntityNameDB
+ processTable resultRow = do
+ fmap EntityNameDB $
+ case resultRow of
+ [PersistText tableName] ->
+ pure tableName
+ [PersistByteString tableName] ->
+ pure (T.decodeUtf8 tableName)
+ other ->
+ throwError $ T.pack $ "Invalid result from information_schema: " ++ show other
+
+-- | Get all columns for the listed tables from the database, ignoring foreign
+-- key references (those are filled in later).
+getColumnsWithoutReferences
+ :: (Text -> IO Statement)
+ -> [EntityNameDB]
+ -> ExceptT Text IO (Map EntityNameDB [Column])
+getColumnsWithoutReferences getStmt entityNames = do
+ results <-
+ liftIO $
+ runStmt
+ getStmt
+ getColumnsSql
+ [PersistArray (map (PersistText . unEntityNameDB) entityNames)]
+ processColumn
+ case partitionEithers results of
+ ([], xs) -> pure $ Map.fromListWith (++) $ map (second (: [])) xs
+ (errs, _) -> throwError (T.intercalate "\n" errs)
+ where
+ getColumnsSql =
+ T.concat
+ [ "SELECT "
+ , "table_name "
+ , ",column_name "
+ , ",is_nullable "
+ , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below
+ , ",column_default "
+ , ",generation_expression "
+ , ",numeric_precision "
+ , ",numeric_scale "
+ , ",character_maximum_length "
+ , "FROM information_schema.columns "
+ , "WHERE table_catalog=current_database() "
+ , "AND table_schema=current_schema() "
+ , "AND table_name=ANY (?) "
+ ]
+
+ -- DOMAINS: Postgres supports the concept of domains, which are data types
+ -- with optional constraints. An app might make an "email" domain over the
+ -- varchar type, with a CHECK that the emails are valid. In this case the
+ -- generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN
+ -- foo TYPE email. This code exists to use the domain name (email) instead
+ -- of the underlying type (varchar). This is tested in EquivalentTypeTest.hs.
+ processColumn :: [PersistValue] -> Either Text (EntityNameDB, Column)
+ processColumn resultRow = do
+ case resultRow of
+ [ PersistText tableName
+ , PersistText columnName
+ , PersistText isNullable
+ , PersistText typeName
+ , defaultValue
+ , generationExpression
+ , numericPrecision
+ , numericScale
+ , maxlen
+ ] -> mapLeft (addErrorContext tableName columnName) $ do
+ defaultValue' <-
+ case defaultValue of
+ PersistNull ->
+ pure Nothing
+ PersistText t ->
+ pure $ Just t
+ _ ->
+ throwError $ T.pack $ "Invalid default column: " ++ show defaultValue
+ generationExpression' <-
+ case generationExpression of
+ PersistNull ->
+ pure Nothing
+ PersistText t ->
+ pure $ Just t
+ _ ->
+ throwError $ T.pack $ "Invalid generated column: " ++ show generationExpression
+ let
+ typeStr =
+ case maxlen of
+ PersistInt64 n ->
+ T.concat [typeName, "(", T.pack (show n), ")"]
+ _ ->
+ typeName
+
+ t <- getType numericPrecision numericScale typeStr
+
+ pure
+ ( EntityNameDB tableName
+ , Column
+ { cName = FieldNameDB columnName
+ , cNull = isNullable == "YES"
+ , cSqlType = t
+ , cDefault = fmap stripSuffixes defaultValue'
+ , cGenerated = fmap stripSuffixes generationExpression'
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ )
+ other ->
+ Left $
+ T.pack $
+ "Invalid result from information_schema: " ++ show other
+
+ stripSuffixes t =
+ loop'
+ [ "::character varying"
+ , "::text"
+ ]
+ where
+ loop' [] = t
+ loop' (p : ps) =
+ case T.stripSuffix p t of
+ Nothing -> loop' ps
+ Just t' -> t'
+
+ getType _ _ "int4" = pure SqlInt32
+ getType _ _ "int8" = pure SqlInt64
+ getType _ _ "varchar" = pure SqlString
+ getType _ _ "text" = pure SqlString
+ getType _ _ "date" = pure SqlDay
+ getType _ _ "bool" = pure SqlBool
+ getType _ _ "timestamptz" = pure SqlDayTime
+ getType _ _ "float4" = pure SqlReal
+ getType _ _ "float8" = pure SqlReal
+ getType _ _ "bytea" = pure SqlBlob
+ getType _ _ "time" = pure SqlTime
+ getType precision scale "numeric" = getNumeric precision scale
+ getType _ _ a = pure $ SqlOther a
+
+ getNumeric (PersistInt64 a) (PersistInt64 b) =
+ pure $ SqlNumeric (fromIntegral a) (fromIntegral b)
+ getNumeric PersistNull PersistNull =
+ throwError $
+ T.concat
+ [ "No precision and scale were specified. "
+ , "Postgres defaults to a maximum scale of 147,455 and precision of 16383,"
+ , " which is probably not what you intended."
+ , " Specify the values as numeric(total_digits, digits_after_decimal_place)."
+ ]
+ getNumeric a b =
+ throwError $
+ T.concat
+ [ "Can not get numeric field precision. "
+ , "Expected an integer for both precision and scale, "
+ , "got: "
+ , T.pack $ show a
+ , " and "
+ , T.pack $ show b
+ , ", respectively."
+ , " Specify the values as numeric(total_digits, digits_after_decimal_place)."
+ ]
+
+-- | Prefix an error message with the table and column it relates to.
+addErrorContext :: Text -> Text -> Text -> Text
+addErrorContext tableName columnName originalMsg =
+ T.concat
+ [ "Error in column "
+ , tableName
+ , "."
+ , columnName
+ , ": "
+ , originalMsg
+ ]
+
+-- | Get all constraints for the listed tables from the database, except for foreign
+-- keys and primary keys (those go in the Column data type)
+getConstraints
+ :: (Text -> IO Statement)
+ -> [EntityNameDB]
+ -> ExceptT Text IO (Map EntityNameDB (Map ConstraintNameDB [FieldNameDB]))
+getConstraints getStmt entityNames = do
+ results <-
+ liftIO $
+ runStmt
+ getStmt
+ getConstraintsSql
+ [PersistArray (map (PersistText . unEntityNameDB) entityNames)]
+ processConstraint
+ case partitionEithers results of
+ ([], xs) -> pure $ Map.unionsWith (Map.unionWith (<>)) xs
+ (errs, _) -> throwError (T.intercalate "\n" errs)
+ where
+ getConstraintsSql =
+ T.concat
+ [ "SELECT "
+ , "c.table_name, "
+ , "c.constraint_name, "
+ , "c.column_name "
+ , "FROM information_schema.key_column_usage AS c, "
+ , "information_schema.table_constraints AS k "
+ , "WHERE c.table_catalog=current_database() "
+ , "AND c.table_catalog=k.table_catalog "
+ , "AND c.table_schema=current_schema() "
+ , "AND c.table_schema=k.table_schema "
+ , "AND c.table_name=ANY (?) "
+ , "AND c.table_name=k.table_name "
+ , "AND c.constraint_name=k.constraint_name "
+ , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') "
+ , "ORDER BY c.constraint_name, c.column_name"
+ ]
+
+ processConstraint
+ :: [PersistValue]
+ -> Either Text (Map EntityNameDB (Map ConstraintNameDB [FieldNameDB]))
+ processConstraint resultRow = do
+ (tableName, constraintName, columnName) <- case resultRow of
+ [PersistText tab, PersistText con, PersistText col] ->
+ pure (tab, con, col)
+ [PersistByteString tab, PersistByteString con, PersistByteString col] ->
+ pure (T.decodeUtf8 tab, T.decodeUtf8 con, T.decodeUtf8 col)
+ o ->
+ throwError $ T.pack $ "unexpected datatype returned for postgres o=" ++ show o
+
+ pure $
+ Map.singleton
+ (EntityNameDB tableName)
+ (Map.singleton (ConstraintNameDB constraintName) [FieldNameDB columnName])
+
+-- | Get foreign key constraint information for all columns in the supplied
+-- tables from the database. We return a list of references per column because
+-- there may be duplicate FK constraints in the database.
+--
+-- Note that we only care about FKs where the column in question has ordinal
+-- position 1, i.e. it is the first column appearing in the FK constraint.
+-- Eventually we may want to fill this gap so that multi-column FK constraints
+-- can be dealt with by this migrator, but for now that is not something that
+-- persistent-postgresql handles.
+getForeignKeyReferences
+ :: (Text -> IO Statement)
+ -> [EntityNameDB]
+ -> ExceptT Text IO (Map EntityNameDB (Map FieldNameDB (Set ColumnReference)))
+getForeignKeyReferences getStmt entityNames = do
+ results <-
+ liftIO $
+ runStmt
+ getStmt
+ getForeignKeyReferencesSql
+ [PersistArray (map (PersistText . unEntityNameDB) entityNames)]
+ processForeignKeyReference
+ case partitionEithers results of
+ ([], xs) -> pure $ Map.unionsWith (Map.unionWith Set.union) xs
+ (errs, _) -> throwError (T.intercalate "\n" errs)
+ where
+ getForeignKeyReferencesSql = T.decodeUtf8 $(embedFileRelative "sql/getForeignKeyReferences.sql")
+
+ processForeignKeyReference
+ :: [PersistValue]
+ -> Either Text (Map EntityNameDB (Map FieldNameDB (Set ColumnReference)))
+ processForeignKeyReference resultRow = do
+ ( sourceTableName
+ , sourceColumnName
+ , refTableName
+ , constraintName
+ , updRule
+ , delRule
+ ) <-
+ case resultRow of
+ [ PersistText constrName
+ , PersistText srcTable
+ , PersistText refTable
+ , PersistText srcColumn
+ , PersistText _refColumn
+ , PersistText updRule
+ , PersistText delRule
+ ] ->
+ pure
+ ( EntityNameDB srcTable
+ , FieldNameDB srcColumn
+ , EntityNameDB refTable
+ , ConstraintNameDB constrName
+ , updRule
+ , delRule
+ )
+ other ->
+ throwError $ T.pack $ "unexpected row returned for postgres: " ++ show other
+
+ fcOnUpdate <- parseCascade updRule
+ fcOnDelete <- parseCascade delRule
+
+ let
+ columnRef =
+ ColumnReference
+ { crTableName = refTableName
+ , crConstraintName = constraintName
+ , crFieldCascade =
+ FieldCascade
+ { fcOnUpdate = Just fcOnUpdate
+ , fcOnDelete = Just fcOnDelete
+ }
+ }
+
+ pure $
+ Map.singleton
+ sourceTableName
+ (Map.singleton sourceColumnName (Set.singleton columnRef))
+
+-- | Parse a cascade action as represented in @pg_constraint@.
+parseCascade :: Text -> Either Text CascadeAction
+parseCascade txt =
+ case txt of
+ "a" ->
+ Right NoAction
+ "c" ->
+ Right Cascade
+ "n" ->
+ Right SetNull
+ "d" ->
+ Right SetDefault
+ "r" ->
+ Right Restrict
+ _ ->
+ Left $ "Unexpected value in parseCascade: " <> txt
+
+mapLeft :: (a1 -> a2) -> Either a1 b -> Either a2 b
+mapLeft _ (Right x) = Right x
+mapLeft f (Left x) = Left (f x)
+
+migrateEntitiesFromSchemaState
+ :: BackendSpecificOverrides
+ -> SchemaState
+ -> [EntityDef]
+ -> [EntityDef]
+ -> Either [Text] [AlterDB]
+migrateEntitiesFromSchemaState overrides (SchemaState schemaStateMap) allDefs defsToMigrate =
+ let
+ go :: EntityDef -> Either Text [AlterDB]
+ go entity = do
+ let
+ name = getEntityDBName entity
+ case Map.lookup name schemaStateMap of
+ Just entityState ->
+ Right $ migrateEntityFromSchemaState overrides entityState allDefs entity
+ Nothing ->
+ Left $ T.pack $ "No entry for entity in schemaState: " <> show name
+ in
+ case partitionEithers (map go defsToMigrate) of
+ ([], xs) -> Right (concat xs)
+ (errs, _) -> Left errs
+
+migrateEntityFromSchemaState
+ :: BackendSpecificOverrides
+ -> EntitySchemaState
+ -> [EntityDef]
+ -> EntityDef
+ -> [AlterDB]
+migrateEntityFromSchemaState overrides schemaState allDefs entity =
+ case schemaState of
+ EntityDoesNotExist ->
+ (addTable newcols entity) : uniques ++ references ++ foreignsAlt
+ EntityExists ExistingEntitySchemaState{essColumns, essUniqueConstraints} ->
+ let
+ (acs, ats) =
+ getAlters
+ allDefs
+ entity
+ (newcols, udspair)
+ ( map pickColumnReference (Map.elems essColumns)
+ , Map.toList essUniqueConstraints
+ )
+ acs' = map (AlterColumn name) acs
+ ats' = map (AlterTable name) ats
+ in
+ acs' ++ ats'
+ where
+ name = getEntityDBName entity
+ (newcols', udefs, fdefs) = postgresMkColumns overrides allDefs entity
+ newcols = filter (not . safeToRemove entity . cName) newcols'
+ udspair = map udToPair udefs
+
+ uniques = flip concatMap udspair $ \(uname, ucols) ->
+ [AlterTable name $ AddUniqueConstraint uname ucols]
+ references =
+ mapMaybe
+ ( \Column{cName, cReference} ->
+ getAddReference allDefs entity cName =<< cReference
+ )
+ newcols
+ foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs
+
+ -- HACK! This was added to preserve existing behaviour during a refactor.
+ -- The migrator currently expects to only see cReference set in the old
+ -- columns if it is also set in the new ones. It also ignores any existing
+ -- FK constraints in the database that don't match the expected FK
+ -- constraint name as defined by the Persistent EntityDef.
+ --
+ -- This means that the migrator sometimes behaves incorrectly for standalone
+ -- Foreign declarations, like Child in the ForeignKey test in
+ -- persistent-test, as well as in situations where there are duplicate FK
+ -- constraints for a given column.
+ --
+ -- See https://github.com/yesodweb/persistent/issues/1611#issuecomment-3613251095 for
+ -- more info
+ pickColumnReference (oldCol, oldReferences) =
+ case List.find (\c -> cName c == cName oldCol) newcols of
+ Just new -> fromMaybe oldCol $ do
+ -- Note that if this do block evaluates to Nothing, it means
+ -- we'll return a Column that has cReference = Nothing -
+ -- effectively, we are telling the migrator that this particular
+ -- column has no FK constraints in the DB.
+
+ -- If the persistent models don't define a FK constraint, ignore
+ -- any FK constraints that might exist in the DB (this is
+ -- arguably a bug, but it's a pre-existing one)
+ newRef <- cReference new
+
+ -- If the persistent models _do_ define an FK constraint but
+ -- there's no matching FK constraint in the DB, we don't have
+ -- to do anything else here: `getAlters` should handle adding
+ -- the FK constraint for us
+ oldRef <-
+ List.find
+ (\oldRef -> crConstraintName oldRef == crConstraintName newRef)
+ oldReferences
+
+ -- Finally, if the persistent models define an FK constraint and
+ -- an FK constraint of that name exists in the DB, return it, so
+ -- that `getAlters` can check that the constraint is set up
+ -- correctly
+ pure $ oldCol{cReference = Just oldRef}
+ Nothing ->
+ -- We have a column that exists in the DB but not in the
+ -- EntityDef. We can no-op here, since `getAlters` will handle
+ -- dropping this for us.
+ oldCol
+
+-- | Indicates whether a Postgres Column is safe to drop.
+--
+-- @since 2.17.1.0
+newtype SafeToRemove = SafeToRemove Bool
+ deriving (Show, Eq)
+
+-- | Represents a change to a Postgres column in a DB statement.
+--
+-- @since 2.17.1.0
+data AlterColumn
+ = ChangeType Column SqlType Text
+ | IsNull Column
+ | NotNull Column
+ | AddColumn Column
+ | Drop Column SafeToRemove
+ | Default Column Text
+ | NoDefault Column
+ | UpdateNullToValue Column Text
+ | AddReference
+ EntityNameDB
+ ConstraintNameDB
+ (NEL.NonEmpty FieldNameDB)
+ [Text]
+ FieldCascade
+ | DropReference ConstraintNameDB
+ deriving (Show, Eq)
+
+-- | Represents a change to a Postgres table in a DB statement.
+--
+-- @since 2.17.1.0
+data AlterTable
+ = AddUniqueConstraint ConstraintNameDB [FieldNameDB]
+ | DropConstraint ConstraintNameDB
+ deriving (Show, Eq)
+
+-- | Represents a change to a Postgres DB in a statement.
+--
+-- @since 2.17.1.0
+data AlterDB
+ = AddTable EntityNameDB EntityIdDef [Column]
+ | AlterColumn EntityNameDB AlterColumn
+ | AlterTable EntityNameDB AlterTable
+ deriving (Show, Eq)
+
+-- | Create a table if it doesn't exist.
+--
+-- @since 2.17.1.0
+addTable :: [Column] -> EntityDef -> AlterDB
+addTable cols entity =
+ AddTable name entityId nonIdCols
+ where
+ nonIdCols =
+ case entityPrimary entity of
+ Just _ ->
+ cols
+ _ ->
+ filter keepField cols
+ where
+ keepField c =
+ Just (cName c) /= fmap fieldDB (getEntityIdField entity)
+ && not (safeToRemove entity (cName c))
+ entityId = getEntityId entity
+ name = getEntityDBName entity
+
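+-- | Render a column's SQL type, using @SERIAL8@ for an 'SqlInt64' id column
+-- that has no explicit default.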
+maySerial :: SqlType -> Maybe Text -> Text
+maySerial SqlInt64 Nothing = " SERIAL8 "
+maySerial sType _ = " " <> showSqlType sType
+
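+-- | Render a @DEFAULT@ clause if a default value is present.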
+mayDefault :: Maybe Text -> Text
+mayDefault def = case def of
+ Nothing -> ""
+ Just d -> " DEFAULT " <> d
+
+getAlters
+ :: [EntityDef]
+ -> EntityDef
+ -> ([Column], [(ConstraintNameDB, [FieldNameDB])])
+ -> ([Column], [(ConstraintNameDB, [FieldNameDB])])
+ -> ([AlterColumn], [AlterTable])
+getAlters defs def (c1, u1) (c2, u2) =
+ (getAltersC c1 c2, getAltersU u1 u2)
+ where
+ getAltersC [] old =
+ map (\x -> Drop x $ SafeToRemove $ safeToRemove def $ cName x) old
+ getAltersC (new : news) old =
+ let
+ (alters, old') = findAlters defs def new old
+ in
+ alters ++ getAltersC news old'
+
+ getAltersU
+ :: [(ConstraintNameDB, [FieldNameDB])]
+ -> [(ConstraintNameDB, [FieldNameDB])]
+ -> [AlterTable]
+ getAltersU [] old =
+ map DropConstraint $ filter (not . isManual) $ map fst old
+ getAltersU ((name, cols) : news) old =
+ case lookup name old of
+ Nothing ->
+ AddUniqueConstraint name cols : getAltersU news old
+ Just ocols ->
+ let
+ old' = filter (\(x, _) -> x /= name) old
+ in
+ if sort cols == sort ocols
+ then getAltersU news old'
+ else
+ DropConstraint name
+ : AddUniqueConstraint name cols
+ : getAltersU news old'
+
+ -- Don't drop constraints which were manually added.
+ isManual (ConstraintNameDB x) = "__manual_" `T.isPrefixOf` x
+
+-- | Postgres' default maximum identifier length in bytes
+-- (You can re-compile Postgres with a new limit, but I'm assuming that virtually no one does this).
+-- See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
+maximumIdentifierLength :: Int
+maximumIdentifierLength = 63
+
+-- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer
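+--
+-- For example, @sqlTypeEq SqlInt64 (SqlOther "bigint")@ should be 'True':
+-- @INT8@ case-folds to @int8@, which normalizes to @bigint@.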
+sqlTypeEq :: SqlType -> SqlType -> Bool
+sqlTypeEq x y =
+ let
+ -- Non exhaustive helper to map postgres aliases to the same name. Based on
+ -- https://www.postgresql.org/docs/9.5/datatype.html.
+ -- This prevents needless `ALTER TYPE`s when the type is the same.
+ normalize "int8" = "bigint"
+ normalize "serial8" = "bigserial"
+ normalize v = v
+ in
+ normalize (T.toCaseFold (showSqlType x))
+ == normalize (T.toCaseFold (showSqlType y))
+
+-- | Check whether we should alter a foreign key. This is almost an equality
+-- check, except that we consider 'Nothing' and 'Just Restrict' equivalent.
+equivalentRef :: Maybe ColumnReference -> Maybe ColumnReference -> Bool
+equivalentRef Nothing Nothing = True
+equivalentRef (Just cr1) (Just cr2) =
+ crTableName cr1 == crTableName cr2
+ && crConstraintName cr1 == crConstraintName cr2
+ && eqCascade (fcOnUpdate $ crFieldCascade cr1) (fcOnUpdate $ crFieldCascade cr2)
+ && eqCascade (fcOnDelete $ crFieldCascade cr1) (fcOnDelete $ crFieldCascade cr2)
+ where
+ eqCascade :: Maybe CascadeAction -> Maybe CascadeAction -> Bool
+ eqCascade Nothing Nothing = True
+ eqCascade Nothing (Just Restrict) = True
+ eqCascade (Just Restrict) Nothing = True
+ eqCascade (Just cs1) (Just cs2) = cs1 == cs2
+ eqCascade _ _ = False
+equivalentRef _ _ = False
+
+-- | Generate the default foreign key constraint name for a given source table and
+-- source column name. Note that this function should generally not be used
+-- except as an argument to postgresMkColumns, because if you use it in other contexts,
+-- you're likely to miss nonstandard constraint names declared in the persistent
+-- model files via @constraint=@.
+refName :: EntityNameDB -> FieldNameDB -> ConstraintNameDB
+refName (EntityNameDB table) (FieldNameDB column) =
+ let
+ overhead = T.length $ T.concat ["_", "_fkey"]
+ (fromTable, fromColumn) = shortenNames overhead (T.length table, T.length column)
+ in
+ ConstraintNameDB $
+ T.concat [T.take fromTable table, "_", T.take fromColumn column, "_fkey"]
+ where
+ -- Postgres automatically truncates too-long foreign key names to a
+ -- combination of truncatedTableName + "_" + truncatedColumnName + "_fkey".
+ -- This works fine for normal use cases, but it creates an issue for
+ -- Persistent, because after running the migrations, Persistent sees that the
+ -- truncated foreign key constraint doesn't have the expected name and
+ -- suggests that you migrate again. To work around this, we copy the Postgres
+ -- truncation approach before sending foreign key constraints to it.
+ --
+ -- I believe this will also be an issue for extremely long table names, but
+ -- it's much more likely to occur with foreign key constraints because
+ -- they're usually tablename * 2 in length.
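+ --
+ -- Worked example (hypothetical lengths): with a 40-byte table name and a
+ -- 40-byte column name, the overhead is 6 bytes ("_" plus "_fkey"), so
+ -- shortenNames trims the longer length one byte at a time until (40, 40)
+ -- becomes (29, 28); 29 + 28 + 6 = 63 fits the limit exactly.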
+
+ -- Approximation of the algorithm Postgres uses to truncate identifiers
+ -- See makeObjectName https://github.com/postgres/postgres/blob/5406513e997f5ee9de79d4076ae91c04af0c52f6/src/backend/commands/indexcmds.c#L2074-L2080
+ shortenNames :: Int -> (Int, Int) -> (Int, Int)
+ shortenNames overhead (x, y)
+ | x + y + overhead <= maximumIdentifierLength = (x, y)
+ | x > y = shortenNames overhead (x - 1, y)
+ | otherwise = shortenNames overhead (x, y - 1)
+
+postgresMkColumns
+ :: BackendSpecificOverrides
+ -> [EntityDef]
+ -> EntityDef
+ -> ([Column], [UniqueDef], [ForeignDef])
+postgresMkColumns overrides allDefs t =
+ mkColumns allDefs t $
+ setBackendSpecificForeignKeyName refName overrides
+
+-- | Check whether a column is marked "safe to remove" in the entity
+-- definition.
+safeToRemove :: EntityDef -> FieldNameDB -> Bool
+safeToRemove def (FieldNameDB colName) =
+ any (elem FieldAttrSafeToRemove . fieldAttrs) $
+ filter ((== FieldNameDB colName) . fieldDB) $
+ allEntityFields
+ where
+ allEntityFields =
+ getEntityFieldsDatabase def <> case getEntityId def of
+ EntityIdField fdef ->
+ [fdef]
+ _ ->
+ []
+
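+-- | Convert a unique constraint definition to its DB-level name and columns.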
+udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB])
+udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud)
+
+-- | Get the references to be added to a table for the given column.
+getAddReference
+ :: [EntityDef]
+ -> EntityDef
+ -> FieldNameDB
+ -> ColumnReference
+ -> Maybe AlterDB
+getAddReference allDefs entity cname cr@ColumnReference{crTableName = s, crConstraintName = constraintName} = do
+ guard $ Just cname /= fmap fieldDB (getEntityIdField entity)
+ pure $
+ AlterColumn
+ table
+ (AddReference s constraintName (cname NEL.:| []) id_ (crFieldCascade cr))
+ where
+ table = getEntityDBName entity
+ id_ =
+ fromMaybe
+ (error $ "Could not find ID of entity " ++ show s)
+ $ do
+ entDef <- find ((== s) . getEntityDBName) allDefs
+ return $ NEL.toList $ Util.dbIdColumnsEsc escapeF entDef
+
+mkForeignAlt
+ :: EntityDef
+ -> ForeignDef
+ -> Maybe AlterDB
+mkForeignAlt entity fdef = case NEL.nonEmpty childfields of
+ Nothing -> Nothing
+ Just childfields' -> Just $ AlterColumn tableName_ addReference
+ where
+ addReference =
+ AddReference
+ (foreignRefTableDBName fdef)
+ constraintName
+ childfields'
+ escapedParentFields
+ (foreignFieldCascade fdef)
+ where
+ tableName_ = getEntityDBName entity
+ constraintName =
+ foreignConstraintNameDBName fdef
+ (childfields, parentfields) =
+ unzip (map (\((_, b), (_, d)) -> (b, d)) (foreignFields fdef))
+ escapedParentFields =
+ map escapeF parentfields
+
+escapeC :: ConstraintNameDB -> Text
+escapeC = escapeWith escape
+
+escapeE :: EntityNameDB -> Text
+escapeE = escapeWith escape
+
+escapeF :: FieldNameDB -> Text
+escapeF = escapeWith escape
+
+escape :: Text -> Text
+escape s = T.concat ["\"", T.replace "\"" "\"\"" s, "\""]
+
+showAlterDb :: AlterDB -> (Bool, Text)
+showAlterDb (AddTable name entityId nonIdCols) = (False, rawText)
+ where
+ idtxt =
+ case entityId of
+ EntityIdNaturalKey pdef ->
+ T.concat
+ [ " PRIMARY KEY ("
+ , T.intercalate "," $ map (escapeF . fieldDB) $ NEL.toList $ compositeFields pdef
+ , ")"
+ ]
+ EntityIdField field ->
+ let
+ defText = defaultAttribute $ fieldAttrs field
+ sType = fieldSqlType field
+ in
+ T.concat
+ [ escapeF $ fieldDB field
+ , maySerial sType defText
+ , " PRIMARY KEY UNIQUE"
+ , mayDefault defText
+ ]
+ rawText =
+ T.concat
+ -- Lower case e: see Database.Persist.Sql.Migration
+ [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION!
+ , escapeE name
+ , "("
+ , idtxt
+ , if null nonIdCols then "" else ","
+ , T.intercalate "," $ map showColumn nonIdCols
+ , ")"
+ ]
+showAlterDb (AlterColumn t ac) =
+ (isUnsafe ac, showAlter t ac)
+ where
+ isUnsafe (Drop _ (SafeToRemove safeRemove)) = not safeRemove
+ isUnsafe _ = False
+showAlterDb (AlterTable t at) = (False, showAlterTable t at)
+
+showAlterTable :: EntityNameDB -> AlterTable -> Text
+showAlterTable table (AddUniqueConstraint cname cols) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " ADD CONSTRAINT "
+ , escapeC cname
+ , " UNIQUE("
+ , T.intercalate "," $ map escapeF cols
+ , ")"
+ ]
+showAlterTable table (DropConstraint cname) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " DROP CONSTRAINT "
+ , escapeC cname
+ ]
+
+showAlter :: EntityNameDB -> AlterColumn -> Text
+showAlter table (ChangeType c t extra) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " ALTER COLUMN "
+ , escapeF (cName c)
+ , " TYPE "
+ , showSqlType t
+ , extra
+ ]
+showAlter table (IsNull c) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " ALTER COLUMN "
+ , escapeF (cName c)
+ , " DROP NOT NULL"
+ ]
+showAlter table (NotNull c) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " ALTER COLUMN "
+ , escapeF (cName c)
+ , " SET NOT NULL"
+ ]
+showAlter table (AddColumn col) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " ADD COLUMN "
+ , showColumn col
+ ]
+showAlter table (Drop c _) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " DROP COLUMN "
+ , escapeF (cName c)
+ ]
+showAlter table (Default c s) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " ALTER COLUMN "
+ , escapeF (cName c)
+ , " SET DEFAULT "
+ , s
+ ]
+showAlter table (NoDefault c) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " ALTER COLUMN "
+ , escapeF (cName c)
+ , " DROP DEFAULT"
+ ]
+showAlter table (UpdateNullToValue c s) =
+ T.concat
+ [ "UPDATE "
+ , escapeE table
+ , " SET "
+ , escapeF (cName c)
+ , "="
+ , s
+ , " WHERE "
+ , escapeF (cName c)
+ , " IS NULL"
+ ]
+showAlter table (AddReference reftable fkeyname t2 id2 cascade) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " ADD CONSTRAINT "
+ , escapeC fkeyname
+ , " FOREIGN KEY("
+ , T.intercalate "," $ map escapeF $ NEL.toList t2
+ , ") REFERENCES "
+ , escapeE reftable
+ , "("
+ , T.intercalate "," id2
+ , ")"
+ ]
+ <> renderFieldCascade cascade
+showAlter table (DropReference cname) =
+ T.concat
+ [ "ALTER TABLE "
+ , escapeE table
+ , " DROP CONSTRAINT "
+ , escapeC cname
+ ]
+
+showColumn :: Column -> Text
+showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) =
+ T.concat
+ [ escapeF n
+ , " "
+ , showSqlType sqlType'
+ , " "
+ , if nu then "NULL" else "NOT NULL"
+ , case def of
+ Nothing -> ""
+ Just s -> " DEFAULT " <> s
+ , case gen of
+ Nothing -> ""
+ Just s -> " GENERATED ALWAYS AS (" <> s <> ") STORED"
+ ]
+
+showSqlType :: SqlType -> Text
+showSqlType SqlString = "VARCHAR"
+showSqlType SqlInt32 = "INT4"
+showSqlType SqlInt64 = "INT8"
+showSqlType SqlReal = "DOUBLE PRECISION"
+showSqlType (SqlNumeric precision scale) = T.concat ["NUMERIC(", T.pack (show precision), ",", T.pack (show scale), ")"]
+showSqlType SqlDay = "DATE"
+showSqlType SqlTime = "TIME"
+showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE"
+showSqlType SqlBlob = "BYTEA"
+showSqlType SqlBool = "BOOLEAN"
+-- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682
+showSqlType (SqlOther (T.toLower -> "integer")) = "INT4"
+showSqlType (SqlOther t) = t
+
+findAlters
+ :: [EntityDef]
+ -- ^ The list of all entity definitions that persistent is aware of.
+ -> EntityDef
+ -- ^ The entity definition for the entity that we're working on.
+ -> Column
+ -- ^ The column that we're searching for potential alterations for, derived
+ -- from the Persistent EntityDef. That is: this is how we _want_ the column
+ -- to look, and not necessarily how it actually looks in the database right
+ -- now.
+ -> [Column]
+ -- ^ The columns for this table, as they currently exist in the database.
+ -> ([AlterColumn], [Column])
+findAlters defs edef newCol oldCols =
+ case List.find (\c -> cName c == cName newCol) oldCols of
+ Nothing ->
+ ([AddColumn newCol] ++ refAdd (cReference newCol), oldCols)
+ Just
+ oldCol ->
+ let
+ refDrop Nothing = []
+ refDrop (Just ColumnReference{crConstraintName = cname}) =
+ [DropReference cname]
+
+ modRef =
+ if equivalentRef (cReference oldCol) (cReference newCol)
+ then []
+ else refDrop (cReference oldCol) ++ refAdd (cReference newCol)
+ modNull = case (cNull newCol, cNull oldCol) of
+ (True, False) -> do
+ guard $ Just (cName newCol) /= fmap fieldDB (getEntityIdField edef)
+ pure (IsNull newCol)
+ (False, True) ->
+ let
+ up = case cDefault newCol of
+ Nothing -> id
+ Just s -> (:) (UpdateNullToValue newCol s)
+ in
+ up [NotNull newCol]
+ _ -> []
+ modType
+ | sqlTypeEq (cSqlType newCol) (cSqlType oldCol) = []
+ -- When converting from Persistent pre-2.0 databases, we
+ -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is
+ -- treated as UTC.
+ | cSqlType newCol == SqlDayTime && cSqlType oldCol == SqlOther "timestamp" =
+ [ ChangeType newCol (cSqlType newCol) $
+ T.concat
+ [ " USING "
+ , escapeF (cName newCol)
+ , " AT TIME ZONE 'UTC'"
+ ]
+ ]
+ | otherwise = [ChangeType newCol (cSqlType newCol) ""]
+ modDef =
+ if cDefault newCol == cDefault oldCol
+ || isJust (T.stripPrefix "nextval" =<< cDefault oldCol)
+ then []
+ else case cDefault newCol of
+ Nothing -> [NoDefault newCol]
+ Just s -> [Default newCol s]
+ dropSafe =
+ if safeToRemove edef (cName newCol)
+ then error "wtf" [Drop newCol (SafeToRemove True)]
+ else []
+ in
+ ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe
+ , filter (\c -> cName c /= cName newCol) oldCols
+ )
+ where
+ refAdd Nothing = []
+ -- This check works around a bug where persistent will sometimes
+ -- generate an erroneous ForeignRef for ID fields.
+ -- See: https://github.com/yesodweb/persistent/issues/1615
+ refAdd _ | fmap fieldDB (getEntityIdField edef) == Just (cName newCol) = []
+ refAdd (Just colRef) =
+ case find ((== crTableName colRef) . getEntityDBName) defs of
+ Just refdef ->
+ [ AddReference
+ (crTableName colRef)
+ (crConstraintName colRef)
+ (cName newCol NEL.:| [])
+ (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef)
+ (crFieldCascade colRef)
+ ]
+ Nothing ->
+ error $
+ "could not find the entityDef for reftable["
+ ++ show (crTableName colRef)
+ ++ "]"
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgCodec.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgCodec.hs
new file mode 100644
index 000000000..46b3016ee
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgCodec.hs
@@ -0,0 +1,367 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- | Composable binary codecs for PostgreSQL types.
+--
+-- 'PgDecode' and 'PgEncode' form the /value-level/ codec layer: they
+-- operate on raw binary payloads ('PD.Value' \/ 'PE.Encoding') and
+-- compose through PostgreSQL's compound-type wire formats (composites,
+-- arrays, ranges).
+--
+-- The column-level layer ('FieldDecode' \/ 'FieldEncode') delegates
+-- to these for compound types while handling OID dispatch itself.
+--
+-- Both layers receive an 'OidCache' so they can resolve
+-- dynamically-assigned OIDs (composites, enums, domains).
+module Database.Persist.Postgresql.Internal.PgCodec
+ ( -- * Decode
+ -- ** Reader types
+ PgDecoder (..)
+ , PgComposite (..)
+ -- ** Class
+ , PgDecode (..)
+ -- ** Combinators
+ , pgValue
+ , pgComposite
+ , pgField
+ , pgFieldNullable
+ , pgArray
+ , pgArrayNullable
+
+ -- * Encode
+ -- ** Reader types
+ , PgEncoder (..)
+ -- ** Class
+ , PgEncode (..)
+ -- ** Combinators
+ , pgConst
+ , pgEncodeField
+ , pgArrayEncoder
+ , pgCompositeEncoder
+ ) where
+
+import Control.Monad (replicateM)
+import Data.ByteString (ByteString)
+import Data.Int (Int16, Int32, Int64)
+import Data.Proxy (Proxy (..))
+import Data.Scientific (Scientific, fromRationalRepetendUnlimited)
+import Data.Text (Text)
+import Data.Time (Day, TimeOfDay, UTCTime)
+import Data.Word (Word32)
+import qualified PostgreSQL.Binary.Decoding as PD
+import qualified PostgreSQL.Binary.Encoding as PE
+
+import Database.Persist.Postgresql.Internal.PgType
+
+---------------------------------------------------------------------------
+-- Decode: reader types
+---------------------------------------------------------------------------
+
+-- | A value-level binary decoder parameterised by 'OidCache'.
+--
+-- Wraps @postgresql-binary@'s 'PD.Value' in a reader so the cache is
+-- available to compound decoders (composites that want to validate
+-- sub-field OIDs, etc.) without threading it manually.
+newtype PgDecoder a = PgDecoder { runPgDecoder :: OidCache -> PD.Value a }
+
+instance Functor PgDecoder where
+ fmap f (PgDecoder g) = PgDecoder $ \c -> fmap f (g c)
+ {-# INLINE fmap #-}
+
+instance Applicative PgDecoder where
+ pure a = PgDecoder $ \_ -> pure a
+ {-# INLINE pure #-}
+ PgDecoder f <*> PgDecoder a = PgDecoder $ \c -> f c <*> a c
+ {-# INLINE (<*>) #-}
+
+instance Monad PgDecoder where
+ PgDecoder ma >>= k = PgDecoder $ \c -> do
+ a <- ma c
+ runPgDecoder (k a) c
+ {-# INLINE (>>=) #-}
+
+-- | A composite-field decoder parameterised by 'OidCache'.
+--
+-- Wraps @postgresql-binary@'s 'PD.Composite' applicative. Build one
+-- with 'pgField' \/ 'pgFieldNullable', then convert to a 'PgDecoder'
+-- with 'pgComposite'.
+newtype PgComposite a = PgComposite (OidCache -> PD.Composite a)
+
+instance Functor PgComposite where
+ fmap f (PgComposite g) = PgComposite $ \c -> fmap f (g c)
+ {-# INLINE fmap #-}
+
+instance Applicative PgComposite where
+ pure a = PgComposite $ \_ -> pure a
+ {-# INLINE pure #-}
+ PgComposite f <*> PgComposite a = PgComposite $ \c -> f c <*> a c
+ {-# INLINE (<*>) #-}
+
+---------------------------------------------------------------------------
+-- Decode: combinators
+---------------------------------------------------------------------------
+
+-- | Lift a raw @postgresql-binary@ decoder that doesn't need the cache.
+pgValue :: PD.Value a -> PgDecoder a
+pgValue v = PgDecoder $ \_ -> v
+{-# INLINE pgValue #-}
+
+-- | Build a 'PgDecoder' for a PostgreSQL composite type from its fields.
+--
+-- @
+-- instance PgDecode Address where
+-- pgDecoder = pgComposite $ Address
+-- \<$\> pgField pgDecoder
+-- \<*\> pgField pgDecoder
+-- \<*\> pgField pgDecoder
+-- @
+pgComposite :: PgComposite a -> PgDecoder a
+pgComposite (PgComposite f) = PgDecoder $ \c -> PD.composite (f c)
+{-# INLINE pgComposite #-}
+
+-- | Decode a non-nullable field inside a composite.
+pgField :: PgDecoder a -> PgComposite a
+pgField (PgDecoder f) = PgComposite $ \c ->
+ PD.valueComposite (f c)
+{-# INLINE pgField #-}
+
+-- | Decode a nullable field inside a composite.
+pgFieldNullable :: PgDecoder a -> PgComposite (Maybe a)
+pgFieldNullable (PgDecoder f) = PgComposite $ \c ->
+ PD.nullableValueComposite (f c)
+{-# INLINE pgFieldNullable #-}
+
+-- | Decode a one-dimensional PostgreSQL array of non-nullable elements.
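+--
+-- For instance, @pgArray pgDecoder@ at type @PgDecoder [Int64]@ decodes an
+-- @int8[]@ payload; the @PgDecode [a]@ instance below is defined this way.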
+pgArray :: PgDecoder a -> PgDecoder [a]
+pgArray (PgDecoder f) = PgDecoder $ \c ->
+ PD.array $ PD.dimensionArray replicateM (PD.valueArray (f c))
+{-# INLINE pgArray #-}
+
+-- | Decode a one-dimensional PostgreSQL array of nullable elements.
+pgArrayNullable :: PgDecoder a -> PgDecoder [Maybe a]
+pgArrayNullable (PgDecoder f) = PgDecoder $ \c ->
+ PD.array $ PD.dimensionArray replicateM (PD.nullableValueArray (f c))
+{-# INLINE pgArrayNullable #-}
+
+---------------------------------------------------------------------------
+-- Decode: class
+---------------------------------------------------------------------------
+
+-- | Value-level binary decoder for a Haskell type.
+--
+-- Scalar instances wrap the corresponding @postgresql-binary@ decoder.
+-- Compound instances compose via 'pgComposite', 'pgField', 'pgArray'.
+class PgDecode a where
+ pgDecoder :: PgDecoder a
+
+-- Scalars
+
+instance PgDecode Bool where
+ pgDecoder = pgValue PD.bool
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Int16 where
+ pgDecoder = pgValue PD.int
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Int32 where
+ pgDecoder = pgValue PD.int
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Int64 where
+ pgDecoder = pgValue PD.int
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Int where
+ pgDecoder = pgValue (fromIntegral <$> (PD.int :: PD.Value Int64))
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Double where
+ pgDecoder = pgValue PD.float8
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Float where
+ pgDecoder = pgValue PD.float4
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Scientific where
+ pgDecoder = pgValue PD.numeric
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Rational where
+ pgDecoder = pgValue (toRational <$> PD.numeric)
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Text where
+ pgDecoder = pgValue PD.text_strict
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode ByteString where
+ pgDecoder = pgValue PD.bytea_strict
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode Day where
+ pgDecoder = pgValue PD.date
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode TimeOfDay where
+ pgDecoder = pgValue PD.time_int
+ {-# INLINE pgDecoder #-}
+
+instance PgDecode UTCTime where
+ pgDecoder = pgValue PD.timestamptz_int
+ {-# INLINE pgDecoder #-}
+
+-- Compound
+
+instance PgDecode a => PgDecode [a] where
+ pgDecoder = pgArray pgDecoder
+ {-# INLINE pgDecoder #-}
+
+instance {-# OVERLAPPING #-} PgDecode a => PgDecode [Maybe a] where
+ pgDecoder = pgArrayNullable pgDecoder
+ {-# INLINE pgDecoder #-}
+
+---------------------------------------------------------------------------
+-- Encode: reader types
+---------------------------------------------------------------------------
+
+-- | A value-level binary encoder parameterised by 'OidCache'.
+--
+-- The cache is needed for composite and enum types whose OIDs are
+-- assigned dynamically by PostgreSQL.
+newtype PgEncoder a = PgEncoder { runPgEncoder :: OidCache -> a -> PE.Encoding }
+
+instance Semigroup (PgEncoder a) where
+ PgEncoder f <> PgEncoder g = PgEncoder $ \c a ->
+ f c a <> g c a
+ {-# INLINE (<>) #-}
+
+-- | Lift a pure encoder that doesn't need the cache.
+pgConst :: (a -> PE.Encoding) -> PgEncoder a
+pgConst f = PgEncoder $ \_ -> f
+{-# INLINE pgConst #-}
+
+---------------------------------------------------------------------------
+-- Encode: combinators
+---------------------------------------------------------------------------
+
+-- | Encode a value as a composite field (OID + payload).
+pgEncodeField :: forall a. PgEncode a => OidCache -> a -> PE.Composite
+pgEncodeField cache a =
+ PE.field (pgTypeOid' cache (Proxy @a)) (runPgEncoder pgEncoder cache a)
+{-# INLINE pgEncodeField #-}
+
+-- | Build a 'PgEncoder' for a one-dimensional array of non-nullable elements.
+pgArrayEncoder :: forall a. PgEncode a => PgEncoder [a]
+pgArrayEncoder = PgEncoder $ \cache xs ->
+ PE.array_foldable
+ (pgTypeOid' cache (Proxy @a))
+ (\x -> Just (runPgEncoder pgEncoder cache x))
+ xs
+{-# INLINE pgArrayEncoder #-}
+
+-- | Build a 'PgEncoder' for a composite type from a function that
+-- produces a 'PE.Composite' builder.
+pgCompositeEncoder :: (OidCache -> a -> PE.Composite) -> PgEncoder a
+pgCompositeEncoder f = PgEncoder $ \cache a -> PE.composite (f cache a)
+{-# INLINE pgCompositeEncoder #-}
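+
+-- A sketch of a composite 'PgEncode' instance built from these combinators.
+-- The @Address@ type is hypothetical, and combining fields with '<>' assumes
+-- the 'PE.Composite' 'Semigroup' instance from @postgresql-binary@; a real
+-- instance would resolve 'pgTypeOid'' from the 'OidCache' rather than leaving
+-- it undefined:
+--
+-- @
+-- instance PgEncode Address where
+--   pgEncoder = pgCompositeEncoder $ \\cache (Address street city) ->
+--     pgEncodeField cache street <> pgEncodeField cache city
+--   pgTypeOid' cache _ = undefined -- looked up from pg_type at connection time
+-- @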
+
+---------------------------------------------------------------------------
+-- Encode: class
+---------------------------------------------------------------------------
+
+-- | Value-level binary encoder for a Haskell type.
+--
+-- 'pgEncoder' produces the binary payload. 'pgTypeOid'' returns the
+-- PostgreSQL OID for the type, needed when this value appears as an
+-- element inside a composite or array.
+class PgEncode a where
+ pgEncoder :: PgEncoder a
+ pgTypeOid' :: OidCache -> Proxy a -> Word32
+
+-- Scalars
+
+instance PgEncode Bool where
+ pgEncoder = pgConst PE.bool
+ pgTypeOid' _ _ = scalarOidWord32 PgBool
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Int16 where
+ pgEncoder = pgConst PE.int2_int16
+ pgTypeOid' _ _ = scalarOidWord32 PgInt2
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Int32 where
+ pgEncoder = pgConst PE.int4_int32
+ pgTypeOid' _ _ = scalarOidWord32 PgInt4
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Int64 where
+ pgEncoder = pgConst PE.int8_int64
+ pgTypeOid' _ _ = scalarOidWord32 PgInt8
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Int where
+ pgEncoder = pgConst $ \i -> PE.int8_int64 (fromIntegral i)
+ pgTypeOid' _ _ = scalarOidWord32 PgInt8
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Double where
+ pgEncoder = pgConst PE.float8
+ pgTypeOid' _ _ = scalarOidWord32 PgFloat8
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Float where
+ pgEncoder = pgConst PE.float4
+ pgTypeOid' _ _ = scalarOidWord32 PgFloat4
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Scientific where
+ pgEncoder = pgConst PE.numeric
+ pgTypeOid' _ _ = scalarOidWord32 PgNumeric
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Rational where
+ pgEncoder = pgConst $ \r ->
+ let (sci, _) = fromRationalRepetendUnlimited r
+ in PE.numeric sci
+ pgTypeOid' _ _ = scalarOidWord32 PgNumeric
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Text where
+ pgEncoder = pgConst PE.text_strict
+ pgTypeOid' _ _ = scalarOidWord32 PgText
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode ByteString where
+ pgEncoder = pgConst PE.bytea_strict
+ pgTypeOid' _ _ = scalarOidWord32 PgBytea
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode Day where
+ pgEncoder = pgConst PE.date
+ pgTypeOid' _ _ = scalarOidWord32 PgDate
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode TimeOfDay where
+ pgEncoder = pgConst PE.time_int
+ pgTypeOid' _ _ = scalarOidWord32 PgTime
+ {-# INLINE pgEncoder #-}
+
+instance PgEncode UTCTime where
+ pgEncoder = pgConst PE.timestamptz_int
+ pgTypeOid' _ _ = scalarOidWord32 PgTimestamptz
+ {-# INLINE pgEncoder #-}
+
+-- Compound
+
+instance PgEncode a => PgEncode [a] where
+ pgEncoder = pgArrayEncoder
+ pgTypeOid' cache _ = case arrayOidWord32 (pgTypeOid' cache (Proxy @a)) of
+ Just oid -> oid
+ Nothing -> 0
+ {-# INLINE pgEncoder #-}
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgType.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgType.hs
new file mode 100644
index 000000000..c792b691d
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgType.hs
@@ -0,0 +1,274 @@
+{-# LANGUAGE ViewPatterns #-}
+
+-- | Unified PostgreSQL type\/OID classification.
+--
+-- Single source of truth for the mapping between PostgreSQL type OIDs and
+-- the Haskell-side 'PgType' representation. Both "Encoding" and "Decoding"
+-- modules reference this module instead of maintaining their own OID tables.
+--
+-- 'fromOid' is designed for use as a view pattern:
+--
+-- @
+-- f (fromOid -> Scalar PgBool) = …
+-- f (fromOid -> Array PgInt8) = …
+-- f (fromOid -> Unrecognized _) = …
+-- @
+module Database.Persist.Postgresql.Internal.PgType
+ ( -- * Type classification
+ PgScalar (..)
+ , PgType (..)
+ -- * OID → PgType (view pattern)
+ , fromOid
+ -- * PgType → OID
+ , pgTypeOid
+ , scalarOid
+ , scalarOidWord32
+ , arrayOid
+ , arrayOidWord32
+ -- * OID cache for custom types
+ , OidCache
+ , emptyOidCache
+ , resolveOid
+ ) where
+
+import Data.IntMap.Strict (IntMap)
+import qualified Data.IntMap.Strict as IntMap
+import Data.Text (Text)
+import Data.Word (Word32)
+import Foreign.C.Types (CUInt)
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+
+-- | Known PostgreSQL scalar types.
+data PgScalar
+ = PgBool
+ | PgBytea
+ | PgChar
+ | PgName
+ | PgInt2
+ | PgInt4
+ | PgInt8
+ | PgFloat4
+ | PgFloat8
+ | PgText
+ | PgXml
+ | PgMoney
+ | PgBpchar
+ | PgVarchar
+ | PgDate
+ | PgTime
+ | PgTimestamp
+ | PgTimestamptz
+ | PgInterval
+ | PgBit
+ | PgVarbit
+ | PgNumeric
+ | PgVoid
+ | PgJson
+ | PgJsonb
+ | PgUnknown
+ -- ^ The PostgreSQL @unknown@ pseudotype (OID 705), used for untyped
+ -- string literals. Not to be confused with 'Unrecognized'.
+ | PgUuid
+ deriving (Eq, Ord, Show)
+
+-- | Classified PostgreSQL type: a known scalar, a known array of a scalar,
+-- a user-defined composite\/enum, or an OID we don't have a built-in
+-- mapping for.
+data PgType
+ = Scalar !PgScalar
+ | Array !PgScalar
+ | Composite !Text !Int
+ -- ^ User-defined composite type. Carries the type name and the
+ -- array-type OID (0 if unknown), looked up from @pg_type@ at
+ -- connection time.
+ | CompositeArray !Text
+ -- ^ Array of a user-defined composite. Carries the element type name.
+ | Enum !Text !Int
+ -- ^ User-defined enum. Carries the type name and the array-type OID.
+ | EnumArray !Text
+ -- ^ Array of a user-defined enum. Carries the element type name.
+ | Unrecognized !Int
+ deriving (Eq, Show)
+
+---------------------------------------------------------------------------
+-- OID → PgType
+---------------------------------------------------------------------------
+
+-- | Classify a @libpq@ 'LibPQ.Oid' into a 'PgType'.
+--
+-- Intended for use as a view pattern so that callers can pattern-match on
+-- structured types rather than raw integers.
+fromOid :: LibPQ.Oid -> PgType
+fromOid (LibPQ.Oid n) = fromCUInt n
+
+fromCUInt :: CUInt -> PgType
+fromCUInt 16 = Scalar PgBool
+fromCUInt 17 = Scalar PgBytea
+fromCUInt 18 = Scalar PgChar
+fromCUInt 19 = Scalar PgName
+fromCUInt 20 = Scalar PgInt8
+fromCUInt 21 = Scalar PgInt2
+fromCUInt 23 = Scalar PgInt4
+fromCUInt 25 = Scalar PgText
+fromCUInt 114 = Scalar PgJson
+fromCUInt 142 = Scalar PgXml
+fromCUInt 700 = Scalar PgFloat4
+fromCUInt 701 = Scalar PgFloat8
+fromCUInt 705 = Scalar PgUnknown
+fromCUInt 790 = Scalar PgMoney
+fromCUInt 1042 = Scalar PgBpchar
+fromCUInt 1043 = Scalar PgVarchar
+fromCUInt 1082 = Scalar PgDate
+fromCUInt 1083 = Scalar PgTime
+fromCUInt 1114 = Scalar PgTimestamp
+fromCUInt 1184 = Scalar PgTimestamptz
+fromCUInt 1186 = Scalar PgInterval
+fromCUInt 1560 = Scalar PgBit
+fromCUInt 1562 = Scalar PgVarbit
+fromCUInt 1700 = Scalar PgNumeric
+fromCUInt 2278 = Scalar PgVoid
+fromCUInt 2950 = Scalar PgUuid
+fromCUInt 3802 = Scalar PgJsonb
+-- Arrays
+fromCUInt 143 = Array PgXml
+fromCUInt 199 = Array PgJson
+fromCUInt 791 = Array PgMoney
+fromCUInt 1000 = Array PgBool
+fromCUInt 1001 = Array PgBytea
+fromCUInt 1002 = Array PgChar
+fromCUInt 1003 = Array PgName
+fromCUInt 1005 = Array PgInt2
+fromCUInt 1007 = Array PgInt4
+fromCUInt 1009 = Array PgText
+fromCUInt 1014 = Array PgBpchar
+fromCUInt 1015 = Array PgVarchar
+fromCUInt 1016 = Array PgInt8
+fromCUInt 1021 = Array PgFloat4
+fromCUInt 1022 = Array PgFloat8
+fromCUInt 1115 = Array PgTimestamp
+fromCUInt 1182 = Array PgDate
+fromCUInt 1183 = Array PgTime
+fromCUInt 1185 = Array PgTimestamptz
+fromCUInt 1187 = Array PgInterval
+fromCUInt 1231 = Array PgNumeric
+fromCUInt 1561 = Array PgBit
+fromCUInt 1563 = Array PgVarbit
+fromCUInt 2951 = Array PgUuid
+fromCUInt 3807 = Array PgJsonb
+fromCUInt n = Unrecognized (fromIntegral n)
+
+---------------------------------------------------------------------------
+-- PgType → OID
+---------------------------------------------------------------------------
+
+-- | The OID for a 'PgType'.
+pgTypeOid :: PgType -> LibPQ.Oid
+pgTypeOid (Scalar s) = scalarOid s
+pgTypeOid (Array s) = maybe (LibPQ.Oid 0) id (arrayOid s)
+pgTypeOid (Composite _ _) = LibPQ.Oid 0
+pgTypeOid (CompositeArray _) = LibPQ.Oid 0
+pgTypeOid (Enum _ _) = LibPQ.Oid 0
+pgTypeOid (EnumArray _) = LibPQ.Oid 0
+pgTypeOid (Unrecognized n) = LibPQ.Oid (fromIntegral n)
+
+-- | The scalar OID.
+scalarOid :: PgScalar -> LibPQ.Oid
+scalarOid = LibPQ.Oid . scalarCUInt
+
+-- | The scalar OID as 'Word32', for @postgresql-binary@ array element
+-- encoding which takes the element OID as a 'Word32'.
+scalarOidWord32 :: PgScalar -> Word32
+scalarOidWord32 = fromIntegral . scalarCUInt
+
+scalarCUInt :: PgScalar -> CUInt
+scalarCUInt PgBool = 16
+scalarCUInt PgBytea = 17
+scalarCUInt PgChar = 18
+scalarCUInt PgName = 19
+scalarCUInt PgInt2 = 21
+scalarCUInt PgInt4 = 23
+scalarCUInt PgInt8 = 20
+scalarCUInt PgFloat4 = 700
+scalarCUInt PgFloat8 = 701
+scalarCUInt PgText = 25
+scalarCUInt PgXml = 142
+scalarCUInt PgMoney = 790
+scalarCUInt PgBpchar = 1042
+scalarCUInt PgVarchar = 1043
+scalarCUInt PgDate = 1082
+scalarCUInt PgTime = 1083
+scalarCUInt PgTimestamp = 1114
+scalarCUInt PgTimestamptz = 1184
+scalarCUInt PgInterval = 1186
+scalarCUInt PgBit = 1560
+scalarCUInt PgVarbit = 1562
+scalarCUInt PgNumeric = 1700
+scalarCUInt PgVoid = 2278
+scalarCUInt PgJson = 114
+scalarCUInt PgJsonb = 3802
+scalarCUInt PgUnknown = 705
+scalarCUInt PgUuid = 2950
+
+-- | The array OID corresponding to a scalar type, if one exists.
+arrayOid :: PgScalar -> Maybe LibPQ.Oid
+arrayOid = fmap LibPQ.Oid . arrayCUInt
+
+arrayCUInt :: PgScalar -> Maybe CUInt
+arrayCUInt PgBool = Just 1000
+arrayCUInt PgBytea = Just 1001
+arrayCUInt PgChar = Just 1002
+arrayCUInt PgName = Just 1003
+arrayCUInt PgInt2 = Just 1005
+arrayCUInt PgInt4 = Just 1007
+arrayCUInt PgInt8 = Just 1016
+arrayCUInt PgFloat4 = Just 1021
+arrayCUInt PgFloat8 = Just 1022
+arrayCUInt PgText = Just 1009
+arrayCUInt PgXml = Just 143
+arrayCUInt PgMoney = Just 791
+arrayCUInt PgBpchar = Just 1014
+arrayCUInt PgVarchar = Just 1015
+arrayCUInt PgDate = Just 1182
+arrayCUInt PgTime = Just 1183
+arrayCUInt PgTimestamp = Just 1115
+arrayCUInt PgTimestamptz = Just 1185
+arrayCUInt PgInterval = Just 1187
+arrayCUInt PgBit = Just 1561
+arrayCUInt PgVarbit = Just 1563
+arrayCUInt PgNumeric = Just 1231
+arrayCUInt PgJson = Just 199
+arrayCUInt PgJsonb = Just 3807
+arrayCUInt PgUuid = Just 2951
+arrayCUInt PgVoid = Nothing
+arrayCUInt PgUnknown = Nothing
+
+-- | Reverse lookup: given a scalar element OID (as 'Word32'), return
+-- the corresponding array OID. Used by 'PgEncode' instances for
+-- @[a]@ to tag the encoded array parameter.
+arrayOidWord32 :: Word32 -> Maybe Word32
+arrayOidWord32 w = case fromCUInt (fromIntegral w) of
+ Scalar s -> fromIntegral . (\(LibPQ.Oid o) -> o) <$> arrayOid s
+ _ -> Nothing
+
+---------------------------------------------------------------------------
+-- OID cache
+---------------------------------------------------------------------------
+
+-- | Cache for dynamically-discovered type OIDs (custom enums, composites,
+-- domains, etc.) that aren't in the built-in table. Keyed by raw OID
+-- integer.
+--
+-- Currently unused by the decoding path, but wired into 'PgConn' so that
+-- future custom-type support can populate and consult it without changing
+-- any signatures.
+type OidCache = IntMap PgType
+
+emptyOidCache :: OidCache
+emptyOidCache = IntMap.empty
+
+-- | Classify an OID, falling back to the cache for OIDs not in the
+-- built-in table.
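+--
+-- For example (hypothetical OID for a user-defined enum):
+--
+-- @
+-- let cache = IntMap.insert 16385 (Enum "mood" 0) emptyOidCache
+-- in resolveOid cache (LibPQ.Oid 16385) == Enum "mood" 0
+-- @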
+resolveOid :: OidCache -> LibPQ.Oid -> PgType
+resolveOid cache oid = case fromOid oid of
+ Unrecognized n -> maybe (Unrecognized n) id (IntMap.lookup n cache)
+ known -> known
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Placeholders.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Placeholders.hs
new file mode 100644
index 000000000..6f45e71fb
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Placeholders.hs
@@ -0,0 +1,186 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE OverloadedStrings #-}
+
+-- | Rewrite @?@ placeholders to PostgreSQL @$1, $2, ...@ numbered parameters.
+--
+-- This replaces the functionality that @postgresql-simple@ provides internally.
+-- We must be careful to:
+--
+-- * Replace @?@ with @$N@
+-- * Replace @??@ with a literal @?@ (used by persistent for column expansion, e.g. @RETURNING ??@)
+-- * Skip contents of string literals (@\'...\'@), identifiers (@\"...\"@),
+-- and comments (@--@ line comments, @/* ... */@ block comments)
+module Database.Persist.Postgresql.Internal.Placeholders
+ ( rewritePlaceholders
+ -- * Numbered parameter detection
+ , ParamStyle (..)
+ , detectParamStyle
+ -- * SQL lexer helpers (shared with Pipeline.hs)
+ , skipStringLiteral
+ , skipQuotedIdent
+ , skipBlockComment
+ ) where
+
+import Data.ByteString (ByteString)
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Builder as BB
+import qualified Data.ByteString.Lazy as LBS
+import Data.Char (isDigit, ord)
+import Data.Word (Word8)
+
+
+-- | Rewrite @?@-style placeholders to @$1, $2, ...@ numbered parameters.
+--
+-- Returns the rewritten SQL and the total number of parameters found.
+--
+-- @??@ is replaced with a literal @?@ (not a parameter).
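+--
+-- A sketch of the expected behaviour (illustrative inputs):
+--
+-- @
+-- rewritePlaceholders "SELECT * FROM t WHERE a = ? AND b = ?"
+--   == ("SELECT * FROM t WHERE a = $1 AND b = $2", 2)
+-- rewritePlaceholders "SELECT json_col ?? 'key' FROM t WHERE id = ?"
+--   == ("SELECT json_col ? 'key' FROM t WHERE id = $1", 1)
+-- @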
+rewritePlaceholders :: ByteString -> (ByteString, Int)
+rewritePlaceholders input =
+ let (builder, count) = go input 0 mempty
+ in (LBS.toStrict (BB.toLazyByteString builder), count)
+ where
+ go :: ByteString -> Int -> BB.Builder -> (BB.Builder, Int)
+ go bs !n !acc
+ | BS.null bs = (acc, n)
+ | otherwise =
+ let c = BS.head bs
+ rest = BS.tail bs
+ in case c of
+ -- String literal: copy through including contents
+ 39 {- '\'' -} ->
+ let (literal, after) = skipStringLiteral rest
+ in go after n (acc <> BB.word8 39 <> BB.byteString literal)
+
+ -- Quoted identifier: copy through including contents
+ 34 {- '"' -} ->
+ let (ident, after) = skipQuotedIdent rest
+ in go after n (acc <> BB.word8 34 <> BB.byteString ident)
+
+ -- Line comment: copy through to end of line
+ 45 {- '-' -}
+ | not (BS.null rest) && BS.head rest == 45 ->
+ let (comment, after) = BS.break (== 10) rest
+ in go after n (acc <> BB.word8 45 <> BB.byteString comment)
+
+ -- Block comment: copy through to closing */
+ 47 {- '/' -}
+ | not (BS.null rest) && BS.head rest == 42 {- '*' -} ->
+ let (comment, after) = skipBlockComment (BS.tail rest) 1
+ in go after n (acc <> BB.byteString "/*" <> BB.byteString comment)
+
+ -- Placeholder
+ 63 {- '?' -}
+ -- ?? -> literal ?
+ | not (BS.null rest) && BS.head rest == 63 ->
+ go (BS.tail rest) n (acc <> BB.word8 63)
+ -- ? -> $N
+ | otherwise ->
+ let n' = n + 1
+ in go rest n' (acc <> BB.word8 36 <> BB.intDec n')
+
+ -- Anything else: copy through
+ _ -> go rest n (acc <> BB.word8 c)
+
+-- SQL lexer helpers. These use index-based scanning for efficiency
+-- (no intermediate ByteString allocations per character), only doing
+-- a single splitAt at the boundary.
+
+-- | Skip a string literal (everything after the opening @\'@).
+-- Returns (consumed bytes including closing @\'@, remaining input).
+-- Handles @\'\'@ escape sequences.
+skipStringLiteral :: ByteString -> (ByteString, ByteString)
+skipStringLiteral = goStr 0
+ where
+ goStr !i bs
+ | i >= BS.length bs = (bs, BS.empty)
+ | BS.index bs i == 39 {- '\'' -} =
+ if i + 1 < BS.length bs && BS.index bs (i + 1) == 39
+ then goStr (i + 2) bs
+ else BS.splitAt (i + 1) bs
+ | otherwise = goStr (i + 1) bs
+
+-- | Skip a quoted identifier (everything after the opening @\"@).
+-- Returns (consumed bytes including closing @\"@, remaining input).
+-- Handles @\"\"@ escape sequences.
+skipQuotedIdent :: ByteString -> (ByteString, ByteString)
+skipQuotedIdent = goId 0
+ where
+ goId !i bs
+ | i >= BS.length bs = (bs, BS.empty)
+ | BS.index bs i == 34 {- '"' -} =
+ if i + 1 < BS.length bs && BS.index bs (i + 1) == 34
+ then goId (i + 2) bs
+ else BS.splitAt (i + 1) bs
+ | otherwise = goId (i + 1) bs
+
+-- | Skip a block comment (after the opening delimiter).
+-- Handles nested block comments. Returns (consumed bytes including
+-- the closing delimiter, remaining input).
+skipBlockComment :: ByteString -> Int -> (ByteString, ByteString)
+skipBlockComment = goBlk 0
+ where
+ goBlk !i bs !depth
+ | i >= BS.length bs = (bs, BS.empty)
+ | i + 1 < BS.length bs
+ , BS.index bs i == 42 {- '*' -}, BS.index bs (i + 1) == 47 {- '/' -} =
+ if depth <= 1
+ then BS.splitAt (i + 2) bs
+ else goBlk (i + 2) bs (depth - 1)
+ | i + 1 < BS.length bs
+ , BS.index bs i == 47 {- '/' -}, BS.index bs (i + 1) == 42 {- '*' -} =
+ goBlk (i + 2) bs (depth + 1)
+ | otherwise = goBlk (i + 1) bs depth
+
+---------------------------------------------------------------------------
+-- Numbered parameter detection
+---------------------------------------------------------------------------
+
+-- | Parameter placeholder style detected in a SQL statement.
+data ParamStyle
+ = QuestionMarkParams
+ -- ^ Traditional @?@ placeholders (persistent style). Each @?@ is a
+ -- distinct positional parameter.
+ | NumberedParams !Int
+ -- ^ PostgreSQL-native @$N@ placeholders. The 'Int' is the highest
+ -- @$N@ found. A single @$1@ can appear multiple times in the query,
+ -- referencing the same parameter value.
+ deriving (Eq, Show)
+
+-- | Scan a SQL statement (skipping string literals, quoted identifiers,
+-- and comments) for @$N@ patterns. If any are found, return
+-- 'NumberedParams' with the max N; otherwise 'QuestionMarkParams'.
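+--
+-- For example (illustrative inputs; the @$1@ inside a string literal is
+-- skipped):
+--
+-- @
+-- detectParamStyle "SELECT * FROM t WHERE a = $1 OR b = $2" == NumberedParams 2
+-- detectParamStyle "SELECT * FROM t WHERE a = ?" == QuestionMarkParams
+-- detectParamStyle "SELECT '$1' FROM t" == QuestionMarkParams
+-- @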
+detectParamStyle :: ByteString -> ParamStyle
+detectParamStyle = go 0
+ where
+ go :: Int -> ByteString -> ParamStyle
+ go !maxN bs
+ | BS.null bs = if maxN > 0 then NumberedParams maxN else QuestionMarkParams
+ | otherwise =
+ let c = BS.head bs
+ rest = BS.tail bs
+ in case c of
+ 39 {- '\'' -} ->
+ let (_lit, after) = skipStringLiteral rest
+ in go maxN after
+ 34 {- '"' -} ->
+ let (_ident, after) = skipQuotedIdent rest
+ in go maxN after
+ 45 {- '-' -}
+ | not (BS.null rest) && BS.head rest == 45 ->
+ let (_comment, after) = BS.break (== 10) rest
+ in go maxN after
+ 47 {- '/' -}
+ | not (BS.null rest) && BS.head rest == 42 ->
+ let (_comment, after) = skipBlockComment (BS.tail rest) 1
+ in go maxN after
+ 36 {- '$' -} ->
+ let (digits, after) = BS.span (\w -> isDigit (chr w)) rest
+ in if BS.null digits
+ then go maxN after
+ else
+ let n = BS.foldl' (\acc w -> acc * 10 + (ord (chr w) - 48)) 0 digits
+ in go (max maxN n) after
+ _ -> go maxN rest
+
+ chr :: Word8 -> Char
+ chr = toEnum . fromIntegral
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/JSON.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/JSON.hs
new file mode 100644
index 000000000..57828f371
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/JSON.hs
@@ -0,0 +1,404 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# OPTIONS_GHC -fno-warn-orphans #-}
+
+-- | Filter operators for JSON values added to PostgreSQL 9.4
+module Database.Persist.Postgresql.JSON
+ ( (@>.)
+ , (<@.)
+ , (?.)
+ , (?|.)
+ , (?&.)
+ , Value ()
+ ) where
+
+import Data.Aeson (FromJSON, ToJSON, Value, eitherDecodeStrict, encode)
+import qualified Data.ByteString.Lazy as BSL
+import Data.Proxy (Proxy)
+import Data.Text (Text)
+import qualified Data.Text as T
+import Data.Text.Encoding as TE (encodeUtf8)
+
+import Database.Persist
+ ( EntityField
+ , Filter (..)
+ , PersistField (..)
+ , PersistFilter (..)
+ , PersistValue (..)
+ )
+import Database.Persist.Sql (PersistFieldSql (..), SqlType (..))
+import Database.Persist.Types (FilterValue (..))
+
+infix 4 @>., <@., ?., ?|., ?&.
+
+-- | This operator checks inclusion of the JSON value
+-- on the right hand side in the JSON value on the left
+-- hand side.
+--
+-- === __Objects__
+--
+-- An empty Object matches any object
+--
+-- @
+-- {} \@> {} == True
+-- {"a":1,"b":false} \@> {} == True
+-- @
+--
+-- Key-value pairs are matched at the top level
+--
+-- @
+-- {"a":1,"b":{"c":true}} \@> {"a":1} == True
+-- {"a":1,"b":{"c":true}} \@> {"b":1} == False
+-- {"a":1,"b":{"c":true}} \@> {"b":{}} == True
+-- {"a":1,"b":{"c":true}} \@> {"c":true} == False
+-- {"a":1,"b":{"c":true}} \@> {"b":{"c":true}} == True
+-- @
+--
+-- === __Arrays__
+--
+-- An empty Array matches any array
+--
+-- @
+-- [] \@> [] == True
+-- [1,2,"hi",false,null] \@> [] == True
+-- @
+--
+-- An array on the right-hand side has to be a subset of the one on the
+-- left. Objects and arrays inside it are themselves compared as subsets.
+--
+-- @
+-- [1,2,"hi",false,null] \@> [1] == True
+-- [1,2,"hi",false,null] \@> [null,"hi"] == True
+-- [1,2,"hi",false,null] \@> ["hi",true] == False
+-- [1,2,"hi",false,null] \@> ["hi",2,null,false,1] == True
+-- [1,2,"hi",false,null] \@> [1,2,"hi",false,null,{}] == False
+-- @
+--
+-- Arrays and objects inside arrays are matched the same way they
+-- would be on their own.
+--
+-- @
+-- [1,"hi",[false,3],{"a":[null]}] \@> [{}] == True
+-- [1,"hi",[false,3],{"a":[null]}] \@> [{"a":[]}] == True
+-- [1,"hi",[false,3],{"a":[null]}] \@> [{"b":[null]}] == False
+-- [1,"hi",[false,3],{"a":[null]}] \@> [[]] == True
+-- [1,"hi",[false,3],{"a":[null]}] \@> [[3]] == True
+-- [1,"hi",[false,3],{"a":[null]}] \@> [[true,3]] == False
+-- @
+--
+-- A regular value has to be a member
+--
+-- @
+-- [1,2,"hi",false,null] \@> 1 == True
+-- [1,2,"hi",false,null] \@> 5 == False
+-- [1,2,"hi",false,null] \@> "hi" == True
+-- [1,2,"hi",false,null] \@> false == True
+-- [1,2,"hi",false,null] \@> "2" == False
+-- @
+--
+-- An object will never match with an array
+--
+-- @
+-- [1,2,"hi",[false,3],{"a":null}] \@> {} == False
+-- [1,2,"hi",[false,3],{"a":null}] \@> {"a":null} == False
+-- @
+--
+-- === __Other values__
+--
+-- For any other JSON values the `(\@>.)` operator
+-- functions like an equivalence operator.
+--
+-- @
+-- "hello" \@> "hello" == True
+-- "hello" \@> \"Hello" == False
+-- "hello" \@> "h" == False
+-- "hello" \@> {"hello":1} == False
+-- "hello" \@> ["hello"] == False
+--
+-- 5 \@> 5 == True
+-- 5 \@> 5.00 == True
+-- 5 \@> 1 == False
+-- 5 \@> 7 == False
+-- 12345 \@> 1234 == False
+-- 12345 \@> 2345 == False
+-- 12345 \@> "12345" == False
+-- 12345 \@> [1,2,3,4,5] == False
+--
+-- true \@> true == True
+-- true \@> false == False
+-- false \@> true == False
+-- true \@> "true" == False
+--
+-- null \@> null == True
+-- null \@> 23 == False
+-- null \@> "null" == False
+-- null \@> {} == False
+-- @
+--
+-- @since 2.8.2
+(@>.) :: EntityField record Value -> Value -> Filter record
+(@>.) field val = Filter field (FilterValue val) $ BackendSpecificFilter " @> "
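+
+-- A typical call site, assuming a hypothetical entity @Item@ with a
+-- @jsonb@ column exposed as @ItemValue :: EntityField Item Value@:
+--
+-- @
+-- selectList [ItemValue \@>. object [\"a\" .= Number 1]] []
+-- @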
+
+-- | Same as '@>.' except the inclusion check is reversed.
+-- i.e. is the JSON value on the left hand side included
+-- in the JSON value of the right hand side.
+--
+-- @since 2.8.2
+(<@.) :: EntityField record Value -> Value -> Filter record
+(<@.) field val = Filter field (FilterValue val) $ BackendSpecificFilter " <@ "
+
+-- | This operator takes a column and a string to find a
+-- top-level key/field in an object.
+--
+-- @column ?. string@
+--
+-- N.B. This operator might have some unexpected interactions
+-- with non-object values. Please reference the examples.
+--
+-- === __Objects__
+--
+-- @
+-- {"a":null} ? "a" == True
+-- {"test":false,"a":500} ? "a" == True
+-- {"b":{"a":[]}} ? "a" == False
+-- {} ? "a" == False
+-- {} ? "{}" == False
+-- {} ? "" == False
+-- {"":9001} ? "" == True
+-- @
+--
+-- === __Arrays__
+--
+-- This operator will match an array if the string to be matched
+-- is an element of that array, but nothing else.
+--
+-- @
+-- ["a"] ? "a" == True
+-- [["a"]] ? "a" == False
+-- [9,false,"1",null] ? "1" == True
+-- [] ? "[]" == False
+-- [{"a":true}] ? "a" == False
+-- @
+--
+-- === __Other values__
+--
+-- This operator functions like an equivalence operator on strings only.
+-- Any other value does not match.
+--
+-- @
+-- "a" ? "a" == True
+-- "1" ? "1" == True
+-- "ab" ? "a" == False
+-- 1 ? "1" == False
+-- null ? "null" == False
+-- true ? "true" == False
+-- 1.5 ? "1.5" == False
+-- @
+--
+-- @since 2.10.0
+(?.) :: EntityField record Value -> Text -> Filter record
+(?.) = jsonFilter " ?? "
+
+-- | This operator takes a column and a list of strings to
+-- test whether ANY of the elements of the list are top
+-- level fields in an object.
+--
+-- @column ?|. list@
+--
+-- /N.B. An empty list __will never match anything__. Also, this/
+-- /operator might have some unexpected interactions with/
+-- /non-object values. Please reference the examples./
+--
+-- === __Objects__
+--
+-- @
+-- {"a":null} ?| ["a","b","c"] == True
+-- {"test":false,"a":500} ?| ["a","b","c"] == True
+-- {} ?| ["a","{}"] == False
+-- {"b":{"a":[]}} ?| ["a","c"] == False
+-- {"b":{"a":[]},"test":null} ?| [] == False
+-- @
+--
+-- === __Arrays__
+--
+-- This operator will match an array if __any__ of the elements
+-- of the list are matching string elements of the array.
+--
+-- @
+-- ["a"] ?| ["a","b","c"] == True
+-- [["a"]] ?| ["a","b","c"] == False
+-- [9,false,"1",null] ?| ["a","false"] == False
+-- [] ?| ["a","b","c"] == False
+-- [] ?| [] == False
+-- [{"a":true}] ?| ["a","b","c"] == False
+-- [null,4,"b",[]] ?| ["a","b","c"] == True
+-- @
+--
+-- === __Other values__
+--
+-- This operator functions much like an equivalence operator
+-- on strings only. If a string matches with __any__ element of
+-- the given list, the comparison matches. No other values match.
+--
+-- @
+-- "a" ?| ["a","b","c"] == True
+-- "1" ?| ["a","b","1"] == True
+-- "ab" ?| ["a","b","c"] == False
+-- 1 ?| ["a","1"] == False
+-- null ?| ["a","null"] == False
+-- true ?| ["a","true"] == False
+-- "a" ?| [] == False
+-- @
+--
+-- @since 2.10.0
+(?|.) :: EntityField record Value -> [Text] -> Filter record
+(?|.) field = jsonFilter " ??| " field . PostgresArray
+
+-- | This operator takes a column and a list of strings to
+-- test whether ALL of the elements of the list are top
+-- level fields in an object.
+--
+-- @column ?&. list@
+--
+-- /N.B. An empty list __will match anything__. Also, this/
+-- /operator might have some unexpected interactions with/
+-- /non-object values. Please reference the examples./
+--
+-- === __Objects__
+--
+-- @
+-- {"a":null} ?& ["a"] == True
+-- {"a":null} ?& ["a","a"] == True
+-- {"test":false,"a":500} ?& ["a"] == True
+-- {"test":false,"a":500} ?& ["a","b"] == False
+-- {} ?& ["{}"] == False
+-- {"b":{"a":[]}} ?& ["a"] == False
+-- {"b":{"a":[]},"c":false} ?& ["a","c"] == False
+-- {"a":1,"b":2,"c":3,"d":4} ?& ["b","d"] == True
+-- {} ?& [] == True
+-- {"b":{"a":[]},"test":null} ?& [] == True
+-- @
+--
+-- === __Arrays__
+--
+-- This operator will match an array if __all__ of the elements
+-- of the list are matching string elements of the array.
+--
+-- @
+-- ["a"] ?& ["a"] == True
+-- ["a"] ?& ["a","a"] == True
+-- [["a"]] ?& ["a"] == False
+-- ["a","b","c"] ?& ["a","b","d"] == False
+-- [9,"false","1",null] ?& ["1","false"] == True
+-- [] ?& ["a","b"] == False
+-- [{"a":true}] ?& ["a"] == False
+-- ["a","b","c","d"] ?& ["b","c","d"] == True
+-- [null,4,{"test":false}] ?& [] == True
+-- [] ?& [] == True
+-- @
+--
+-- === __Other values__
+--
+-- This operator functions much like an equivalence operator
+-- on strings only. If a string matches with all elements of
+-- the given list, the comparison matches.
+--
+-- @
+-- "a" ?& ["a"] == True
+-- "1" ?& ["a","1"] == False
+-- "b" ?& ["b","b"] == True
+-- "ab" ?& ["a","b"] == False
+-- 1 ?& ["1"] == False
+-- null ?& ["null"] == False
+-- true ?& ["true"] == False
+-- 31337 ?& [] == True
+-- true ?& [] == True
+-- null ?& [] == True
+-- @
+--
+-- @since 2.10.0
+(?&.) :: EntityField record Value -> [Text] -> Filter record
+(?&.) field = jsonFilter " ??& " field . PostgresArray
+
+jsonFilter
+ :: (PersistField a) => Text -> EntityField record Value -> a -> Filter record
+jsonFilter op field a = Filter field (UnsafeValue a) $ BackendSpecificFilter op
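+
+-- N.B. The operator strings above are spelled with doubled question
+-- marks (@??@, @??|@, @??&@) because a single @?@ would be taken for a
+-- parameter placeholder; the placeholder rewriter turns @??@ back into
+-- a literal @?@ before the SQL reaches PostgreSQL.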
+
+-----------------
+-- AESON VALUE --
+-----------------
+
+instance PersistField Value where
+ toPersistValue = toPersistValueJsonB
+ fromPersistValue = fromPersistValueJsonB
+
+instance PersistFieldSql Value where
+ sqlType = sqlTypeJsonB
+
+-- FIXME: PersistText might be a bit more efficient,
+-- but needs testing/profiling before changing it.
+-- (When entering into the DB the type isn't as important as fromPersistValue)
+toPersistValueJsonB :: (ToJSON a) => a -> PersistValue
+toPersistValueJsonB = PersistLiteralEscaped . BSL.toStrict . encode
+
+fromPersistValueJsonB :: (FromJSON a) => PersistValue -> Either Text a
+fromPersistValueJsonB (PersistText t) =
+ case eitherDecodeStrict $ TE.encodeUtf8 t of
+ Left str -> Left $ fromPersistValueParseError "FromJSON" t $ T.pack str
+ Right v -> Right v
+fromPersistValueJsonB (PersistByteString bs) =
+ case eitherDecodeStrict bs of
+ Left str -> Left $ fromPersistValueParseError "FromJSON" bs $ T.pack str
+ Right v -> Right v
+fromPersistValueJsonB x = Left $ fromPersistValueError "FromJSON" "string or bytea" x
+
+-- Constraints on the type might not be necessary,
+-- but better to leave them in.
+sqlTypeJsonB :: (ToJSON a, FromJSON a) => Proxy a -> SqlType
+sqlTypeJsonB _ = SqlOther "JSONB"
+
+fromPersistValueError
+ :: Text
+ -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64"
+ -> Text
+ -- ^ Database type(s), should appear different from Haskell name, e.g. "integer" or "INT", not "Int".
+ -> PersistValue
+ -- ^ Incorrect value
+ -> Text
+ -- ^ Error message
+fromPersistValueError haskellType databaseType received =
+ T.concat
+ [ "Failed to parse Haskell type `"
+ , haskellType
+ , "`; expected "
+ , databaseType
+ , " from database, but received: "
+ , T.pack (show received)
+ , ". Potential solution: Check that your database schema matches your Persistent model definitions."
+ ]
+
+fromPersistValueParseError
+ :: (Show a)
+ => Text
+ -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64"
+ -> a
+ -- ^ Received value
+ -> Text
+ -- ^ Additional error
+ -> Text
+ -- ^ Error message
+fromPersistValueParseError haskellType received err =
+ T.concat
+ [ "Failed to parse Haskell type `"
+ , haskellType
+ , "`, but received "
+ , T.pack (show received)
+ , " | with error: "
+ , err
+ ]
+
+newtype PostgresArray a = PostgresArray [a]
+
+instance (PersistField a) => PersistField (PostgresArray a) where
+ toPersistValue (PostgresArray ts) = PersistArray $ toPersistValue <$> ts
+ fromPersistValue (PersistArray as) = PostgresArray <$> traverse fromPersistValue as
+ fromPersistValue wat = Left $ fromPersistValueError "PostgresArray" "array" wat
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline.hs
new file mode 100644
index 000000000..c769985b5
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline.hs
@@ -0,0 +1,2694 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE DerivingVia #-}
+{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE PatternSynonyms #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE NamedFieldPuns #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE ViewPatterns #-}
+
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- | A pipelined PostgreSQL backend for persistent.
+--
+-- Uses binary wire protocol via @postgresql-libpq@ and @postgresql-binary@,
+-- with support for libpq pipeline mode.
+module Database.Persist.Postgresql.Pipeline
+ ( -- * Backend types
+ PostgreSQLBackend
+ , ReadBackend (..)
+ , WriteBackend (..)
+ , readToWrite
+ , readToUnknown
+ , writeToUnknown
+ -- * Connection pools
+ , withPostgresqlPipelinePool
+ , createPostgresqlPipelinePool
+ , withPostgresqlPipelineConn
+ -- * Re-exports from persistent
+ , module Database.Persist.Sql
+ , ConnectionString
+ -- * Configuration
+ -- $configuration
+ , PipelineSettings
+ -- ** Construction
+ , defaultPipelineSettings
+ , singleRowSettings
+ , chunkedSettings
+ -- ** Fields
+ , pipelineFetchMode
+ , pipelineConnStr
+ , pipelinePoolSize
+ , FetchMode (..)
+ , PostgresConf (..)
+ , PostgresConfHooks (..)
+ , defaultPostgresConfHooks
+ -- * Upsert utilities
+ , HandleUpdateCollision
+ , copyField
+ , copyUnlessNull
+ , copyUnlessEmpty
+ , copyUnlessEq
+ , excludeNotEqualToOriginal
+ , upsertWhere
+ , upsertManyWhere
+ -- * Misc utilities
+ , PgInterval (..)
+ , tableName
+ , fieldName
+ , mockMigration
+ , migrateEnableExtension
+ -- * Optimized batch operations
+ , getManyKeys
+ , deleteManyKeys
+ , insertManyUnnest
+ , insertManyUnnest_
+ , putManyUnnest
+ , repsertManyUnnest
+ -- * Pipeline mode
+ , withPipeline
+ , flushPipeline
+ -- * Pipeline-aware raw execution
+ , rawExecuteSync
+ , rawExecuteNoReturn
+ -- * Raw access
+ , RawPostgresqlPipeline (..)
+ , createRawPostgresqlPipelinePool
+ -- * Backend creation
+ , createBackend
+ -- * Low-level access (for testing)
+ , getPipelineConn
+ , inlineAndRewrite
+ , collapseInClauses
+ ) where
+
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+
+import Control.Exception (throwIO)
+import Control.Monad
+import Control.Monad.IO.Unlift (MonadIO (..), MonadUnliftIO)
+import Control.Monad.Trans.Class (lift)
+import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT)
+import Control.Monad.Trans.Reader (ReaderT (..), ask, asks, runReaderT)
+import Control.Monad.Trans.Writer (WriterT (..), runWriterT)
+import qualified Data.Foldable as Foldable
+import qualified Data.List.NonEmpty as NEL
+import Data.Proxy (Proxy (..))
+
+import Data.Acquire (Acquire, mkAcquire)
+import qualified Data.Aeson as A
+import Data.Aeson (FromJSON(..), ToJSON(..), (.:), (.:?), (.!=), withObject)
+import Data.Aeson.Types (modifyFailure)
+import GHC.Generics (Generic)
+import Data.ByteString (ByteString)
+import qualified Data.ByteString.Builder as BB
+import qualified Data.ByteString.Char8 as B8
+import qualified Data.ByteString.Lazy as LBS
+import Data.Conduit
+import Data.Data (Data)
+import Data.Either (partitionEithers)
+import Data.IORef
+import Data.Int (Int64)
+import Data.Maybe (fromMaybe)
+import Data.Function (on)
+import Data.List (nubBy, transpose)
+import Data.List.NonEmpty (NonEmpty)
+import qualified Data.Map as Map
+import qualified Data.Monoid as Monoid
+import Data.Pool (Pool)
+import Data.Text (Text)
+import qualified Data.Text as T
+import qualified Data.Text.Encoding as T
+import qualified Data.Text.IO as TIO
+import System.Environment (getEnvironment)
+
+import Database.Persist.Class.PersistUnique (persistUniqueKeyValues)
+import Database.Persist.Compatible
+import qualified Data.Vault.Strict as Vault
+import Database.Persist.Postgresql.Internal
+import Database.Persist.Postgresql.Pipeline.Internal
+import Database.Persist.Postgresql.Pipeline.FFI
+ ( rawResultStatusInt
+ , setChunkedRowsMode
+ , pattern PGRES_SINGLE_TUPLE
+ , pattern PGRES_TUPLES_OK
+ , pattern PGRES_TUPLES_CHUNK
+ , pattern PGRES_FATAL_ERROR
+ )
+import Database.Persist.DirectDecode
+ (FromRow (..), RowDecoder (..), newCounter, runRowReaderCPS, runRowDecoderCPS)
+import Database.Persist.Postgresql.Internal.Decoding (decodePersistValue)
+import Database.Persist.Postgresql.Internal.DirectDecode (PgRowEnv (..))
+import Database.Persist.Postgresql.Internal.DirectEncode (PgParam (..), pgParamToLibPQ)
+import Database.Persist.Postgresql.Internal.Encoding (encodePersistValue)
+import Database.Persist.Postgresql.Internal.Placeholders
+ (skipStringLiteral, skipQuotedIdent, skipBlockComment, ParamStyle (..), detectParamStyle, rewritePlaceholders)
+import Database.Persist.Postgresql.Internal.PgType (OidCache, PgType, emptyOidCache, fromOid)
+import Database.Persist.Sql.DirectRaw (HasDirectQuery (..))
+import Database.Persist.SqlBackend.Internal
+ (SqlBackend (..), DirectQueryCap (..))
+import qualified Data.Vector as V
+import Database.Persist.Sql hiding (readToWrite, readToUnknown, writeToUnknown)
+import qualified Database.Persist.Sql.Util as Util
+import Database.Persist.SqlBackend
+import System.IO.Unsafe (unsafeInterleaveIO, unsafePerformIO)
+
+-- | A @libpq@ connection string. A simple example of connection
+-- string would be @\"host=localhost port=5432 user=test
+-- dbname=test password=test\"@. Please read libpq's
+-- documentation at
+-- <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>
+-- for more details on how to create such strings.
+type ConnectionString = ByteString
+
+-- $configuration
+--
+-- 'PipelineSettings' is an opaque configuration record. Start from
+-- 'defaultPipelineSettings' (or a convenience smart constructor like
+-- 'singleRowSettings' / 'chunkedSettings') and customise individual
+-- fields with record update syntax:
+--
+-- @
+-- mySettings :: PipelineSettings
+-- mySettings = defaultPipelineSettings
+-- { pipelineFetchMode = FetchChunked 500
+-- , pipelinePoolSize = 10
+-- , pipelineConnStr = \"host=localhost dbname=mydb\"
+-- }
+-- @
+--
+-- The constructor is intentionally hidden so that new fields can be
+-- added (with defaults) in future releases without breaking
+-- downstream code that uses record updates.
+
+-- | Configuration for a pipelined PostgreSQL backend.
+--
+-- Construct via 'defaultPipelineSettings', 'singleRowSettings', or
+-- 'chunkedSettings', then adjust the fields 'pipelineFetchMode',
+-- 'pipelineConnStr', and 'pipelinePoolSize' with record update syntax.
+-- The constructor is not exported; use the accessor functions to read
+-- values.
+data PipelineSettings = MkPipelineSettings
+ { pipelineFetchMode :: !FetchMode
+ -- ^ How SELECT-like query results are fetched. Default: 'FetchAll'.
+ , pipelineConnStr :: !ConnectionString
+ -- ^ libpq connection string. Default: @\"\"@ (empty -- uses libpq
+ -- defaults \/ environment variables).
+ , pipelinePoolSize :: !Int
+ -- ^ Number of connections in the pool. Default: @1@.
+ }
+
+-- | Sensible defaults: 'FetchAll', empty connection string, pool size 1.
+defaultPipelineSettings :: PipelineSettings
+defaultPipelineSettings = MkPipelineSettings
+ { pipelineFetchMode = FetchAll
+ , pipelineConnStr = ""
+ , pipelinePoolSize = 1
+ }
+
+-- | 'FetchSingleRow' preset. One row per @PGresult@ -- lowest memory,
+-- works with every PostgreSQL version.
+singleRowSettings :: ConnectionString -> Int -> PipelineSettings
+singleRowSettings cstr poolSz = MkPipelineSettings
+ { pipelineFetchMode = FetchSingleRow
+ , pipelineConnStr = cstr
+ , pipelinePoolSize = poolSz
+ }
+
+-- | 'FetchChunked' preset. Up to @chunkSize@ rows per @PGresult@ -- good
+-- balance of bounded memory and low allocation overhead. Requires
+-- libpq >= 17; automatically falls back to 'FetchSingleRow' otherwise.
+chunkedSettings :: Int -> ConnectionString -> Int -> PipelineSettings
+chunkedSettings chunkSize cstr poolSz = MkPipelineSettings
+ { pipelineFetchMode = FetchChunked chunkSize
+ , pipelineConnStr = cstr
+ , pipelinePoolSize = poolSz
+ }
+
+-- | Create a PostgreSQL connection pool and run the given action.
+--
+-- @
+-- withPostgresqlPipelinePool settings $ \\pool ->
+-- runSqlPool (selectList [] []) pool
+-- @
+withPostgresqlPipelinePool
+ :: (MonadLoggerIO m, MonadUnliftIO m)
+ => PipelineSettings
+ -> (Pool (WriteBackend PostgreSQLBackend) -> m a)
+ -> m a
+withPostgresqlPipelinePool settings =
+ withSqlPool (openPipelineWith settings) (pipelinePoolSize settings)
+
+-- | Create a PostgreSQL connection pool (returned, not bracketed).
+createPostgresqlPipelinePool
+ :: (MonadUnliftIO m, MonadLoggerIO m)
+ => PipelineSettings
+ -> m (Pool (WriteBackend PostgreSQLBackend))
+createPostgresqlPipelinePool settings =
+ createSqlPool (openPipelineWith settings) (pipelinePoolSize settings)
+
+-- | Open a single pipelined PostgreSQL connection and run the given action.
+withPostgresqlPipelineConn
+ :: (MonadUnliftIO m, MonadLoggerIO m)
+ => PipelineSettings
+ -> (WriteBackend PostgreSQLBackend -> m a)
+ -> m a
+withPostgresqlPipelineConn settings =
+ withSqlConn (openPipelineWith settings)
+
+-- | Open a pipelined connection with the given settings, returning a
+-- write-capable backend.
+openPipelineWith :: PipelineSettings -> LogFunc -> IO (WriteBackend PostgreSQLBackend)
+openPipelineWith settings logFunc = do
+ pc <- openPgConn (pipelineFetchMode settings) (pipelineConnStr settings)
+ smap <- newIORef mempty
+ return $ WriteBackend $ PostgreSQLBackend $
+ createBackend logFunc (pgVersion pc) smap pc
+
+-- | Create the backend given a logging function, server version,
+-- mutable statement cell, and connection.
+createBackend
+ :: LogFunc
+ -> NonEmpty Word
+ -> IORef (Map.Map Text Statement)
+ -> PgConn
+ -> SqlBackend
+createBackend logFunc serverVersion smap pc =
+ maybe id setConnPutManySql (upsertFunction putManySql serverVersion) $
+ maybe id setConnUpsertSql (upsertFunction upsertSql' serverVersion) $
+ setConnInsertManySql insertManySql' $
+ maybe id setConnRepsertManySql (upsertFunction repsertManySql serverVersion) $
+ modifyConnVault (Vault.insert underlyingConnectionKey pc) $
+ setDirectQueryCap pc $
+ mkSqlBackend
+ MkSqlBackendArgs
+ { connPrepare = prepare' pc
+ , connStmtMap = smap
+ , connInsertSql = insertSql'
+ , connClose = closePgConn pc
+ , connMigrateSql = migrate' emptyBackendSpecificOverrides
+ , connBegin = \_ mIsolation -> do
+ -- Safety: drain any leftover pending results
+ -- from a previous session that didn't go through
+ -- COMMIT/ROLLBACK (e.g. runSqlPoolNoTransaction).
+ -- Zero-cost when pgPending == 0.
+ drainPending pc
+ -- Fire-and-forget: BEGIN is pipelined with the
+ -- first user query for zero extra round-trips.
+ let isolationStr = case mIsolation of
+ Nothing -> ""
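+                          -- PostgreSQL parses READ UNCOMMITTED but runs
+                          -- it as READ COMMITTED, so both map to the
+                          -- same level here.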
+ Just ReadUncommitted -> " ISOLATION LEVEL READ COMMITTED"
+ Just ReadCommitted -> " ISOLATION LEVEL READ COMMITTED"
+ Just RepeatableRead -> " ISOLATION LEVEL REPEATABLE READ"
+ Just Serializable -> " ISOLATION LEVEL SERIALIZABLE"
+ ok <- pgSendQueryParams pc ("BEGIN" <> isolationStr) [] LibPQ.Binary
+ unless ok $ throwIO $ PipelineExecError "failed to send BEGIN"
+ beginReply <- pgRecvResult pc
+ modifyIORef (pgPending pc) (beginReply :)
+ , connCommit = \_ -> do
+ -- Drain all outstanding results (lazy + fire-and-forget).
+ drainPending pc
+ -- Send COMMIT + sync, read COMMIT result + sync marker.
+ ok <- pgSendQueryParams pc "COMMIT" [] LibPQ.Binary
+ unless ok $ throwIO $ PipelineExecError "failed to send COMMIT"
+ ok2 <- pgPipelineSync pc
+ unless ok2 $ throwIO $ PipelineModeError "failed to sync pipeline"
+ commitReply <- pgRecvResult pc
+ st <- LibPQ.resultStatus commitReply
+ LibPQ.unsafeFreeResult commitReply
+ drainSyncResult pc
+ case st of
+ LibPQ.FatalError -> throwIO $ PipelineExecError "COMMIT failed"
+ _ -> return ()
+ , connRollback = \_ -> do
+ -- Drain all outstanding results.
+ drainPending pc
+ ok <- pgSendQueryParams pc "ROLLBACK" [] LibPQ.Binary
+ unless ok $ throwIO $ PipelineExecError "failed to send ROLLBACK"
+ ok2 <- pgPipelineSync pc
+ unless ok2 $ throwIO $ PipelineModeError "failed to sync pipeline"
+ drainToSync pc
+ writeIORef (pgPending pc) []
+ , connEscapeFieldName = escapeF
+ , connEscapeTableName = escapeE . getEntityDBName
+ , connEscapeRawName = escape
+ , connNoLimit = "LIMIT ALL"
+ , connRDBMS = "postgresql"
+ , connLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL"
+ , connLogFunc = logFunc
+ }
+
+-- | Install the direct-query capability for PgRowEnv.
+--
+-- The 'DirectQueryCap' receives a pre-built 'RowDecoder' from the call
+-- site (produced by 'lookupDirectDecoder'). The backend just sends the
+-- query, builds 'PgRowEnv' per row, and applies the decoder.
+setDirectQueryCap :: PgConn -> SqlBackend -> SqlBackend
+setDirectQueryCap pc' sb = sb { connDirectQueryCap = Just cap }
+ where
+ cap = MkDirectQueryCap (Proxy @PgRowEnv) $ \decoder sql pvs -> do
+ let sqlBS = T.encodeUtf8 sql
+ drainPending pc'
+ cache <- readIORef (pgOidCache pc')
+ let params = map encodePersistValue pvs
+ ok <- case detectParamStyle sqlBS of
+ QuestionMarkParams ->
+ let (rewritten, _) = rewritePlaceholders sqlBS
+ in pgSendQueryParams pc' rewritten params LibPQ.Binary
+ NumberedParams _ ->
+ pgSendQueryParams pc' sqlBS params LibPQ.Binary
+ unless ok $ do
+ merr <- LibPQ.errorMessage (pgConn pc')
+ throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr)
+ _ <- pgSendFlushRequest pc'
+ pgFlush pc'
+ ret <- readOneQueryResult pc'
+ colTypes <- getColumnPgTypes ret
+ rowCount <- LibPQ.ntuples ret
+ results <- decodeAllRows decoder ret rowCount colTypes cache
+ LibPQ.unsafeFreeResult ret
+ pure results
+
+ decodeAllRows decoder ret rowCount colTypes cache = go [] (LibPQ.Row 0)
+ where
+ go !acc !row
+ | row == rowCount = pure (reverse acc)
+ | otherwise = do
+ let env = PgRowEnv ret row colTypes cache
+ val <- runRowDecoder decoder env
+ (\e -> throwIO $ PipelineExecError $ T.encodeUtf8 e)
+ pure
+ go (val : acc) (row + 1)
+
+-- | Key for storing PgConn in the SqlBackend vault.
+underlyingConnectionKey :: Vault.Key PgConn
+underlyingConnectionKey = unsafePerformIO Vault.newKey
+{-# NOINLINE underlyingConnectionKey #-}
+
+-- | Access underlying PgConn, returning 'Nothing' if the 'SqlBackend'
+-- provided isn't backed by this pipeline backend.
+getPipelineConn
+ :: (BackendCompatible SqlBackend backend) => backend -> Maybe PgConn
+getPipelineConn = Vault.lookup underlyingConnectionKey <$> getConnVault
+
+-------------------------------------------------------------------------------
+-- PostgreSQLBackend / ReadBackend / WriteBackend
+--
+-- PostgreSQLBackend wraps SqlBackend. ReadBackend and WriteBackend
+-- are capability newtypes that mirror SqlReadBackend / SqlWriteBackend
+-- but carry the UNNEST / = ANY optimised typeclass instances.
+-------------------------------------------------------------------------------
+
+-- | The underlying PostgreSQL backend. Not used directly in user
+-- signatures -- use @'ReadBackend' 'PostgreSQLBackend'@ or
+-- @'WriteBackend' 'PostgreSQLBackend'@ to select read-only or
+-- read-write capability.
+newtype PostgreSQLBackend = PostgreSQLBackend { unPostgreSQLBackend :: SqlBackend }
+
+-- | A read-only capability wrapper. @ReadBackend PostgreSQLBackend@
+-- provides 'PersistStoreRead', 'PersistQueryRead', and
+-- 'PersistUniqueRead' instances but no write instances.
+newtype ReadBackend backend = ReadBackend { unReadBackend :: backend }
+
+-- | A read-write capability wrapper. @WriteBackend PostgreSQLBackend@
+-- provides all read instances plus 'PersistStoreWrite',
+-- 'PersistQueryWrite', and 'PersistUniqueWrite' with UNNEST-based
+-- columnar inserts.
+newtype WriteBackend backend = WriteBackend { unWriteBackend :: backend }
+
+-- | Lift a read-only action into a read-write context.
+readToWrite
+ :: (Monad m)
+ => ReaderT (ReadBackend PostgreSQLBackend) m a
+ -> ReaderT (WriteBackend PostgreSQLBackend) m a
+readToWrite ma = do
+ WriteBackend w <- ask
+ lift $ runReaderT ma (ReadBackend w)
+
+-- | Run a read-only action against a bare 'PostgreSQLBackend'.
+readToUnknown
+ :: (Monad m)
+ => ReaderT (ReadBackend PostgreSQLBackend) m a
+ -> ReaderT PostgreSQLBackend m a
+readToUnknown ma = do
+ unknown <- ask
+ lift $ runReaderT ma (ReadBackend unknown)
+
+-- | Run a read-write action against a bare 'PostgreSQLBackend'.
+writeToUnknown
+ :: (Monad m)
+ => ReaderT (WriteBackend PostgreSQLBackend) m a
+ -> ReaderT PostgreSQLBackend m a
+writeToUnknown ma = do
+ unknown <- ask
+ lift $ runReaderT ma (WriteBackend unknown)
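+
+-- For example (a sketch, with a hypothetical entity @User@): a
+-- read-only action can be reused inside a write-capable context:
+--
+-- @
+-- countUsers :: ReaderT (ReadBackend PostgreSQLBackend) IO Int
+-- countUsers = count ([] :: [Filter User])
+--
+-- insertThenCount :: User -> ReaderT (WriteBackend PostgreSQLBackend) IO Int
+-- insertThenCount u = insert_ u >> readToWrite countUsers
+-- @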
+
+-- PostgreSQLBackend base instances
+
+instance HasPersistBackend PostgreSQLBackend where
+ type BaseBackend PostgreSQLBackend = SqlBackend
+ persistBackend = unPostgreSQLBackend
+
+instance BackendCompatible SqlBackend PostgreSQLBackend where
+ projectBackend = unPostgreSQLBackend
+
+instance PersistCore PostgreSQLBackend where
+ newtype BackendKey PostgreSQLBackend = PostgreSQLBackendKey
+ { unPostgreSQLBackendKey :: Int64 }
+ deriving stock (Show, Read, Eq, Ord, Generic)
+ deriving newtype
+ ( Num, Integral, PersistField, PersistFieldSql
+ , Real, Enum, Bounded
+ , A.ToJSON, A.FromJSON
+ )
+
+-- ReadBackend instances
+
+instance HasPersistBackend (ReadBackend PostgreSQLBackend) where
+ type BaseBackend (ReadBackend PostgreSQLBackend) = SqlBackend
+ persistBackend = unPostgreSQLBackend . unReadBackend
+
+instance BackendCompatible SqlBackend (ReadBackend PostgreSQLBackend) where
+ projectBackend = unPostgreSQLBackend . unReadBackend
+
+instance PersistCore (ReadBackend PostgreSQLBackend) where
+ newtype BackendKey (ReadBackend PostgreSQLBackend) = ReadPostgreSQLBackendKey
+ { unReadPostgreSQLBackendKey :: Int64 }
+ deriving stock (Show, Read, Eq, Ord, Generic)
+ deriving newtype
+ ( Num, Integral, PersistField, PersistFieldSql
+ , Real, Enum, Bounded
+ , A.ToJSON, A.FromJSON
+ )
+
+instance PersistStoreRead (ReadBackend PostgreSQLBackend) where
+ get k = liftViaWrite $ pipelinedGet k
+ getMany ks = withBaseBackend $ getMany ks
+
+instance PersistQueryRead (ReadBackend PostgreSQLBackend) where
+ count filts = liftViaWrite $ pipelinedCount filts
+ exists filts = liftViaWrite $ pipelinedExists filts
+ selectSourceRes filts opts = withBaseBackend $ selectSourceRes filts opts
+ selectKeysRes filts opts = withBaseBackend $ selectKeysRes filts opts
+
+instance PersistUniqueRead (ReadBackend PostgreSQLBackend) where
+ getBy uniq = liftViaWrite $ pipelinedGetBy uniq
+ existsBy uniq = withBaseBackend $ existsBy uniq
+
+liftViaWrite
+ :: MonadIO m
+ => ReaderT (WriteBackend PostgreSQLBackend) IO a
+ -> ReaderT (ReadBackend PostgreSQLBackend) m a
+liftViaWrite action = do
+ backend <- ask
+ liftIO $ runReaderT action (WriteBackend $ unReadBackend backend)
+
+-- WriteBackend instances (read + write, with UNNEST overrides)
+
+instance HasPersistBackend (WriteBackend PostgreSQLBackend) where
+ type BaseBackend (WriteBackend PostgreSQLBackend) = SqlBackend
+ persistBackend = unPostgreSQLBackend . unWriteBackend
+
+instance BackendCompatible SqlBackend (WriteBackend PostgreSQLBackend) where
+ projectBackend = unPostgreSQLBackend . unWriteBackend
+
+instance PersistCore (WriteBackend PostgreSQLBackend) where
+ newtype BackendKey (WriteBackend PostgreSQLBackend) = WritePostgreSQLBackendKey
+ { unWritePostgreSQLBackendKey :: Int64 }
+ deriving stock (Show, Read, Eq, Ord, Generic)
+ deriving newtype
+ ( Num, Integral, PersistField, PersistFieldSql
+ , Real, Enum, Bounded
+ , A.ToJSON, A.FromJSON
+ )
+
+instance PersistStoreRead (WriteBackend PostgreSQLBackend) where
+ get = pipelinedGet
+ getMany ks = withBaseBackend $ getMany ks
+
+instance PersistStoreWrite (WriteBackend PostgreSQLBackend) where
+ insert = pipelinedInsert
+ insert_ v = withBaseBackend $ insert_ v
+ insertMany vs = withBaseBackend $ insertMany vs
+ insertKey k v = withBaseBackend $ insertKey k v
+ insertEntityMany vs = withBaseBackend $ insertEntityMany vs
+ repsert k v = withBaseBackend $ repsert k v
+ replace k v = withBaseBackend $ replace k v
+ delete k = withBaseBackend $ delete k
+ update k upds = withBaseBackend $ update k upds
+
+ insertMany_ [] = return ()
+ insertMany_ records = do
+ let t = entityDef records
+ (sql, params) = insertMany_Hook t (map Util.mkInsertValues records)
+ rawExecute sql params
+
+ repsertMany [] = return ()
+ repsertMany krsDups = do
+ let krs = nubBy ((==) `on` fst) (reverse krsDups)
+ t = entityDef (map snd krs)
+ toVals (k, r) =
+ case entityPrimary t of
+ Nothing -> keyToValues k <> Util.mkInsertValues r
+ Just _ -> Util.mkInsertValues r
+ (sql, params) = repsertManyHook t (map toVals krs)
+ rawExecute sql params
+
+instance PersistQueryRead (WriteBackend PostgreSQLBackend) where
+ count = pipelinedCount
+ exists = pipelinedExists
+ selectSourceRes filts opts = withBaseBackend $ selectSourceRes filts opts
+ selectKeysRes filts opts = withBaseBackend $ selectKeysRes filts opts
+
+instance PersistQueryWrite (WriteBackend PostgreSQLBackend) where
+ deleteWhere filts = withBaseBackend $ deleteWhere filts
+ updateWhere filts upds = withBaseBackend $ updateWhere filts upds
+
+instance PersistUniqueRead (WriteBackend PostgreSQLBackend) where
+ getBy = pipelinedGetBy
+ existsBy uniq = withBaseBackend $ existsBy uniq
+
+instance PersistUniqueWrite (WriteBackend PostgreSQLBackend) where
+ deleteBy uniq = withBaseBackend $ deleteBy uniq
+ upsertBy uniqueKey record updates = pipelinedUpsertBy uniqueKey record updates
+
+ putMany [] = return ()
+ putMany rsD@(r:_) = do
+ let uKeys = persistUniqueKeys r
+ case uKeys of
+ [] -> insertMany_ rsD
+ _ -> do
+ let rs = nubBy ((==) `on` persistUniqueKeyValues) (reverse rsD)
+ t = entityDef rs
+ toVals rv = map toPersistValue (toPersistFields rv)
+ (sql, params) = putManyHook t (map toVals rs)
+ rawExecute sql params
+
+---------------------------------------------------------------------------
+-- Hedis-style automatic pipelining primitives
+---------------------------------------------------------------------------
+
+-- | Send a SQL query into the pipeline buffer and pop a lazy reply.
+-- Does NOT flush or read -- the IO happens when the returned
+-- 'LibPQ.Result' thunk is forced.
+--
+-- Appends to 'pgPending' so commit\/rollback forces the result even
+-- if the caller discards the lazy thunk. Also returns the thunk
+-- to the caller for lazy decoding.
+pipelinedSend :: PgConn -> Text -> [PersistValue] -> IO LibPQ.Result
+pipelinedSend pc sql vals = do
+ let (rewrittenSQL, remainingVals) = inlineAndRewrite (T.encodeUtf8 sql) vals
+ params = map encodePersistValue remainingVals
+ ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary
+ unless ok $ do
+ merr <- LibPQ.errorMessage (pgConn pc)
+ throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr)
+ lazyReply <- pgRecvResult pc
+ modifyIORef (pgPending pc) (lazyReply :)
+ return lazyReply
+
+-- | Decode all rows from a lazy 'LibPQ.Result' into @[[PersistValue]]@
+-- and free the result (early free for better memory). Double-free is
+-- safe since 'LibPQ.Result' is a 'ForeignPtr'. The entire decode+free
+-- happens inside 'unsafeInterleaveIO', so calling this returns a lazy
+-- thunk.
+pipelinedDecodeRows :: LibPQ.Result -> IO [[PersistValue]]
+pipelinedDecodeRows lazyReply = unsafeInterleaveIO $ do
+ st <- LibPQ.resultStatus lazyReply
+ case st of
+ LibPQ.TuplesOk -> do
+ rowCount <- LibPQ.ntuples lazyReply
+ cols <- LibPQ.nfields lazyReply
+ oids <- ireplicateM cols $ \col -> do
+ oid <- LibPQ.ftype lazyReply col
+ return (col, oid)
+ rows <- forM [LibPQ.Row 0 .. rowCount - 1] $ \row ->
+ forM oids $ \(col, oid) -> do
+ mbs <- LibPQ.getvalue' lazyReply row col
+ case decodePersistValue oid mbs of
+ Left err -> throwIO $ PersistMarshalError err
+ Right val -> return val
+ LibPQ.unsafeFreeResult lazyReply
+ return rows
+ LibPQ.CommandOk -> do
+ LibPQ.unsafeFreeResult lazyReply
+ return []
+ _ -> do
+ merr <- LibPQ.resultErrorMessage lazyReply
+ LibPQ.unsafeFreeResult lazyReply
+ throwIO $ PipelineExecError
+ (fromMaybe "pipelined query error" merr)
+
+-- | Run a query through the pipeline and return decoded rows lazily.
+-- Combines 'pipelinedSend' + 'pipelinedDecodeRows'.
+pipelinedQuery :: PgConn -> Text -> [PersistValue] -> IO [[PersistValue]]
+pipelinedQuery pc sql vals = pipelinedSend pc sql vals >>= pipelinedDecodeRows
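+
+-- For example (a sketch): two queries sent back-to-back share one
+-- network round-trip, and each result is only materialised when its
+-- thunk is forced:
+--
+-- @
+-- rows1 <- pipelinedQuery pc \"SELECT ...\" []
+-- rows2 <- pipelinedQuery pc \"SELECT ...\" []
+-- -- nothing flushed or read yet; forcing rows1\/rows2 does the IO
+-- @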
+
+-- | Decode a single entity from a lazy reply, returning lazily.
+pipelinedDecodeEntity
+ :: PersistEntity record
+ => EntityDef -> LibPQ.Result -> IO (Maybe record)
+pipelinedDecodeEntity t lazyReply = unsafeInterleaveIO $ do
+ rows <- pipelinedDecodeRows lazyReply
+ case rows of
+ [] -> return Nothing
+ (vals:_) -> case Util.parseEntityValues t vals of
+ Left err -> throwIO $ PersistMarshalError err
+ Right entity -> return (Just (entityVal entity))
+
+-- | Decode the first column of the first row as Int64, returning lazily.
+pipelinedDecodeInt64 :: LibPQ.Result -> IO (Maybe Int64)
+pipelinedDecodeInt64 lazyReply = unsafeInterleaveIO $ do
+ rows <- pipelinedDecodeRows lazyReply
+ case rows of
+ ([PersistInt64 n]:_) -> return (Just n)
+ _ -> return Nothing
+
+-- | Decode the first column of the first row as Bool, returning lazily.
+pipelinedDecodeBool :: LibPQ.Result -> IO (Maybe Bool)
+pipelinedDecodeBool lazyReply = unsafeInterleaveIO $ do
+ rows <- pipelinedDecodeRows lazyReply
+ case rows of
+ ([PersistBool b]:_) -> return (Just b)
+ _ -> return Nothing
+
+-- | Decode a key from the first row (RETURNING), returning lazily.
+pipelinedDecodeKey
+ :: PersistEntity record
+ => LibPQ.Result -> IO (Key record)
+pipelinedDecodeKey lazyReply = unsafeInterleaveIO $ do
+ rows <- pipelinedDecodeRows lazyReply
+ case rows of
+ ([PersistInt64 i]:_) -> case keyFromValues [PersistInt64 i] of
+ Left err -> throwIO $ PersistMarshalError err
+ Right k -> return k
+ (vals:_) -> case keyFromValues vals of
+ Left err -> throwIO $ PersistMarshalError err
+ Right k -> return k
+ _ -> throwIO $ PersistMarshalError "pipelinedDecodeKey: no key returned"
+
+-- | Get the 'PgConn' and 'SqlBackend' from a backend.
+-- Does NOT drain pending results -- callers that need prior writes
+-- to be visible should call 'drainPending' themselves.
+withPipelineConn
+ :: (BackendCompatible SqlBackend backend, MonadIO m)
+ => backend
+ -> (PgConn -> SqlBackend -> IO a)
+ -> m a
+withPipelineConn backend f = liftIO $ do
+ let conn = projectBackend backend :: SqlBackend
+ case Vault.lookup underlyingConnectionKey (connVault conn) of
+ Nothing -> fail "persistent-postgresql-ng: no PgConn"
+ Just pc -> f pc conn
+
+---------------------------------------------------------------------------
+-- Hedis-style pipelined operations
+---------------------------------------------------------------------------
+
+pipelinedGet
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ )
+ => Key record
+ -> ReaderT (WriteBackend PostgreSQLBackend) m (Maybe record)
+pipelinedGet k = do
+ backend <- ask
+ withPipelineConn backend $ \pc conn -> do
+ let t = entityDef (dummyFromKey k)
+ escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn
+ cols = T.intercalate ","
+ $ Foldable.toList
+ $ Util.keyAndEntityColumnNames t conn
+ wher = T.intercalate " AND "
+ $ Foldable.toList
+ $ fmap (<> "=? ")
+ $ Util.dbIdColumns conn t
+ sql = T.concat
+ [ "SELECT ", cols, " FROM ", escTbl t, " WHERE ", wher ]
+ lazyReply <- pipelinedSend pc sql (keyToValues k)
+ pipelinedDecodeEntity t lazyReply
+ where
+ dummyFromKey :: Key record -> Maybe record
+ dummyFromKey _ = Nothing
+
+pipelinedGetBy
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ )
+ => Unique record
+ -> ReaderT (WriteBackend PostgreSQLBackend) m (Maybe (Entity record))
+pipelinedGetBy uniq = do
+ backend <- ask
+ withPipelineConn backend $ \pc conn -> do
+ let t = entityDef (Nothing :: Maybe record)
+ escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn
+ escFld = Database.Persist.SqlBackend.Internal.connEscapeFieldName conn
+ cols = T.intercalate ","
+ $ Foldable.toList
+ $ Util.keyAndEntityColumnNames t conn
+ uniqs = persistUniqueToFieldNames uniq
+ wher = T.intercalate " AND "
+ $ Foldable.toList
+ $ fmap (\(_, n) -> escFld n <> "=?") uniqs
+ sql = T.concat
+ [ "SELECT ", cols, " FROM ", escTbl t, " WHERE ", wher ]
+ vals = persistUniqueToValues uniq
+ lazyReply <- pipelinedSend pc sql vals
+ unsafeInterleaveIO $ do
+ rows <- pipelinedDecodeRows lazyReply
+ case rows of
+ [] -> return Nothing
+ (v:_) -> case Util.parseEntityValues t v of
+ Left err -> throwIO $ PersistMarshalError err
+ Right entity -> return (Just entity)
+
+pipelinedInsert
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , SafeToInsert record
+ , MonadIO m
+ )
+ => record
+ -> ReaderT (WriteBackend PostgreSQLBackend) m (Key record)
+pipelinedInsert record = do
+ backend <- ask
+ withPipelineConn backend $ \pc conn -> do
+ let vals = Util.mkInsertValues record
+ case Database.Persist.SqlBackend.Internal.connInsertSql conn (entityDef (Just record)) vals of
+ ISRSingle sql -> pipelinedSend pc sql vals >>= pipelinedDecodeKey
+ ISRSingleCustom sql params -> pipelinedSend pc sql params >>= pipelinedDecodeKey
+ _ -> runReaderT (withBaseBackend $ insert record) backend
+
+pipelinedCount
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ )
+ => [Filter record]
+ -> ReaderT (WriteBackend PostgreSQLBackend) m Int
+pipelinedCount filts = do
+ backend <- ask
+ withPipelineConn backend $ \pc conn -> do
+ let t = entityDef (Nothing :: Maybe record)
+ escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn
+ (wher, vals) = if null filts
+ then ("", [])
+ else filterClauseWithVals (Just PrefixTableName) conn filts
+ sql = T.concat
+ [ "SELECT COUNT(*) FROM ", escTbl t, wher ]
+ lazyReply <- pipelinedSend pc sql vals
+ mn <- pipelinedDecodeInt64 lazyReply
+ return $ maybe 0 fromIntegral mn
+
+pipelinedExists
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ )
+ => [Filter record]
+ -> ReaderT (WriteBackend PostgreSQLBackend) m Bool
+pipelinedExists filts = do
+ backend <- ask
+ withPipelineConn backend $ \pc conn -> do
+ let t = entityDef (Nothing :: Maybe record)
+ escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn
+ (wher, vals) = if null filts
+ then ("", [])
+ else filterClauseWithVals (Just PrefixTableName) conn filts
+ sql = T.concat
+ [ "SELECT EXISTS(SELECT 1 FROM ", escTbl t, wher, ")" ]
+ lazyReply <- pipelinedSend pc sql vals
+ mb <- pipelinedDecodeBool lazyReply
+        return $ fromMaybe False mb
+
+pipelinedUpsertBy
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , SafeToInsert record
+ , MonadIO m
+ )
+ => Unique record
+ -> record
+ -> [Update record]
+ -> ReaderT (WriteBackend PostgreSQLBackend) m (Entity record)
+pipelinedUpsertBy uniqueKey record updates = do
+ backend <- ask
+ case updates of
+ [] -> withBaseBackend $ upsertBy uniqueKey record updates
+ _ -> withPipelineConn backend $ \pc conn -> do
+ let escFld = Database.Persist.SqlBackend.Internal.connEscapeFieldName conn
+ escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn
+ t = entityDef (Just record)
+ refCol n = T.concat [escTbl t, ".", n]
+ mkUpd = Util.mkUpdateText' escFld refCol
+ case Database.Persist.SqlBackend.Internal.connUpsertSql conn of
+ Just upsertSqlFn -> do
+ let upds = T.intercalate "," $ map mkUpd updates
+ rawSql_ = upsertSqlFn t (persistUniqueToFieldNames uniqueKey) upds
+ colList = T.intercalate ","
+ $ Foldable.toList
+ $ Util.keyAndEntityColumnNames t conn
+ sql = T.replace "??" colList rawSql_
+ vals = map toPersistValue (toPersistFields record)
+ ++ map Util.updatePersistValue updates
+ ++ concatMap persistUniqueToValues [uniqueKey]
+ lazyReply <- pipelinedSend pc sql vals
+ unsafeInterleaveIO $ do
+ rows <- pipelinedDecodeRows lazyReply
+ case rows of
+ (v:_) -> case Util.parseEntityValues t v of
+ Left err -> throwIO $ PersistMarshalError err
+ Right entity -> return entity
+ [] -> throwIO $ PersistMarshalError "pipelinedUpsertBy: no rows returned"
+ Nothing -> runReaderT (withBaseBackend $ upsertBy uniqueKey record updates) backend
+
+-- | Prepare a statement. Rewrites @?@ placeholders to @$1, $2, ...@
+-- and handles 'PersistLiteral_ Unescaped' values by inlining them
+-- into the SQL text.
+prepare' :: PgConn -> Text -> IO Statement
+prepare' pc sql = do
+ return
+ Statement
+ { stmtFinalize = return ()
+ , stmtReset = return ()
+ , stmtExecute = execute' pc sqlBS
+ , stmtQuery = withStmt' pc sqlBS
+ }
+ where
+ sqlBS = T.encodeUtf8 sql
+
+-- | Inline 'PersistLiteral_ Unescaped' values into the SQL text,
+-- removing them from the parameter list. Then rewrite remaining @?@
+-- placeholders to @$1, $2, ...@.
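+--
+-- For example (an illustrative sketch; @x@, @y1@, @y2@ stand for
+-- ordinary parameter values):
+--
+-- @
+-- inlineAndRewrite
+--   \"... a = ? AND b IN (?,?) AND c = ?\"
+--   [x, y1, y2, PersistLiteral_ Unescaped \"DEFAULT\"]
+-- ==
+--   ( \"... a = $1 AND b = ANY($2) AND c = DEFAULT\"
+--   , [x, PersistArray [y1, y2]] )
+-- @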
+inlineAndRewrite
+ :: ByteString
+ -> [PersistValue]
+ -> (ByteString, [PersistValue])
+inlineAndRewrite sql vals
+ | any isUnescaped vals =
+ let (sql1, vals1) = inlineUnescaped sql vals
+ in collapseAndRewrite sql1 vals1
+ | otherwise = collapseAndRewrite sql vals
+ where
+ isUnescaped (PersistLiteral_ Unescaped _) = True
+ isUnescaped _ = False
+
+-- | Fused IN-clause collapsing and @?@ → @$N@ placeholder rewriting in a
+-- single pass using a 'BB.Builder'. Combines the work of 'collapseInClauses'
+-- and 'rewritePlaceholders' to avoid an intermediate 'ByteString' allocation.
+collapseAndRewrite :: ByteString -> [PersistValue] -> (ByteString, [PersistValue])
+collapseAndRewrite sql params =
+ let (builder, _n, keptParams) = go sql params 0 mempty []
+ in (LBS.toStrict (BB.toLazyByteString builder), reverse keptParams)
+ where
+ go :: ByteString -> [PersistValue] -> Int -> BB.Builder -> [PersistValue]
+ -> (BB.Builder, Int, [PersistValue])
+ go !bs ps !n !acc keep
+ | B8.null bs = (acc, n, keep)
+ | otherwise =
+ let c = B8.head bs
+ rest = B8.tail bs
+ in case c of
+ '\'' -> let (lit, after) = skipStringLiteral rest
+ in go after ps n (acc <> BB.char8 '\'' <> BB.byteString lit) keep
+
+ '"' -> let (ident, after) = skipQuotedIdent rest
+ in go after ps n (acc <> BB.char8 '"' <> BB.byteString ident) keep
+
+ '-' | not (B8.null rest) && B8.head rest == '-' ->
+ let (comment, after) = B8.break (== '\n') rest
+ in go after ps n (acc <> BB.char8 '-' <> BB.byteString comment) keep
+
+ '/' | not (B8.null rest) && B8.head rest == '*' ->
+ let (comment, after) = skipBlockComment (B8.tail rest) 1
+ in go after ps n (acc <> BB.byteString "/*" <> BB.byteString comment) keep
+
+ -- ?? → literal ?
+ '?' | not (B8.null rest) && B8.head rest == '?' ->
+ go (B8.tail rest) ps n (acc <> BB.char8 '?') keep
+
+ -- ? → $N
+ '?' -> case ps of
+ [] -> go rest [] n (acc <> BB.char8 '?') keep
+ (p : ps') ->
+ let n' = n + 1
+ in go rest ps' n' (acc <> BB.char8 '$' <> BB.intDec n') (p : keep)
+
+ _ | matchesIn bs ->
+ tryCollapseIn bs ps n acc keep
+ | matchesNotIn bs ->
+ tryCollapseNotIn bs ps n acc keep
+ | otherwise ->
+ go rest ps n (acc <> BB.char8 c) keep
+
+ tryCollapseIn bs ps n acc keep =
+ let afterIn = B8.drop 5 bs
+ (nQMarks, afterParens) = countQMarks afterIn
+ in if nQMarks >= 2
+ then let (collapsed, remaining) = splitAt nQMarks ps
+ n' = n + 1
+ in go afterParens remaining n'
+ (acc <> BB.byteString " = ANY(" <> BB.char8 '$' <> BB.intDec n' <> BB.char8 ')')
+ (PersistArray collapsed : keep)
+ else go (B8.tail bs) ps n (acc <> BB.char8 (B8.head bs)) keep
+
+ tryCollapseNotIn bs ps n acc keep =
+ let afterNotIn = B8.drop 9 bs
+ (nQMarks, afterParens) = countQMarks afterNotIn
+ in if nQMarks >= 2
+ then let (collapsed, remaining) = splitAt nQMarks ps
+ n' = n + 1
+ in go afterParens remaining n'
+ (acc <> BB.byteString " <> ALL(" <> BB.char8 '$' <> BB.intDec n' <> BB.char8 ')')
+ (PersistArray collapsed : keep)
+ else go (B8.tail bs) ps n (acc <> BB.char8 (B8.head bs)) keep
+
+-- IN-clause detection helpers, shared by collapseInClauses and collapseAndRewrite.
+
+-- | Check if current position matches @\" IN (\"@ (case-insensitive).
+matchesIn :: ByteString -> Bool
+matchesIn bs =
+ B8.length bs >= 5
+ && isSqlSpace (B8.index bs 0)
+ && (B8.index bs 1 == 'I' || B8.index bs 1 == 'i')
+ && (B8.index bs 2 == 'N' || B8.index bs 2 == 'n')
+ && B8.index bs 3 == ' '
+ && B8.index bs 4 == '('
+
+-- | Check if current position matches @\" NOT IN (\"@ (case-insensitive).
+matchesNotIn :: ByteString -> Bool
+matchesNotIn bs =
+ B8.length bs >= 9
+ && isSqlSpace (B8.index bs 0)
+ && (B8.index bs 1 == 'N' || B8.index bs 1 == 'n')
+ && (B8.index bs 2 == 'O' || B8.index bs 2 == 'o')
+ && (B8.index bs 3 == 'T' || B8.index bs 3 == 't')
+ && B8.index bs 4 == ' '
+ && (B8.index bs 5 == 'I' || B8.index bs 5 == 'i')
+ && (B8.index bs 6 == 'N' || B8.index bs 6 == 'n')
+ && B8.index bs 7 == ' '
+ && B8.index bs 8 == '('
+
+-- | SQL whitespace: space, newline, carriage return, tab.
+isSqlSpace :: Char -> Bool
+isSqlSpace w = w == ' ' || w == '\n' || w == '\r' || w == '\t'
+{-# INLINE isSqlSpace #-}
+
+-- | Count consecutive @?@ marks separated by commas (with optional whitespace)
+-- inside parentheses. Input starts after the opening @(@.
+-- Returns @(count, rest after closing paren)@.
+-- Returns @(0, original input)@ to abort collapsing on unexpected content.
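+--
+-- For example (a sketch): @countQMarks \"?, ?, ?) AND x\"@ yields
+-- @(3, \" AND x\")@.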
+countQMarks :: ByteString -> (Int, ByteString)
+countQMarks = countLoop 0
+ where
+ countLoop !n bs0
+ | B8.null bs0 = (n, bs0)
+ | otherwise =
+ let bs = B8.dropWhile isSqlSpace bs0
+ in if B8.null bs
+ then (n, bs)
+ else case B8.head bs of
+ '?'
+ | not (B8.null (B8.tail bs)) && B8.head (B8.tail bs) == '?' ->
+ -- ?? is a literal ?, not a param -- abort collapsing
+ (0, bs0)
+ | otherwise ->
+ let bs' = B8.tail bs
+ bs'' = B8.dropWhile isSqlSpace bs'
+ in if B8.null bs''
+ then (n + 1, bs'')
+ else case B8.head bs'' of
+ ',' -> countLoop (n + 1) (B8.tail bs'')
+ ')' -> (n + 1, B8.tail bs'')
+ _ -> (0, bs0) -- unexpected char, abort
+ ')' -> (n, B8.tail bs) -- empty parens or trailing
+ _ -> (0, bs0) -- unexpected char, abort
+
+-- | Rewrite @IN (?,?,?)@ to @= ANY(?)@ and @NOT IN (?,?,?)@ to
+-- @<> ALL(?)@, collapsing multiple individual parameters into a single
+-- 'PersistList' parameter. This reduces the number of bind parameters
+-- sent to PostgreSQL when 'PersistList' is encoded as a native array.
+--
+-- Only collapses when there are 2 or more @?@ placeholders in the
+-- IN clause. Single-element @IN (?)@ is left unchanged.
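+--
+-- For example (a sketch):
+--
+-- @
+-- collapseInClauses \"WHERE x IN (?,?,?)\" [a, b, c]
+--   == (\"WHERE x = ANY(?)\", [PersistArray [a, b, c]])
+-- @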
+collapseInClauses :: ByteString -> [PersistValue] -> (ByteString, [PersistValue])
+collapseInClauses sql params =
+ let (builder, keptParams) = go sql params mempty []
+ in (LBS.toStrict (BB.toLazyByteString builder), reverse keptParams)
+ where
+ go :: ByteString -> [PersistValue] -> BB.Builder -> [PersistValue]
+ -> (BB.Builder, [PersistValue])
+ go !bs ps !acc keep
+ | B8.null bs = (acc, keep)
+ | otherwise =
+ let c = B8.head bs
+ rest = B8.tail bs
+ in case c of
+ -- String literal: copy through (don't match IN inside strings)
+ '\'' -> let (lit, after) = skipStringLiteral rest
+ in go after ps (acc <> BB.char8 '\'' <> BB.byteString lit) keep
+
+ -- Quoted identifier: copy through
+ '"' -> let (ident, after) = skipQuotedIdent rest
+ in go after ps (acc <> BB.char8 '"' <> BB.byteString ident) keep
+
+ -- Line comment
+ '-' | not (B8.null rest) && B8.head rest == '-' ->
+ let (comment, after) = B8.break (== '\n') rest
+ in go after ps (acc <> BB.char8 '-' <> BB.byteString comment) keep
+
+ -- Block comment
+ '/' | not (B8.null rest) && B8.head rest == '*' ->
+ let (comment, after) = skipBlockComment (B8.tail rest) 1
+ in go after ps (acc <> BB.byteString "/*" <> BB.byteString comment) keep
+
+ -- ?? -> literal ?, not a parameter
+ '?' | not (B8.null rest) && B8.head rest == '?' ->
+ go (B8.tail rest) ps (acc <> BB.byteString "??") keep
+
+ -- ? -> regular parameter, just pass through
+ '?' -> case ps of
+ [] -> go rest [] (acc <> BB.char8 '?') keep
+ (p : ps') -> go rest ps' (acc <> BB.char8 '?') (p : keep)
+
+ -- Check for IN/NOT IN pattern
+ _ | matchesIn bs ->
+ tryCollapseIn bs ps acc keep
+ | matchesNotIn bs ->
+ tryCollapseNotIn bs ps acc keep
+ | otherwise ->
+ go rest ps (acc <> BB.char8 c) keep
+
+ tryCollapseIn bs ps acc keep =
+ let afterIn = B8.drop 5 bs -- skip " IN ("
+ (nQMarks, afterParens) = countQMarks afterIn
+ in if nQMarks >= 2
+ then let (collapsed, remaining) = splitAt nQMarks ps
+ in go afterParens remaining (acc <> BB.byteString " = ANY(?)") (PersistArray collapsed : keep)
+ else go (B8.tail bs) ps (acc <> BB.char8 (B8.head bs)) keep
+
+ tryCollapseNotIn bs ps acc keep =
+ let afterNotIn = B8.drop 9 bs -- skip " NOT IN ("
+ (nQMarks, afterParens) = countQMarks afterNotIn
+ in if nQMarks >= 2
+ then let (collapsed, remaining) = splitAt nQMarks ps
+ in go afterParens remaining (acc <> BB.byteString " <> ALL(?)") (PersistArray collapsed : keep)
+ else go (B8.tail bs) ps (acc <> BB.char8 (B8.head bs)) keep
+
+-- | Scan the SQL for @?@ placeholders and the parameter list for
+-- 'PersistLiteral_ Unescaped' values. When a @?@ corresponds to an
+-- Unescaped value, inline the literal bytes directly into the SQL text.
+-- Non-Unescaped values are kept in the output parameter list.
+--
+-- This handles @??@ (literal @?@ for column expansion) by passing it
+-- through without consuming a parameter.
+--
+-- Correctly skips string literals, quoted identifiers, line comments,
+-- and block comments.
+inlineUnescaped :: ByteString -> [PersistValue] -> (ByteString, [PersistValue])
+inlineUnescaped sql vals =
+ let (builder, keptParams) = go sql vals mempty []
+ in (LBS.toStrict (BB.toLazyByteString builder), reverse keptParams)
+ where
+ go :: ByteString -> [PersistValue] -> BB.Builder -> [PersistValue]
+ -> (BB.Builder, [PersistValue])
+ go bs params !acc keepParams
+ | B8.null bs = (acc, keepParams)
+ | otherwise =
+ let c = B8.head bs
+ rest = B8.tail bs
+ in case c of
+ -- String literal: copy through (properly handles '' escapes)
+ '\'' ->
+ let (lit, after) = skipStringLiteral rest
+ in go after params (acc <> BB.char8 '\'' <> BB.byteString lit) keepParams
+
+ -- Quoted identifier: copy through (properly handles "" escapes)
+ '"' ->
+ let (ident, after) = skipQuotedIdent rest
+ in go after params (acc <> BB.char8 '"' <> BB.byteString ident) keepParams
+
+ -- Line comment: copy through
+ '-' | not (B8.null rest) && B8.head rest == '-' ->
+ let (comment, after) = B8.break (== '\n') rest
+ in go after params (acc <> BB.char8 '-' <> BB.byteString comment) keepParams
+
+ -- Block comment: copy through
+ '/' | not (B8.null rest) && B8.head rest == '*' ->
+ let (comment, after) = skipBlockComment (B8.tail rest) 1
+ in go after params (acc <> BB.byteString "/*" <> BB.byteString comment) keepParams
+
+ -- Placeholder
+ '?' ->
+ -- ?? -> literal ?
+ if not (B8.null rest) && B8.head rest == '?'
+ then go (B8.tail rest) params (acc <> BB.byteString "??") keepParams
+ -- ? -> check if param is Unescaped
+ else case params of
+ [] -> go rest [] (acc <> BB.char8 '?') keepParams
+ (PersistLiteral_ Unescaped literal : ps) ->
+ go rest ps (acc <> BB.byteString literal) keepParams
+ (p : ps) ->
+ go rest ps (acc <> BB.char8 '?') (p : keepParams)
+
+ _ -> go rest params (acc <> BB.char8 c) keepParams
+
+-- Pipeline helpers
+-- These functions implement the libpq pipeline result protocol:
+-- Each query result: getResult → Just result, then getResult → Nothing (separator).
+-- PipelineSync result: getResult → Just PipelineSync (no NULL terminator).
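+--
+-- Example stream for two pipelined queries followed by a sync:
+--   getResult -> Just r1, getResult -> Nothing   (first query)
+--   getResult -> Just r2, getResult -> Nothing   (second query)
+--   getResult -> Just sync                       (PipelineSync, no separator)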
+
+-- | Drain one result from the pipeline via the lazy reply stream.
+-- Pops and forces one element, frees the PGresult.
+-- Returns @Just errorMsg@ on FatalError, @Nothing@ otherwise.
+drainOneResult :: PgConn -> IO (Maybe ByteString)
+drainOneResult pc = do
+ ret <- pgRecvResult pc -- pop lazy thunk
+ -- Force the thunk (triggers flush+read if needed)
+ st <- LibPQ.resultStatus ret
+ err <- case st of
+ LibPQ.FatalError -> do
+ merr <- LibPQ.resultErrorMessage ret
+ return (Just (fromMaybe "unknown error" merr))
+ _ -> return Nothing
+ LibPQ.unsafeFreeResult ret
+ return err
+
+-- | Read one query result from the pipeline via the lazy reply stream.
+-- Pops and forces one element. Returns the Result (caller must free).
+-- Throws on error status.
+readOneQueryResult :: PgConn -> IO LibPQ.Result
+readOneQueryResult pc = do
+ ret <- pgRecvResult pc -- pop lazy thunk
+ -- Force the thunk
+ st <- LibPQ.resultStatus ret
+ case st of
+ LibPQ.CommandOk -> return ret
+ LibPQ.TuplesOk -> return ret
+ _ -> do
+ merr <- LibPQ.resultErrorMessage ret
+ LibPQ.unsafeFreeResult ret
+ syncPipeline pc
+ throwIO $ PipelineExecError (fromMaybe "query error in pipeline" merr)
+
+-- | Drain N pending results, collecting error messages. Does NOT throw.
+drainNResults :: PgConn -> Int -> IO [ByteString]
+drainNResults pc n = do
+ errs <- replicateM n (drainOneResult pc)
+ return [e | Just e <- errs]
+
+-- | Read a PipelineSync result. No NULL terminator follows it.
+drainSyncResult :: PgConn -> IO ()
+drainSyncResult pc = do
+ mret <- pgGetResult pc
+ case mret of
+ Nothing -> throwIO $ PipelineModeError "expected PipelineSync, got NULL"
+ Just ret -> do
+ st <- LibPQ.resultStatus ret
+ LibPQ.unsafeFreeResult ret
+ case st of
+ LibPQ.PipelineSync -> return ()
+ _ -> throwIO $ PipelineModeError $
+ "expected PipelineSync, got " ++ show st
+
+-- | Drain everything (results and NULL separators) until PipelineSync.
+-- Ignores all errors. Used by 'connRollback' and 'closePgConn'.
+drainToSync :: PgConn -> IO ()
+drainToSync pc = do
+ mret <- pgGetResult pc
+ case mret of
+ Nothing -> drainToSync pc -- NULL separator between results, keep reading
+ Just ret -> do
+ st <- LibPQ.resultStatus ret
+ LibPQ.unsafeFreeResult ret
+ case st of
+ LibPQ.PipelineSync -> return ()
+ _ -> drainToSync pc
+
+-- | Send a pipeline sync and drain to the sync point, clearing any
+-- aborted-pipeline state. Best-effort: if the sync itself fails to
+-- send, we silently continue (the subsequent throw matters more).
+syncPipeline :: PgConn -> IO ()
+syncPipeline pc = do
+ ok <- pgPipelineSync pc
+ when ok $ drainToSync pc
+
+-- | Drain all pending results (both fire-and-forget DML and lazy reads).
+-- Used by 'withStmt'' before sending a query that needs results.
+-- Forces each thunk (triggering flush+read via the lazy stream cascade),
+-- checks for errors, and frees the PGresult. Double-free is safe:
+-- 'LibPQ.Result' is a 'ForeignPtr', so 'unsafeFreeResult' is idempotent.
+-- If the caller's decode thunk already freed the result, this is a no-op.
+--
+-- Throws 'PipelineExecError' if any drained operation failed. On error,
+-- a 'syncPipeline' is issued and drained before throwing so that the
+-- pipeline is left in a clean (non-aborted) state. This allows callers
+-- who catch 'PipelineExecError' to continue issuing commands (e.g.
+-- @ROLLBACK TO SAVEPOINT@) without hitting @PGRES_PIPELINE_ABORTED@.
+drainPending :: PgConn -> IO ()
+drainPending pc = do
+ pending <- atomicModifyIORef (pgPending pc) (\ps -> ([], ps))
+ case pending of
+ [] -> return ()
+ _ -> do
+ _ <- pgSendFlushRequest pc
+ pgFlush pc
+ errors <- forM (reverse pending) $ \ret -> do
+ st <- LibPQ.resultStatus ret
+ err <- case st of
+ LibPQ.FatalError -> do
+ merr <- LibPQ.resultErrorMessage ret
+ return (Just (fromMaybe "unknown error" merr))
+ _ -> return Nothing
+ LibPQ.unsafeFreeResult ret
+ return err
+ case [e | Just e <- errors] of
+ (e:_) -> do
+ syncPipeline pc
+ throwIO $ PipelineExecError e
+ [] -> return ()
+
+-- | Execute a statement in pipeline mode (fire-and-forget).
+-- Sends the query, pops a lazy reply from the stream, and appends
+-- it to 'pgPending'. The result is forced at the next drain point
+-- (read operation, commit, or rollback).
+execute' :: PgConn -> ByteString -> [PersistValue] -> IO Int64
+execute' pc sql vals = do
+ let (rewrittenSQL, remainingVals) = inlineAndRewrite sql vals
+ params = map encodePersistValue remainingVals
+ ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary
+ unless ok $ do
+ merr <- LibPQ.errorMessage (pgConn pc)
+ throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr)
+ lazyReply <- pgRecvResult pc
+ modifyIORef (pgPending pc) (lazyReply :)
+ return 0
+
+-- | Execute a query in pipeline mode and stream the results as a conduit.
+-- First drains all pending fire-and-forget results, then sends the query
+-- and reads its result.
+--
+-- The fetch strategy depends on 'pgFetchMode':
+--
+-- * 'FetchAll' (default): reads the entire result into one @PGresult@
+-- before yielding any rows. Memory = O(result set).
+--
+-- * 'FetchSingleRow': activates @PQsetSingleRowMode@ so that each
+-- @PGresult@ contains exactly one row. The conduit fetches, yields,
+-- and frees one @PGresult@ at a time. Memory = O(1 row).
+--
+-- * 'FetchChunked N': activates @PQsetChunkedRowsMode(N)@ (PG17+) so
+-- that each @PGresult@ contains up to N rows. Same streaming
+-- discipline as single-row mode but with lower per-row allocation
+-- overhead. Memory = O(N rows).
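+--
+-- The mode is fixed per connection. For example, a sketch that opens a
+-- pool streaming large results in 1000-row chunks (using the
+-- 'PipelineSettings' fields defined in this package):
+--
+-- @
+-- pool <- createPostgresqlPipelinePool defaultPipelineSettings
+--     { pipelineConnStr = "host=localhost dbname=test"
+--     , pipelineFetchMode = FetchChunked 1000
+--     }
+-- @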
+withStmt'
+ :: (MonadIO m)
+ => PgConn
+ -> ByteString
+ -> [PersistValue]
+ -> Acquire (ConduitM () [PersistValue] m ())
+withStmt' pc sql vals =
+ case pgFetchMode pc of
+ FetchAll -> withStmtDefault pc sql vals
+ FetchSingleRow -> withStmtStreaming pc sql vals
+ FetchChunked _ -> withStmtStreaming pc sql vals
+
+-------------------------------------------------------------------------------
+-- Default (FetchAll): entire result in one PGresult
+-------------------------------------------------------------------------------
+
+-- | Default fetch: read entire result into memory, then stream rows from it.
+withStmtDefault
+ :: (MonadIO m)
+ => PgConn
+ -> ByteString
+ -> [PersistValue]
+ -> Acquire (ConduitM () [PersistValue] m ())
+withStmtDefault pc sql vals =
+ pull `fmap` mkAcquire openS closeS
+ where
+ openS = do
+ drainPending pc
+ let (rewrittenSQL, remainingVals) = inlineAndRewrite sql vals
+ params = map encodePersistValue remainingVals
+ ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary
+ unless ok $ do
+ merr <- LibPQ.errorMessage (pgConn pc)
+ throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr)
+ -- readOneQueryResult pops from the lazy reply stream.
+ -- The stream thunk handles flush+read internally.
+ ret <- readOneQueryResult pc
+ cols <- LibPQ.nfields ret
+ oids <- ireplicateM cols $ \col -> do
+ oid <- LibPQ.ftype ret col
+ return (col, oid)
+ rowCount <- LibPQ.ntuples ret
+ return (ret, rowCount, oids)
+
+ closeS (ret, _, _) = LibPQ.unsafeFreeResult ret
+
+ pull (ret, rowCount, oids) = go (LibPQ.Row 0)
+ where
+ go !row
+ | row == rowCount = return ()
+ | otherwise = do
+ vals' <- liftIO $ forM oids $ \(col, oid) -> do
+ mbs <- LibPQ.getvalue' ret row col
+ case decodePersistValue oid mbs of
+ Left err -> fail $ "decodePersistValue: " ++ T.unpack err
+ Right val -> return val
+ yield vals'
+ go (row + 1)
+
+-------------------------------------------------------------------------------
+-- Streaming (FetchSingleRow / FetchChunked): incremental PGresult fetching
+-------------------------------------------------------------------------------
+
+-- | Streaming fetch: the conduit drives libpq I/O, fetching one small
+-- @PGresult@ at a time, yielding its rows, freeing it, then fetching the
+-- next. At most one @PGresult@ is resident at any time.
+withStmtStreaming
+ :: (MonadIO m)
+ => PgConn
+ -> ByteString
+ -> [PersistValue]
+ -> Acquire (ConduitM () [PersistValue] m ())
+withStmtStreaming pc sql vals =
+ pull `fmap` mkAcquire openS closeS
+ where
+ openS = do
+ drainPending pc
+ let (rewrittenSQL, remainingVals) = inlineAndRewrite sql vals
+ params = map encodePersistValue remainingVals
+ ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary
+ unless ok $ do
+ merr <- LibPQ.errorMessage (pgConn pc)
+ throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr)
+ -- Activate row-fetch mode immediately after sendQueryParams
+ activateFetchMode pc
+ _ <- pgSendFlushRequest pc
+ pgFlush pc
+ -- Return an IORef that tracks whether we have fully consumed
+ -- the result stream. closeS uses this for cleanup.
+ doneRef <- newIORef False
+ return doneRef
+
+ closeS doneRef = do
+ done <- readIORef doneRef
+ unless done $ drainStreamRemainder pc
+
+ -- The conduit body: all libpq result I/O happens here, interleaved
+ -- with yield. This is what makes it a true streaming source.
+ pull doneRef = do
+ mfirst <- liftIO $ pgGetResult pc
+ case mfirst of
+ Nothing -> do
+ -- NULL right away = empty result (protocol: no rows)
+ liftIO $ writeIORef doneRef True
+ Just first -> do
+ st <- liftIO $ rawResultStatusInt first
+ -- Extract column metadata from the first result.
+ -- All subsequent results share the same column layout.
+ oids <- liftIO $ getColumnOids first
+ -- Process the first result and enter the streaming loop
+ case () of
+ _ | st == PGRES_SINGLE_TUPLE || st == PGRES_TUPLES_CHUNK -> do
+ yieldRowsAndFree first oids
+ streamLoop oids doneRef
+ | st == PGRES_TUPLES_OK -> do
+ -- Zero-row result: the query returned no data.
+ liftIO $ LibPQ.unsafeFreeResult first
+ -- Consume the NULL separator
+ _ <- liftIO $ pgGetResult pc
+ liftIO $ writeIORef doneRef True
+ | st == PGRES_FATAL_ERROR -> do
+ merr <- liftIO $ LibPQ.resultErrorMessage first
+ liftIO $ LibPQ.unsafeFreeResult first
+ -- Consume NULL separator
+ _ <- liftIO $ pgGetResult pc
+ liftIO $ syncPipeline pc
+ liftIO $ writeIORef doneRef True
+ liftIO $ throwIO $ PipelineExecError
+ (fromMaybe "query error in pipeline (streaming)" merr)
+ | otherwise -> do
+ liftIO $ LibPQ.unsafeFreeResult first
+ _ <- liftIO $ pgGetResult pc
+ liftIO $ writeIORef doneRef True
+ liftIO $ throwIO $ PipelineExecError $
+ "unexpected result status in streaming mode: "
+ <> B8.pack (show st)
+
+ -- Core streaming loop: fetch next PGresult, yield its rows, free it.
+ -- Each PGresult is freed BEFORE fetching the next one.
+ streamLoop oids doneRef = do
+ mret <- liftIO $ pgGetResult pc
+ case mret of
+ Nothing -> do
+ -- NULL separator: the query's result stream is fully consumed.
+ liftIO $ writeIORef doneRef True
+ Just ret -> do
+ st <- liftIO $ rawResultStatusInt ret
+ case () of
+ _ | st == PGRES_SINGLE_TUPLE || st == PGRES_TUPLES_CHUNK -> do
+ yieldRowsAndFree ret oids
+ streamLoop oids doneRef
+ | st == PGRES_TUPLES_OK -> do
+ -- End marker (0 rows). Free and consume NULL separator.
+ liftIO $ LibPQ.unsafeFreeResult ret
+ _ <- liftIO $ pgGetResult pc
+ liftIO $ writeIORef doneRef True
+ | st == PGRES_FATAL_ERROR -> do
+ merr <- liftIO $ LibPQ.resultErrorMessage ret
+ liftIO $ LibPQ.unsafeFreeResult ret
+ _ <- liftIO $ pgGetResult pc
+ liftIO $ syncPipeline pc
+ liftIO $ writeIORef doneRef True
+ liftIO $ throwIO $ PipelineExecError
+ (fromMaybe "query error in pipeline (streaming)" merr)
+ | otherwise -> do
+ liftIO $ LibPQ.unsafeFreeResult ret
+ _ <- liftIO $ pgGetResult pc
+ liftIO $ writeIORef doneRef True
+ liftIO $ throwIO $ PipelineExecError $
+ "unexpected result status in streaming mode: "
+ <> B8.pack (show st)
+
+-- | Extract column (index, OID) pairs from a result.
+getColumnOids :: LibPQ.Result -> IO [(LibPQ.Column, LibPQ.Oid)]
+getColumnOids ret = do
+ cols <- LibPQ.nfields ret
+ ireplicateM cols $ \col -> do
+ oid <- LibPQ.ftype ret col
+ return (col, oid)
+
+-- | Yield all rows from a @PGresult@ via the conduit, then free the result.
+yieldRowsAndFree
+ :: (MonadIO m)
+ => LibPQ.Result
+ -> [(LibPQ.Column, LibPQ.Oid)]
+ -> ConduitM () [PersistValue] m ()
+yieldRowsAndFree ret oids = do
+ rowCount <- liftIO $ LibPQ.ntuples ret
+ go (LibPQ.Row 0) rowCount
+ liftIO $ LibPQ.unsafeFreeResult ret
+ where
+ go !row rowCount
+ | row == rowCount = return ()
+ | otherwise = do
+ vals' <- liftIO $ forM oids $ \(col, oid) -> do
+ mbs <- LibPQ.getvalue' ret row col
+ case decodePersistValue oid mbs of
+ Left err -> fail $ "decodePersistValue: " ++ T.unpack err
+ Right val -> return val
+ yield vals'
+ go (row + 1) rowCount
+
+-- | Extract column metadata as a 'V.Vector' of @(Column, PgType)@ pairs
+-- for use with 'PgRowEnv'.
+getColumnPgTypes :: LibPQ.Result -> IO (V.Vector (LibPQ.Column, PgType))
+getColumnPgTypes ret = do
+ cols <- LibPQ.nfields ret
+ V.fromList <$> ireplicateM cols (\col -> do
+ oid <- LibPQ.ftype ret col
+ return (col, fromOid oid))
+
+-- | Yield all rows from a @PGresult@ decoded directly via 'FromRow',
+-- then free the result. No 'PersistValue' intermediary.
+--
+-- Uses 'prepareRow' to resolve all 'FieldRunner's once from column
+-- metadata, then applies the prepared 'RowDecoder' to each row.
+yieldRowsDirectAndFree
+ :: (FromRow PgRowEnv a, MonadIO m)
+ => LibPQ.Result
+ -> V.Vector (LibPQ.Column, PgType)
+ -> OidCache
+ -> ConduitM () a m ()
+yieldRowsDirectAndFree ret colTypes cache = do
+ rowCount <- liftIO $ LibPQ.ntuples ret
+ case rowCount of
+ LibPQ.Row 0 -> return ()
+ _ -> do
+ let metaEnv = PgRowEnv ret (LibPQ.Row 0) colTypes cache
+ ctr <- liftIO newCounter
+ decoder <- liftIO $ prepareRow metaEnv ctr
+ (\e -> throwIO $ PipelineExecError $ T.encodeUtf8 e)
+ pure
+ go decoder (LibPQ.Row 0) rowCount
+ liftIO $ LibPQ.unsafeFreeResult ret
+ where
+ go !decoder !row !rowCount
+ | row == rowCount = return ()
+ | otherwise = do
+ let env = PgRowEnv ret row colTypes cache
+ val <- liftIO $ runRowDecoderCPS decoder env
+ (throwIO . PipelineExecError . T.encodeUtf8)
+ pure
+ yield val
+ go decoder (row + 1) rowCount
+
+-- | Send a query, handling both @?@ and @$N@ parameter styles.
+--
+-- If the SQL contains @$N@ placeholders (detected by 'detectParamStyle'),
+-- it is passed through to libpq as-is with the parameter array indexed by
+-- @$N@ number. Otherwise, @?@ placeholders are rewritten to @$1, $2, ..@
+-- via 'inlineAndRewrite'.
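+--
+-- For example (illustrative queries):
+--
+-- @
+-- sendQueryWithParams pc "SELECT x FROM t WHERE id = $1" [v] -- passed through
+-- sendQueryWithParams pc "SELECT x FROM t WHERE id = ?" [v]  -- rewritten to $1
+-- @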
+sendQueryWithParams
+ :: PgConn
+ -> ByteString
+ -> [PersistValue]
+ -> IO Bool
+sendQueryWithParams pc sql vals = case detectParamStyle sql of
+ QuestionMarkParams ->
+ let (rewritten, remaining) = inlineAndRewrite sql vals
+ params = map encodePersistValue remaining
+ in pgSendQueryParams pc rewritten params LibPQ.Binary
+ NumberedParams _maxN ->
+ let params = map encodePersistValue vals
+ in pgSendQueryParams pc sql params LibPQ.Binary
+
+---------------------------------------------------------------------------
+-- HasDirectQuery instance
+---------------------------------------------------------------------------
+
+instance HasDirectQuery (WriteBackend PostgreSQLBackend) where
+ type Env (WriteBackend PostgreSQLBackend) = PgRowEnv
+ type Param (WriteBackend PostgreSQLBackend) = PgParam
+ directQuerySource backend sql params =
+ case getPipelineConn backend of
+ Nothing -> error "persistent-postgresql-ng: no PgConn in vault"
+ Just pc ->
+ let sqlBS = T.encodeUtf8 sql
+ in pull `fmap` mkAcquire (openDQ pc sqlBS) closeDQ
+ where
+ paramList = V.toList params
+
+ openDQ pc sqlBS = do
+ drainPending pc
+ cache <- readIORef (pgOidCache pc)
+ ok <- sendDirectParams pc sqlBS paramList
+ unless ok $ do
+ merr <- LibPQ.errorMessage (pgConn pc)
+ throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr)
+ _ <- pgSendFlushRequest pc
+ pgFlush pc
+ ret <- readOneQueryResult pc
+ colTypes <- getColumnPgTypes ret
+ rowCount <- LibPQ.ntuples ret
+ return (ret, rowCount, colTypes, cache)
+
+ closeDQ (ret, _, _, _) = LibPQ.unsafeFreeResult ret
+
+ pull (ret, rowCount, colTypes, cache) = go (LibPQ.Row 0)
+ where
+ go !row
+ | row == rowCount = return ()
+ | otherwise = do
+ yield (PgRowEnv ret row colTypes cache)
+ go (row + 1)
+
+instance HasDirectQuery (ReadBackend PostgreSQLBackend) where
+ type Env (ReadBackend PostgreSQLBackend) = PgRowEnv
+ type Param (ReadBackend PostgreSQLBackend) = PgParam
+ directQuerySource backend sql params =
+ directQuerySource (writeFromRead backend) sql params
+ where
+ writeFromRead :: ReadBackend PostgreSQLBackend -> WriteBackend PostgreSQLBackend
+ writeFromRead (ReadBackend pg) = WriteBackend pg
+
+-- | Send a query with pre-encoded parameters (the direct-encode path).
+--
+-- Handles both @?@ and @$N@ placeholder styles: if @?@ placeholders are
+-- detected, they are rewritten to @$1, $2, ...@ via 'rewritePlaceholders'.
+-- Unlike 'sendQueryWithParams', no IN-clause collapsing or unescaped
+-- literal inlining is performed -- the direct path assumes the caller
+-- provides final SQL and properly typed parameters.
+sendDirectParams
+ :: PgConn
+ -> ByteString
+ -> [PgParam]
+ -> IO Bool
+sendDirectParams pc sql params = case detectParamStyle sql of
+ QuestionMarkParams ->
+ let (rewritten, _count) = rewritePlaceholders sql
+ in pgSendQueryParams pc rewritten (map pgParamToLibPQ params) LibPQ.Binary
+ NumberedParams _maxN ->
+ pgSendQueryParams pc sql (map pgParamToLibPQ params) LibPQ.Binary
+
+-- | Activate the appropriate row-fetch mode for the query just sent.
+-- Must be called immediately after 'pgSendQueryParams', before any
+-- other operation on the connection.
+activateFetchMode :: PgConn -> IO ()
+activateFetchMode pc =
+ case pgFetchMode pc of
+ FetchAll -> return ()
+ FetchSingleRow -> do
+ ok <- LibPQ.setSingleRowMode (pgConn pc)
+ unless ok $
+ throwIO $ PipelineExecError "failed to set single-row mode"
+ FetchChunked n -> do
+ ok <- setChunkedRowsMode (pgConn pc) n
+ unless ok $
+ throwIO $ PipelineExecError "failed to set chunked rows mode"
+
+-- | Drain any remaining results from a streaming query. Used by closeS
+-- when the conduit consumer stops pulling before the result stream is
+-- exhausted (e.g. @take 5@ on a large query).
+drainStreamRemainder :: PgConn -> IO ()
+drainStreamRemainder pc = go
+ where
+ go = do
+ mret <- pgGetResult pc
+ case mret of
+ Nothing -> return () -- NULL separator reached, done
+ Just ret -> do
+ st <- rawResultStatusInt ret
+ LibPQ.unsafeFreeResult ret
+ case () of
+ _ | st == PGRES_SINGLE_TUPLE || st == PGRES_TUPLES_CHUNK ->
+ go -- more data, keep draining
+ | st == PGRES_TUPLES_OK -> do
+ -- End marker. Consume the NULL separator.
+ _ <- pgGetResult pc
+ return ()
+ | otherwise ->
+ -- Error or unexpected status; consume NULL separator
+ -- and stop. Pipeline sync will be handled elsewhere.
+ do _ <- pgGetResult pc
+ return ()
+
+-- Index-based monadic iteration, avoiding intermediate list allocation.
+
+ireplicateM :: (Monad m, Eq i, Num i) => i -> (i -> m a) -> m [a]
+ireplicateM n f = go 0
+ where
+ go i
+ | i == n = pure []
+ | otherwise = f i >>= \a -> (a :) <$> go (i + 1)
+{-# INLINABLE ireplicateM #-}
+
+-- SQL generation functions, adapted from existing persistent-postgresql
+
+insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
+insertSql' ent vals =
+ case getEntityId ent of
+ EntityIdNaturalKey _pdef ->
+ ISRManyKeys sql vals
+ EntityIdField field ->
+ ISRSingle (sql <> " RETURNING " <> escapeF (fieldDB field))
+ where
+ (fieldNames, placeholders) = unzip (Util.mkInsertPlaceholders ent escapeF)
+ sql =
+ T.concat
+ [ "INSERT INTO "
+ , escapeE $ getEntityDBName ent
+ , if null (getEntityFields ent)
+ then " DEFAULT VALUES"
+ else
+ T.concat
+ [ "("
+ , T.intercalate "," fieldNames
+ , ") VALUES("
+ , T.intercalate "," placeholders
+ , ")"
+ ]
+ ]
+
+upsertSql' :: EntityDef -> NonEmpty (FieldNameHS, FieldNameDB) -> Text -> Text
+upsertSql' ent uniqs updateVal =
+ T.concat
+ [ "INSERT INTO "
+ , escapeE (getEntityDBName ent)
+ , "("
+ , T.intercalate "," fieldNames
+ , ") VALUES ("
+ , T.intercalate "," placeholders
+ , ") ON CONFLICT ("
+ , T.intercalate "," $ map (escapeF . snd) (NEL.toList uniqs)
+ , ") DO UPDATE SET "
+ , updateVal
+ , " WHERE "
+ , wher
+ , " RETURNING ??"
+ ]
+ where
+ (fieldNames, placeholders) = unzip (Util.mkInsertPlaceholders ent escapeF)
+ wher = T.intercalate " AND " $ map (singleClause . snd) $ NEL.toList uniqs
+ singleClause :: FieldNameDB -> Text
+ singleClause field = escapeE (getEntityDBName ent) <> "." <> (escapeF field) <> " =?"
+
+insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult
+insertManySql' ent valss =
+ ISRSingleCustom sql params
+ where
+ fields = filter isFieldNotGenerated (getEntityFields ent)
+ fieldNames = map (escapeF . fieldDB) fields
+ typeCasts = map (unnestTypeCast . fieldSqlType) fields
+ unnestArgs = map ("?" <>) typeCasts
+ sql =
+ T.concat
+ [ "INSERT INTO "
+ , escapeE (getEntityDBName ent)
+ , "("
+ , T.intercalate "," fieldNames
+ , ") SELECT * FROM UNNEST("
+ , T.intercalate "," unnestArgs
+ , ") RETURNING "
+ , Util.commaSeparated $ NEL.toList $ Util.dbIdColumnsEsc escapeF ent
+ ]
+ params = map PersistArray (transpose valss)
+
+-- | Hook for 'connInsertMany_Custom': UNNEST-based bulk insert (no RETURNING).
+insertMany_Hook :: EntityDef -> [[PersistValue]] -> (Text, [PersistValue])
+insertMany_Hook ent valss = (unnestInsertSql ent False, map PersistArray (transpose valss))
+
+-- | Hook for 'connPutManyCustom': UNNEST-based bulk upsert.
+putManyHook :: EntityDef -> [[PersistValue]] -> (Text, [PersistValue])
+putManyHook ent valss =
+ (sql, map PersistArray (transpose valss))
+ where
+ fields = filter isFieldNotGenerated (getEntityFields ent)
+ fieldDbToText = escapeF . fieldDB
+ table = escapeE (getEntityDBName ent)
+ columns = map fieldDbToText fields
+ typeCasts = map (unnestTypeCast . fieldSqlType) fields
+ unnestArgs = map ("?" <>) typeCasts
+ conflictColumns =
+ concatMap
+ (map (escapeF . snd) . NEL.toList . uniqueFields)
+ (getEntityUniques ent)
+ updates = map (\c -> c <> "=EXCLUDED." <> c) columns
+ sql = T.concat
+ [ "INSERT INTO ", table
+ , "(", T.intercalate "," columns, ")"
+ , " SELECT * FROM UNNEST("
+ , T.intercalate "," unnestArgs
+ , ")"
+ , " ON CONFLICT ("
+ , T.intercalate "," conflictColumns
+ , ") DO UPDATE SET "
+ , T.intercalate "," updates
+ ]
+
+-- | Hook for 'connRepsertManyCustom': UNNEST-based bulk repsert (conflict on PK).
+repsertManyHook :: EntityDef -> [[PersistValue]] -> (Text, [PersistValue])
+repsertManyHook ent valss =
+ (sql, map PersistArray (transpose valss))
+ where
+ allFields = NEL.toList $ keyAndEntityFields ent
+ fields = filter isFieldNotGenerated allFields
+ fieldDbToText = escapeF . fieldDB
+ table = escapeE (getEntityDBName ent)
+ columns = map fieldDbToText fields
+ typeCasts = map (unnestTypeCast . fieldSqlType) fields
+ unnestArgs = map ("?" <>) typeCasts
+ conflictColumns =
+ NEL.toList $ escapeF . fieldDB <$> getEntityKeyFields ent
+ keyFieldNames = map (escapeF . fieldDB) (NEL.toList $ getEntityKeyFields ent)
+ updates =
+ map (\c -> c <> "=EXCLUDED." <> c) $
+ filter (`notElem` keyFieldNames) columns
+ sql = T.concat
+ [ "INSERT INTO ", table
+ , "(", T.intercalate "," columns, ")"
+ , " SELECT * FROM UNNEST("
+ , T.intercalate "," unnestArgs
+ , ")"
+ , " ON CONFLICT ("
+ , T.intercalate "," conflictColumns
+ , ") DO UPDATE SET "
+ , case updates of
+ [] -> case columns of
+ (c:_) -> c <> "=" <> table <> "." <> c
+ _ -> error "repsertManyHook: entity has no columns"
+ _ -> T.intercalate "," updates
+ ]
+
+migrate'
+ :: BackendSpecificOverrides
+ -> [EntityDef]
+ -> (Text -> IO Statement)
+ -> EntityDef
+ -> IO (Either [Text] CautiousMigration)
+migrate' overrides allDefs getter entity =
+ fmap (fmap $ map showAlterDb) $
+ migrateStructured overrides allDefs getter entity
+
+-- | Gate upsert support on server version: returns @Just f@ when the
+-- server is PostgreSQL >= 9.5 (the first version with @ON CONFLICT@).
+upsertFunction :: a -> NonEmpty Word -> Maybe a
+upsertFunction f version =
+ if version >= postgres9dot5
+ then Just f
+ else Nothing
+ where
+ postgres9dot5 :: NonEmpty Word
+ postgres9dot5 = 9 NEL.:| [5]
+
+-- putManySql / repsertManySql
+
+putManySql :: EntityDef -> Int -> Text
+putManySql ent n = putManySql' conflictColumns fields ent n
+ where
+ fields = getEntityFields ent
+ conflictColumns =
+ concatMap
+ (map (escapeF . snd) . NEL.toList . uniqueFields)
+ (getEntityUniques ent)
+
+repsertManySql :: EntityDef -> Int -> Text
+repsertManySql ent n = putManySql' conflictColumns fields ent n
+ where
+ fields = NEL.toList $ keyAndEntityFields ent
+ conflictColumns = NEL.toList $ escapeF . fieldDB <$> getEntityKeyFields ent
+
+putManySql' :: [Text] -> [FieldDef] -> EntityDef -> Int -> Text
+putManySql' conflictColumns (filter isFieldNotGenerated -> fields) ent n = q
+ where
+ fieldDbToText = escapeF . fieldDB
+ mkAssignment f = T.concat [f, "=EXCLUDED.", f]
+
+ table = escapeE . getEntityDBName $ ent
+ columns = Util.commaSeparated $ map fieldDbToText fields
+ placeholders = map (const "?") fields
+ updates = map (mkAssignment . fieldDbToText) fields
+
+ q =
+ T.concat
+ [ "INSERT INTO "
+ , table
+ , Util.parenWrapped columns
+ , " VALUES "
+ , Util.commaSeparated
+ . replicate n
+ . Util.parenWrapped
+ . Util.commaSeparated
+ $ placeholders
+ , " ON CONFLICT "
+ , Util.parenWrapped . Util.commaSeparated $ conflictColumns
+ , " DO UPDATE SET "
+ , Util.commaSeparated updates
+ ]
+
+-------------------------------------------------------------------------------
+-- Optimized batch operations
+--
+-- These standalone functions generate more efficient SQL than the default
+-- persistent typeclass methods. They use PostgreSQL-specific features
+-- (= ANY for key lookups, UNNEST for columnar inserts) to minimize bind
+-- parameters and improve query plan caching.
+-------------------------------------------------------------------------------
+
+-- | Retrieve multiple records by key using a single array parameter.
+--
+-- Generates @SELECT ... WHERE id = ANY($1)@ instead of the default
+-- @WHERE id=? OR id=? OR ...@, reducing bind parameters from N to 1
+-- and improving query plan caching (the query shape is constant
+-- regardless of how many keys are requested).
+--
+-- Only supports single-column keys (the common auto-increment case).
+-- For composite keys, falls back to the default 'getMany'.
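+--
+-- Usage (hypothetical @UserId@ keys):
+--
+-- @
+-- userMap <- getManyKeys [uid1, uid2, uid3]
+-- let mUser = Map.lookup uid1 userMap
+-- @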
+getManyKeys
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ )
+ => [Key record]
+ -> ReaderT SqlBackend m (Map.Map (Key record) record)
+getManyKeys [] = return Map.empty
+getManyKeys ks = do
+ let t = entityDef (Nothing :: Maybe record)
+ idCols = NEL.toList $ Util.dbIdColumnsEsc escapeF t
+ case idCols of
+ [idCol] -> do
+ let sql = T.concat
+ [ "SELECT ?? FROM "
+ , escapeE (getEntityDBName t)
+ , " WHERE "
+ , idCol
+ , " = ANY(?)"
+ ]
+ keyVals = map (headErr . keyToValues) ks
+ entities <- rawSql sql [PersistArray keyVals]
+ return $ Map.fromList $
+ map (\(e :: Entity record) -> (entityKey e, entityVal e)) entities
+ _ ->
+ -- Composite key: fall back to persistent's default OR-based approach
+ getMany ks
+ where
+ headErr [] = error "getManyKeys: keyToValues returned empty list"
+ headErr (x:_) = x
+
+-- | Delete multiple records by key using a single array parameter.
+--
+-- Generates @DELETE FROM ... WHERE id = ANY($1)@ instead of issuing
+-- N individual @DELETE@ statements. In pipeline mode this is
+-- fire-and-forget, so N deletes that would normally be N pipelined
+-- commands become a single command.
+--
+-- Only supports single-column keys. For composite keys, falls back
+-- to individual deletes.
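+--
+-- Usage (hypothetical keys):
+--
+-- @
+-- deleteManyKeys staleSessionIds
+-- @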
+deleteManyKeys
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ )
+ => [Key record]
+ -> ReaderT SqlBackend m ()
+deleteManyKeys [] = return ()
+deleteManyKeys ks = do
+ let t = entityDef (Nothing :: Maybe record)
+ idCols = NEL.toList $ Util.dbIdColumnsEsc escapeF t
+ case idCols of
+ [idCol] -> do
+ let sql = T.concat
+ [ "DELETE FROM "
+ , escapeE (getEntityDBName t)
+ , " WHERE "
+ , idCol
+ , " = ANY(?)"
+ ]
+ keyVals = map (headErr . keyToValues) ks
+ rawExecute sql [PersistArray keyVals]
+ _ ->
+ mapM_ delete ks
+ where
+ headErr [] = error "deleteManyKeys: keyToValues returned empty list"
+ headErr (x:_) = x
+
+-- | Insert multiple records using PostgreSQL @UNNEST@ for columnar
+-- parameter binding. Returns the generated keys.
+--
+-- Generates:
+--
+-- @
+-- INSERT INTO t (col1, col2, ...)
+-- SELECT * FROM UNNEST(?::type1[], ?::type2[], ...)
+-- RETURNING id
+-- @
+--
+-- This uses M array parameters (one per column) instead of N*M scalar
+-- parameters (one per field per row), avoiding PostgreSQL's per-query
+-- parameter limit and improving throughput for large batches.
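+--
+-- Usage (hypothetical @User@ records):
+--
+-- @
+-- keys <- insertManyUnnest users
+-- @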
+insertManyUnnest
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ , SafeToInsert record
+ )
+ => [record]
+ -> ReaderT SqlBackend m [Key record]
+insertManyUnnest [] = return []
+insertManyUnnest records = do
+ let t = entityDef (Nothing :: Maybe record)
+ sql = unnestInsertSql t True
+ params = unnestInsertParams records
+ rawSql sql params
+
+-- | Like 'insertManyUnnest' but does not return keys.
+--
+-- In pipeline mode this is fire-and-forget: the @INSERT@ is sent without
+-- waiting for a response, and errors surface at the next drain point.
+insertManyUnnest_
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ , SafeToInsert record
+ )
+ => [record]
+ -> ReaderT SqlBackend m ()
+insertManyUnnest_ [] = return ()
+insertManyUnnest_ records = do
+ let t = entityDef (Nothing :: Maybe record)
+ sql = unnestInsertSql t False
+ params = unnestInsertParams records
+ rawExecute sql params
+
+-- | Upsert multiple records using @UNNEST@ + @ON CONFLICT DO UPDATE SET@.
+--
+-- Generates:
+--
+-- @
+-- INSERT INTO t (cols)
+-- SELECT * FROM UNNEST(?::type1[], ?::type2[], ...)
+-- ON CONFLICT (unique_cols) DO UPDATE SET col1=EXCLUDED.col1, ...
+-- @
+--
+-- Uses M array parameters (one per column) for the entire batch.
+-- Falls back to 'insertManyUnnest_' if the entity has no unique
+-- constraints.
+putManyUnnest
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ , SafeToInsert record
+ )
+ => [record]
+ -> ReaderT SqlBackend m ()
+putManyUnnest [] = return ()
+putManyUnnest records = do
+ let t = entityDef (Nothing :: Maybe record)
+ uniqs = getEntityUniques t
+ case uniqs of
+ [] -> insertManyUnnest_ records
+ _ -> do
+ let fields = filter isFieldNotGenerated (getEntityFields t)
+ fieldDbToText = escapeF . fieldDB
+ table = escapeE (getEntityDBName t)
+ columns = map fieldDbToText fields
+ typeCasts = map (unnestTypeCast . fieldSqlType) fields
+ unnestArgs = map ("?" <>) typeCasts
+ conflictColumns =
+ concatMap
+ (map (escapeF . snd) . NEL.toList . uniqueFields)
+ uniqs
+ updates = map (\c -> c <> "=EXCLUDED." <> c) columns
+ sql = T.concat
+ [ "INSERT INTO ", table
+ , "(", T.intercalate "," columns, ")"
+ , " SELECT * FROM UNNEST("
+ , T.intercalate "," unnestArgs
+ , ")"
+ , " ON CONFLICT ("
+ , T.intercalate "," conflictColumns
+ , ") DO UPDATE SET "
+ , T.intercalate "," updates
+ ]
+ params = unnestInsertParams records
+ rawExecute sql params
+
+-- | Repsert (replace-or-insert) multiple records using @UNNEST@ +
+-- @ON CONFLICT (primary key) DO UPDATE SET@.
+--
+-- Unlike 'putManyUnnest' which conflicts on unique constraints, this
+-- conflicts on the primary key columns and includes them in the
+-- @INSERT@.
+--
+-- Uses M array parameters (one per column) for the entire batch.
+repsertManyUnnest
+ :: forall record m
+ . ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ )
+ => [(Key record, record)]
+ -> ReaderT SqlBackend m ()
+repsertManyUnnest [] = return ()
+repsertManyUnnest krs = do
+ let t = entityDef (Nothing :: Maybe record)
+ allFields = NEL.toList $ keyAndEntityFields t
+ fields = filter isFieldNotGenerated allFields
+ fieldDbToText = escapeF . fieldDB
+ table = escapeE (getEntityDBName t)
+ columns = map fieldDbToText fields
+ typeCasts = map (unnestTypeCast . fieldSqlType) fields
+        unnestArgs = map ("?" <>) typeCasts
+ conflictColumns =
+ NEL.toList $ escapeF . fieldDB <$> getEntityKeyFields t
+ -- Update all non-key fields on conflict
+ keyFieldNames = map fieldDB (NEL.toList $ getEntityKeyFields t)
+ updates =
+ map (\c -> c <> "=EXCLUDED." <> c) $
+                filter (\c -> maybe True (const False) (lookup' c keyFieldNames)) columns
+ sql = T.concat
+ [ "INSERT INTO ", table
+ , "(", T.intercalate "," columns, ")"
+ , " SELECT * FROM UNNEST("
+ , T.intercalate "," unnestArgs
+ , ") ON CONFLICT ("
+ , T.intercalate "," conflictColumns
+ , ") DO UPDATE SET "
+ , case updates of
+ [] -> case columns of
+ (c:_) -> c <> "=" <> table <> "." <> c
+ [] -> error "repsertManyUnnest: entity has no columns"
+ _ -> T.intercalate "," updates
+ ]
+ -- Build values: for each (key, record) pair, produce key columns ++ entity columns
+ -- matching keyAndEntityFields order (minus generated columns).
+ rowValues = map (repsertRowValues t) krs
+ params = map PersistArray (transpose rowValues)
+ rawExecute sql params
+ where
+ -- Check if any FieldNameDB matches (by comparing the raw DB name)
+ lookup' :: Text -> [FieldNameDB] -> Maybe FieldNameDB
+ lookup' _ [] = Nothing
+ lookup' col (fn:fns)
+ | escapeF fn == col = Just fn
+ | otherwise = lookup' col fns
+
+ repsertRowValues :: EntityDef -> (Key record, record) -> [PersistValue]
+ repsertRowValues ent (k, r) =
+ case entityPrimary ent of
+ Nothing -> keyToValues k <> Util.mkInsertValues r
+ Just _ -> Util.mkInsertValues r
+
+-- | Generate @INSERT ... SELECT * FROM UNNEST(...)@ SQL, optionally
+-- with a @RETURNING@ clause.
+unnestInsertSql :: EntityDef -> Bool -> Text
+unnestInsertSql t withReturning =
+ T.concat
+ [ "INSERT INTO "
+ , escapeE (getEntityDBName t)
+ , "("
+ , T.intercalate "," fieldNames
+ , ") SELECT * FROM UNNEST("
+ , T.intercalate "," unnestArgs
+ , ")"
+ , if withReturning
+ then " RETURNING "
+ <> (T.intercalate "," . NEL.toList $
+ Util.dbIdColumnsEsc escapeF t)
+ else ""
+ ]
+ where
+ fields = filter isFieldNotGenerated (getEntityFields t)
+ fieldNames = map (escapeF . fieldDB) fields
+ typeCasts = map (unnestTypeCast . fieldSqlType) fields
+    unnestArgs = map ("?" <>) typeCasts
+
+-- | Transpose record values from row-major to column-major and wrap
+-- each column in a 'PersistArray'.
+unnestInsertParams :: (PersistEntity record) => [record] -> [PersistValue]
+unnestInsertParams records =
+ let rowValues = map Util.mkInsertValues records
+ in map PersistArray (transpose rowValues)
+
+-- | Map a 'SqlType' to a PostgreSQL array type cast for @UNNEST@
+-- parameters, e.g. @SqlInt64@ becomes @\"::int8[]\"@.
+unnestTypeCast :: SqlType -> Text
+unnestTypeCast SqlString = "::varchar[]"
+unnestTypeCast SqlInt32 = "::int4[]"
+unnestTypeCast SqlInt64 = "::int8[]"
+unnestTypeCast SqlReal = "::float8[]"
+unnestTypeCast (SqlNumeric _ _) = "::numeric[]"
+unnestTypeCast SqlBool = "::boolean[]"
+unnestTypeCast SqlDay = "::date[]"
+unnestTypeCast SqlTime = "::time[]"
+unnestTypeCast SqlDayTime = "::timestamptz[]"
+unnestTypeCast SqlBlob = "::bytea[]"
+unnestTypeCast (SqlOther t) = "::" <> t <> "[]"
+
+-- | Get the SQL string for the table that a PersistEntity represents.
+tableName :: (PersistEntity record) => record -> Text
+tableName = escapeE . tableDBName
+
+-- | Get the SQL string for the field that an EntityField represents.
+fieldName :: (PersistEntity record) => EntityField record typ -> Text
+fieldName = escapeF . fieldDBName
+
+-- HandleUpdateCollision and upsert functions
+
+-- | This type is used to determine how to update rows using Postgres'
+-- @INSERT ... ON CONFLICT ... DO UPDATE@ functionality.
+data HandleUpdateCollision record where
+ CopyField :: EntityField record typ -> HandleUpdateCollision record
+ CopyUnlessEq
+ :: (PersistField typ)
+ => EntityField record typ
+ -> typ
+ -> HandleUpdateCollision record
+
+-- | Copy the field into the database only if the value in the
+-- corresponding record is non-@NULL@.
+copyUnlessNull
+ :: (PersistField typ)
+ => EntityField record (Maybe typ)
+ -> HandleUpdateCollision record
+copyUnlessNull field = CopyUnlessEq field Nothing
+
+-- | Copy the field into the database only if the value in the
+-- corresponding record is non-empty.
+copyUnlessEmpty
+ :: (Monoid.Monoid typ, PersistField typ)
+ => EntityField record typ
+ -> HandleUpdateCollision record
+copyUnlessEmpty field = CopyUnlessEq field Monoid.mempty
+
+-- | Copy the field into the database only if the field is not equal to the
+-- provided value.
+copyUnlessEq
+ :: (PersistField typ)
+ => EntityField record typ
+ -> typ
+ -> HandleUpdateCollision record
+copyUnlessEq = CopyUnlessEq
+
+-- | Copy the field directly from the record.
+copyField
+ :: (PersistField typ) => EntityField record typ -> HandleUpdateCollision record
+copyField = CopyField
+
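+-- A typical combination (hypothetical @User@ entity with @UserName@ and
+-- @UserEmail@ fields): on conflict, always copy the incoming name, but
+-- keep the stored email when the incoming one is 'Nothing':
+--
+-- @
+-- upsertManyWhere
+--     users
+--     [copyField UserName, copyUnlessNull UserEmail]
+--     []
+--     []
+-- @
+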
+upsertWhere
+ :: ( backend ~ PersistEntityBackend record
+ , PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , MonadIO m
+ , PersistStore backend
+ , BackendCompatible SqlBackend backend
+ , OnlyOneUniqueKey record
+ )
+ => record
+ -> [Update record]
+ -> [Filter record]
+ -> ReaderT backend m ()
+upsertWhere record updates filts =
+ upsertManyWhere [record] [] updates filts
+
+upsertManyWhere
+ :: forall record backend m
+ . ( backend ~ PersistEntityBackend record
+ , BackendCompatible SqlBackend backend
+ , PersistEntityBackend record ~ SqlBackend
+ , PersistEntity record
+ , OnlyOneUniqueKey record
+ , MonadIO m
+ )
+ => [record]
+ -> [HandleUpdateCollision record]
+ -> [Update record]
+ -> [Filter record]
+ -> ReaderT backend m ()
+upsertManyWhere [] _ _ _ = return ()
+upsertManyWhere records fieldValues updates filters = do
+ conn <- asks projectBackend
+ let
+ uniqDef = onlyOneUniqueDef (Proxy :: Proxy record)
+ uncurry rawExecute $
+ mkBulkUpsertQuery records conn fieldValues updates filters uniqDef
+
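+-- | A 'Filter' for use with 'upsertManyWhere': restricts the
+-- @DO UPDATE@ to rows whose stored value differs from the incoming
+-- @EXCLUDED@ value, so rows that would not change are left untouched.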
+excludeNotEqualToOriginal
+ :: (PersistField typ, PersistEntity rec)
+ => EntityField rec typ
+ -> Filter rec
+excludeNotEqualToOriginal field =
+ Filter
+ { filterField = field
+ , filterFilter = Ne
+ , filterValue =
+ UnsafeValue $
+ PersistLiteral_
+ Unescaped
+ bsForExcludedField
+ }
+ where
+ bsForExcludedField =
+ T.encodeUtf8 $
+ "EXCLUDED." <> fieldName field
+
+mkBulkUpsertQuery
+ :: ( PersistEntity record
+ , PersistEntityBackend record ~ SqlBackend
+ , OnlyOneUniqueKey record
+ )
+ => [record]
+ -> SqlBackend
+ -> [HandleUpdateCollision record]
+ -> [Update record]
+ -> [Filter record]
+ -> UniqueDef
+ -> (Text, [PersistValue])
+mkBulkUpsertQuery records conn fieldValues updates filters uniqDef =
+ (q, recordValues <> updsValues <> copyUnlessValues <> whereVals)
+ where
+ mfieldDef x = case x of
+ CopyField rec -> Right (fieldDbToText (persistFieldDef rec))
+ CopyUnlessEq rec val -> Left (fieldDbToText (persistFieldDef rec), toPersistValue val)
+ (fieldsToMaybeCopy, updateFieldNames) = partitionEithers $ map mfieldDef fieldValues
+ fieldDbToText = escapeF . fieldDB
+ entityDef' = entityDef records
+ conflictColumns =
+ map (escapeF . snd) $ NEL.toList $ uniqueFields uniqDef
+ firstField = case entityFieldNames of
+ [] -> error "The entity you're trying to insert does not have any fields."
+ (field : _) -> field
+ entityFieldNames = map fieldDbToText (getEntityFields entityDef')
+ nameOfTable = escapeE . getEntityDBName $ entityDef'
+ copyUnlessValues = map snd fieldsToMaybeCopy
+ recordValues = concatMap (map toPersistValue . toPersistFields) records
+ recordPlaceholders =
+ Util.commaSeparated
+ $ map
+ (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields)
+ $ records
+    mkCondFieldSet n _ =
+        T.concat
+            [ n
+            , "=COALESCE(NULLIF(EXCLUDED.", n, ",?),"
+            , nameOfTable, ".", n, ")"
+            ]
+ condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy
+ fieldSets = map (\n -> T.concat [n, "=EXCLUDED.", n]) updateFieldNames
+ upds =
+ map
+ (Util.mkUpdateText' escapeF (\n -> T.concat [nameOfTable, ".", n]))
+ updates
+ updsValues = map (\(Update _ val _) -> toPersistValue val) updates
+ (wher, whereVals) =
+ if null filters
+ then ("", [])
+ else (filterClauseWithVals (Just PrefixTableName) conn filters)
+ updateText =
+ case fieldSets <> upds <> condFieldSets of
+ [] ->
+ T.concat [firstField, "=", nameOfTable, ".", firstField]
+ xs ->
+ Util.commaSeparated xs
+ q =
+ T.concat
+ [ "INSERT INTO "
+ , nameOfTable
+ , Util.parenWrapped . Util.commaSeparated $ entityFieldNames
+ , " VALUES "
+ , recordPlaceholders
+ , " ON CONFLICT "
+ , Util.parenWrapped $ Util.commaSeparated $ conflictColumns
+ , " DO UPDATE SET "
+ , updateText
+ , wher
+ ]
+
+-- | Enable a Postgres extension.
+migrateEnableExtension :: Text -> Migration
+migrateEnableExtension extName = WriterT $ WriterT $ do
+ res :: [Single Int] <-
+ rawSql
+ "SELECT COUNT(*) FROM pg_catalog.pg_extension WHERE extname = ?"
+ [PersistText extName]
+ if res == [Single 0]
+        then return (((), []), [(False, "CREATE EXTENSION \"" <> extName <> "\"")])
+ else return (((), []), [])
+
+-- | Mock a migration even when the database is not present.
+mockMigration :: Migration -> IO ()
+mockMigration mig = do
+ smap <- newIORef mempty
+ let
+ mockMigrateOverrides = emptyBackendSpecificOverrides
+ mockMigrateFn allDefs _ entity =
+ fmap (fmap $ map showAlterDb) $
+ return $
+ Right $
+ mockMigrateStructured mockMigrateOverrides allDefs entity
+ sqlbackend =
+ mkSqlBackend
+ MkSqlBackendArgs
+ { connPrepare = \_ -> do
+ return
+ Statement
+ { stmtFinalize = return ()
+ , stmtReset = return ()
+ , stmtExecute = undefined
+ , stmtQuery = \_ -> return $ return ()
+ }
+ , connInsertSql = undefined
+ , connStmtMap = smap
+ , connClose = undefined
+ , connMigrateSql = mockMigrateFn
+ , connBegin = undefined
+ , connCommit = undefined
+ , connRollback = undefined
+ , connEscapeFieldName = escapeF
+ , connEscapeTableName = escapeE . getEntityDBName
+ , connEscapeRawName = escape
+ , connNoLimit = undefined
+ , connRDBMS = undefined
+ , connLimitOffset = undefined
+ , connLogFunc = undefined
+ }
+ result = runReaderT $ runWriterT $ runWriterT mig
+ resp <- result sqlbackend
+ mapM_ TIO.putStrLn $ map snd $ snd resp
+
+-- PostgresConf
+
+-- | Information required to connect to a PostgreSQL database.
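+--
+-- Can be loaded from a JSON\/YAML config via the 'FromJSON' instance
+-- below. Illustrative config (@port@, @poolsize@, @stripes@, and
+-- @idleTimeout@ are optional):
+--
+-- @
+-- database: mydb
+-- host: localhost
+-- port: 5432
+-- user: postgres
+-- password: secret
+-- poolsize: 10
+-- @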
+data PostgresConf = PostgresConf
+ { pgConnStr :: ConnectionString
+ , pgPoolStripes :: Int
+ , pgPoolIdleTimeout :: Integer
+ , pgPoolSize :: Int
+ }
+ deriving (Show, Read, Data)
+
+instance FromJSON PostgresConf where
+ parseJSON v = modifyFailure ("Persistent: error loading PostgreSQL conf: " ++) $
+ flip (withObject "PostgresConf") v $ \o -> do
+ let defaultPoolConfig = defaultConnectionPoolConfig
+ database <- o .: "database"
+ host <- o .: "host"
+ port <- o .:? "port" .!= (5432 :: Int)
+ user <- o .: "user"
+ password <- o .: "password"
+ poolSize <- o .:? "poolsize" .!= (connectionPoolConfigSize defaultPoolConfig)
+ poolStripes <-
+ o .:? "stripes" .!= (connectionPoolConfigStripes defaultPoolConfig)
+ poolIdleTimeout <-
+ o .:? "idleTimeout"
+ .!= (floor $ connectionPoolConfigIdleTimeout defaultPoolConfig)
+ let
+ cstr =
+ B8.intercalate " "
+ [ "host='" <> B8.pack host <> "'"
+ , "port=" <> B8.pack (show port)
+ , "user='" <> B8.pack user <> "'"
+ , "password='" <> B8.pack password <> "'"
+ , "dbname='" <> B8.pack database <> "'"
+ ]
+ return $ PostgresConf cstr poolStripes poolIdleTimeout poolSize
+
+instance PersistConfig PostgresConf where
+ type PersistConfigBackend PostgresConf = ReaderT (WriteBackend PostgreSQLBackend)
+ type PersistConfigPool PostgresConf = Pool (WriteBackend PostgreSQLBackend)
+ createPoolConfig conf =
+ runNoLoggingT $ createPostgresqlPipelinePool
+ defaultPipelineSettings
+ { pipelineConnStr = pgConnStr conf
+ , pipelinePoolSize = pgPoolSize conf
+ }
+ runPool _ = runSqlPool
+ loadConfig = parseJSON
+ applyEnv c0 = do
+ env <- getEnvironment
+ return $
+ addUser env $
+ addPass env $
+ addDatabase env $
+ addPort env $
+ addHost env c0
+ where
+ addParam param val c =
+ c{pgConnStr = B8.concat [pgConnStr c, " ", param, "='", pgescape val, "'"]}
+
+ pgescape = B8.pack . go
+ where
+ go ('\'' : rest) = '\\' : '\'' : go rest
+ go ('\\' : rest) = '\\' : '\\' : go rest
+ go (x : rest) = x : go rest
+ go [] = []
+
+ maybeAddParam param envvar env =
+ maybe id (addParam param) $ lookup envvar env
+
+ addHost = maybeAddParam "host" "PGHOST"
+ addPort = maybeAddParam "port" "PGPORT"
+ addUser = maybeAddParam "user" "PGUSER"
+ addPass = maybeAddParam "password" "PGPASS"
+ addDatabase = maybeAddParam "dbname" "PGDATABASE"
+
+-- | Hooks for configuring the connection to Postgres.
+data PostgresConfHooks = PostgresConfHooks
+ { pgConfHooksGetServerVersion :: LibPQ.Connection -> IO (NonEmpty Word)
+ , pgConfHooksAfterCreate :: LibPQ.Connection -> IO ()
+ }
+
+-- | Default settings for 'PostgresConfHooks'.
+defaultPostgresConfHooks :: PostgresConfHooks
+defaultPostgresConfHooks =
+ PostgresConfHooks
+ { pgConfHooksGetServerVersion = \conn -> do
+ v <- LibPQ.serverVersion conn
+ let ver = fromIntegral v
+ major = ver `div` 10000
+ minor = (ver `mod` 10000) `div` 100
+ if ver == 0
+ then return $ 9 NEL.:| [4]
+ else return $ major NEL.:| [minor]
+ , pgConfHooksAfterCreate = const $ pure ()
+ }
+
+-- Pipeline mode
+
+-- | No-op: pipeline mode is always on.
+--
+-- In previous versions, this function entered and exited libpq pipeline mode
+-- around a block of operations. Now that pipelining is automatic and always
+-- active, this is the identity function. It is kept for API compatibility.
+withPipeline :: (MonadIO m) => ReaderT SqlBackend m a -> ReaderT SqlBackend m a
+withPipeline = id
+
+-- | Flush all pending fire-and-forget results in the pipeline.
+--
+-- Forces all queued DML operations to be sent to the server and their
+-- results consumed. Throws 'PipelineExecError' if any operation
+-- failed. This is useful for eagerly discovering errors from
+-- fire-and-forget operations (INSERT, UPDATE, DELETE) without waiting
+-- for the next SELECT or COMMIT.
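+--
+-- Usage sketch:
+--
+-- @
+-- mapM_ insert_ users  -- queued, fire-and-forget
+-- flushPipeline        -- any queued errors are raised here
+-- @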
+flushPipeline :: (MonadIO m) => ReaderT SqlBackend m ()
+flushPipeline = do
+ backend <- ask
+ case getPipelineConn backend of
+ Nothing -> return ()
+ Just pc -> liftIO $ drainPending pc
+
+-- | Execute raw SQL synchronously, returning the actual affected row count.
+--
+-- Unlike 'rawExecuteCount' (which returns 0 in pipeline mode because DML
+-- is fire-and-forget), this function drains all pending pipeline operations,
+-- sends the statement, flushes, and reads the result -- giving you the real
+-- number of rows affected.
+--
+-- Use this when you need the count (e.g. for @deleteCount@, @updateCount@
+-- semantics).
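+--
+-- Usage (hypothetical table and column):
+--
+-- @
+-- n <- rawExecuteSync "DELETE FROM logs WHERE created_at < ?" [PersistUTCTime cutoff]
+-- @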
+rawExecuteSync
+ :: (MonadIO m)
+ => Text
+ -> [PersistValue]
+ -> ReaderT SqlBackend m Int64
+rawExecuteSync sql vals = do
+ backend <- ask
+ case getPipelineConn backend of
+ Nothing -> return 0
+ Just pc -> liftIO $ do
+ drainPending pc
+ let sqlBS = T.encodeUtf8 sql
+ (rewrittenSQL, remainingVals) = inlineAndRewrite sqlBS vals
+ params = map encodePersistValue remainingVals
+ ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary
+ unless ok $ do
+ merr <- LibPQ.errorMessage (pgConn pc)
+ throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr)
+ _ <- pgSendFlushRequest pc
+ pgFlush pc
+ ret <- readOneQueryResult pc
+ mbs <- LibPQ.cmdTuples ret
+ LibPQ.unsafeFreeResult ret
+ case mbs of
+ Nothing -> return 0
+ Just bs -> case B8.readInt bs of
+ Just (n, _) -> return (fromIntegral n)
+ Nothing -> return 0
+
+-- | Execute raw SQL as fire-and-forget (non-flushing).
+--
+-- The statement is queued in the pipeline and its result is consumed at the
+-- next drain point (SELECT, COMMIT, or 'flushPipeline'). Errors from this
+-- statement will surface at that drain point as a 'PipelineExecError'.
+--
+-- This is the same behavior as 'rawExecute', but the name makes the
+-- pipeline semantics explicit.
+rawExecuteNoReturn
+ :: (MonadIO m)
+ => Text
+ -> [PersistValue]
+ -> ReaderT SqlBackend m ()
+rawExecuteNoReturn sql vals = do
+ backend <- ask
+ case getPipelineConn backend of
+ Nothing -> return ()
+ Just pc -> liftIO $ do
+ let sqlBS = T.encodeUtf8 sql
+ (rewrittenSQL, remainingVals) = inlineAndRewrite sqlBS vals
+ params = map encodePersistValue remainingVals
+ ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary
+ unless ok $ do
+ merr <- LibPQ.errorMessage (pgConn pc)
+ throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr)
+ lazyReply <- pgRecvResult pc
+ modifyIORef (pgPending pc) (lazyReply :)
+
+-- RawPostgresqlPipeline
+
+-- | Wrapper for persistent SqlBackends that carry the corresponding
+-- libpq 'PgConn'.
+data RawPostgresqlPipeline backend = RawPostgresqlPipeline
+ { persistentBackend :: backend
+ , rawPipelineConnection :: PgConn
+ }
+
+instance BackendCompatible (RawPostgresqlPipeline b) (RawPostgresqlPipeline b) where
+ projectBackend = id
+
+instance BackendCompatible b (RawPostgresqlPipeline b) where
+ projectBackend = persistentBackend
+
+withRawConnection
+ :: (PgConn -> SqlBackend)
+ -> PgConn
+ -> RawPostgresqlPipeline SqlBackend
+withRawConnection f pc =
+ RawPostgresqlPipeline
+ { persistentBackend = f pc
+ , rawPipelineConnection = pc
+ }
+
+-- | Create a PostgreSQL connection pool which also exposes the
+-- raw PgConn.
+createRawPostgresqlPipelinePool
+ :: (MonadUnliftIO m, MonadLoggerIO m)
+ => PipelineSettings
+ -> m (Pool (RawPostgresqlPipeline SqlBackend))
+createRawPostgresqlPipelinePool settings =
+ createSqlPool (\logFunc -> do
+ pc <- openPgConn (pipelineFetchMode settings) (pipelineConnStr settings)
+ smap <- newIORef mempty
+ return $ withRawConnection (\pc' -> createBackend logFunc (pgVersion pc') smap pc') pc
+ ) (pipelinePoolSize settings)
+
+instance (PersistCore b) => PersistCore (RawPostgresqlPipeline b) where
+ newtype BackendKey (RawPostgresqlPipeline b) = RawPostgresqlPipelineKey
+ { unRawPostgresqlPipelineKey :: BackendKey (Compatible b (RawPostgresqlPipeline b))
+ }
+
+makeCompatibleKeyInstances [t| forall b. Compatible b (RawPostgresqlPipeline b) |]
+
+$(pure [])
+
+makeCompatibleInstances [t| forall b. Compatible b (RawPostgresqlPipeline b) |]
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/FFI.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/FFI.hs
new file mode 100644
index 000000000..e2627ecc9
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/FFI.hs
@@ -0,0 +1,149 @@
+{-# LANGUAGE PatternSynonyms #-}
+
+-- | Additional libpq FFI bindings for features not yet in @postgresql-libpq@.
+--
+-- Provides:
+--
+-- * 'setChunkedRowsMode' — PG17's @PQsetChunkedRowsMode@, with a
+-- compile-time stub when headers are pre-17.
+-- * 'hasChunkedRowsSupport' — runtime check for chunked mode availability.
+-- * 'rawResultStatusInt' — raw @ExecStatus@ as 'CInt', bypassing the
+-- @postgresql-libpq@ enum which throws on unknown values (e.g.
+-- @PGRES_TUPLES_CHUNK@).
+-- * Pattern synonyms for the raw @ExecStatus@ integer constants.
+module Database.Persist.Postgresql.Pipeline.FFI
+ ( -- * Chunked rows mode (PG17+)
+ setChunkedRowsMode
+ , hasChunkedRowsSupport
+ -- * Raw result status
+ , rawResultStatusInt
+ -- * Raw ExecStatus constants
+ , pattern PGRES_EMPTY_QUERY
+ , pattern PGRES_COMMAND_OK
+ , pattern PGRES_TUPLES_OK
+ , pattern PGRES_COPY_OUT
+ , pattern PGRES_COPY_IN
+ , pattern PGRES_BAD_RESPONSE
+ , pattern PGRES_NONFATAL_ERROR
+ , pattern PGRES_FATAL_ERROR
+ , pattern PGRES_COPY_BOTH
+ , pattern PGRES_SINGLE_TUPLE
+ , pattern PGRES_PIPELINE_SYNC
+ , pattern PGRES_PIPELINE_ABORTED
+ , pattern PGRES_TUPLES_CHUNK
+ ) where
+
+import Foreign.C.Types (CInt (..))
+import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
+import Foreign.Ptr (Ptr)
+import Unsafe.Coerce (unsafeCoerce)
+
+import Database.PostgreSQL.LibPQ.Internal (PGconn, withConn)
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+
+-------------------------------------------------------------------------------
+-- Foreign imports from cbits/hs_libpq_extra.c
+-------------------------------------------------------------------------------
+
+foreign import ccall unsafe "hs_PQsetChunkedRowsMode"
+ c_setChunkedRowsMode :: Ptr PGconn -> CInt -> IO CInt
+
+foreign import ccall unsafe "hs_has_chunked_rows_support"
+ c_hasChunkedRowsSupport :: IO CInt
+
+-- | Phantom type tag matching the C @PGresult@ struct. We declare our
+-- own rather than reusing the one from @postgresql-libpq@ (which is not
+-- exported).
+data PGresult
+
+foreign import ccall unsafe "hs_PQresultStatusRaw"
+ c_resultStatusRaw :: Ptr PGresult -> IO CInt
+
+-------------------------------------------------------------------------------
+-- Haskell wrappers
+-------------------------------------------------------------------------------
+
+-- | Activate chunked row retrieval for the query that was just sent via
+-- @sendQueryParams@. Must be called immediately after sending, before
+-- any other operation on the connection.
+--
+-- Returns 'True' on success, 'False' if the mode could not be set
+-- (e.g. libpq < 17, or wrong call timing).
+setChunkedRowsMode :: LibPQ.Connection -> Int -> IO Bool
+setChunkedRowsMode conn chunkSize =
+ withConn conn $ \ptr -> do
+ rc <- c_setChunkedRowsMode ptr (fromIntegral chunkSize)
+ return (rc == 1)
+
+-- | 'True' if the linked libpq supports @PQsetChunkedRowsMode@ (PG17+).
+-- Cached at first call; safe to call repeatedly.
+hasChunkedRowsSupport :: IO Bool
+hasChunkedRowsSupport = do
+ rc <- c_hasChunkedRowsSupport
+ return (rc == 1)
+
+-- | Get the raw @ExecStatus@ integer from a 'LibPQ.Result'.
+--
+-- This bypasses @LibPQ.resultStatus@ which calls @fail@ on status codes
+-- it doesn't recognise (e.g. @PGRES_TUPLES_CHUNK@ = 12 on older
+-- @postgresql-libpq@ builds).
+--
+-- __Safety__: 'LibPQ.Result' is a @newtype@ around @ForeignPtr PGresult@.
+-- We use 'unsafeCoerce' to extract the 'ForeignPtr' without the
+-- constructor being in scope. This is safe because:
+--
+-- * The representation is guaranteed by @postgresql-libpq@'s source.
+-- * 'withForeignPtr' keeps the pointer alive for the duration of the call.
+-- * The C function is a pure read with no side effects.
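+--
+-- Usage sketch (with a hypothetical @handleChunk@):
+--
+-- @
+-- st <- rawResultStatusInt ret
+-- when (st == PGRES_TUPLES_CHUNK) (handleChunk ret)
+-- @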
+rawResultStatusInt :: LibPQ.Result -> IO CInt
+rawResultStatusInt result =
+ withForeignPtr (unsafeCoerce result :: ForeignPtr PGresult) c_resultStatusRaw
+
+-------------------------------------------------------------------------------
+-- Raw ExecStatus constants
+--
+-- These mirror the C enum ExecStatusType. Values 0–11 are stable across
+-- all supported PostgreSQL versions; 12 (PGRES_TUPLES_CHUNK) was added
+-- in PG17.
+-------------------------------------------------------------------------------
+
+pattern PGRES_EMPTY_QUERY :: CInt
+pattern PGRES_EMPTY_QUERY = 0
+
+pattern PGRES_COMMAND_OK :: CInt
+pattern PGRES_COMMAND_OK = 1
+
+pattern PGRES_TUPLES_OK :: CInt
+pattern PGRES_TUPLES_OK = 2
+
+pattern PGRES_COPY_OUT :: CInt
+pattern PGRES_COPY_OUT = 3
+
+pattern PGRES_COPY_IN :: CInt
+pattern PGRES_COPY_IN = 4
+
+pattern PGRES_BAD_RESPONSE :: CInt
+pattern PGRES_BAD_RESPONSE = 5
+
+pattern PGRES_NONFATAL_ERROR :: CInt
+pattern PGRES_NONFATAL_ERROR = 6
+
+pattern PGRES_FATAL_ERROR :: CInt
+pattern PGRES_FATAL_ERROR = 7
+
+pattern PGRES_COPY_BOTH :: CInt
+pattern PGRES_COPY_BOTH = 8
+
+pattern PGRES_SINGLE_TUPLE :: CInt
+pattern PGRES_SINGLE_TUPLE = 9
+
+pattern PGRES_PIPELINE_SYNC :: CInt
+pattern PGRES_PIPELINE_SYNC = 10
+
+pattern PGRES_PIPELINE_ABORTED :: CInt
+pattern PGRES_PIPELINE_ABORTED = 11
+
+-- | Chunked tuple result (PG17+). Each @PGresult@ contains up to N rows
+-- as specified by @PQsetChunkedRowsMode@.
+pattern PGRES_TUPLES_CHUNK :: CInt
+pattern PGRES_TUPLES_CHUNK = 12
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/Internal.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/Internal.hs
new file mode 100644
index 000000000..a2a328e3a
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/Internal.hs
@@ -0,0 +1,352 @@
+{-# LANGUAGE OverloadedStrings #-}
+
+-- | Low-level PostgreSQL connection wrapper with server version info.
+--
+-- Wraps 'LibPQ.Connection' and provides pipeline mode primitives.
+module Database.Persist.Postgresql.Pipeline.Internal
+ ( PgConn (..)
+ , FetchMode (..)
+ , openPgConn
+ , closePgConn
+ , pgSendQueryParams
+ , pgGetResult
+ , pgEnterPipelineMode
+ , pgExitPipelineMode
+ , pgPipelineSync
+ , pgPipelineStatus
+ , pgFlush
+ , pgSendFlushRequest
+ , pgRecvResult
+ , PipelineError (..)
+ ) where
+
+import Control.Exception (Exception, throwIO)
+import Control.Monad (when)
+import Data.ByteString (ByteString)
+import qualified Data.ByteString.Char8 as B8
+import Data.IORef
+import Data.List.NonEmpty (NonEmpty (..))
+import Data.Maybe (mapMaybe)
+import GHC.Conc (threadWaitWrite)
+import System.IO.Unsafe (unsafeInterleaveIO)
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+
+import Database.Persist.Postgresql.Internal.PgType (OidCache, emptyOidCache)
+import Database.Persist.Postgresql.Pipeline.FFI (hasChunkedRowsSupport)
+
+-- | Controls how query result rows are fetched from the server.
+data FetchMode
+ = FetchAll
+ -- ^ Default. The server returns the entire result set in a single
+ -- @PGresult@. Simplest and fastest for small\/medium result sets,
+ -- but the full result must fit in memory.
+ | FetchSingleRow
+ -- ^ Single-row mode (@PQsetSingleRowMode@). Each @PGresult@ contains
+ -- exactly one row (status @PGRES_SINGLE_TUPLE@). Lowest memory
+ -- footprint but highest per-row allocation overhead. Works with
+ -- any PostgreSQL\/libpq version.
+ | FetchChunked !Int
+ -- ^ Chunked mode (@PQsetChunkedRowsMode@, PG17+ libpq). Each
+ -- @PGresult@ contains up to the given number of rows (status
+ -- @PGRES_TUPLES_CHUNK@). Good balance of bounded memory and reduced
+ -- allocation overhead.
+ --
+ -- If the linked libpq does not support chunked mode, this is
+ -- automatically downgraded to 'FetchSingleRow' at connection time.
+ deriving (Eq, Show)
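+
+-- For illustration: a connection that streams large result sets with
+-- bounded memory, falling back to single-row mode when the linked libpq
+-- predates PG17 (the chunk size and connection string are placeholders):
+--
+-- @
+-- pc <- openPgConn (FetchChunked 256) "host=localhost dbname=test"
+-- @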
+
+-- | A PostgreSQL connection with cached server version info.
+--
+-- Pipeline mode is always on. The 'pgPending' list tracks the
+-- fire-and-forget query results that have not yet been read; they are
+-- drained at query time ('stmtQuery') or at transaction boundaries
+-- ('connCommit', 'connRollback').
+data PgConn = PgConn
+ { pgConn :: !LibPQ.Connection
+ , pgVersion :: !(NonEmpty Word)
+ , pgPending :: !(IORef [LibPQ.Result])
+ -- ^ All outstanding pipeline results (lazy thunks popped from
+ -- 'pgReplies'). Both fire-and-forget DML ('execute'') and lazy
+ -- reads ('pipelinedSend') append here. 'drainPending' forces,
+ -- error-checks, and frees all of them. Double-free is safe
+ -- because 'LibPQ.Result' is a 'ForeignPtr' -- 'unsafeFreeResult'
+ -- is idempotent.
+ --
+ -- Stored in reverse order (newest first) for O(1) append.
+ , pgFetchMode :: !FetchMode
+ -- ^ Row fetch strategy for SELECT-like queries on this connection.
+ , pgOidCache :: !(IORef OidCache)
+ -- ^ Cache for dynamically-discovered type OIDs (custom enums,
+ -- composites, domains). Starts empty; future custom-type support
+ -- can populate it at connection time or on first encounter.
+ -- Use 'resolveOid' from "Database.Persist.Postgresql.Internal.PgType"
+ -- to consult both the built-in table and this cache.
+ , pgReplies :: !(IORef [LibPQ.Result])
+ -- ^ Lazy reply stream (Hedis-style). Each element, when forced,
+ -- flushes the send buffer and reads one pipeline result + NULL
+ -- separator from the connection. Created at connection time via
+ -- 'unsafeInterleaveIO'. Used by 'pgRecvResult' which pops the
+ -- head lazily (without forcing the IO) so that back-to-back reads
+ -- are automatically pipelined.
+ }
+
+-- | Errors from pipeline operations.
+data PipelineError
+ = PipelineExecError !ByteString
+ -- ^ Error message from a failed query
+ | PipelineModeError !String
+ -- ^ Error entering/exiting pipeline mode
+ deriving (Show)
+
+instance Exception PipelineError
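+
+-- These can be caught like any other exception, e.g. with @try@ from
+-- "Control.Exception":
+--
+-- @
+-- r <- try (pgFlush pc) :: IO (Either PipelineError ())
+-- @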
+
+-- | Resolve the requested 'FetchMode', downgrading 'FetchChunked' to
+-- 'FetchSingleRow' when the linked libpq doesn't support it.
+resolveFetchMode :: FetchMode -> IO FetchMode
+resolveFetchMode fm@FetchAll = return fm
+resolveFetchMode fm@FetchSingleRow = return fm
+resolveFetchMode fm@(FetchChunked _) = do
+ ok <- hasChunkedRowsSupport
+ if ok
+ then return fm
+ else return FetchSingleRow
+
+-- | Open a libpq connection, detect the server version, and enter pipeline mode.
+--
+-- The 'FetchMode' is resolved (see 'resolveFetchMode') and stored in the
+-- returned 'PgConn'.
+openPgConn :: FetchMode -> ByteString -> IO PgConn
+openPgConn requestedMode connStr = do
+ conn <- LibPQ.connectdb connStr
+ st <- LibPQ.status conn
+ case st of
+ LibPQ.ConnectionOk -> do
+ ver <- getServerVersionLibPQ conn
+ -- Enable nonblocking mode before entering pipeline mode.
+ -- In blocking mode, pipeline operations risk deadlock when
+ -- both client and server buffers fill simultaneously.
+ nbOk <- LibPQ.setnonblocking conn True
+ when (not nbOk) $ do
+ LibPQ.finish conn
+ fail "persistent-postgresql-ng: failed to set nonblocking mode"
+ pending <- newIORef ([] :: [LibPQ.Result])
+ oidCache <- newIORef emptyOidCache
+ mode <- resolveFetchMode requestedMode
+ repliesRef <- newIORef []
+ let pc = PgConn
+ { pgConn = conn
+ , pgVersion = ver
+ , pgPending = pending
+ , pgFetchMode = mode
+ , pgOidCache = oidCache
+ , pgReplies = repliesRef
+ }
+ replies <- mkReplyStream pc
+ writeIORef repliesRef replies
+ ok <- pgEnterPipelineMode pc
+ if ok
+ then return pc
+ else do
+ LibPQ.finish conn
+ fail "persistent-postgresql-ng: failed to enter pipeline mode"
+ _ -> do
+ merr <- LibPQ.errorMessage conn
+ LibPQ.finish conn
+ let msg = maybe "Unknown connection error" B8.unpack merr
+ fail $ "persistent-postgresql-ng: connection failed: " ++ msg
+
+-- | Close a libpq connection. Exits pipeline mode first.
+closePgConn :: PgConn -> IO ()
+closePgConn pc = do
+ -- Best-effort: exit pipeline mode before closing.
+ -- If there are pending results, pipelineSync + drain them first.
+ pendingList <- readIORef (pgPending pc)
+ when (not (null pendingList)) $ do
+ _ <- pgPipelineSync pc
+ drainToSyncClose (pgConn pc)
+ _ <- pgExitPipelineMode pc
+ LibPQ.finish (pgConn pc)
+ where
+ -- Drain results until PipelineSync, ignoring all errors.
+ drainToSyncClose conn = do
+ mret <- LibPQ.getResult conn
+ case mret of
+ Nothing -> drainToSyncClose conn -- NULL separator, keep reading
+ Just ret -> do
+ st <- LibPQ.resultStatus ret
+ LibPQ.unsafeFreeResult ret
+ case st of
+ LibPQ.PipelineSync -> return ()
+ _ -> drainToSyncClose conn
+
+-- | Send a query in pipeline mode (non-blocking).
+pgSendQueryParams
+ :: PgConn
+ -> ByteString
+ -> [Maybe (LibPQ.Oid, ByteString, LibPQ.Format)]
+ -> LibPQ.Format
+ -> IO Bool
+pgSendQueryParams pc sql params resultFmt =
+ LibPQ.sendQueryParams (pgConn pc) sql params resultFmt
+
+-- | Get the next result from the connection (for pipeline mode).
+pgGetResult :: PgConn -> IO (Maybe LibPQ.Result)
+pgGetResult = LibPQ.getResult . pgConn
+
+-- | Enter pipeline mode.
+pgEnterPipelineMode :: PgConn -> IO Bool
+pgEnterPipelineMode = LibPQ.enterPipelineMode . pgConn
+
+-- | Exit pipeline mode.
+pgExitPipelineMode :: PgConn -> IO Bool
+pgExitPipelineMode = LibPQ.exitPipelineMode . pgConn
+
+-- | Insert a sync point in pipeline mode.
+pgPipelineSync :: PgConn -> IO Bool
+pgPipelineSync = LibPQ.pipelineSync . pgConn
+
+-- | Get the current pipeline status.
+pgPipelineStatus :: PgConn -> IO LibPQ.PipelineStatus
+pgPipelineStatus = LibPQ.pipelineStatus . pgConn
+
+-- | Flush the client send buffer, waiting for socket writability as needed.
+-- In nonblocking mode, 'LibPQ.flush' may return 'FlushWriting' when the
+-- socket's send buffer is full. We wait for writability using GHC's I/O
+-- manager ('threadWaitWrite'), then consume any pending server data
+-- ('consumeInput') to prevent the server from blocking on its send buffer,
+-- and retry.
+-- Throws 'PipelineExecError' on failure.
+pgFlush :: PgConn -> IO ()
+pgFlush pc = do
+ st <- LibPQ.flush (pgConn pc)
+ case st of
+ LibPQ.FlushOk -> return ()
+ LibPQ.FlushFailed -> throwIO $ PipelineExecError "flush failed"
+ LibPQ.FlushWriting -> do
+ mfd <- LibPQ.socket (pgConn pc)
+ case mfd of
+ Nothing -> throwIO $ PipelineExecError "flush: no socket"
+ Just fd -> flushLoop fd
+ where
+ flushLoop fd = do
+ threadWaitWrite fd
+ -- Consume any pending server data to prevent server-side blocking.
+ -- The server may be unable to read our data because its send buffer
+ -- is full with results we haven't consumed yet.
+ _ <- LibPQ.consumeInput (pgConn pc)
+ st <- LibPQ.flush (pgConn pc)
+ case st of
+ LibPQ.FlushOk -> return ()
+ LibPQ.FlushFailed -> throwIO $ PipelineExecError "flush failed"
+ LibPQ.FlushWriting -> flushLoop fd
+
+-- | Send a request for the server to flush its output buffer.
+-- In pipeline mode, the server buffers results until it receives a
+-- Sync or FlushRequest message.
+pgSendFlushRequest :: PgConn -> IO Bool
+pgSendFlushRequest = LibPQ.sendFlushRequest . pgConn
+
+-- | Pop the next result from the lazy reply stream without forcing the
+-- read. The actual flush+read happens when the returned 'LibPQ.Result'
+-- is evaluated (via 'unsafeInterleaveIO').
+--
+-- Uses @head@/@tail@ instead of pattern matching to avoid forcing the
+-- cons cell -- this is the key to Hedis-style automatic pipelining.
+-- @atomicModifyIORef@ stores @tail xs@ and returns @head xs@ as
+-- unevaluated thunks. The actual IO fires only when the caller
+-- inspects the returned 'LibPQ.Result'.
+pgRecvResult :: PgConn -> IO LibPQ.Result
+pgRecvResult pc = atomicModifyIORef (pgReplies pc) (\xs -> (tail xs, head xs))
+{-# INLINE pgRecvResult #-}
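+
+-- A minimal sketch of the resulting automatic pipelining (assumes @pc@
+-- is an open 'PgConn'; @evaluate@ is from "Control.Exception"; error
+-- handling elided):
+--
+-- @
+-- _ <- pgSendQueryParams pc "SELECT 1" [] LibPQ.Binary
+-- _ <- pgSendQueryParams pc "SELECT 2" [] LibPQ.Binary
+-- r1 <- pgRecvResult pc  -- nothing flushed or read yet
+-- r2 <- pgRecvResult pc
+-- _ <- evaluate r1       -- flushes both queries, reads the first result
+-- _ <- evaluate r2       -- reads the second result; no extra round-trip
+-- @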
+
+-- | Build the infinite lazy reply stream. Each cons cell, when forced,
+-- flushes the send buffer and reads one query result (result + NULL
+-- separator) from the connection.
+mkReplyStream :: PgConn -> IO [LibPQ.Result]
+mkReplyStream pc = go
+ where
+ go = unsafeInterleaveIO $ do
+ -- Flush the send buffer so the server sees all queued queries.
+ _ <- LibPQ.sendFlushRequest (pgConn pc)
+ pgFlush pc
+ -- Read one result.
+ mret <- LibPQ.getResult (pgConn pc)
+ case mret of
+ Nothing ->
+ -- No result available -- shouldn't happen in a well-formed
+ -- pipeline, but protect against it.
+ go
+ Just ret -> do
+ -- Whatever the status (CommandOk for DML, TuplesOk for a SELECT,
+ -- or an error status the caller handles), consume the NULL
+ -- separator and cons the result onto the stream. The status is
+ -- deliberately not inspected here: LibPQ.resultStatus can fail
+ -- on codes it doesn't recognise.
+ _ <- LibPQ.getResult (pgConn pc)
+ rest <- go
+ return (ret : rest)
+
+-- | Get the server version from a libpq connection.
+--
+-- Uses @LibPQ.serverVersion@, which reports @server_version_num@. For
+-- PostgreSQL 10+ this is @major * 10000 + minor@ (e.g. @140005@ is 14.5);
+-- for older servers it is @major * 10000 + minor * 100 + patch@
+-- (e.g. @90405@ is 9.4.5).
+getServerVersionLibPQ :: LibPQ.Connection -> IO (NonEmpty Word)
+getServerVersionLibPQ conn = do
+ v <- LibPQ.serverVersion conn
+ let ver = fromIntegral v
+ major = ver `div` 10000
+ if ver == 0
+ then
+ -- Fallback: if serverVersion returns 0, query it via SQL.
+ getServerVersionQuery conn
+ else if major >= 10
+ then return $ major :| [ver `mod` 10000]
+ else return $ major :| [(ver `mod` 10000) `div` 100, ver `mod` 100]
+
+-- | Fallback: query the server version via SQL.
+getServerVersionQuery :: LibPQ.Connection -> IO (NonEmpty Word)
+getServerVersionQuery conn = do
+ mret <- LibPQ.exec conn "SHOW server_version"
+ case mret of
+ Nothing -> return $ 9 :| [4] -- Default to minimum supported
+ Just ret -> do
+ st <- LibPQ.resultStatus ret
+ case st of
+ LibPQ.TuplesOk -> do
+ mbs <- LibPQ.getvalue ret (LibPQ.toRow (0 :: Int)) (LibPQ.toColumn (0 :: Int))
+ LibPQ.unsafeFreeResult ret
+ case mbs of
+ Nothing -> return $ 9 :| [4]
+ Just bs -> case parseVersionBS bs of
+ Just v -> return v
+ Nothing -> return $ 9 :| [4]
+ _ -> do
+ LibPQ.unsafeFreeResult ret
+ return $ 9 :| [4]
+
+-- | Parse a version string like "14.5" or "14.5.1" from a ByteString.
+parseVersionBS :: ByteString -> Maybe (NonEmpty Word)
+parseVersionBS bs =
+ let parts = B8.split '.' bs
+ nums = mapMaybe readWord parts
+ in case nums of
+ (x:xs) -> Just (x :| xs)
+ [] -> Nothing
+ where
+ readWord s = case reads (B8.unpack s) of
+ [(n, "")] -> Just n
+ _ -> Nothing
diff --git a/persistent-postgresql-ng/LICENSE b/persistent-postgresql-ng/LICENSE
new file mode 100644
index 000000000..f7f439fe3
--- /dev/null
+++ b/persistent-postgresql-ng/LICENSE
@@ -0,0 +1,20 @@
+Copyright (c) 2026 Ian Duncan
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be included
+in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/persistent-postgresql-ng/bench-baseline.html b/persistent-postgresql-ng/bench-baseline.html
new file mode 100644
index 000000000..5ccf8537f
--- /dev/null
+++ b/persistent-postgresql-ng/bench-baseline.html
@@ -0,0 +1,1159 @@
+<!--
+  criterion benchmark report (generated HTML, 1159 lines; interactive
+  page markup elided here). The report is titled "criterion performance
+  measurements"; its overview chart supports index, lexical, colexical,
+  and time ascending/descending sort orders.
+-->
diff --git a/persistent-postgresql-ng/bench/Main.hs b/persistent-postgresql-ng/bench/Main.hs
new file mode 100644
index 000000000..97b2952b1
--- /dev/null
+++ b/persistent-postgresql-ng/bench/Main.hs
@@ -0,0 +1,302 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- | Benchmarks comparing persistent-postgresql-ng (binary protocol)
+-- against persistent-postgresql (text protocol via postgresql-simple).
+--
+-- Run with:
+--
+-- @
+-- stack bench persistent-postgresql-ng
+-- @
+--
+-- Requires PostgreSQL running on localhost:5432 with a @test@ database
+-- accessible by the current user (or the @postgres@ role).
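+--
+-- Connection settings come from the standard libpq environment variables
+-- read by 'getConnString' below, e.g.:
+--
+-- @
+-- PGHOST=127.0.0.1 PGPORT=5433 PGDATABASE=bench stack bench persistent-postgresql-ng
+-- @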
+module Main (main) where
+
+import Control.Exception (evaluate)
+import Control.Monad (forM_, void)
+import Control.Monad.IO.Class (liftIO)
+import Control.Monad.Logger (runNoLoggingT)
+import Criterion.Main
+import qualified Data.ByteString.Char8 as B8
+import Data.Maybe (fromMaybe)
+import Data.Pool (Pool)
+import Data.Text (Text)
+import qualified Data.Text as T
+import Control.Monad.Trans.Reader (ReaderT, withReaderT)
+import Database.Persist
+import Database.Persist.Class.PersistStore (BackendCompatible (..))
+import Database.Persist.Sql
+import Database.Persist.TH
+import System.Environment (lookupEnv)
+
+import qualified Database.Persist.Postgresql as PgSimple
+import qualified Database.Persist.Postgresql.Pipeline as PgPipeline
+
+share
+ [mkPersist sqlSettings, mkMigrate "benchMigrate"]
+ [persistLowerCase|
+BenchPerson
+ name Text
+ age Int
+ email Text
+ UniqueBenchPersonEmail email
+ deriving Show Eq
+|]
+
+getConnString :: IO B8.ByteString
+getConnString = do
+ host <- fromMaybe "localhost" <$> lookupEnv "PGHOST"
+ port <- fromMaybe "5432" <$> lookupEnv "PGPORT"
+ user <- fromMaybe "postgres" <$> lookupEnv "PGUSER"
+ db <- fromMaybe "test" <$> lookupEnv "PGDATABASE"
+ pure $ B8.pack $ "host=" <> host <> " port=" <> port
+ <> " user=" <> user <> " dbname=" <> db
+
+cleanTablesPl :: ReaderT (PgPipeline.WriteBackend PgPipeline.PostgreSQLBackend) IO ()
+cleanTablesPl = deleteWhere ([] :: [Filter BenchPerson])
+
+cleanTablesSim :: SqlPersistT IO ()
+cleanTablesSim = deleteWhere ([] :: [Filter BenchPerson])
+
+people :: Int -> [BenchPerson]
+people n =
+ [ BenchPerson
+ (T.pack $ "Person " <> show i)
+ (20 + i `mod` 60)
+ (T.pack $ "person" <> show i <> "@test.com")
+ | i <- [1..n]
+ ]
+
+pl :: Pool (PgPipeline.WriteBackend PgPipeline.PostgreSQLBackend)
+ -> ReaderT (PgPipeline.WriteBackend PgPipeline.PostgreSQLBackend) IO a -> IO a
+pl pool action = runSqlPool action pool
+
+
+sim :: Pool SqlBackend -> SqlPersistT IO a -> IO a
+sim pool action = runSqlPool action pool
+
+main :: IO ()
+main = do
+ connStr <- getConnString
+
+ pipelinePool <- runNoLoggingT $ PgPipeline.createPostgresqlPipelinePool
+ PgPipeline.defaultPipelineSettings
+ { PgPipeline.pipelineConnStr = connStr
+ , PgPipeline.pipelinePoolSize = 1
+ }
+ simplePool <- runNoLoggingT $ PgSimple.createPostgresqlPool connStr 1
+
+ pl pipelinePool $ withReaderT projectBackend $ void $ runMigrationSilent benchMigrate
+
+ let p = pl pipelinePool
+ s = sim simplePool
+
+ defaultMain
+ [ bgroup "delete x100 then select (pipeline sweet spot)"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl
+ keys <- mapM insert (people 100)
+ forM_ keys delete
+ void $ selectList ([] :: [Filter BenchPerson]) []
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim
+ keys <- mapM insert (people 100)
+ mapM_ delete keys
+ void $ selectList ([] :: [Filter BenchPerson]) []
+ ]
+
+ , bgroup "update x100 then select (pipeline sweet spot)"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl
+ keys <- mapM insert (people 100)
+ forM_ keys $ \k -> update k [BenchPersonAge =. 99]
+ void $ selectList ([] :: [Filter BenchPerson]) []
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim
+ keys <- mapM insert (people 100)
+ forM_ keys $ \k -> update k [BenchPersonAge =. 99]
+ void $ selectList ([] :: [Filter BenchPerson]) []
+ ]
+
+ , bgroup "mixed DML x100 (delete+update+replace, no reads until end)"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl
+ keys <- mapM insert (people 300)
+ let (delKeys, rest) = splitAt 100 keys
+ let (updKeys, repKeys) = splitAt 100 rest
+ forM_ delKeys delete
+ forM_ updKeys $ \k -> update k [BenchPersonAge =. 99]
+ forM_ (zip repKeys (people 100)) $ \(k, person) ->
+ replace k (person { benchPersonAge = 42 })
+ void $ selectList ([] :: [Filter BenchPerson]) []
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim
+ keys <- mapM insert (people 300)
+ let (delKeys, rest) = splitAt 100 keys
+ let (updKeys, repKeys) = splitAt 100 rest
+ mapM_ delete delKeys
+ forM_ updKeys $ \k -> update k [BenchPersonAge =. 99]
+ forM_ (zip repKeys (people 100)) $ \(k, person) ->
+ replace k (person { benchPersonAge = 42 })
+ void $ selectList ([] :: [Filter BenchPerson]) []
+ ]
+
+ , bgroup "getManyKeys x100 (single round-trip vs N)"
+ [ bench "pipeline (= ANY)" $ whnfIO $ p $ do
+ cleanTablesPl
+ keys <- mapM insert (people 100)
+ void $ withReaderT projectBackend $ PgPipeline.getManyKeys keys
+ , bench "simple (OR chain)" $ whnfIO $ s $ do
+ cleanTablesSim
+ keys <- mapM insert (people 100)
+ void $ getMany keys
+ ]
+
+ , bgroup "single insert x100"
+ [ bench "pipeline" $ whnfIO $ p $ cleanTablesPl >> forM_ (people 100) insert_
+ , bench "simple" $ whnfIO $ s $ cleanTablesSim >> forM_ (people 100) insert_
+ ]
+
+ , bgroup "insertMany x100"
+ [ bench "pipeline" $ whnfIO $ p $ cleanTablesPl >> insertMany_ (people 100)
+ , bench "simple" $ whnfIO $ s $ cleanTablesSim >> insertMany_ (people 100)
+ ]
+
+ , bgroup "insertMany x1000"
+ [ bench "pipeline" $ whnfIO $ p $ cleanTablesPl >> insertMany_ (people 1000)
+ , bench "simple" $ whnfIO $ s $ cleanTablesSim >> insertMany_ (people 1000)
+ ]
+
+ , bgroup "selectList x100"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl >> insertMany_ (people 100)
+ void $ selectList ([] :: [Filter BenchPerson]) []
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim >> insertMany_ (people 100)
+ void $ selectList ([] :: [Filter BenchPerson]) []
+ ]
+
+ , bgroup "selectList filtered"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl >> insertMany_ (people 100)
+ void $ selectList [BenchPersonAge >. 50] []
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim >> insertMany_ (people 100)
+ void $ selectList [BenchPersonAge >. 50] []
+ ]
+
+ , bgroup "select IN x20 (= ANY vs IN (?,?...))"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl >> insertMany_ (people 100)
+ let names = [T.pack $ "Person " <> show i | i <- [1..20 :: Int]]
+ void $ selectList [BenchPersonName <-. names] []
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim >> insertMany_ (people 100)
+ let names = [T.pack $ "Person " <> show i | i <- [1..20 :: Int]]
+ void $ selectList [BenchPersonName <-. names] []
+ ]
+
+ -- The key pipelining benchmark: mapM get sends all queries
+ -- before reading any results (Hedis-style lazy replies).
+ -- Keys are pre-materialized (as Int64) to isolate the get path.
+ , env (do rawKeys <- runSqlPool
+ (do cleanTablesSim
+ ks <- mapM insert (people 100)
+ _ <- liftIO $ evaluate (length ks)
+ return (map (\k -> fromSqlKey k) ks))
+ simplePool
+ evaluate rawKeys) $ \rawKeys ->
+ bgroup "get by key x100 (pipelined reads)"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ let keys = map (toSqlKey . fromIntegral) rawKeys :: [Key BenchPerson]
+ results <- mapM get keys
+ -- Force all Maybe values (triggers the deferred reads)
+ liftIO $ forM_ results $ evaluate . fmap benchPersonAge
+ , bench "simple" $ whnfIO $ s $ do
+ let keys = map (toSqlKey . fromIntegral) rawKeys :: [Key BenchPerson]
+ results <- mapM get keys
+ liftIO $ forM_ results $ evaluate . fmap benchPersonAge
+ ]
+
+ -- Insert pipelining: mapM insert sends all INSERT RETURNING
+ -- queries before reading any keys back.
+ , bgroup "insert x100 (pipelined RETURNING)"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl
+ keys <- mapM insert (people 100)
+ liftIO $ void $ evaluate (length keys)
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim
+ keys <- mapM insert (people 100)
+ liftIO $ void $ evaluate (length keys)
+ ]
+
+ , bgroup "update x100"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl
+ keys <- mapM insert (people 100)
+ forM_ keys $ \k -> update k [BenchPersonAge =. 99]
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim
+ keys <- mapM insert (people 100)
+ forM_ keys $ \k -> update k [BenchPersonAge =. 99]
+ ]
+
+ , bgroup "delete x100"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl
+ keys <- mapM insert (people 100)
+ forM_ keys delete
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim
+ keys <- mapM insert (people 100)
+ mapM_ delete keys
+ ]
+
+ , bgroup "upsert x100"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl
+ forM_ (people 100) $ \person -> upsert person [BenchPersonAge =. benchPersonAge person]
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim
+ forM_ (people 100) $ \person -> upsert person [BenchPersonAge =. benchPersonAge person]
+ ]
+
+ , bgroup "replace x100"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl
+ keys <- mapM insert (people 100)
+ forM_ (zip keys (people 100)) $ \(k, person) ->
+ replace k (person { benchPersonAge = 99 })
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim
+ keys <- mapM insert (people 100)
+ forM_ (zip keys (people 100)) $ \(k, person) ->
+ replace k (person { benchPersonAge = 99 })
+ ]
+
+ , bgroup "deleteWhere x100"
+ [ bench "pipeline" $ whnfIO $ p $ do
+ cleanTablesPl >> insertMany_ (people 100)
+ forM_ [1..100 :: Int] $ \i ->
+ deleteWhere [BenchPersonName ==. T.pack ("Person " <> show i)]
+ , bench "simple" $ whnfIO $ s $ do
+ cleanTablesSim >> insertMany_ (people 100)
+ forM_ [1..100 :: Int] $ \i ->
+ deleteWhere [BenchPersonName ==. T.pack ("Person " <> show i)]
+ ]
+ ]
diff --git a/persistent-postgresql-ng/bench/delay-proxy.py b/persistent-postgresql-ng/bench/delay-proxy.py
new file mode 100644
index 000000000..64be36ba7
--- /dev/null
+++ b/persistent-postgresql-ng/bench/delay-proxy.py
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+"""TCP proxy that adds artificial latency to every packet.
+
+Usage: python3 delay-proxy.py <listen_port> <target_host> <target_port> <delay_ms>
+
+Each direction (client→server and server→client) gets the delay added,
+so the effective round-trip latency increase is 2 * delay_ms.
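+
+Example: python3 delay-proxy.py 6432 localhost 5432 5
+Clients connect to 127.0.0.1:6432 and see ~10ms of added round-trip
+latency to the PostgreSQL server on localhost:5432.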
+"""
+
+import asyncio
+import sys
+import time
+
+
+async def pipe(reader, writer, delay_sec, label):
+ try:
+ while True:
+ data = await reader.read(65536)
+ if not data:
+ break
+ if delay_sec > 0:
+ await asyncio.sleep(delay_sec)
+ writer.write(data)
+ await writer.drain()
+ except (ConnectionResetError, BrokenPipeError):
+ pass
+ finally:
+ writer.close()
+
+
+async def handle_client(client_reader, client_writer, target_host, target_port, delay_sec):
+ try:
+ server_reader, server_writer = await asyncio.open_connection(target_host, target_port)
+ except Exception as e:
+ print(f"Failed to connect to {target_host}:{target_port}: {e}")
+ client_writer.close()
+ return
+
+ await asyncio.gather(
+ pipe(client_reader, server_writer, delay_sec, "c→s"),
+ pipe(server_reader, client_writer, delay_sec, "s→c"),
+ )
+
+
+async def main():
+ if len(sys.argv) != 5:
+ print(f"Usage: {sys.argv[0]} ")
+ sys.exit(1)
+
+ listen_port = int(sys.argv[1])
+ target_host = sys.argv[2]
+ target_port = int(sys.argv[3])
+ delay_ms = float(sys.argv[4])
+ delay_sec = delay_ms / 1000.0
+
+ async def on_connect(reader, writer):
+ await handle_client(reader, writer, target_host, target_port, delay_sec)
+
+ server = await asyncio.start_server(on_connect, "127.0.0.1", listen_port)
+ addr = server.sockets[0].getsockname()
+ print(f"Delay proxy listening on {addr[0]}:{addr[1]}")
+ print(f"Forwarding to {target_host}:{target_port} with {delay_ms}ms delay per direction")
+ print(f"Effective RTT increase: {delay_ms * 2}ms")
+
+ async with server:
+ await server.serve_forever()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/persistent-postgresql-ng/bench/run-with-latency.sh b/persistent-postgresql-ng/bench/run-with-latency.sh
new file mode 100755
index 000000000..b07fffd1a
--- /dev/null
+++ b/persistent-postgresql-ng/bench/run-with-latency.sh
@@ -0,0 +1,51 @@
+#!/usr/bin/env bash
+# Run benchmarks with artificial network latency on the loopback interface.
+#
+# Usage:
+# sudo ./bench/run-with-latency.sh [delay_ms]
+#
+# Default delay: 1ms (simulates a local datacenter hop)
+#
+# Requires root for pfctl/dnctl. Cleans up on exit.
+
+set -euo pipefail
+
+DELAY_MS="${1:-1}"
+PG_PORT="${PGPORT:-5432}"
+PIPE_NR=42
+
+cleanup() {
+ echo "Cleaning up dummynet rules..."
+ pfctl -f /etc/pf.conf 2>/dev/null || true
+ dnctl pipe "$PIPE_NR" delete 2>/dev/null || true
+ echo "Done."
+}
+trap cleanup EXIT
+
+echo "=== Adding ${DELAY_MS}ms latency to loopback port ${PG_PORT} ==="
+
+dnctl pipe "$PIPE_NR" config delay "${DELAY_MS}ms"
+
+# Build pf rules: anchor existing config + dummynet for localhost:PG_PORT
+# Traffic in both directions goes through the pipe (so RTT = 2 * DELAY_MS)
+{
+ # Keep existing pf.conf rules
+ cat /etc/pf.conf
+ echo ""
+ echo "dummynet in quick on lo0 proto tcp from any to any port ${PG_PORT} pipe ${PIPE_NR}"
+ echo "dummynet out quick on lo0 proto tcp from any port ${PG_PORT} to any pipe ${PIPE_NR}"
+} | pfctl -f - -e 2>/dev/null || true
+
+echo "=== Verifying latency ==="
+# Quick check: time a trivial psql round-trip
+START=$(python3 -c "import time; print(time.time())")
+psql -h localhost -p "$PG_PORT" -c "SELECT 1" -t -q > /dev/null 2>&1 || true
+END=$(python3 -c "import time; print(time.time())")
+ELAPSED=$(python3 -c "print(f'{(${END} - ${START})*1000:.1f}ms')")
+echo "Single SELECT 1 round-trip: ${ELAPSED} (expected ~${DELAY_MS}ms added)"
+
+echo ""
+echo "=== Running benchmarks with ${DELAY_MS}ms latency ==="
+cd "$(dirname "$0")/../.."
+stack bench persistent-postgresql-ng \
+ --benchmark-arguments "--output persistent-postgresql-ng/bench/bench-latency-${DELAY_MS}ms.html"
diff --git a/persistent-postgresql-ng/cbits/hs_libpq_extra.c b/persistent-postgresql-ng/cbits/hs_libpq_extra.c
new file mode 100644
index 000000000..d53219b10
--- /dev/null
+++ b/persistent-postgresql-ng/cbits/hs_libpq_extra.c
@@ -0,0 +1,74 @@
+/*
+ * Extra libpq helpers for persistent-postgresql-ng.
+ *
+ * Provides:
+ * - hs_PQsetChunkedRowsMode: compile-time guarded wrapper for PG17+
+ * - hs_has_chunked_rows_support: reports whether chunked mode is available
+ * - hs_PQresultStatusRaw: raw int result status (bypasses Haskell enum)
+ *
+ * We avoid #include so that this file compiles without
+ * needing the libpq header search path in the Cabal build. Instead
+ * we forward-declare the opaque types and functions we use.
+ */
+
+/* Opaque libpq types */
+typedef struct pg_conn PGconn;
+typedef struct pg_result PGresult;
+
+/* libpq functions we call — these are resolved at link time from the
+ * libpq shared library that postgresql-libpq already links against. */
+extern int PQresultStatus(const PGresult *res);
+
+/*
+ * PGRES_TUPLES_CHUNK (= 12) was added in PostgreSQL 17. We hard-code the
+ * value here because the PG enum values are ABI-stable.
+ *
+ * For PQsetChunkedRowsMode we cannot simply call it when the header
+ * doesn't declare it. We use a weak symbol check on platforms that
+ * support it (GCC/Clang) to detect availability at load time.
+ */
+
+#if defined(__GNUC__) || defined(__clang__)
+
+/* Weak declaration: resolves to NULL if the symbol doesn't exist in
+ * the linked libpq (i.e. libpq < 17). */
+extern int PQsetChunkedRowsMode(PGconn *conn, int chunkSize)
+ __attribute__((weak));
+
+int hs_PQsetChunkedRowsMode(PGconn *conn, int chunkSize) {
+ if (PQsetChunkedRowsMode) {
+ return PQsetChunkedRowsMode(conn, chunkSize);
+ }
+ return 0;
+}
+
+int hs_has_chunked_rows_support(void) {
+ return PQsetChunkedRowsMode != 0;
+}
+
+#else
+
+/* Fallback for compilers without weak symbol support: no chunked mode. */
+int hs_PQsetChunkedRowsMode(PGconn *conn, int chunkSize) {
+ (void)conn;
+ (void)chunkSize;
+ return 0;
+}
+
+int hs_has_chunked_rows_support(void) {
+ return 0;
+}
+
+#endif
+
+/*
+ * Return the raw ExecStatus integer for a PGresult.
+ *
+ * The Haskell postgresql-libpq library's resultStatus throws on unknown
+ * status codes (e.g. PGRES_TUPLES_CHUNK = 12 when compiled against
+ * older headers). This helper lets us inspect the raw value and branch
+ * in Haskell without hitting that failure path.
+ */
+int hs_PQresultStatusRaw(const PGresult *res) {
+ return PQresultStatus(res);
+}
diff --git a/persistent-postgresql-ng/persistent-postgresql-ng.cabal b/persistent-postgresql-ng/persistent-postgresql-ng.cabal
new file mode 100644
index 000000000..ec2393b1a
--- /dev/null
+++ b/persistent-postgresql-ng/persistent-postgresql-ng.cabal
@@ -0,0 +1,142 @@
+name: persistent-postgresql-ng
+version: 0.1.0.0
+license: MIT
+license-file: LICENSE
+author: Ian Duncan
+maintainer: ian@ianduncan.me
+synopsis: Pipelined PostgreSQL backend for persistent using the binary protocol
+description:
+ A PostgreSQL backend for persistent that uses the binary wire protocol
+ via postgresql-libpq (>= 0.11) and postgresql-binary, with support for
+ libpq pipeline mode to reduce round-trips.
+category: Database
+stability: Experimental
+cabal-version: >=1.10
+build-type: Simple
+extra-source-files:
+ sql/*.sql
+
+library
+ build-depends:
+ aeson >=1.0
+ , base >=4.9 && <5
+ , bytestring >=0.10
+ , bytestring-strict-builder >=0.4
+ , conduit >=1.2.12
+ , containers >=0.5
+ , file-embed >=0.0.16
+ , monad-logger >=0.3.25
+ , mtl
+ , persistent >=2.18.1 && <3
+ , postgresql-binary >=0.13 && <0.15
+ , postgresql-libpq >=0.11 && <0.12
+ , resource-pool
+ , resourcet >=1.1.9
+ , scientific >=0.3
+ , text >=1.2
+ , time >=1.6
+ , transformers >=0.5
+ , unliftio-core
+ , vault
+ , vector
+
+ exposed-modules:
+ Database.Persist.Postgresql.Pipeline
+ Database.Persist.Postgresql.Pipeline.Internal
+ Database.Persist.Postgresql.Pipeline.FFI
+ Database.Persist.Postgresql.Internal
+ Database.Persist.Postgresql.Internal.Decoding
+ Database.Persist.Postgresql.Internal.Encoding
+ Database.Persist.Postgresql.Internal.DirectDecode
+ Database.Persist.Postgresql.Internal.DirectEncode
+ Database.Persist.Postgresql.Internal.PgCodec
+ Database.Persist.Postgresql.Internal.Migration
+ Database.Persist.Postgresql.Internal.PgType
+ Database.Persist.Postgresql.Internal.Placeholders
+ Database.Persist.Postgresql.JSON
+ Database.Persist.Postgresql.CustomType
+
+ c-sources: cbits/hs_libpq_extra.c
+
+ ghc-options: -Wall
+ default-language: Haskell2010
+
+test-suite test
+ type: exitcode-stdio-1.0
+ main-is: main.hs
+ hs-source-dirs: test
+ other-modules:
+ ArrayAggTest
+ BinaryRoundTripSpec
+ DirectDecodeSpec
+ DirectEntityPOC
+ CustomConstraintTest
+ EquivalentTypeTestPostgres
+ ImplicitUuidSpec
+ InCollapseSpec
+ JSONTest
+ MigrationReferenceSpec
+ MigrationSpec
+ PgPipelineInit
+ PgPipelineIntervalTest
+ PipelineDeferralSpec
+ PipelineModeSpec
+ PipelineRegressionSpec
+ PlaceholderSpec
+ UpsertWhere
+
+ ghc-options: -Wall
+ build-depends:
+ aeson
+ , base >=4.9 && <5
+ , bytestring
+ , containers
+ , fast-logger
+ , hspec >=2.4
+ , hspec-expectations
+ , hspec-expectations-lifted
+ , HUnit
+ , monad-logger
+ , persistent
+ , persistent-postgresql-ng
+ , persistent-qq
+ , persistent-test
+ , postgresql-binary
+ , postgresql-libpq
+ , QuickCheck
+ , quickcheck-instances
+ , resourcet
+ , scientific
+ , text
+ , time
+ , transformers
+ , unliftio
+ , unliftio-core
+ , unordered-containers
+ , vector
+
+ default-language: Haskell2010
+
+benchmark bench
+ type: exitcode-stdio-1.0
+ main-is: Main.hs
+ hs-source-dirs: bench
+ ghc-options: -Wall -O2 -rtsopts
+ build-depends:
+ base >=4.9 && <5
+ , bytestring
+ , criterion >=1.5
+ , monad-logger
+ , persistent
+ , persistent-postgresql
+ , persistent-postgresql-ng
+ , resource-pool
+ , text
+ , time
+ , transformers
+
+ default-language: Haskell2010
+
+source-repository head
+ type: git
+ location: https://github.com/yesodweb/persistent.git
diff --git a/persistent-postgresql-ng/sql/getForeignKeyReferences.sql b/persistent-postgresql-ng/sql/getForeignKeyReferences.sql
new file mode 100644
index 000000000..3bd1f8d35
--- /dev/null
+++ b/persistent-postgresql-ng/sql/getForeignKeyReferences.sql
@@ -0,0 +1,84 @@
+-- Get all foreign key references among the given set of table names in the
+-- current namespace/schema. This query is used by the migrator to check whether
+-- foreign key definitions are up to date.
+--
+-- This query takes one parameter: an array of table names.
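+--
+-- For example, with the parameter bound to the array '{users,posts}',
+-- the filter below behaves as:
+-- src_table.relname = ANY ('{users,posts}'::text[])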
+with
+ foreign_constraints as (
+ select
+ c.*
+ from
+ pg_constraint AS c
+ inner join pg_class src_table
+ on src_table.oid = c.conrelid
+ inner join pg_namespace ns
+ on ns.oid = c.connamespace
+ where
+ -- f = foreign key constraint
+ c.contype = 'f'
+ and src_table.relname = ANY (?)
+ and ns.nspname = current_schema()
+ ),
+ foreign_constraint_with_source_columns as (
+ select
+ c.oid,
+ array_agg(
+ a.attname::text
+ ORDER BY
+ k.n ASC
+ ) as column_names
+ from
+ foreign_constraints AS c
+ -- conkey is a list of the column indices on the source
+ -- table
+ CROSS JOIN LATERAL unnest(c.conkey) WITH ORDINALITY AS k (attnum, n)
+ INNER JOIN pg_attribute AS a
+ -- conrelid is the id of the source table
+ ON k.attnum = a.attnum AND c.conrelid = a.attrelid
+ group by
+ c.oid
+ ),
+ foreign_constraint_with_foreign_columns as (
+ select
+ c.oid,
+ array_agg(
+ a.attname::text
+ ORDER BY
+ k.n ASC
+ ) as foreign_column_names
+ from
+ foreign_constraints AS c
+ -- confkey is a list of the column indices on the foreign
+ -- table
+ CROSS JOIN LATERAL unnest(c.confkey) WITH ORDINALITY AS k (attnum, n)
+ JOIN pg_attribute AS a
+ -- confrelid is the id of the foreign table
+ ON k.attnum = a.attnum AND c.confrelid = a.attrelid
+ group by
+ c.oid
+ )
+SELECT
+ fkey_constraint.conname::text as fkey_name,
+ src_table.relname::text AS source_table,
+ foreign_table.relname::text AS referenced_table,
+ -- NB: postgres arrays are one-indexed!
+ src_columns.column_names[1],
+ foreign_columns.foreign_column_names[1],
+ fkey_constraint.confupdtype,
+ fkey_constraint.confdeltype
+from
+ foreign_constraints AS fkey_constraint
+ inner join foreign_constraint_with_source_columns src_columns
+ on src_columns.oid = fkey_constraint.oid
+ inner join foreign_constraint_with_foreign_columns foreign_columns
+ on foreign_columns.oid = fkey_constraint.oid
+ inner join pg_class src_table
+ on src_table.oid = fkey_constraint.conrelid
+ inner join pg_class foreign_table
+ on foreign_table.oid = fkey_constraint.confrelid
+
+-- In the future, we may want to look at multi-column FK constraints too. but
+-- for now we only care about single-column constraints.
+where
+ array_length(src_columns.column_names, 1) = 1
+ and array_length(foreign_columns.foreign_column_names, 1) = 1;
diff --git a/persistent-postgresql-ng/test/ArrayAggTest.hs b/persistent-postgresql-ng/test/ArrayAggTest.hs
new file mode 100644
index 000000000..cf97a7c10
--- /dev/null
+++ b/persistent-postgresql-ng/test/ArrayAggTest.hs
@@ -0,0 +1,71 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module ArrayAggTest where
+
+import Control.Monad.IO.Class (MonadIO)
+import Data.Aeson
+import Data.List (sort)
+import qualified Data.Text as T
+import Test.Hspec.Expectations ()
+
+import PersistentTestModels
+import PgPipelineInit
+
+share
+ [mkPersist persistSettings, mkMigrate "jsonTestMigrate"]
+ [persistLowerCase|
+ TestValue
+ json Value
+|]
+
+cleanDB
+ :: (BaseBackend backend ~ SqlBackend, PersistQueryWrite backend, MonadIO m)
+ => ReaderT backend m ()
+cleanDB = deleteWhere ([] :: [Filter TestValue])
+
+emptyArr :: Value
+emptyArr = toJSON ([] :: [Value])
+
+specs :: Spec
+specs = do
+ describe "rawSql/array_agg" $ do
+ let
+ runArrayAggTest :: (PersistField [a], Ord a, Show a) => Text -> [a] -> Assertion
+ runArrayAggTest dbField expected = runConnAssert $ do
+ void $
+ insertMany
+ [ UserPT "a" $ Just "b"
+ , UserPT "c" $ Just "d"
+ , UserPT "e" Nothing
+ , UserPT "g" $ Just "h"
+ ]
+ escape <- getEscapeRawNameFunction
+ let
+ query =
+ T.concat
+ [ "SELECT array_agg("
+ , escape dbField
+ , ") "
+ , "FROM "
+ , escape "UserPT"
+ ]
+ [Single xs] <- rawSql query []
+ liftIO $ sort xs @?= expected
+
+ it "works for [Text]" $ do
+ runArrayAggTest "ident" ["a", "c", "e", "g" :: Text]
+ it "works for [Maybe Text]" $ do
+ runArrayAggTest "password" [Nothing, Just "b", Just "d", Just "h" :: Maybe Text]
diff --git a/persistent-postgresql-ng/test/BinaryRoundTripSpec.hs b/persistent-postgresql-ng/test/BinaryRoundTripSpec.hs
new file mode 100644
index 000000000..fa50acca0
--- /dev/null
+++ b/persistent-postgresql-ng/test/BinaryRoundTripSpec.hs
@@ -0,0 +1,152 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+
+-- | Pure round-trip tests for binary encoding/decoding of PersistValues.
+module BinaryRoundTripSpec (specs) where
+
+import Test.Hspec
+import Data.ByteString (ByteString)
+import Data.Int (Int64)
+import Data.Text (Text)
+import Data.Time
+import qualified Database.PostgreSQL.LibPQ as LibPQ
+import Database.Persist (PersistValue (..))
+import Database.Persist.Postgresql.Internal.Encoding (encodePersistValue)
+import Database.Persist.Postgresql.Internal.Decoding (decodePersistValue)
+
+-- | Encode a value then decode it, verifying round-trip equality.
+roundTrip :: PersistValue -> Either Text PersistValue
+roundTrip pv = case encodePersistValue pv of
+ Nothing -> Right PersistNull
+ Just (oid, bs, _fmt) -> decodePersistValue oid (Just bs)
+
+specs :: Spec
+specs = describe "Binary round-trip" $ do
+ describe "scalar types" $ do
+ it "PersistNull" $ do
+ roundTrip PersistNull `shouldBe` Right PersistNull
+
+ it "PersistBool True" $ do
+ roundTrip (PersistBool True) `shouldBe` Right (PersistBool True)
+
+ it "PersistBool False" $ do
+ roundTrip (PersistBool False) `shouldBe` Right (PersistBool False)
+
+ it "PersistInt64" $ do
+ roundTrip (PersistInt64 42) `shouldBe` Right (PersistInt64 42)
+
+ it "PersistInt64 negative" $ do
+ roundTrip (PersistInt64 (-9999)) `shouldBe` Right (PersistInt64 (-9999))
+
+ it "PersistInt64 max" $ do
+ roundTrip (PersistInt64 maxBound) `shouldBe` Right (PersistInt64 maxBound)
+
+ it "PersistDouble" $ do
+ roundTrip (PersistDouble 3.14) `shouldBe` Right (PersistDouble 3.14)
+
+ it "PersistText" $ do
+ roundTrip (PersistText "hello") `shouldBe` Right (PersistText "hello")
+
+ it "PersistText empty" $ do
+ roundTrip (PersistText "") `shouldBe` Right (PersistText "")
+
+ it "PersistText unicode" $ do
+ roundTrip (PersistText "\x1F600\x00E9") `shouldBe` Right (PersistText "\x1F600\x00E9")
+
+ it "PersistByteString" $ do
+ roundTrip (PersistByteString "bytes\x00\xff")
+ `shouldBe` Right (PersistByteString "bytes\x00\xff")
+
+ it "PersistDay" $ do
+ let d = fromGregorian 2024 6 15
+ roundTrip (PersistDay d) `shouldBe` Right (PersistDay d)
+
+ it "PersistTimeOfDay" $ do
+ let t = TimeOfDay 14 30 0
+ roundTrip (PersistTimeOfDay t) `shouldBe` Right (PersistTimeOfDay t)
+
+ it "PersistUTCTime" $ do
+ let t = UTCTime (fromGregorian 2024 6 15) (12 * 3600 + 30 * 60)
+ roundTrip (PersistUTCTime t) `shouldBe` Right (PersistUTCTime t)
+
+ it "PersistRational" $ do
+ -- Rational -> numeric -> Rational round-trip
+ let r = 355 / 113 -- approximation of pi
+ case roundTrip (PersistRational r) of
+ Right (PersistRational r') ->
+ -- numeric has finite precision, so we check approximate equality
+ abs (fromRational (r - r') :: Double) `shouldSatisfy` (< 1e-10)
+ other -> expectationFailure $ "Expected PersistRational, got: " ++ show other
+
+ it "PersistRational integer" $ do
+ roundTrip (PersistRational 42) `shouldBe` Right (PersistRational 42)
+
+ describe "PersistList (JSON text for storage)" $ do
+ it "PersistList encodes with OID 0 (unknown)" $ do
+ let val = PersistList [PersistInt64 1, PersistInt64 2, PersistInt64 3]
+ case encodePersistValue val of
+ Just (oid, _, _) ->
+ oid `shouldBe` LibPQ.Oid 0 -- unknown, PG infers from column
+ Nothing -> expectationFailure "Expected Just for PersistList encoding"
+
+ it "PersistList empty encodes with OID 0" $ do
+ let val = PersistList []
+ case encodePersistValue val of
+ Just (oid, _, _) ->
+ oid `shouldBe` LibPQ.Oid 0 -- unknown, PG infers from column
+ Nothing -> expectationFailure "Expected Just for empty PersistList encoding"
+
+ describe "PersistArray (native PostgreSQL arrays)" $ do
+ it "PersistArray of Int64" $ do
+ let val = PersistArray [PersistInt64 1, PersistInt64 2, PersistInt64 3]
+ case encodePersistValue val of
+ Just (oid, _, _) -> do
+ oid `shouldBe` LibPQ.Oid 1016 -- int8[]
+ case roundTrip val of
+ Right (PersistList xs) ->
+ xs `shouldBe` [PersistInt64 1, PersistInt64 2, PersistInt64 3]
+ other -> expectationFailure $ "Expected PersistList, got: " ++ show other
+ Nothing -> expectationFailure "Expected Just for array encoding"
+
+ it "PersistArray of Text" $ do
+ let val = PersistArray [PersistText "hello", PersistText "world"]
+ case encodePersistValue val of
+ Just (oid, _, _) -> do
+ oid `shouldBe` LibPQ.Oid 1009 -- text[]
+ case roundTrip val of
+ Right (PersistList xs) ->
+ xs `shouldBe` [PersistText "hello", PersistText "world"]
+ other -> expectationFailure $ "Expected PersistList, got: " ++ show other
+ Nothing -> expectationFailure "Expected Just for array encoding"
+
+ it "PersistArray of Bool" $ do
+ let val = PersistArray [PersistBool True, PersistBool False]
+ case encodePersistValue val of
+ Just (oid, _, _) ->
+ oid `shouldBe` LibPQ.Oid 1000 -- bool[]
+ Nothing -> expectationFailure "Expected Just for array encoding"
+
+ it "PersistArray of Double" $ do
+ let val = PersistArray [PersistDouble 1.5, PersistDouble 2.5]
+ case encodePersistValue val of
+ Just (oid, _, _) ->
+ oid `shouldBe` LibPQ.Oid 1022 -- float8[]
+ Nothing -> expectationFailure "Expected Just for array encoding"
+
+ it "PersistArray with nulls" $ do
+ let val = PersistArray [PersistInt64 1, PersistNull, PersistInt64 3]
+ case encodePersistValue val of
+ Just (oid, _, _) -> do
+ oid `shouldBe` LibPQ.Oid 1016 -- int8[]
+ case roundTrip val of
+ Right (PersistList xs) ->
+ xs `shouldBe` [PersistInt64 1, PersistNull, PersistInt64 3]
+ other -> expectationFailure $ "Expected PersistList, got: " ++ show other
+ Nothing -> expectationFailure "Expected Just for array encoding"
+
+ it "PersistArray empty encodes with OID 0" $ do
+ let val = PersistArray []
+ case encodePersistValue val of
+ Just (oid, _, _) ->
+ oid `shouldBe` LibPQ.Oid 0 -- unknown, PG infers from column
+ Nothing -> expectationFailure "Expected Just for empty array encoding"
diff --git a/persistent-postgresql-ng/test/CustomConstraintTest.hs b/persistent-postgresql-ng/test/CustomConstraintTest.hs
new file mode 100644
index 000000000..5944e2e60
--- /dev/null
+++ b/persistent-postgresql-ng/test/CustomConstraintTest.hs
@@ -0,0 +1,78 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE EmptyDataDecls #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module CustomConstraintTest where
+
+import qualified Data.Text as T
+import PgPipelineInit
+
+share
+ [mkPersist sqlSettings, mkMigrate "customConstraintMigrate"]
+ [persistLowerCase|
+CustomConstraint1
+ some_field Text
+ deriving Show
+
+CustomConstraint2
+ cc_id CustomConstraint1Id constraint=custom_constraint
+ deriving Show
+
+CustomConstraint3
+ cc_id1 CustomConstraint1Id
+ cc_id2 CustomConstraint1Id
+ deriving Show
+|]
+
+specs :: Spec
+specs = do
+ describe "custom constraint used in migration" $ do
+ it "custom constraint is actually created" $ runConnAssert $ do
+ void $ runMigrationSilent customConstraintMigrate
+ void $ runMigrationSilent customConstraintMigrate
+ let
+ query =
+ T.concat
+ [ "SELECT DISTINCT COUNT(*) "
+ , "FROM information_schema.constraint_column_usage ccu, "
+ , "information_schema.key_column_usage kcu, "
+ , "information_schema.table_constraints tc "
+ , "WHERE tc.constraint_type='FOREIGN KEY' "
+ , "AND kcu.constraint_name=tc.constraint_name "
+ , "AND ccu.constraint_name=kcu.constraint_name "
+ , "AND kcu.ordinal_position=1 "
+ , "AND ccu.table_name=? "
+ , "AND ccu.column_name=? "
+ , "AND kcu.table_name=? "
+ , "AND kcu.column_name=? "
+ , "AND tc.constraint_name=?"
+ ]
+ [Single exists_] <-
+ rawSql
+ query
+ [ PersistText "custom_constraint1"
+ , PersistText "id"
+ , PersistText "custom_constraint2"
+ , PersistText "cc_id"
+ , PersistText "custom_constraint"
+ ]
+ liftIO $ 1 @?= (exists_ :: Int)
+
+ it "allows multiple constraints on a single column" $ runConnAssert $ do
+ void $ runMigrationSilent customConstraintMigrate
+ rawExecute
+ "ALTER TABLE \"custom_constraint3\" ADD CONSTRAINT \"extra_constraint\" FOREIGN KEY(\"cc_id1\") REFERENCES \"custom_constraint1\"(\"id\")"
+ []
+ void $ getMigration customConstraintMigrate
+ pure ()
diff --git a/persistent-postgresql-ng/test/DirectDecodeSpec.hs b/persistent-postgresql-ng/test/DirectDecodeSpec.hs
new file mode 100644
index 000000000..c43ea7d3e
--- /dev/null
+++ b/persistent-postgresql-ng/test/DirectDecodeSpec.hs
@@ -0,0 +1,128 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE FlexibleContexts #-}
+
+-- | Tests for the direct decode path: FieldDecode instances for PgRowEnv,
+-- FromRow composition, and $N parameter detection.
+module DirectDecodeSpec (specs) where
+
+import Test.Hspec
+
+import Data.ByteString (ByteString)
+import Data.Int (Int16, Int32, Int64)
+import Data.Text (Text)
+import Data.Time
+
+import Database.Persist.DirectDecode
+ (FieldDecode (..), FromRow (..), RowReader, runRowReader, nextField)
+import Database.Persist.Names (FieldNameDB (..))
+import Database.Persist.Postgresql.Internal.DirectDecode (PgRowEnv (..))
+import Database.Persist.Postgresql.Internal.PgType
+import Database.Persist.Postgresql.Internal.Placeholders
+ (ParamStyle (..), detectParamStyle)
+
+specs :: Spec
+specs = do
+ paramStyleSpecs
+ fieldDecodeRoundTripSpecs
+
+---------------------------------------------------------------------------
+-- Parameter style detection
+---------------------------------------------------------------------------
+
+paramStyleSpecs :: Spec
+paramStyleSpecs = describe "detectParamStyle" $ do
+ it "detects ? as QuestionMarkParams" $ do
+ detectParamStyle "SELECT * FROM t WHERE a = ?" `shouldBe` QuestionMarkParams
+
+ it "detects $1 as Numbered 1" $ do
+ detectParamStyle "SELECT * FROM t WHERE a = $1" `shouldBe` NumberedParams 1
+
+ it "detects max $N" $ do
+ detectParamStyle "SELECT * FROM t WHERE a = $1 AND b = $3" `shouldBe` NumberedParams 3
+
+ it "detects $1 used multiple times" $ do
+ detectParamStyle "SELECT * FROM t WHERE a = $1 OR b = $1" `shouldBe` NumberedParams 1
+
+ it "ignores $N inside string literals" $ do
+ detectParamStyle "SELECT * FROM t WHERE a = '$1'" `shouldBe` QuestionMarkParams
+
+ it "ignores $N inside quoted identifiers" $ do
+ detectParamStyle "SELECT * FROM \"$1\" WHERE a = ?" `shouldBe` QuestionMarkParams
+
+ it "ignores $N inside line comments" $ do
+ detectParamStyle "SELECT * -- $1\nFROM t WHERE a = ?" `shouldBe` QuestionMarkParams
+
+ it "ignores $N inside block comments" $ do
+ detectParamStyle "SELECT /* $1 */ * FROM t WHERE a = ?" `shouldBe` QuestionMarkParams
+
+ it "handles empty SQL" $ do
+ detectParamStyle "" `shouldBe` QuestionMarkParams
+
+ it "handles no params" $ do
+ detectParamStyle "SELECT 1" `shouldBe` QuestionMarkParams
+
+---------------------------------------------------------------------------
+-- FieldDecode coverage
+--
+-- Ideally we would encode values with postgresql-binary, build a fake
+-- PGresult-backed PgRowEnv, and decode them via FieldDecode. A PGresult
+-- cannot easily be constructed without a live connection, however, so
+-- the specs below are compile-time coverage tests: if this module
+-- compiles, the FieldDecode instances and nextField combinators exist
+-- at the expected types, and the TH-generated FromRow instances (e.g.
+-- for the DataTypeTable entity in main.hs) are available.
+--
+-- The decode logic itself is exercised elsewhere:
+-- 1. detectParamStyle handles all SQL patterns (above)
+-- 2. BinaryRoundTripSpec covers the PersistValue binary path
+-- 3. integration tests against live PostgreSQL cover the direct path
+---------------------------------------------------------------------------
+
+fieldDecodeRoundTripSpecs :: Spec
+fieldDecodeRoundTripSpecs = describe "FieldDecode PgRowEnv" $ do
+ -- These are compile-time tests: if the module compiles, the
+ -- FieldDecode instances and nextField combinators exist.
+ it "nextField @Text compiles" $ do
+ let _decoder :: RowReader PgRowEnv Text
+ _decoder = nextField (FieldNameDB "col")
+ True `shouldBe` True
+
+ it "nextField @(Maybe Int64) compiles" $ do
+ let _decoder :: RowReader PgRowEnv (Maybe Int64)
+ _decoder = nextField (FieldNameDB "col")
+ True `shouldBe` True
+
+ it "applicative composition compiles" $ do
+ let _decoder :: RowReader PgRowEnv (Text, Int64, Maybe Bool)
+ _decoder = (,,)
+ <$> nextField (FieldNameDB "name")
+ <*> nextField (FieldNameDB "age")
+ <*> nextField (FieldNameDB "active")
+ True `shouldBe` True
+
+ it "all scalar FieldDecode instances compile" $ do
+ let _bool :: RowReader PgRowEnv Bool = nextField (FieldNameDB "")
+ _int16 :: RowReader PgRowEnv Int16 = nextField (FieldNameDB "")
+ _int32 :: RowReader PgRowEnv Int32 = nextField (FieldNameDB "")
+ _int64 :: RowReader PgRowEnv Int64 = nextField (FieldNameDB "")
+ _int :: RowReader PgRowEnv Int = nextField (FieldNameDB "")
+ _double :: RowReader PgRowEnv Double = nextField (FieldNameDB "")
+ _text :: RowReader PgRowEnv Text = nextField (FieldNameDB "")
+ _bs :: RowReader PgRowEnv ByteString = nextField (FieldNameDB "")
+ _day :: RowReader PgRowEnv Day = nextField (FieldNameDB "")
+ _tod :: RowReader PgRowEnv TimeOfDay = nextField (FieldNameDB "")
+ _utc :: RowReader PgRowEnv UTCTime = nextField (FieldNameDB "")
+ True `shouldBe` True
diff --git a/persistent-postgresql-ng/test/DirectEntityPOC.hs b/persistent-postgresql-ng/test/DirectEntityPOC.hs
new file mode 100644
index 000000000..bf88a4e36
--- /dev/null
+++ b/persistent-postgresql-ng/test/DirectEntityPOC.hs
@@ -0,0 +1,183 @@
+{-# LANGUAGE AllowAmbiguousTypes #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE InstanceSigs #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+
+-- | Proof-of-concept: the DirectEntity + Typeable/HRefl approach for
+-- bridging the SqlBackend type-erasure gap.
+--
+-- Part 1 (unit tests): Mock env, compile-time + runtime verification.
+-- Part 2 (integration): Real PostgreSQL query through SqlBackend using
+-- rawSqlDirectCompat + DirectEntity with PgRowEnv.
+module DirectEntityPOC (specs, integrationSpecs, rawSqlDirectCompatTest, PgPair (..)) where
+
+import Control.Monad.IO.Class (MonadIO, liftIO)
+import Control.Monad.Trans.Reader (ReaderT)
+import Data.Int (Int64)
+import Data.Proxy (Proxy (..))
+import Data.Text (Text)
+import qualified Data.Text as T
+import Data.Typeable (Typeable)
+import Type.Reflection (eqTypeRep, typeRep, (:~~:)(HRefl))
+import Test.Hspec
+
+import Database.Persist.DirectDecode
+import Database.Persist.Names (FieldNameDB (..))
+import Database.Persist.Sql.DirectRaw (rawSqlDirectCompat)
+import Database.Persist.SqlBackend.Internal (SqlBackend)
+import Database.Persist.Postgresql.Internal.DirectDecode (PgRowEnv (..))
+
+---------------------------------------------------------------------------
+-- Mock env for unit tests
+---------------------------------------------------------------------------
+
+data MockEnv = MockEnv
+ { mockRow :: !Int
+ , mockData :: ![(Text, Text)]
+ }
+
+instance FieldDecode MockEnv Text where
+ prepareField _env name _col _onErr onOk =
+ onOk $ FieldRunner $ \env' _onErr' onOk' ->
+ case lookup (unFieldNameDB name) (mockData env') of
+ Just t -> onOk' t
+ Nothing -> onOk' ""
+
+---------------------------------------------------------------------------
+-- Test record with MockEnv support
+---------------------------------------------------------------------------
+
+data TestUser = TestUser
+ { testUserName :: !Text
+ , testUserCity :: !Text
+ } deriving (Eq, Show)
+
+instance FromRow MockEnv TestUser where
+ rowReader = TestUser
+ <$> nextField (FieldNameDB "name")
+ <*> nextField (FieldNameDB "city")
+
+ prepareRow env ctr onErr onOk = do
+ col0 <- advanceCounter ctr
+ prepareField @MockEnv @Text env (FieldNameDB "name") col0 onErr $ \rName -> do
+ col1 <- advanceCounter ctr
+ prepareField @MockEnv @Text env (FieldNameDB "city") col1 onErr $ \rCity ->
+ onOk $ RowDecoder $ \env' onErr' onOk' ->
+ runField rName env' onErr' $ \n ->
+ runField rCity env' onErr' $ \c ->
+ onOk' (TestUser n c)
+
+instance DirectEntity TestUser where
+ lookupDirectDecoder :: forall env. Typeable env
+ => Proxy env -> Maybe (RowDecoder env TestUser)
+ lookupDirectDecoder _ =
+ case eqTypeRep (typeRep @env) (typeRep @MockEnv) of
+ Just HRefl -> Just $ RowDecoder $ \env' onErr' onOk' -> do
+ ctr <- newCounter
+ prepareRow env' ctr onErr' $ \decoder ->
+ runRowDecoder decoder env' onErr' onOk'
+ Nothing -> Nothing
+
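+-- A caller that only knows @Typeable env@ recovers a concrete decoder via
+-- 'lookupDirectDecoder'. A minimal sketch ('decodeRows' is a hypothetical
+-- helper, not part of the library):
+--
+-- > decodeRows :: forall env r. (Typeable env, DirectEntity r)
+-- >             => [env] -> IO (Maybe [r])
+-- > decodeRows envs =
+-- >     case lookupDirectDecoder @r (Proxy @env) of
+-- >         Nothing -> pure Nothing
+-- >         Just decoder ->
+-- >             Just <$> traverse (\e -> runRowDecoderCPS decoder e (fail . T.unpack) pure) envs
+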
+---------------------------------------------------------------------------
+-- A record for the real PgRowEnv integration test
+---------------------------------------------------------------------------
+
+data PgPair = PgPair
+ { pgPairName :: !Text
+ , pgPairAge :: !Int64
+ } deriving (Eq, Show)
+
+instance FromRow PgRowEnv PgPair where
+ rowReader = PgPair
+ <$> nextField (FieldNameDB "name")
+ <*> nextField (FieldNameDB "age")
+
+ prepareRow env ctr onErr onOk = do
+ col0 <- advanceCounter ctr
+ prepareField @PgRowEnv @Text env (FieldNameDB "name") col0 onErr $ \rName -> do
+ col1 <- advanceCounter ctr
+ prepareField @PgRowEnv @Int64 env (FieldNameDB "age") col1 onErr $ \rAge ->
+ onOk $ RowDecoder $ \env' onErr' onOk' ->
+ runField rName env' onErr' $ \n ->
+ runField rAge env' onErr' $ \a ->
+ onOk' (PgPair n a)
+
+instance DirectEntity PgPair where
+ lookupDirectDecoder :: forall env. Typeable env
+ => Proxy env -> Maybe (RowDecoder env PgPair)
+ lookupDirectDecoder _ =
+ case eqTypeRep (typeRep @env) (typeRep @PgRowEnv) of
+ Just HRefl -> Just $ RowDecoder $ \env' onErr' onOk' -> do
+ ctr <- newCounter
+ prepareRow env' ctr onErr' $ \decoder ->
+ runRowDecoder decoder env' onErr' onOk'
+ Nothing -> Nothing
+
+---------------------------------------------------------------------------
+-- Unit tests (no database needed)
+---------------------------------------------------------------------------
+
+specs :: Spec
+specs = describe "DirectEntity / SqlBackend bridge (unit)" $ do
+ it "lookupDirectDecoder returns Just for matching env" $ do
+ let mDecoder = lookupDirectDecoder @TestUser (Proxy @MockEnv)
+ case mDecoder of
+ Nothing -> expectationFailure "expected Just"
+ Just _ -> pure ()
+
+ it "lookupDirectDecoder returns Nothing for non-matching env" $ do
+ let mDecoder = lookupDirectDecoder @TestUser (Proxy @PgRowEnv)
+ case mDecoder of
+ Nothing -> pure ()
+ Just _ -> expectationFailure "expected Nothing"
+
+ it "end-to-end mock: decode through existential" $ do
+ let rows =
+ [ MockEnv 0 [("name", "Alice"), ("city", "NYC")]
+ , MockEnv 1 [("name", "Bob"), ("city", "SF")]
+ ]
+ case lookupDirectDecoder @TestUser (Proxy @MockEnv) of
+ Nothing -> expectationFailure "expected Just"
+ Just decoder -> do
+ let onErr e = fail (T.unpack e)
+ results <- mapM (\env -> runRowDecoderCPS decoder env onErr pure) rows
+ results `shouldBe`
+ [ TestUser "Alice" "NYC"
+ , TestUser "Bob" "SF"
+ ]
+
+ it "PgPair: lookupDirectDecoder returns Just for PgRowEnv" $ do
+ let mDecoder = lookupDirectDecoder @PgPair (Proxy @PgRowEnv)
+ case mDecoder of
+ Nothing -> expectationFailure "expected Just"
+ Just _ -> pure ()
+
+---------------------------------------------------------------------------
+-- Integration tests (require PostgreSQL)
+---------------------------------------------------------------------------
+
+-- | Run inside a SqlPersistT from the test harness.
+integrationSpecs :: Spec
+integrationSpecs = describe "DirectEntity / SqlBackend bridge (integration)" $
+ pure ()
+
+-- | The actual integration test action, run inside SqlPersistT by main.hs.
+--
+-- Executes a raw SQL query through 'rawSqlDirectCompat' which uses:
+-- 1. connDirectQueryCap on SqlBackend (existential PgRowEnv)
+-- 2. DirectEntity PgPair (eqTypeRep recovers PgRowEnv)
+-- 3. RowDecoder applied to real PGresult rows
+rawSqlDirectCompatTest
+ :: (MonadIO m)
+ => ReaderT SqlBackend m (Maybe [PgPair])
+rawSqlDirectCompatTest = do
+ -- Use a CTE to avoid needing a real table
+ rawSqlDirectCompat
+ "SELECT name, age FROM (VALUES ('Alice'::text, 30::int8), ('Bob'::text, 25::int8)) AS t(name, age)"
+ []
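+
+-- A successful run should return @Just [PgPair "Alice" 30, PgPair "Bob" 25]@,
+-- i.e. the rows embedded in the VALUES clause; @Nothing@ would mean the
+-- backend exposed no direct-query capability for 'PgRowEnv'.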
diff --git a/persistent-postgresql-ng/test/EquivalentTypeTestPostgres.hs b/persistent-postgresql-ng/test/EquivalentTypeTestPostgres.hs
new file mode 100644
index 000000000..74f85e196
--- /dev/null
+++ b/persistent-postgresql-ng/test/EquivalentTypeTestPostgres.hs
@@ -0,0 +1,54 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -Wno-unused-top-binds #-}
+
+module EquivalentTypeTestPostgres (specs) where
+
+import Control.Monad.Trans.Resource (runResourceT)
+import qualified Data.Text as T
+
+import Database.Persist.TH
+import PgPipelineInit
+
+share
+ [mkPersist sqlSettings, mkMigrate "migrateAll1"]
+ [persistLowerCase|
+EquivalentType sql=equivalent_types
+ field1 Int sqltype=bigint
+ field2 T.Text sqltype=text
+ field3 T.Text sqltype=us_postal_code
+ deriving Eq Show
+|]
+
+share
+ [mkPersist sqlSettings, mkMigrate "migrateAll2"]
+ [persistLowerCase|
+EquivalentType2 sql=equivalent_types
+ field1 Int sqltype=int8
+ field2 T.Text
+ field3 T.Text sqltype=us_postal_code
+ deriving Eq Show
+|]
+
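+-- @int8@ is PostgreSQL's internal name for @bigint@, and the unannotated
+-- @T.Text@ field must reconcile with the existing @text@ column, so both
+-- entity blocks describe the same physical table and the second migration
+-- should be empty.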
+specs :: Spec
+specs = describe "doesn't migrate equivalent types" $ do
+ it "works" $ asIO $ runResourceT $ runConn $ do
+ _ <- rawExecute "DROP DOMAIN IF EXISTS us_postal_code CASCADE" []
+ _ <-
+ rawExecute "CREATE DOMAIN us_postal_code AS TEXT CHECK(VALUE ~ '^\\d{5}$')" []
+
+ _ <- runMigrationSilent migrateAll1
+ xs <- getMigration migrateAll2
+ liftIO $ xs @?= []
diff --git a/persistent-postgresql-ng/test/ImplicitUuidSpec.hs b/persistent-postgresql-ng/test/ImplicitUuidSpec.hs
new file mode 100644
index 000000000..a8a7197cc
--- /dev/null
+++ b/persistent-postgresql-ng/test/ImplicitUuidSpec.hs
@@ -0,0 +1,82 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module ImplicitUuidSpec where
+
+import PgPipelineInit
+
+import Data.Proxy
+import Database.Persist.Postgresql.Pipeline
+
+import Database.Persist.ImplicitIdDef
+import Database.Persist.ImplicitIdDef.Internal (fieldTypeFromTypeable)
+
+share
+ [ mkPersist (sqlSettingsUuid "uuid_generate_v1mc()")
+ , mkEntityDefList "entities"
+ ]
+ [persistLowerCase|
+
+WithDefUuid
+ name Text sqltype=varchar(80)
+
+ deriving Eq Show Ord
+
+|]
+
+implicitUuidMigrate :: Migration
+implicitUuidMigrate = do
+ runSqlCommand $ rawExecute "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"" []
+ migrateModels entities
+
+wipe :: IO ()
+wipe = runConnAssert $ do
+    rawExecute "DROP TABLE IF EXISTS with_def_uuid;" []
+ runMigration implicitUuidMigrate
+
+itDb
+ :: String -> SqlPersistT (LoggingT (ResourceT IO)) a -> SpecWith (Arg (IO ()))
+itDb msg action = it msg $ runConnAssert $ void action
+
+pass :: IO ()
+pass = pure ()
+
+spec :: Spec
+spec = describe "ImplicitUuidSpec" $ before_ wipe $ do
+ describe "WithDefUuidKey" $ do
+ it "works on UUIDs" $ do
+ let
+                _withDefUuidKey = WithDefUuidKey (UUID "Hello")
+ pass
+ describe "getEntityId" $ do
+ let
+ Just idField = getEntityIdField (entityDef (Proxy @WithDefUuid))
+ it "has a UUID SqlType" $ asIO $ do
+ fieldSqlType idField `shouldBe` SqlOther "UUID"
+ it "is an implicit ID column" $ asIO $ do
+ fieldIsImplicitIdColumn idField `shouldBe` True
+
+ describe "insert" $ do
+ itDb "successfully has a default" $ do
+ let
+ matt =
+ WithDefUuid
+ { withDefUuidName =
+ "Matt"
+ }
+ k <- insert matt
+ mrec <- get k
+ mrec `shouldBe` Just matt
diff --git a/persistent-postgresql-ng/test/InCollapseSpec.hs b/persistent-postgresql-ng/test/InCollapseSpec.hs
new file mode 100644
index 000000000..03b66d0bd
--- /dev/null
+++ b/persistent-postgresql-ng/test/InCollapseSpec.hs
@@ -0,0 +1,76 @@
+{-# LANGUAGE OverloadedStrings #-}
+
+-- | Tests for the IN -> ANY SQL rewriting in the pipeline backend.
+module InCollapseSpec (specs) where
+
+import Test.Hspec
+import Database.Persist (PersistValue (..))
+import Database.Persist.Postgresql.Pipeline (inlineAndRewrite, collapseInClauses)
+
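+-- Observable behavior under test, sketched from the cases below:
+--
+-- > collapseInClauses "WHERE \"x\" IN (?,?)" [PersistInt64 1, PersistInt64 2]
+-- >     == ("WHERE \"x\" = ANY(?)", [PersistArray [PersistInt64 1, PersistInt64 2]])
+--
+-- 'inlineAndRewrite' runs the same collapse and then renumbers the remaining
+-- @?@ placeholders into PostgreSQL-style @$1, $2, ...@.
+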
+specs :: Spec
+specs = describe "IN -> ANY collapsing" $ do
+ describe "collapseInClauses" $ do
+ it "collapses IN (?,?,?) into = ANY(?) with PersistArray" $ do
+ let (sql, params) = collapseInClauses
+ "SELECT * FROM t WHERE \"id\" IN (?,?,?)"
+ [PersistInt64 1, PersistInt64 2, PersistInt64 3]
+ sql `shouldBe` "SELECT * FROM t WHERE \"id\" = ANY(?)"
+ params `shouldBe` [PersistArray [PersistInt64 1, PersistInt64 2, PersistInt64 3]]
+
+ it "collapses NOT IN (?,?,?) into <> ALL(?) with PersistArray" $ do
+ let (sql, params) = collapseInClauses
+ "SELECT * FROM t WHERE \"id\" NOT IN (?,?,?)"
+ [PersistInt64 1, PersistInt64 2, PersistInt64 3]
+ sql `shouldBe` "SELECT * FROM t WHERE \"id\" <> ALL(?)"
+ params `shouldBe` [PersistArray [PersistInt64 1, PersistInt64 2, PersistInt64 3]]
+
+ it "does NOT collapse IN (?) with single param" $ do
+ let (sql, params) = collapseInClauses
+ "SELECT * FROM t WHERE \"id\" IN (?)"
+ [PersistInt64 42]
+ sql `shouldBe` "SELECT * FROM t WHERE \"id\" IN (?)"
+ params `shouldBe` [PersistInt64 42]
+
+ it "preserves params outside IN clauses" $ do
+ let (sql, params) = collapseInClauses
+ "SELECT * FROM t WHERE \"name\" IN (?,?) AND \"age\" > ?"
+ [PersistText "a", PersistText "b", PersistInt64 21]
+ sql `shouldBe` "SELECT * FROM t WHERE \"name\" = ANY(?) AND \"age\" > ?"
+ params `shouldBe` [PersistArray [PersistText "a", PersistText "b"], PersistInt64 21]
+
+ it "handles multiple IN clauses" $ do
+ let (sql, params) = collapseInClauses
+ "WHERE \"a\" IN (?,?) AND \"b\" IN (?,?,?)"
+ [ PersistInt64 1, PersistInt64 2
+ , PersistText "x", PersistText "y", PersistText "z"
+ ]
+ sql `shouldBe` "WHERE \"a\" = ANY(?) AND \"b\" = ANY(?)"
+ params `shouldBe`
+ [ PersistArray [PersistInt64 1, PersistInt64 2]
+ , PersistArray [PersistText "x", PersistText "y", PersistText "z"]
+ ]
+
+ it "does not collapse ? inside string literals" $ do
+ let (sql, params) = collapseInClauses
+ "SELECT * FROM t WHERE name = 'IN (?,?)' AND id IN (?,?)"
+ [PersistInt64 1, PersistInt64 2]
+ sql `shouldBe` "SELECT * FROM t WHERE name = 'IN (?,?)' AND id = ANY(?)"
+ params `shouldBe` [PersistArray [PersistInt64 1, PersistInt64 2]]
+
+ it "passes through when no IN clauses" $ do
+ let (sql, params) = collapseInClauses
+ "SELECT * FROM t WHERE id = ? AND name = ?"
+ [PersistInt64 1, PersistText "foo"]
+ sql `shouldBe` "SELECT * FROM t WHERE id = ? AND name = ?"
+ params `shouldBe` [PersistInt64 1, PersistText "foo"]
+
+ describe "inlineAndRewrite (full pipeline)" $ do
+ it "collapses IN and rewrites to $N" $ do
+ let (sql, params) = inlineAndRewrite
+ "SELECT * FROM t WHERE \"id\" IN (?,?,?) AND \"name\" = ?"
+ [PersistInt64 1, PersistInt64 2, PersistInt64 3, PersistText "foo"]
+ sql `shouldBe` "SELECT * FROM t WHERE \"id\" = ANY($1) AND \"name\" = $2"
+ params `shouldBe`
+ [ PersistArray [PersistInt64 1, PersistInt64 2, PersistInt64 3]
+ , PersistText "foo"
+ ]
diff --git a/persistent-postgresql-ng/test/JSONTest.hs b/persistent-postgresql-ng/test/JSONTest.hs
new file mode 100644
index 000000000..a0ab87b11
--- /dev/null
+++ b/persistent-postgresql-ng/test/JSONTest.hs
@@ -0,0 +1,760 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module JSONTest where
+
+import Control.Monad.IO.Class (MonadIO)
+import Data.Aeson hiding (Key)
+import qualified Data.Vector as V (fromList)
+import Test.HUnit (assertBool)
+import Test.Hspec.Expectations ()
+
+import Database.Persist
+import Database.Persist.Postgresql.JSON
+
+import PgPipelineInit
+
+share
+ [mkPersist persistSettings, mkMigrate "jsonTestMigrate"]
+ [persistLowerCase|
+ TestValue
+ json Value
+ deriving Show
+|]
+
+cleanDB
+ :: (BaseBackend backend ~ SqlBackend, PersistQueryWrite backend, MonadIO m)
+ => ReaderT backend m ()
+cleanDB = deleteWhere ([] :: [Filter TestValue])
+
+emptyArr :: Value
+emptyArr = toJSON ([] :: [Value])
+
+insert'
+ :: (MonadIO m, PersistStoreWrite backend, BaseBackend backend ~ SqlBackend)
+ => Value -> ReaderT backend m (Key TestValue)
+insert' = insert . TestValue
+
+matchKeys
+ :: (Show record, Show (Key record), MonadIO m, Eq (Key record))
+ => [Key record] -> [Entity record] -> m ()
+matchKeys ys xs = do
+ msg1 `assertBoolIO` (xLen == yLen)
+ forM_ ys $ \y -> msg2 y `assertBoolIO` (y `elem` ks)
+ where
+ ks = entityKey <$> xs
+ xLen = length xs
+ yLen = length ys
+ msg1 =
+ mconcat
+ [ "\nexpected: "
+ , show yLen
+ , "\n but got: "
+ , show xLen
+ , "\n[xs: "
+ , show xs
+ , "]"
+ , "\n[ys: "
+ , show ys
+ , "]"
+ ]
+ msg2 y =
+ mconcat
+ [ "key \""
+ , show y
+ , "\" not in result:\n "
+ , show ks
+ ]
+
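+-- Typical usage of 'matchKeys', with schematic names, as throughout the
+-- specs below:
+--
+-- > vals <- selectList [TestValueJson @>. queryValue] []
+-- > [expectedKey1, expectedKey2] `matchKeys` vals
+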
+setup :: IO TestKeys
+setup = asIO $ runConn_ $ do
+ void $ runMigrationSilent jsonTestMigrate
+ testKeys
+
+teardown :: IO ()
+teardown = asIO $ runConn_ $ do
+ cleanDB
+
+shouldBeIO :: (Show a, Eq a, MonadIO m) => a -> a -> m ()
+shouldBeIO x y = liftIO $ shouldBe x y
+
+assertBoolIO :: (MonadIO m) => String -> Bool -> m ()
+assertBoolIO s b = liftIO $ assertBool s b
+
+testKeys :: (Monad m, MonadIO m) => ReaderT SqlBackend m TestKeys
+testKeys = do
+ nullK <- insert' Null
+
+ boolTK <- insert' $ Bool True
+ boolFK <- insert' $ toJSON False
+
+ num0K <- insert' $ Number 0
+ num1K <- insert' $ Number 1
+ numBigK <- insert' $ toJSON (1234567890 :: Int)
+ numFloatK <- insert' $ Number 0.0
+ numSmallK <- insert' $ Number 0.0000000000000000123
+ numFloat2K <- insert' $ Number 1.5
+    -- numBigFloatK will turn into 9876543210.123457 because the value
+    -- round-trips through Double, which carries only ~15-17 significant digits
+ numBigFloatK <- insert' $ toJSON (9876543210.123456789 :: Double)
+
+ strNullK <- insert' $ String ""
+ strObjK <- insert' $ String "{}"
+ strArrK <- insert' $ String "[]"
+ strAK <- insert' $ String "a"
+ strTestK <- insert' $ toJSON ("testing" :: Text)
+ str2K <- insert' $ String "2"
+ strFloatK <- insert' $ String "0.45876"
+
+ arrNullK <- insert' $ Array $ V.fromList []
+ arrListK <- insert' $ toJSON [emptyArr, emptyArr, toJSON [emptyArr, emptyArr]]
+ arrList2K <-
+ insert' $
+ toJSON
+ [ emptyArr
+ , toJSON [Number 3, Bool False]
+ , toJSON [emptyArr, toJSON [Object mempty]]
+ ]
+ arrFilledK <-
+ insert' $
+ toJSON
+ [ Null
+ , Number 4
+ , String "b"
+ , Object mempty
+ , emptyArr
+ , object ["test" .= [Null], "test2" .= String "yes"]
+ ]
+ arrList3K <- insert' $ toJSON [toJSON [String "a"], Number 1]
+ arrList4K <- insert' $ toJSON [String "a", String "b", String "c", String "d"]
+
+ objNullK <- insert' $ Object mempty
+ objTestK <- insert' $ object ["test" .= Null, "test1" .= String "no"]
+ objDeepK <-
+ insert' $ object ["c" .= Number 24.986, "foo" .= object ["deep1" .= Bool True]]
+ objEmptyK <- insert' $ object ["" .= Number 9001]
+ objFullK <-
+ insert' $
+ object
+ [ "a" .= Number 1
+ , "b" .= Number 2
+ , "c" .= Number 3
+ , "d" .= Number 4
+ ]
+ return TestKeys{..}
+
+data TestKeys
+ = TestKeys
+ { nullK :: Key TestValue
+ , boolTK :: Key TestValue
+ , boolFK :: Key TestValue
+ , num0K :: Key TestValue
+ , num1K :: Key TestValue
+ , numBigK :: Key TestValue
+ , numFloatK :: Key TestValue
+ , numSmallK :: Key TestValue
+ , numFloat2K :: Key TestValue
+ , numBigFloatK :: Key TestValue
+ , strNullK :: Key TestValue
+ , strObjK :: Key TestValue
+ , strArrK :: Key TestValue
+ , strAK :: Key TestValue
+ , strTestK :: Key TestValue
+ , str2K :: Key TestValue
+ , strFloatK :: Key TestValue
+ , arrNullK :: Key TestValue
+ , arrListK :: Key TestValue
+ , arrList2K :: Key TestValue
+ , arrFilledK :: Key TestValue
+ , objNullK :: Key TestValue
+ , objTestK :: Key TestValue
+ , objDeepK :: Key TestValue
+ , arrList3K :: Key TestValue
+ , arrList4K :: Key TestValue
+ , objEmptyK :: Key TestValue
+ , objFullK :: Key TestValue
+ }
+ deriving (Eq, Ord, Show)
+
+specs :: Spec
+specs = afterAll_ teardown $ do
+ beforeAll setup $ do
+ describe "Testing JSON operators" $ do
+ describe "@>. object queries" $ do
+ it "matches an empty Object with any object" $
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. Object mempty] []
+ [objNullK, objTestK, objDeepK, objEmptyK, objFullK] `matchKeys` vals
+
+ it "matches a subset of object properties" $
+ -- {test: null, test1: no} @>. {test: null} == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. object ["test" .= Null]] []
+ [objTestK] `matchKeys` vals
+
+ it "matches a nested object against an empty object at the same key" $
+ -- {c: 24.986, foo: {deep1: true}} @>. {foo: {}} == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. object ["foo" .= object []]] []
+ [objDeepK] `matchKeys` vals
+
+ it "doesn't match a nested object against a string at the same key" $
+ -- {c: 24.986, foo: {deep1: true}} @>. {foo: nope} == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. object ["foo" .= String "nope"]] []
+ [] `matchKeys` vals
+
+ it "matches a nested object when the query object is identical" $
+ -- {c: 24.986, foo: {deep1: true}} @>. {foo: {deep1: true}} == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <-
+ selectList [TestValueJson @>. (object ["foo" .= object ["deep1" .= True]])] []
+ [objDeepK] `matchKeys` vals
+
+ it "doesn't match a nested object when queried with that exact object" $
+ -- {c: 24.986, foo: {deep1: true}} @>. {deep1: true} == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. object ["deep1" .= True]] []
+ [] `matchKeys` vals
+
+ describe "@>. array queries" $ do
+ it "matches an empty Array with any list" $
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. emptyArr] []
+ [arrNullK, arrListK, arrList2K, arrFilledK, arrList3K, arrList4K]
+ `matchKeys` vals
+
+ it "matches list when queried with subset (1 item)" $
+ -- [null, 4, 'b', {}, [], {test: [null], test2: 'yes'}] @>. [4] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [4 :: Int]] []
+ [arrFilledK] `matchKeys` vals
+
+ it "matches list when queried with subset (2 items)" $
+ -- [null, 4, 'b', {}, [], {test: [null], test2: 'yes'}] @>. [null,'b'] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [Null, String "b"]] []
+ [arrFilledK] `matchKeys` vals
+
+ it "doesn't match list when queried with intersecting list (1 match, 1 diff)" $
+ -- [null, 4, 'b', {}, [], {test: [null], test2: 'yes'}] @>. [null,'d'] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [emptyArr, String "d"]] []
+ [] `matchKeys` vals
+
+ it "matches list when queried with same list in different order" $
+ -- [null, 4, 'b', {}, [], {test: [null], test2: 'yes'}] @>.
+ -- [[],'b',{test: [null],test2: 'yes'},4,null,{}] == True
+ \TestKeys{..} -> runConnAssert $ do
+ let
+ queryList =
+ toJSON
+ [ emptyArr
+ , String "b"
+ , object ["test" .= [Null], "test2" .= String "yes"]
+ , Number 4
+ , Null
+ , Object mempty
+ ]
+
+ vals <- selectList [TestValueJson @>. queryList] []
+ [arrFilledK] `matchKeys` vals
+
+ it "doesn't match list when queried with same list + 1 item" $
+ -- [null,4,'b',{},[],{test:[null],test2:'yes'}] @>.
+ -- [null,4,'b',{},[],{test:[null],test2: 'yes'}, false] == False
+ \TestKeys{..} -> runConnAssert $ do
+ let
+ testList =
+ toJSON
+ [ Null
+ , Number 4
+ , String "b"
+ , Object mempty
+ , emptyArr
+ , object ["test" .= [Null], "test2" .= String "yes"]
+ , Bool False
+ ]
+
+ vals <- selectList [TestValueJson @>. testList] []
+ [] `matchKeys` vals
+
+ it "matches list when it shares an empty object with the query list" $
+ -- [null,4,'b',{},[],{test: [null],test2: 'yes'}] @>. [{}] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [Object mempty]] []
+ [arrFilledK] `matchKeys` vals
+
+ it "matches list with nested list, when queried with an empty nested list" $
+ -- [null,4,'b',{},[],{test:[null],test2:'yes'}] @>. [{test:[]}] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [object ["test" .= emptyArr]]] []
+ [arrFilledK] `matchKeys` vals
+
+                it "doesn't match list with nested list, when queried with a different nested list" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>.
+ -- [{"test1":[null]}] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [object ["test1" .= [Null]]]] []
+ [] `matchKeys` vals
+
+ it "matches many nested lists when queried with empty nested list" $
+ -- [[],[],[[],[]]] @>. [[]] == True
+ -- [[],[3,false],[[],[{}]]] @>. [[]] == True
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. [[]] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [emptyArr]] []
+ [arrListK, arrList2K, arrFilledK, arrList3K] `matchKeys` vals
+
+ it "matches nested list when queried with a subset of that list" $
+ -- [[],[3,false],[[],[{}]]] @>. [[3]] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [[3 :: Int]]] []
+ [arrList2K] `matchKeys` vals
+
+                it "doesn't match nested list against a partial intersection of that list" $
+ -- [[],[3,false],[[],[{}]]] @>. [[true,3]] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON [[Bool True, Number 3]]] []
+ [] `matchKeys` vals
+
+ it "matches list when queried with raw number contained in the list" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. 4 == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. Number 4] []
+ [arrFilledK] `matchKeys` vals
+
+ it "doesn't match list when queried with raw value not contained in the list" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. 99 == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. Number 99] []
+ [] `matchKeys` vals
+
+ it "matches list when queried with raw string contained in the list" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. "b" == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. String "b"] []
+ [arrFilledK, arrList4K] `matchKeys` vals
+
+                it "doesn't match list with empty object when queried with \"{}\"" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. "{}" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. String "{}"] []
+ [strObjK] `matchKeys` vals
+
+                it "doesn't match list with nested object when queried with object (not in list)" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>.
+ -- {"test":[null],"test2":"yes"} == False
+ \TestKeys{..} -> runConnAssert $ do
+ let
+ queryObject = object ["test" .= [Null], "test2" .= String "yes"]
+ vals <- selectList [TestValueJson @>. queryObject] []
+ [] `matchKeys` vals
+
+ describe "@>. string queries" $ do
+ it "matches identical strings" $
+ -- "testing" @>. "testing" == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. String "testing"] []
+ [strTestK] `matchKeys` vals
+
+                it "doesn't match case-insensitively" $
+ -- "testing" @>. "Testing" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. String "Testing"] []
+ [] `matchKeys` vals
+
+ it "doesn't match substrings" $
+ -- "testing" @>. "test" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. String "test"] []
+ [] `matchKeys` vals
+
+ it "doesn't match strings with object keys" $
+ -- "testing" @>. {"testing":1} == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. object ["testing" .= Number 1]] []
+ [] `matchKeys` vals
+
+ describe "@>. number queries" $ do
+ it "matches identical numbers" $
+ -- 1 @>. 1 == True
+ -- [1] @>. 1 == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON (1 :: Int)] []
+ [num1K, arrList3K] `matchKeys` vals
+
+ it "matches numbers when queried with float" $
+ -- 0 @>. 0.0 == True
+ -- 0.0 @>. 0.0 == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON (0.0 :: Double)] []
+ [num0K, numFloatK] `matchKeys` vals
+
+ it "does not match numbers when queried with a substring of that number" $
+ -- 1234567890 @>. 123456789 == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON (123456789 :: Int)] []
+ [] `matchKeys` vals
+
+ it "does not match number when queried with different number" $
+ -- 1234567890 @>. 234567890 == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON (234567890 :: Int)] []
+ [] `matchKeys` vals
+
+ it "does not match number when queried with string of that number" $
+ -- 1 @>. "1" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. String "1"] []
+ [] `matchKeys` vals
+
+ it "does not match number when queried with list of digits" $
+ -- 1234567890 @>. [1,2,3,4,5,6,7,8,9,0] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <-
+ selectList
+ [TestValueJson @>. toJSON ([1, 2, 3, 4, 5, 6, 7, 8, 9, 0] :: [Int])]
+ []
+ [] `matchKeys` vals
+
+ describe "@>. boolean queries" $ do
+ it "matches identical booleans (True)" $
+ -- true @>. true == True
+ -- false @>. true == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. toJSON True] []
+ [boolTK] `matchKeys` vals
+
+ it "matches identical booleans (False)" $
+ -- false @>. false == True
+ -- true @>. false == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. Bool False] []
+ [boolFK] `matchKeys` vals
+
+ it "does not match boolean with string of boolean" $
+ -- true @>. "true" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. String "true"] []
+ [] `matchKeys` vals
+
+ describe "@>. null queries" $ do
+ it "matches nulls" $
+ -- null @>. null == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. Null] []
+ [nullK, arrFilledK] `matchKeys` vals
+
+ it "does not match null with string of null" $
+ -- null @>. "null" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson @>. String "null"] []
+ [] `matchKeys` vals
+
+ describe "<@. queries" $ do
+ it "matches subobject when queried with superobject" $
+ -- {} <@. {"test":null,"test1":"no","blabla":[]} == True
+ -- {"test":null,"test1":"no"} <@. {"test":null,"test1":"no","blabla":[]} == True
+ \TestKeys{..} -> runConnAssert $ do
+ let
+ queryObject =
+ object
+ [ "test" .= Null
+ , "test1" .= String "no"
+ , "blabla" .= emptyArr
+ ]
+ vals <- selectList [TestValueJson <@. queryObject] []
+ [objNullK, objTestK] `matchKeys` vals
+
+ it "matches raw values and sublists when queried with superlist" $
+ -- [] <@. [null,4,"b",{},[],{"test":[null],"test2":"yes"},false] == True
+ -- null <@. [null,4,"b",{},[],{"test":[null],"test2":"yes"},false] == True
+ -- false <@. [null,4,"b",{},[],{"test":[null],"test2":"yes"},false] == True
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] <@.
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"},false] == True
+ \TestKeys{..} -> runConnAssert $ do
+ let
+ queryList =
+ toJSON
+ [ Null
+ , Number 4
+ , String "b"
+ , Object mempty
+ , emptyArr
+ , object ["test" .= [Null], "test2" .= String "yes"]
+ , Bool False
+ ]
+
+ vals <- selectList [TestValueJson <@. queryList] []
+ [arrNullK, arrFilledK, boolFK, nullK] `matchKeys` vals
+
+ it "matches identical strings" $
+ -- "a" <@. "a" == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson <@. String "a"] []
+ [strAK] `matchKeys` vals
+
+ it "matches identical big floats" $
+ -- 9876543210.123457 <@ 9876543210.123457 == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson <@. Number 9876543210.123457] []
+ [numBigFloatK] `matchKeys` vals
+
+ it "doesn't match different big floats" $
+ -- 9876543210.123457 <@. 9876543210.123456789 == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson <@. Number 9876543210.123456789] []
+ [] `matchKeys` vals
+
+ it "matches nulls" $
+ -- null <@. null == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson <@. Null] []
+ [nullK] `matchKeys` vals
+
+ describe "?. queries" $ do
+ it "matches top level keys and not the keys of nested objects" $
+ -- {"test":null,"test1":"no"} ?. "test" == True
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?. "test" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. "test"] []
+ [objTestK] `matchKeys` vals
+
+ it "doesn't match nested key" $
+ -- {"c":24.986,"foo":{"deep1":true"}} ?. "deep1" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. "deep1"] []
+ [] `matchKeys` vals
+
+ it "matches \"{}\" but not empty object when queried with \"{}\"" $
+ -- "{}" ?. "{}" == True
+ -- {} ?. "{}" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. "{}"] []
+ [strObjK] `matchKeys` vals
+
+                it "matches raw empty string and empty-string key when queried with \"\"" $
+                    -- {} ?. "" == False
+                    -- "" ?. "" == True
+                    -- {"":9001} ?. "" == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. ""] []
+ [strNullK, objEmptyK] `matchKeys` vals
+
+ it "matches lists containing string value when queried with raw string value" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?. "b" == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. "b"] []
+ [arrFilledK, arrList4K, objFullK] `matchKeys` vals
+
+ it "matches lists, objects, and raw values correctly when queried with string" $
+ -- [["a"]] ?. "a" == False
+ -- "a" ?. "a" == True
+ -- ["a","b","c","d"] ?. "a" == True
+ -- {"a":1,"b":2,"c":3,"d":4} ?. "a" == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. "a"] []
+ [strAK, arrList4K, objFullK] `matchKeys` vals
+
+ it "matches string list but not real list when queried with \"[]\"" $
+ -- "[]" ?. "[]" == True
+ -- [] ?. "[]" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. "[]"] []
+ [strArrK] `matchKeys` vals
+
+ it "does not match null when queried with string null" $
+ -- null ?. "null" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. "null"] []
+ [] `matchKeys` vals
+
+                it "does not match bool when queried with string bool" $
+ -- true ?. "true" == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?. "true"] []
+ [] `matchKeys` vals
+
+ describe "?|. queries" $ do
+ it "matches raw vals, lists, objects, and nested objects" $
+ -- "a" ?|. ["a","b","c"] == True
+ -- [["a"],1] ?|. ["a","b","c"] == False
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?|. ["a","b","c"] == True
+ -- ["a","b","c","d"] ?|. ["a","b","c"] == True
+ -- {"a":1,"b":2,"c":3,"d":4} ?|. ["a","b","c"] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?|. ["a", "b", "c"]] []
+ [strAK, arrFilledK, objDeepK, arrList4K, objFullK] `matchKeys` vals
+
+ it "matches str object but not object when queried with \"{}\"" $
+ -- "{}" ?|. ["{}"] == True
+ -- {} ?|. ["{}"] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?|. ["{}"]] []
+ [strObjK] `matchKeys` vals
+
+ it "doesn't match superstrings when queried with substring" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?|. ["test"] == False
+ -- "testing" ?|. ["test"] == False
+ -- {"test":null,"test1":"no"} ?|. ["test"] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?|. ["test"]] []
+ [objTestK] `matchKeys` vals
+
+ it "doesn't match nested keys" $
+ -- {"c":24.986,"foo":{"deep1":true"}} ?|. ["deep1"] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?|. ["deep1"]] []
+ [] `matchKeys` vals
+
+ it "doesn't match anything when queried with empty list" $
+ -- ANYTHING ?|. [] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?|. []] []
+ [] `matchKeys` vals
+
+                it "doesn't match raw non-string values when queried with strings" $
+ -- true ?|. ["true","null","1"] == False
+ -- null ?|. ["true","null","1"] == False
+ -- 1 ?|. ["true","null","1"] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?|. ["true", "null", "1"]] []
+ [] `matchKeys` vals
+
+ it "matches string array when queried with \"[]\"" $
+ -- [] ?|. ["[]"] == False
+ -- "[]" ?|. ["[]"] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?|. ["[]"]] []
+ [strArrK] `matchKeys` vals
+
+ describe "?&. queries" $ do
+ it "matches anything when queried with an empty list" $
+ -- ANYTHING ?&. [] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. []] []
+ flip
+ matchKeys
+ vals
+ [ nullK
+ , boolTK
+ , boolFK
+ , num0K
+ , num1K
+ , numBigK
+ , numFloatK
+ , numSmallK
+ , numFloat2K
+ , numBigFloatK
+ , strNullK
+ , strObjK
+ , strArrK
+ , strAK
+ , strTestK
+ , str2K
+ , strFloatK
+ , arrNullK
+ , arrListK
+ , arrList2K
+ , arrFilledK
+ , arrList3K
+ , arrList4K
+ , objNullK
+ , objTestK
+ , objDeepK
+ , objEmptyK
+ , objFullK
+ ]
+
+ it "matches raw values, lists, and objects when queried with string" $
+ -- "a" ?&. ["a"] == True
+ -- [["a"],1] ?&. ["a"] == False
+ -- ["a","b","c","d"] ?&. ["a"] == True
+ -- {"a":1,"b":2,"c":3,"d":4} ?&. ["a"] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. ["a"]] []
+ [strAK, arrList4K, objFullK] `matchKeys` vals
+
+                it "matches raw values, lists, and objects when queried with multiple strings" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?&. ["b","c"] == False
+ -- {"c":24.986,"foo":{"deep1":true"}} ?&. ["b","c"] == False
+ -- ["a","b","c","d"] ?&. ["b","c"] == True
+ -- {"a":1,"b":2,"c":3,"d":4} ?&. ["b","c"] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. ["b", "c"]] []
+ [arrList4K, objFullK] `matchKeys` vals
+
+ it "matches object string when queried with \"{}\"" $
+ -- {} ?&. ["{}"] == False
+ -- "{}" ?&. ["{}"] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. ["{}"]] []
+ [strObjK] `matchKeys` vals
+
+ it "doesn't match superstrings when queried with substring" $
+ -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?&. ["test"] == False
+ -- "testing" ?&. ["test"] == False
+ -- {"test":null,"test1":"no"} ?&. ["test"] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. ["test"]] []
+ [objTestK] `matchKeys` vals
+
+ it "doesn't match nested keys" $
+ -- {"c":24.986,"foo":{"deep1":true"}} ?&. ["deep1"] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. ["deep1"]] []
+ [] `matchKeys` vals
+
+ it "doesn't match anything when there is a partial match" $
+ -- "a" ?&. ["a","e"] == False
+ -- ["a","b","c","d"] ?&. ["a","e"] == False
+ -- {"a":1,"b":2,"c":3,"d":4} ?&. ["a","e"] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. ["a", "e"]] []
+ [] `matchKeys` vals
+
+ it "matches string array when queried with \"[]\"" $
+ -- [] ?&. ["[]"] == False
+ -- "[]" ?&. ["[]"] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. ["[]"]] []
+ [strArrK] `matchKeys` vals
+
+ it "doesn't match null when queried with string null" $
+ -- THIS WILL FAIL IF THE IMPLEMENTATION USES
+ -- @ '{null}' @
+ -- INSTEAD OF
+ -- @ ARRAY['null'] @
+ -- null ?&. ["null"] == False
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. ["null"]] []
+ [] `matchKeys` vals
+
+                it "doesn't match number when queried with string of that number" $
+                    -- [["a"],1] ?&. ["1"] == False
+                    -- "1" ?&. ["1"] == True
+                    \TestKeys{..} -> runConnAssert $ do
+                        str1 <- insert' $ String "1"
+ vals <- selectList [TestValueJson ?&. ["1"]] []
+ [str1] `matchKeys` vals
+
+                it "doesn't match empty objects or lists when queried with empty string" $
+ -- {} ?&. [""] == False
+ -- [] ?&. [""] == False
+ -- "" ?&. [""] == True
+ -- {"":9001} ?&. [""] == True
+ \TestKeys{..} -> runConnAssert $ do
+ vals <- selectList [TestValueJson ?&. [""]] []
+ [strNullK, objEmptyK] `matchKeys` vals
diff --git a/persistent-postgresql-ng/test/MigrationReferenceSpec.hs b/persistent-postgresql-ng/test/MigrationReferenceSpec.hs
new file mode 100644
index 000000000..f2079a065
--- /dev/null
+++ b/persistent-postgresql-ng/test/MigrationReferenceSpec.hs
@@ -0,0 +1,61 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -Wno-unused-top-binds #-}
+
+module MigrationReferenceSpec where
+
+import PgPipelineInit
+
+import Control.Monad.Trans.Writer (censor, mapWriterT)
+import Data.Text (Text, isInfixOf)
+
+share
+ [mkPersist sqlSettings, mkMigrate "referenceMigrate"]
+ [persistLowerCase|
+
+LocationCapabilities
+ Id Text
+ bio Text
+
+LocationCapabilitiesPrintingProcess
+ locationCapabilitiesId LocationCapabilitiesId
+
+LocationCapabilitiesPrintingFinish
+ locationCapabilitiesId LocationCapabilitiesId
+
+LocationCapabilitiesSubstrate
+ locationCapabilitiesId LocationCapabilitiesId
+
+|]
+
+spec :: Spec
+spec = describe "MigrationReferenceSpec" $ do
+ it "works" $ runConnAssert $ do
+ let
+ noForeignKeys :: CautiousMigration -> CautiousMigration
+ noForeignKeys = filter ((not . isReference) . snd)
+
+ onlyForeignKeys :: CautiousMigration -> CautiousMigration
+ onlyForeignKeys = filter (isReference . snd)
+
+ isReference :: Text -> Bool
+ isReference migration = "REFERENCES" `isInfixOf` migration
+
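+        -- Apply the FK-free statements first and the REFERENCES statements
+        -- second: every referenced table must exist before the foreign key
+        -- constraints that point at it can be created.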
+ runMigration $
+ mapWriterT (censor noForeignKeys) $
+ referenceMigrate
+
+ runMigration $
+ mapWriterT (censor onlyForeignKeys) $
+ referenceMigrate
diff --git a/persistent-postgresql-ng/test/MigrationSpec.hs b/persistent-postgresql-ng/test/MigrationSpec.hs
new file mode 100644
index 000000000..44ed41876
--- /dev/null
+++ b/persistent-postgresql-ng/test/MigrationSpec.hs
@@ -0,0 +1,687 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module MigrationSpec where
+
+import PgPipelineInit
+
+import Data.Foldable (traverse_)
+import qualified Data.Map as Map
+import Data.Proxy
+import qualified Data.Set as Set
+import qualified Data.Text as T
+import Database.Persist.Postgresql.Internal.Migration
+
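+-- | Fetch the backend's statement preparer so the structured migration
+-- helpers below can run their introspection queries on the live connection.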
+getStmtGetter
+ :: (Monad m) => SqlPersistT m (Text -> IO Statement)
+getStmtGetter = do
+ backend <- ask
+ pure (getStmtConn backend)
+
+-- NB: we do not perform these migrations in main.hs
+share
+ [mkPersist persistSettings{mpsGeneric = False}]
+ [persistLowerCase|
+User sql=users
+ name Text
+ title Text Maybe
+ deriving Show Eq
+
+UserFriendship sql=user_friendships
+ user1Id UserId Maybe
+ user2Id UserId Maybe
+ deriving Show Eq
+
+Password sql=passwords
+ passwordHash Text
+ userId UserId Maybe
+ UniqueUserId userId !force
+
+Password2 sql=passwords_2
+ passwordHash Text
+ userId UserId Maybe OnDeleteCascade OnUpdateSetNull
+ UniqueUserId2 userId !force
+
+AdminUser sql=admin_users
+ userId UserId
+ Primary userId
+
+ promotedByUserId UserId
+ UniquePromotedByUserId promotedByUserId
+
+FKParent sql=migration_fk_parent
+
+FKChildV1 sql=migration_fk_child
+
+-- Simulate creating a new FK field on an existing table
+FKChildV2 sql=migration_fk_child
+ parentId FKParentId
+
+ExplicitPrimaryKey sql=explicit_primary_key
+ Id Text
+|]
+
+userEntityDef :: EntityDef
+userEntityDef = entityDef (Proxy :: Proxy User)
+
+userFriendshipEntityDef :: EntityDef
+userFriendshipEntityDef = entityDef (Proxy :: Proxy UserFriendship)
+
+passwordEntityDef :: EntityDef
+passwordEntityDef = entityDef (Proxy :: Proxy Password)
+
+password2EntityDef :: EntityDef
+password2EntityDef = entityDef (Proxy :: Proxy Password2)
+
+adminUserEntityDef :: EntityDef
+adminUserEntityDef = entityDef (Proxy :: Proxy AdminUser)
+
+fkParentEntityDef :: EntityDef
+fkParentEntityDef = entityDef (Proxy :: Proxy FKParent)
+
+fkChildV1EntityDef :: EntityDef
+fkChildV1EntityDef = entityDef (Proxy :: Proxy FKChildV1)
+
+fkChildV2EntityDef :: EntityDef
+fkChildV2EntityDef = entityDef (Proxy :: Proxy FKChildV2)
+
+explicitPrimaryKeyEntityDef :: EntityDef
+explicitPrimaryKeyEntityDef = entityDef (Proxy :: Proxy ExplicitPrimaryKey)
+
+-- Note that FKChild is deliberately omitted here because we have two
+-- versions of it
+allEntityDefs :: [EntityDef]
+allEntityDefs =
+ [ userEntityDef
+ , userFriendshipEntityDef
+ , passwordEntityDef
+ , password2EntityDef
+ , adminUserEntityDef
+ , fkParentEntityDef
+ , explicitPrimaryKeyEntityDef
+ ]
+
+-- Note that this function migrates to the schema expected by FKChildV1
+migrateManually :: (HasCallStack, MonadIO m) => SqlPersistT m ()
+migrateManually = do
+ cleanDB
+ let
+ rawEx sql = rawExecute sql []
+ rawEx
+ "CREATE TABLE users(id int8 primary key, name text not null, title text);"
+ rawEx $
+ T.concat
+ [ "CREATE TABLE user_friendships("
+ , " id int8 primary key,"
+ , " user1_id int8 references users(id) on delete restrict on update restrict,"
+ , " user2_id int8 references users(id) on delete restrict on update restrict"
+ , ");"
+ ]
+ rawEx $
+ T.concat
+ [ "CREATE TABLE passwords("
+ , " id int8 primary key,"
+ , " password_hash text not null,"
+ , " user_id int8 references users(id) on delete restrict on update restrict"
+ , ");"
+ ]
+ rawEx $
+ T.concat
+ [ "ALTER TABLE passwords"
+ , " ADD CONSTRAINT unique_user_id"
+ , " UNIQUE(user_id);"
+ ]
+ rawEx $
+ T.concat
+ [ "CREATE TABLE passwords_2("
+ , " id int8 primary key,"
+ , " password_hash text not null,"
+ , " user_id int8 references users(id) on delete cascade on update set null"
+ , ");"
+ ]
+ rawEx $
+ T.concat
+ [ "ALTER TABLE passwords_2"
+ , " ADD CONSTRAINT unique_user_id2"
+ , " UNIQUE(user_id);"
+ ]
+ -- Add an extra redundant FK constraint on passwords_2.user_id, so that we
+ -- can test that the migrator ignores it
+ rawEx $
+ T.concat
+ [ "ALTER TABLE passwords_2"
+ , " ADD CONSTRAINT duplicate_passwords_2_user_id_fkey"
+ , " FOREIGN KEY (user_id) REFERENCES users(id);"
+ ]
+ rawEx $
+ T.concat
+ [ "CREATE TABLE admin_users("
+ , " user_id int8 not null references users(id) on delete restrict on update restrict primary key,"
+ , " promoted_by_user_id int8 not null references users(id) on delete restrict on update restrict"
+ , ");"
+ ]
+ rawEx $
+ T.concat
+ [ "ALTER TABLE admin_users"
+ , " ADD CONSTRAINT unique_promoted_by_user_id"
+ , " UNIQUE(promoted_by_user_id);"
+ ]
+ rawEx "CREATE TABLE migration_fk_parent(id int8 primary key);"
+ rawEx "CREATE TABLE migration_fk_child(id int8 primary key);"
+ rawEx "CREATE TABLE explicit_primary_key(id text primary key);"
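+    -- The "ignored" table is not described by any EntityDef above; it checks
+    -- that the migrator leaves unknown tables alone.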
+ rawEx "CREATE TABLE ignored(id int8 primary key);"
+
+cleanDB :: (HasCallStack, MonadIO m) => SqlPersistT m ()
+cleanDB = do
+ let
+ rawEx sql = rawExecute sql []
+ rawEx "DROP TABLE IF EXISTS user_friendships;"
+ rawEx "DROP TABLE IF EXISTS passwords;"
+ rawEx "DROP TABLE IF EXISTS passwords_2;"
+ rawEx "DROP TABLE IF EXISTS ignored;"
+ rawEx "DROP TABLE IF EXISTS admin_users;"
+ rawEx "DROP TABLE IF EXISTS users;"
+ rawEx "DROP TABLE IF EXISTS migration_fk_child;"
+ rawEx "DROP TABLE IF EXISTS migration_fk_parent;"
+ rawEx "DROP TABLE IF EXISTS explicit_primary_key;"
+
+spec :: Spec
+spec = describe "MigrationSpec" $ do
+ it "gathers schema state" $ runConnAssert $ do
+ migrateManually
+
+ getter <- getStmtGetter
+ actual <-
+ liftIO $
+ collectSchemaState getter $
+ map
+ EntityNameDB
+ [ "users"
+ , "admin_users"
+ , "user_friendships"
+ , "passwords"
+ , "passwords_2"
+ , "nonexistent"
+ ]
+
+ cleanDB
+
+ let
+ expected =
+ SchemaState
+ ( Map.fromList
+ [
+ ( EntityNameDB{unEntityNameDB = "admin_users"}
+ , EntityExists
+ ( ExistingEntitySchemaState
+ { essColumns =
+ Map.fromList
+ [
+ ( FieldNameDB{unFieldNameDB = "promoted_by_user_id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "promoted_by_user_id"}
+ , cNull = False
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList
+ [ ColumnReference
+ { crTableName = EntityNameDB{unEntityNameDB = "users"}
+ , crConstraintName =
+ ConstraintNameDB{unConstraintNameDB = "admin_users_promoted_by_user_id_fkey"}
+ , crFieldCascade =
+ FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict}
+ }
+ ]
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "user_id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "user_id"}
+ , cNull = False
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList
+ [ ColumnReference
+ { crTableName = EntityNameDB{unEntityNameDB = "users"}
+ , crConstraintName =
+ ConstraintNameDB{unConstraintNameDB = "admin_users_user_id_fkey"}
+ , crFieldCascade =
+ FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict}
+ }
+ ]
+ )
+ )
+ ]
+ , essUniqueConstraints =
+ Map.fromList
+ [
+ ( ConstraintNameDB{unConstraintNameDB = "unique_promoted_by_user_id"}
+ , [FieldNameDB{unFieldNameDB = "promoted_by_user_id"}]
+ )
+ ]
+ }
+ )
+ )
+ , (EntityNameDB{unEntityNameDB = "nonexistent"}, EntityDoesNotExist)
+ ,
+ ( EntityNameDB{unEntityNameDB = "passwords"}
+ , EntityExists
+ ( ExistingEntitySchemaState
+ { essColumns =
+ Map.fromList
+ [
+ ( FieldNameDB{unFieldNameDB = "id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "id"}
+ , cNull = False
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList []
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "password_hash"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "password_hash"}
+ , cNull = False
+ , cSqlType = SqlString
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList []
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "user_id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "user_id"}
+ , cNull = True
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList
+ [ ColumnReference
+ { crTableName = EntityNameDB{unEntityNameDB = "users"}
+ , crConstraintName =
+ ConstraintNameDB{unConstraintNameDB = "passwords_user_id_fkey"}
+ , crFieldCascade =
+ FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict}
+ }
+ ]
+ )
+ )
+ ]
+ , essUniqueConstraints =
+ Map.fromList
+ [
+ ( ConstraintNameDB{unConstraintNameDB = "unique_user_id"}
+ , [FieldNameDB{unFieldNameDB = "user_id"}]
+ )
+ ]
+ }
+ )
+ )
+ ,
+ ( EntityNameDB{unEntityNameDB = "passwords_2"}
+ , EntityExists
+ ( ExistingEntitySchemaState
+ { essColumns =
+ Map.fromList
+ [
+ ( FieldNameDB{unFieldNameDB = "id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "id"}
+ , cNull = False
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList []
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "password_hash"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "password_hash"}
+ , cNull = False
+ , cSqlType = SqlString
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList []
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "user_id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "user_id"}
+ , cNull = True
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList
+ [ ColumnReference
+ { crTableName = EntityNameDB{unEntityNameDB = "users"}
+ , crConstraintName =
+ ConstraintNameDB{unConstraintNameDB = "duplicate_passwords_2_user_id_fkey"}
+ , crFieldCascade =
+ FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction}
+ }
+ , ColumnReference
+ { crTableName = EntityNameDB{unEntityNameDB = "users"}
+ , crConstraintName =
+ ConstraintNameDB{unConstraintNameDB = "passwords_2_user_id_fkey"}
+ , crFieldCascade =
+ FieldCascade{fcOnUpdate = Just SetNull, fcOnDelete = Just Cascade}
+ }
+ ]
+ )
+ )
+ ]
+ , essUniqueConstraints =
+ Map.fromList
+ [
+ ( ConstraintNameDB{unConstraintNameDB = "unique_user_id2"}
+ , [FieldNameDB{unFieldNameDB = "user_id"}]
+ )
+ ]
+ }
+ )
+ )
+ ,
+ ( EntityNameDB{unEntityNameDB = "user_friendships"}
+ , EntityExists
+ ( ExistingEntitySchemaState
+ { essColumns =
+ Map.fromList
+ [
+ ( FieldNameDB{unFieldNameDB = "id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "id"}
+ , cNull = False
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList []
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "user1_id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "user1_id"}
+ , cNull = True
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList
+ [ ColumnReference
+ { crTableName = EntityNameDB{unEntityNameDB = "users"}
+ , crConstraintName =
+ ConstraintNameDB{unConstraintNameDB = "user_friendships_user1_id_fkey"}
+ , crFieldCascade =
+ FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict}
+ }
+ ]
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "user2_id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "user2_id"}
+ , cNull = True
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList
+ [ ColumnReference
+ { crTableName = EntityNameDB{unEntityNameDB = "users"}
+ , crConstraintName =
+ ConstraintNameDB{unConstraintNameDB = "user_friendships_user2_id_fkey"}
+ , crFieldCascade =
+ FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict}
+ }
+ ]
+ )
+ )
+ ]
+ , essUniqueConstraints = Map.fromList []
+ }
+ )
+ )
+ ,
+ ( EntityNameDB{unEntityNameDB = "users"}
+ , EntityExists
+ ( ExistingEntitySchemaState
+ { essColumns =
+ Map.fromList
+ [
+ ( FieldNameDB{unFieldNameDB = "id"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "id"}
+ , cNull = False
+ , cSqlType = SqlInt64
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList []
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "name"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "name"}
+ , cNull = False
+ , cSqlType = SqlString
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList []
+ )
+ )
+ ,
+ ( FieldNameDB{unFieldNameDB = "title"}
+ ,
+ ( Column
+ { cName = FieldNameDB{unFieldNameDB = "title"}
+ , cNull = True
+ , cSqlType = SqlString
+ , cDefault = Nothing
+ , cGenerated = Nothing
+ , cDefaultConstraintName = Nothing
+ , cMaxLen = Nothing
+ , cReference = Nothing
+ }
+ , Set.fromList []
+ )
+ )
+ ]
+ , essUniqueConstraints = Map.fromList []
+ }
+ )
+ )
+ ]
+ )
+
+ actual `shouldBe` Right expected
+
+ it "no-ops on a migrated DB" $ runConnAssert $ do
+ migrateManually
+
+ getter <- getStmtGetter
+ result <-
+ liftIO $
+ migrateEntitiesStructured
+ emptyBackendSpecificOverrides
+ getter
+ allEntityDefs
+ allEntityDefs
+
+ cleanDB
+
+ case result of
+ Right [] ->
+ pure ()
+ Left err ->
+ expectationFailure $ show err
+ Right alters ->
+ map (snd . showAlterDb) alters `shouldBe` []
+
+ it "migrates a clean DB" $ runConnAssert $ do
+ cleanDB
+
+ getter <- getStmtGetter
+ result <-
+ liftIO $
+ migrateEntitiesStructured
+ emptyBackendSpecificOverrides
+ getter
+ allEntityDefs
+ allEntityDefs
+
+ cleanDB
+
+ case result of
+ Right [] ->
+ pure ()
+ Left err ->
+ expectationFailure $ show err
+ Right alters -> do
+ traverse_ (flip rawExecute [] . snd . showAlterDb) alters
+ result2 <-
+ liftIO $
+ migrateEntitiesStructured
+ emptyBackendSpecificOverrides
+ getter
+ allEntityDefs
+ allEntityDefs
+ result2 `shouldBe` Right []
+
+ it "suggests FK constraints for new fields first time" $ runConnAssert $ do
+ migrateManually
+
+ getter <- getStmtGetter
+ result <-
+ liftIO $
+ migrateEntitiesStructured
+ emptyBackendSpecificOverrides
+ getter
+ (fkChildV2EntityDef : allEntityDefs)
+ [fkChildV2EntityDef]
+
+ cleanDB
+
+ case result of
+ Right [] ->
+ pure ()
+ Left err ->
+ expectationFailure $ show err
+ Right alters ->
+ map (snd . showAlterDb) alters
+ `shouldBe` [ "ALTER TABLE \"migration_fk_child\" ADD COLUMN \"parent_id\" INT8 NOT NULL"
+ , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE RESTRICT ON UPDATE RESTRICT"
+ ]
+
+ it "Uses overrides for empty cascade action" $ runConnAssert $ do
+ migrateManually
+
+ getter <- getStmtGetter
+
+ let
+ overrideWithDefault =
+ setBackendSpecificForeignKeyCascadeDefault Cascade emptyBackendSpecificOverrides
+ result <-
+ liftIO $
+ migrateEntitiesStructured
+ overrideWithDefault
+ getter
+ (fkChildV2EntityDef : allEntityDefs)
+ [fkChildV2EntityDef]
+
+ cleanDB
+
+ case result of
+ Right [] ->
+ pure ()
+ Left err ->
+ expectationFailure $ show err
+ Right alters ->
+ map (snd . showAlterDb) alters
+ `shouldBe` [ "ALTER TABLE \"migration_fk_child\" ADD COLUMN \"parent_id\" INT8 NOT NULL"
+ , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE CASCADE ON UPDATE CASCADE"
+ ]
diff --git a/persistent-postgresql-ng/test/PgPipelineInit.hs b/persistent-postgresql-ng/test/PgPipelineInit.hs
new file mode 100644
index 000000000..53de436e0
--- /dev/null
+++ b/persistent-postgresql-ng/test/PgPipelineInit.hs
@@ -0,0 +1,242 @@
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# OPTIONS_GHC -fno-warn-orphans #-}
+
+module PgPipelineInit
+ ( runConn
+ , runConn_
+ , runConnAssert
+ , runConnWith
+ , runConnWith_
+ , runConnAssertWith
+ , FetchMode (..)
+ , MonadIO
+ , persistSettings
+ , MkPersistSettings (..)
+ , BackendKey (..)
+ , GenerateKey (..)
+ -- re-exports
+ , module Control.Monad.Trans.Reader
+ , module Control.Monad
+ , module Database.Persist.Sql
+ , module Database.Persist.SqlBackend
+ , module Database.Persist
+ , module Database.Persist.Sql.Raw.QQ
+ , module Init
+ , module Test.Hspec
+ , module Test.Hspec.Expectations.Lifted
+ , module Test.HUnit
+ , AValue (..)
+ , BS.ByteString
+ , Int32
+ , Int64
+ , liftIO
+ , mkPersist
+ , migrateModels
+ , mkMigrate
+ , share
+ , sqlSettings
+ , persistLowerCase
+ , persistUpperCase
+ , mkEntityDefList
+ , setImplicitIdDef
+ , SomeException
+ , Text
+ , TestFn (..)
+ , LoggingT
+ , ResourceT
+ , UUID (..)
+ , sqlSettingsUuid
+ ) where
+
+import Init
+ ( GenerateKey (..)
+ , MonadFail
+ , RunDb
+ , TestFn (..)
+ , UUID (..)
+ , arbText
+ , asIO
+ , assertEmpty
+ , assertNotEmpty
+ , assertNotEqual
+ , isTravis
+ , liftA2
+ , sqlSettingsUuid
+ , truncateTimeOfDay
+ , truncateToMicro
+ , truncateUTCTime
+ , (==@)
+ , (@/=)
+ , (@==)
+ )
+
+-- re-exports
+import Control.Exception (SomeException)
+import Control.Monad (forM_, liftM, replicateM, void, when)
+import Control.Monad.Trans.Reader
+import Data.Aeson (FromJSON, ToJSON, Value (..), object)
+import qualified Data.Text.Encoding as TE
+import Database.Persist.Postgresql.JSON ()
+import Database.Persist.Sql.Raw.QQ
+import Database.Persist.SqlBackend
+import Database.Persist.TH
+ ( MkPersistSettings (..)
+ , migrateModels
+ , mkEntityDefList
+ , mkMigrate
+ , mkPersist
+ , persistLowerCase
+ , persistUpperCase
+ , setImplicitIdDef
+ , share
+ , sqlSettings
+ )
+import Test.Hspec
+ ( Arg
+ , Spec
+ , SpecWith
+ , afterAll_
+ , before
+ , beforeAll
+ , before_
+ , describe
+ , fdescribe
+ , fit
+ , hspec
+ , it
+ )
+import Test.Hspec.Expectations.Lifted
+import Test.QuickCheck.Instances ()
+import UnliftIO
+
+-- testing
+import Test.HUnit (Assertion, assertBool, assertFailure, (@=?), (@?=))
+import Test.QuickCheck
+
+import Control.Monad (unless, (>=>))
+import Control.Monad.IO.Unlift (MonadUnliftIO)
+import Control.Monad.Logger
+import Control.Monad.Trans.Resource (ResourceT, runResourceT)
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Char8 as B8
+import qualified Data.HashMap.Strict as HM
+import Data.Int (Int32, Int64)
+import Data.Maybe (fromMaybe)
+import Data.Monoid ((<>))
+import Data.Text (Text)
+import Data.Vector (Vector)
+import System.Environment (getEnvironment, lookupEnv)
+import System.Log.FastLogger (fromLogStr)
+
+import Database.Persist
+import Database.Persist.Postgresql.Pipeline
+import Database.Persist.Sql
+import Database.Persist.TH ()
+
+_debugOn :: Bool
+_debugOn = False
+
+dockerPg :: IO (Maybe BS.ByteString)
+dockerPg = do
+ env <- liftIO getEnvironment
+ return $ case lookup "POSTGRES_NAME" env of
+ Just _name -> Just "postgres"
+ _ -> Nothing
+
+persistSettings :: MkPersistSettings
+persistSettings = sqlSettings{mpsGeneric = True}
+
+-- | Run with default 'FetchAll' mode.
+runConn :: (MonadUnliftIO m) => SqlPersistT (LoggingT m) t -> m ()
+runConn f = void (runConn_ f)
+
+-- | Run with default 'FetchAll' mode, returning the result.
+runConn_ :: (MonadUnliftIO m) => SqlPersistT (LoggingT m) t -> m t
+runConn_ = runConnWith_ FetchAll
+
+-- | Run with the given 'FetchMode', discarding the result.
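+--
+-- For example:
+--
+-- > runConnWith FetchSingleRow (rawExecute "SELECT 1" [])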
+runConnWith :: (MonadUnliftIO m) => FetchMode -> SqlPersistT (LoggingT m) t -> m ()
+runConnWith mode f = void (runConnWith_ mode f)
+
+-- | Run with the given 'FetchMode', returning the result.
+runConnWith_ :: (MonadUnliftIO m) => FetchMode -> SqlPersistT (LoggingT m) t -> m t
+runConnWith_ fetchMode f = do
+ travis <- liftIO isTravis
+ let
+ debugPrint = not travis && _debugOn
+ printDebug = if debugPrint then print . fromLogStr else void . return
+ poolSize = 1
+ connString <-
+ if travis
+ then do
+ pure
+ "host=localhost port=5432 user=perstest password=perstest dbname=persistent"
+ else do
+ host <- fromMaybe "localhost" <$> liftIO dockerPg
+ port <- fromMaybe "5432" <$> liftIO (lookupEnv "PGPORT")
+ mpass <- liftIO (lookupEnv "PGPASSWORD")
+ let passFragment = case mpass of
+ Just pw -> " password=" <> B8.pack pw
+ Nothing -> ""
+ pure ("host=" <> host <> " port=" <> B8.pack port <> " user=postgres" <> passFragment <> " dbname=test")
+
+ flip runLoggingT (\_ _ _ s -> printDebug s) $ do
+ logInfoN (if travis then "Running in CI" else "CI not detected")
+ let settings = defaultPipelineSettings
+ { pipelineFetchMode = fetchMode
+ , pipelineConnStr = connString
+ , pipelinePoolSize = poolSize
+ }
+ go = withPostgresqlPipelinePool settings $ runSqlPool (withBaseBackend f)
+ -- Retry on connection failure, same as persistent-postgresql
+ eres <- try go
+ case eres of
+ Left (err :: SomeException) -> do
+ eres' <- try go
+ case eres' of
+ Left (err' :: SomeException) ->
+ if show err == show err'
+ then throwIO err
+ else throwIO err'
+ Right a ->
+ pure a
+ Right a ->
+ pure a
+
+-- | Default mode assertion runner.
+runConnAssert :: SqlPersistT (LoggingT (ResourceT IO)) () -> Assertion
+runConnAssert = runConnAssertWith FetchAll
+
+-- | Assertion runner parameterized by 'FetchMode'.
+runConnAssertWith :: FetchMode -> SqlPersistT (LoggingT (ResourceT IO)) () -> Assertion
+runConnAssertWith mode actions = do
+ runResourceT $ runConnWith mode $ actions >> transactionUndo
+
+newtype AValue = AValue {getValue :: Value}
+
+-- A specialized Arbitrary instance that bounds the width and nesting depth
+-- of generated JSON values, so test cases stay small.
+instance Arbitrary AValue where
+ arbitrary =
+ AValue
+ <$> frequency
+ [ (1, pure Null)
+ , (1, Bool <$> arbitrary)
+ , (2, Number <$> arbitrary)
+ , (2, String <$> arbText)
+ , (3, Array <$> limitIt 4 (fmap (fmap getValue) arbitrary))
+ , (3, object <$> arbObject)
+ ]
+ where
+ limitIt :: Int -> Gen a -> Gen a
+ limitIt i x = sized $ \n -> do
+ let
+ m = if n > i then i else n
+ resize m x
+ arbObject =
+ limitIt 4
+ $ listOf
+ . liftA2 (,) arbText
+ $ limitIt 4 (fmap getValue arbitrary)
diff --git a/persistent-postgresql-ng/test/PgPipelineIntervalTest.hs b/persistent-postgresql-ng/test/PgPipelineIntervalTest.hs
new file mode 100644
index 000000000..71b4e486f
--- /dev/null
+++ b/persistent-postgresql-ng/test/PgPipelineIntervalTest.hs
@@ -0,0 +1,56 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveAnyClass #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE EmptyDataDecls #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module PgPipelineIntervalTest where
+
+import Data.Fixed (Fixed (MkFixed), Micro, Pico)
+import Data.Time.Clock (secondsToNominalDiffTime)
+import Database.Persist.Postgresql.Pipeline (PgInterval (..))
+import PgPipelineInit
+import Test.Hspec.QuickCheck
+
+share
+ [mkPersist sqlSettings, mkMigrate "pgIntervalMigrate"]
+ [persistLowerCase|
+PgIntervalDb
+ interval_field PgInterval
+ deriving Eq
+ deriving Show
+|]
+
+clamp :: (Ord a) => a -> a -> a -> a
+clamp lo hi = max lo . min hi
+
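+-- | Upper bound for the round-trip property below: 2^31 - 1 hours expressed
+-- in microseconds, which stays well inside the 64-bit microsecond field
+-- PostgreSQL uses for the time part of an interval.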
+microsecondLimit :: Int64
+microsecondLimit = 2147483647 * 60 * 60 * 1000000
+
+specs :: Spec
+specs = do
+ describe "Postgres Interval Property tests" $ do
+ prop "Round trips" $ \int64 -> runConnAssert $ do
+ let
+ eg =
+ PgIntervalDb
+ . PgInterval
+ . secondsToNominalDiffTime
+ . (realToFrac :: Micro -> Pico)
+ . MkFixed
+ . toInteger
+ $ clamp (-microsecondLimit) microsecondLimit int64
+ rid <- insert eg
+ r <- getJust rid
+ liftIO $ r `shouldBe` eg
diff --git a/persistent-postgresql-ng/test/PipelineDeferralSpec.hs b/persistent-postgresql-ng/test/PipelineDeferralSpec.hs
new file mode 100644
index 000000000..2cb47486c
--- /dev/null
+++ b/persistent-postgresql-ng/test/PipelineDeferralSpec.hs
@@ -0,0 +1,407 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- | Tests proving that pipeline mode defers result evaluation.
+--
+-- Reads the @pgPending@ counter directly to verify that fire-and-forget
+-- operations (delete, update, replace, etc.) increment the counter without
+-- reading results, and that drain points (get, selectList, commit) reset it.
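+--
+-- The pattern under test, in miniature (a sketch drawn from the cases below;
+-- @k@ and the update are placeholders):
+--
+-- > delete k                     -- fire-and-forget: one pending result queued
+-- > update k [SomeField =. v]    -- still deferred: a second result queued
+-- > _ <- get k                   -- drain point: queue drained to zero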
+module PipelineDeferralSpec (specs) where
+
+import Control.Exception (SomeException, try)
+import Control.Monad.Trans.Resource (runResourceT)
+import Data.IORef (readIORef)
+import Database.Persist.Postgresql.Pipeline (getPipelineConn)
+import Database.Persist.Postgresql.Pipeline.Internal (PgConn (..))
+import PgPipelineInit
+
+share
+ [mkPersist sqlSettings, mkMigrate "deferralMigrate"]
+ [persistLowerCase|
+DeferItem
+ name Text
+ value Int
+ deriving Show Eq
+|]
+
+db :: SqlPersistT (LoggingT (ResourceT IO)) a -> IO a
+db actions = runResourceT $ runConn_ $ do
+ void $ runMigrationSilent deferralMigrate
+ actions
+
+-- | Read the pending result count from the pipeline.
+getPending :: (MonadIO m) => SqlPersistT m Int
+getPending = do
+ backend <- ask
+ case getPipelineConn backend of
+ Nothing -> error "getPending: no PgConn (not a pipeline backend)"
+ Just pc -> liftIO $ length <$> readIORef (pgPending pc)
+
+specs :: Spec
+specs = describe "Pipeline deferral (pgPending counter)" $ do
+
+ it "delete increments pgPending, get drains it" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ k1 <- insert $ DeferItem "a" 1
+ k2 <- insert $ DeferItem "b" 2
+ k3 <- insert $ DeferItem "c" 3
+
+ -- insert uses stmtQuery (RETURNING), so pending should be 0
+ n0 <- getPending
+ liftIO $ n0 `shouldBe` 0
+
+ -- delete uses stmtExecute (fire-and-forget)
+ delete k1
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ delete k2
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 2
+
+ -- get uses stmtQuery → drains all pending first
+ _ <- get k3
+ n3 <- getPending
+ liftIO $ n3 `shouldBe` 0
+
+ it "update increments pgPending, selectList drains it" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ k1 <- insert $ DeferItem "x" 10
+ k2 <- insert $ DeferItem "y" 20
+
+ n0 <- getPending
+ liftIO $ n0 `shouldBe` 0
+
+ update k1 [DeferItemValue =. 99]
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ update k2 [DeferItemValue =. 99]
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 2
+
+ -- selectList drains pending
+ items <- selectList ([] :: [Filter DeferItem]) []
+ n3 <- getPending
+ liftIO $ n3 `shouldBe` 0
+ liftIO $ map (deferItemValue . entityVal) items `shouldBe` [99, 99]
+
+ it "replace increments pgPending" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ k <- insert $ DeferItem "orig" 1
+
+ n0 <- getPending
+ liftIO $ n0 `shouldBe` 0
+
+ replace k (DeferItem "replaced" 2)
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ mitem <- get k
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 0
+ liftIO $ mitem `shouldBe` Just (DeferItem "replaced" 2)
+
+ it "deleteWhere increments pgPending" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ _ <- insert $ DeferItem "del1" 1
+ _ <- insert $ DeferItem "del2" 2
+
+ n0 <- getPending
+ liftIO $ n0 `shouldBe` 0
+
+ deleteWhere [DeferItemValue ==. 1]
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ remaining <- selectList ([] :: [Filter DeferItem]) []
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 0
+ liftIO $ length remaining `shouldBe` 1
+
+ it "updateWhere increments pgPending" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ _ <- insert $ DeferItem "uw1" 1
+ _ <- insert $ DeferItem "uw2" 2
+
+ n0 <- getPending
+ liftIO $ n0 `shouldBe` 0
+
+ updateWhere [DeferItemValue ==. 1] [DeferItemValue =. 100]
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ updateWhere [DeferItemValue ==. 2] [DeferItemValue =. 200]
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 2
+
+ items <- selectList ([] :: [Filter DeferItem]) [Asc DeferItemValue]
+ n3 <- getPending
+ liftIO $ n3 `shouldBe` 0
+ liftIO $ map (deferItemValue . entityVal) items `shouldBe` [100, 200]
+
+ it "many fire-and-forget operations accumulate" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ keys <- mapM insert
+ [ DeferItem "m1" 1, DeferItem "m2" 2, DeferItem "m3" 3
+ , DeferItem "m4" 4, DeferItem "m5" 5, DeferItem "m6" 6
+ , DeferItem "m7" 7, DeferItem "m8" 8, DeferItem "m9" 9
+ , DeferItem "m10" 10
+ ]
+
+ -- 10 deletes, all fire-and-forget
+ forM_ keys delete
+ n <- getPending
+ liftIO $ n `shouldBe` 10
+
+ -- selectList drains all 10 at once
+ remaining <- selectList ([] :: [Filter DeferItem]) []
+ nAfter <- getPending
+ liftIO $ nAfter `shouldBe` 0
+ liftIO $ remaining `shouldBe` []
+
+ it "first select forces only prior DML, second batch stays pending" $ db $ do
+ -- Proves that drain is scoped: only DML before the read is forced.
+ -- DML after the read remains pending until the next drain point.
+ deleteWhere ([] :: [Filter DeferItem])
+ k1 <- insert $ DeferItem "a" 1
+ k2 <- insert $ DeferItem "b" 2
+
+ -- Batch 1: two updates (fire-and-forget)
+ update k1 [DeferItemValue =. 10]
+ update k2 [DeferItemValue =. 20]
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 2
+
+ -- First select: drains batch 1
+ items1 <- selectList ([] :: [Filter DeferItem]) [Asc DeferItemName]
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 0
+ liftIO $ map (deferItemValue . entityVal) items1 `shouldBe` [10, 20]
+
+ -- Batch 2: two more updates (new pending, independent of batch 1)
+ update k1 [DeferItemValue =. 100]
+ update k2 [DeferItemValue =. 200]
+ n3 <- getPending
+ liftIO $ n3 `shouldBe` 2 -- batch 2 is pending, NOT forced by first select
+
+ -- Second select: drains batch 2
+ items2 <- selectList ([] :: [Filter DeferItem]) [Asc DeferItemName]
+ n4 <- getPending
+ liftIO $ n4 `shouldBe` 0
+ liftIO $ map (deferItemValue . entityVal) items2 `shouldBe` [100, 200]
+
+ it "without a second read, second batch remains pending" $ db $ do
+ -- Proves that DML results are ONLY forced by explicit drain points.
+ -- If you never read after the second batch, those results stay pending
+ -- (commit at transaction end will eventually drain them).
+ deleteWhere ([] :: [Filter DeferItem])
+ k <- insert $ DeferItem "x" 1
+
+ -- First DML + read
+ update k [DeferItemValue =. 10]
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ Just item1 <- get k
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 0
+ liftIO $ deferItemValue item1 `shouldBe` 10
+
+ -- Second DML — NOT forced by the prior get
+ update k [DeferItemValue =. 20]
+ n3 <- getPending
+ liftIO $ n3 `shouldBe` 1 -- still pending, first read didn't force this
+
+ -- Third DML — also not forced
+ update k [DeferItemValue =. 30]
+ n4 <- getPending
+ liftIO $ n4 `shouldBe` 2 -- both second and third are pending
+
+ -- Only when we read again does the second batch drain
+ Just item2 <- get k
+ n5 <- getPending
+ liftIO $ n5 `shouldBe` 0 -- both forced now
+ liftIO $ deferItemValue item2 `shouldBe` 30
+
+ it "interleaved read-write: each read only forces prior writes" $ db $ do
+ -- Step-by-step interleaving to prove each read forces exactly
+ -- the writes that preceded it, and no more.
+ deleteWhere ([] :: [Filter DeferItem])
+ k <- insert $ DeferItem "z" 0
+
+ -- Write 1
+ update k [DeferItemValue =. 1]
+        liftIO . (`shouldBe` 1) =<< getPending
+
+ -- Read 1: forces write 1
+ Just v1 <- get k
+        liftIO . (`shouldBe` 0) =<< getPending
+ liftIO $ deferItemValue v1 `shouldBe` 1
+
+ -- Write 2
+ update k [DeferItemValue =. 2]
+        liftIO . (`shouldBe` 1) =<< getPending
+
+ -- Write 3 (stacks on write 2)
+ update k [DeferItemValue =. 3]
+        liftIO . (`shouldBe` 2) =<< getPending
+
+ -- Read 2: forces writes 2 and 3 together
+ Just v2 <- get k
+        liftIO . (`shouldBe` 0) =<< getPending
+ liftIO $ deferItemValue v2 `shouldBe` 3
+
+ -- Write 4
+ update k [DeferItemValue =. 4]
+        liftIO . (`shouldBe` 1) =<< getPending
+
+ -- Read 3: forces write 4
+ Just v3 <- get k
+        liftIO . (`shouldBe` 0) =<< getPending
+ liftIO $ deferItemValue v3 `shouldBe` 4
+
+ it "commit drains all pending (no explicit read needed)" $ db $ do
+ -- This test uses transactionSave (which commits the current
+ -- transaction and begins a new one) to prove that commit is
+ -- a drain point even without any read operations.
+ deleteWhere ([] :: [Filter DeferItem])
+ k1 <- insert $ DeferItem "c1" 1
+ k2 <- insert $ DeferItem "c2" 2
+
+ delete k1
+ delete k2
+ n <- getPending
+ liftIO $ n `shouldBe` 2
+
+ -- transactionSave commits + begins a new transaction
+ transactionSave
+
+ -- After commit, pending should be 0
+ -- (the new BEGIN from transactionSave is also fire-and-forget,
+ -- but it gets drained by the next read operation)
+ nAfterCommit <- getPending
+ -- BEGIN from the new transaction is pending
+ liftIO $ nAfterCommit `shouldBe` 1
+
+ -- Verify the deletes actually took effect
+ remaining <- selectList ([] :: [Filter DeferItem]) []
+ liftIO $ remaining `shouldBe` []
+
+ it "insert (RETURNING) is a drain point" $ db $ do
+ -- insert uses stmtQuery for the RETURNING clause, so it must
+ -- drain all pending DML before executing.
+ deleteWhere ([] :: [Filter DeferItem])
+ k1 <- insert $ DeferItem "drain1" 1
+
+ -- Fire-and-forget DML
+ update k1 [DeferItemValue =. 99]
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ -- insert triggers drain (stmtQuery for RETURNING)
+ _ <- insert $ DeferItem "drain2" 2
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 0
+
+ -- Verify the update was applied (it was drained by the insert)
+ Just item <- get k1
+ liftIO $ deferItemValue item `shouldBe` 99
+
+ it "rawExecute is fire-and-forget" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ _ <- insert $ DeferItem "raw" 1
+
+ n0 <- getPending
+ liftIO $ n0 `shouldBe` 0
+
+ -- rawExecute uses stmtExecute → fire-and-forget
+ rawExecute "UPDATE \"defer_item\" SET \"value\" = 42" []
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ rawExecute "UPDATE \"defer_item\" SET \"value\" = 99" []
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 2
+
+ -- Read forces drain
+ items <- selectList ([] :: [Filter DeferItem]) []
+ n3 <- getPending
+ liftIO $ n3 `shouldBe` 0
+ liftIO $ map (deferItemValue . entityVal) items `shouldBe` [99]
+
+ it "count is a drain point" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ _ <- insert $ DeferItem "cnt1" 1
+ _ <- insert $ DeferItem "cnt2" 2
+
+ -- Fire-and-forget
+ deleteWhere [DeferItemValue ==. 1]
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ -- count uses stmtQuery → drains pending
+ c <- count ([] :: [Filter DeferItem])
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 0
+ liftIO $ c `shouldBe` 1
+
+ it "exists is a drain point" $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ k <- insert $ DeferItem "ex" 1
+
+ delete k
+ n1 <- getPending
+ liftIO $ n1 `shouldBe` 1
+
+ -- exists uses stmtQuery → drains pending
+ e <- exists [DeferItemName ==. "ex"]
+ n2 <- getPending
+ liftIO $ n2 `shouldBe` 0
+ liftIO $ e `shouldBe` False
+
+ it "error in deferred DML surfaces at next drain point" $ do
+ -- A fire-and-forget query that will fail (bad table name).
+ -- The error is deferred until the next drain point (selectList).
+ result <- (try $ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ _ <- insert $ DeferItem "err" 1
+ -- Fire-and-forget an invalid query
+ rawExecute "UPDATE \"nonexistent_table\" SET x = 1" []
+ n <- getPending
+ liftIO $ n `shouldBe` 1
+ -- selectList triggers drain → error surfaces
+ _ <- selectList ([] :: [Filter DeferItem]) []
+ return ()
+ ) :: IO (Either SomeException ())
+ case result of
+ Left _e -> return () -- error propagated correctly
+ Right () -> expectationFailure "Expected exception from deferred bad query"
+
+ it "connection pool works after deferred error + rollback" $ do
+ -- After the above error, verify the pool gives us a working
+ -- connection. (runSqlPool called connRollback, cleaning up
+ -- the pipeline before returning the connection to the pool.)
+ _ <- try $ db $ do
+ rawExecute "UPDATE \"nonexistent_table\" SET x = 1" []
+ _ <- selectList ([] :: [Filter DeferItem]) []
+ return ()
+ :: IO (Either SomeException ())
+ -- Fresh transaction should work fine
+ db $ do
+ deleteWhere ([] :: [Filter DeferItem])
+ k <- insert $ DeferItem "recovery" 42
+ Just item <- get k
+ liftIO $ deferItemValue item `shouldBe` 42
diff --git a/persistent-postgresql-ng/test/PipelineModeSpec.hs b/persistent-postgresql-ng/test/PipelineModeSpec.hs
new file mode 100644
index 000000000..347662b2a
--- /dev/null
+++ b/persistent-postgresql-ng/test/PipelineModeSpec.hs
@@ -0,0 +1,107 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module PipelineModeSpec (specs, specsWith) where
+
+import Control.Monad.Trans.Resource (runResourceT)
+import Database.Persist.Postgresql.Pipeline (withPipeline)
+import PgPipelineInit
+
+share
+ [mkPersist sqlSettings, mkMigrate "pipelineMigrate"]
+ [persistLowerCase|
+PipelineItem
+ name Text
+ value Int
+ deriving Show Eq
+|]
+
+db :: SqlPersistT (LoggingT (ResourceT IO)) a -> IO a
+db = dbWith FetchAll
+
+dbWith :: FetchMode -> SqlPersistT (LoggingT (ResourceT IO)) a -> IO a
+dbWith mode actions = runResourceT $ runConnWith_ mode $ do
+ void $ runMigrationSilent pipelineMigrate
+ actions
+
+specs :: Spec
+specs = specsWith FetchAll
+
+specsWith :: FetchMode -> Spec
+specsWith mode = describe "Pipeline mode (automatic)" $ do
+ let db' = dbWith mode
+
+ it "multiple deletes are pipelined" $ db' $ do
+ deleteWhere ([] :: [Filter PipelineItem])
+ keys <- mapM insert
+ [ PipelineItem "a" 1
+ , PipelineItem "b" 2
+ , PipelineItem "c" 3
+ ]
+ forM_ keys delete
+ remaining <- selectList ([] :: [Filter PipelineItem]) []
+ liftIO $ remaining @?= []
+
+ it "multiple updates are pipelined" $ db' $ do
+ deleteWhere ([] :: [Filter PipelineItem])
+ keys <- mapM insert
+ [ PipelineItem "x" 10
+ , PipelineItem "y" 20
+ , PipelineItem "z" 30
+ ]
+ forM_ keys $ \k -> update k [PipelineItemValue =. 99]
+ items <- selectList ([] :: [Filter PipelineItem]) []
+ let vals = map (pipelineItemValue . entityVal) items
+ liftIO $ vals @?= [99, 99, 99]
+
+ it "mixed DML then read sees prior writes" $ db' $ do
+ deleteWhere ([] :: [Filter PipelineItem])
+ k1 <- insert $ PipelineItem "first" 1
+ update k1 [PipelineItemValue =. 42]
+ items <- selectList ([] :: [Filter PipelineItem]) []
+ liftIO $ length items @?= 1
+ liftIO $ (pipelineItemValue . entityVal . head) items @?= 42
+
+ it "deleteWhere is pipelined" $ db' $ do
+ deleteWhere ([] :: [Filter PipelineItem])
+ _ <- insert $ PipelineItem "del1" 1
+ _ <- insert $ PipelineItem "del2" 2
+ deleteWhere [PipelineItemValue ==. 1]
+ remaining <- selectList ([] :: [Filter PipelineItem]) []
+ liftIO $ length remaining @?= 1
+ liftIO $ (pipelineItemName . entityVal . head) remaining @?= "del2"
+
+ it "connection works across multiple transactions" $ db' $ do
+ deleteWhere ([] :: [Filter PipelineItem])
+ _ <- insert $ PipelineItem "txn1" 100
+ items1 <- selectList ([] :: [Filter PipelineItem]) []
+ liftIO $ length items1 @?= 1
+
+ it "empty transaction succeeds" $ db' $ do
+ return ()
+
+ it "replace is pipelined" $ db' $ do
+ deleteWhere ([] :: [Filter PipelineItem])
+ k <- insert $ PipelineItem "orig" 1
+ replace k (PipelineItem "replaced" 2)
+ mitem <- get k
+ liftIO $ mitem @?= Just (PipelineItem "replaced" 2)
+
+ it "withPipeline is a no-op (pipeline always on)" $ db' $ do
+ deleteWhere ([] :: [Filter PipelineItem])
+ withPipeline $ do
+ _ <- insert $ PipelineItem "inside" 42
+ items <- selectList ([] :: [Filter PipelineItem]) []
+ liftIO $ length items @?= 1
diff --git a/persistent-postgresql-ng/test/PipelineRegressionSpec.hs b/persistent-postgresql-ng/test/PipelineRegressionSpec.hs
new file mode 100644
index 000000000..e290542e5
--- /dev/null
+++ b/persistent-postgresql-ng/test/PipelineRegressionSpec.hs
@@ -0,0 +1,278 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- | Regression tests for Hedis-style automatic pipelining.
+--
+-- These tests verify that all pipelined operations produce correct
+-- results when interleaved in various patterns:
+--
+-- * Back-to-back reads (get, getBy, count, exists)
+-- * Back-to-back inserts with RETURNING
+-- * Writes followed by reads
+-- * Reads followed by writes followed by reads
+-- * Mixed operations across entity types
+-- * Large batches (100+ operations)
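+--
+-- The canonical shape is a batch of operations issued back to back, as in
+--
+-- > keys <- mapM insert items   -- inserts pipelined via RETURNING
+-- > results <- mapM get keys    -- reads pipelined the same way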
+module PipelineRegressionSpec (specs) where
+
+import Control.Monad.Trans.Resource (runResourceT)
+import qualified Data.Map.Strict as Map
+import Data.Maybe (catMaybes, isJust, isNothing)
+import qualified Data.Text as T
+import PgPipelineInit
+
+share
+ [mkPersist sqlSettings, mkMigrate "pipeRegMigrate"]
+ [persistLowerCase|
+PipeReg
+ name Text
+ value Int
+ UniquePipeRegName name
+ deriving Show Eq
+
+PipeRegOther
+ label Text
+ score Double
+ deriving Show Eq
+|]
+
+dbWith :: FetchMode -> SqlPersistT (LoggingT (ResourceT IO)) a -> IO a
+dbWith mode actions = runResourceT $ runConnWith_ mode $ do
+ void $ runMigrationSilent pipeRegMigrate
+ deleteWhere ([] :: [Filter PipeReg])
+ deleteWhere ([] :: [Filter PipeRegOther])
+ actions
+
+specs :: Spec
+specs = describe "Pipeline regression" $ do
+ forM_ [("FetchAll", FetchAll), ("FetchSingleRow", FetchSingleRow)] $
+ \(modeName, mode) -> describe modeName $ do
+ let db = dbWith mode
+
+ ---------------------------------------------------------------------------
+ -- Back-to-back gets (Hedis-style pipelining)
+ ---------------------------------------------------------------------------
+
+ describe "back-to-back get" $ do
+ it "returns correct results for 10 keys" $ db $ do
+ keys <- mapM insert
+ [ PipeReg (T.pack $ "g" <> show i) (i * 10) | i <- [1..10 :: Int] ]
+ results <- mapM get keys
+ liftIO $ length (catMaybes results) @?= 10
+ liftIO $ forM_ (zip [1..] results) $ \(i :: Int, mr) ->
+ mr @?= Just (PipeReg (T.pack $ "g" <> show i) (i * 10))
+
+ it "returns Nothing for missing keys" $ db $ do
+ k <- insert $ PipeReg "exists" 1
+ let missingK = toSqlKey 999999
+ results <- sequence [get k, get missingK, get k]
+ liftIO $ map isJust results @?= [True, False, True]
+
+ it "handles 100 gets" $ db $ do
+ keys <- mapM insert
+ [ PipeReg (T.pack $ "h" <> show i) i | i <- [1..100 :: Int] ]
+ results <- mapM get keys
+ liftIO $ length (catMaybes results) @?= 100
+
+ ---------------------------------------------------------------------------
+ -- Back-to-back inserts (pipelined RETURNING)
+ ---------------------------------------------------------------------------
+
+ describe "back-to-back insert" $ do
+ it "returns distinct keys for 10 inserts" $ db $ do
+ keys <- mapM insert
+ [ PipeReg (T.pack $ "ins" <> show i) i | i <- [1..10 :: Int] ]
+ liftIO $ length keys @?= 10
+ liftIO $ length (Map.fromList (zip keys keys)) @?= 10
+
+ it "inserted records are readable" $ db $ do
+ keys <- mapM insert
+ [ PipeReg (T.pack $ "ir" <> show i) (i * 100) | i <- [1..5 :: Int] ]
+ results <- mapM get keys
+ liftIO $ forM_ (zip [1..] results) $ \(i :: Int, mr) ->
+ mr @?= Just (PipeReg (T.pack $ "ir" <> show i) (i * 100))
+
+ ---------------------------------------------------------------------------
+ -- Back-to-back getBy (unique lookups)
+ ---------------------------------------------------------------------------
+
+ describe "back-to-back getBy" $ do
+ it "finds existing records" $ db $ do
+ mapM_ insert
+ [ PipeReg "alice" 1, PipeReg "bob" 2, PipeReg "carol" 3 ]
+ results <- mapM (getBy . UniquePipeRegName)
+ ["alice", "bob", "carol"]
+ liftIO $ length (catMaybes results) @?= 3
+ liftIO $ map (fmap (pipeRegValue . entityVal)) results
+ @?= [Just 1, Just 2, Just 3]
+
+ it "returns Nothing for missing" $ db $ do
+ _ <- insert $ PipeReg "present" 1
+ results <- sequence
+ [ getBy (UniquePipeRegName "present")
+ , getBy (UniquePipeRegName "absent")
+ ]
+ liftIO $ isJust (head results) @?= True
+ liftIO $ isNothing (results !! 1) @?= True
+
+ ---------------------------------------------------------------------------
+ -- Back-to-back count / exists
+ ---------------------------------------------------------------------------
+
+ describe "back-to-back count and exists" $ do
+ it "counts correctly" $ db $ do
+ mapM_ insert [ PipeReg (T.pack $ "c" <> show i) i | i <- [1..5 :: Int] ]
+ counts <- sequence
+ [ count ([] :: [Filter PipeReg])
+ , count [PipeRegValue >. 3]
+ , count [PipeRegValue <. 0]
+ ]
+ liftIO $ counts @?= [5, 2, 0]
+
+ it "exists correctly" $ db $ do
+ mapM_ insert [ PipeReg "e1" 1, PipeReg "e2" 2 ]
+ results <- sequence
+ [ exists ([] :: [Filter PipeReg])
+ , exists [PipeRegValue ==. 1]
+ , exists [PipeRegValue ==. 999]
+ ]
+ liftIO $ results @?= [True, True, False]
+
+ ---------------------------------------------------------------------------
+ -- Interleaved writes and reads
+ ---------------------------------------------------------------------------
+
+ describe "interleaved writes and reads" $ do
+ it "write-read-write-read" $ db $ do
+ k1 <- insert $ PipeReg "wr1" 10
+ r1 <- get k1
+ k2 <- insert $ PipeReg "wr2" 20
+ r2 <- get k2
+ liftIO $ r1 @?= Just (PipeReg "wr1" 10)
+ liftIO $ r2 @?= Just (PipeReg "wr2" 20)
+
+ it "fire-and-forget writes then reads see results" $ db $ do
+ k1 <- insert $ PipeReg "ff1" 1
+ k2 <- insert $ PipeReg "ff2" 2
+ k3 <- insert $ PipeReg "ff3" 3
+ update k1 [PipeRegValue =. 100]
+ update k2 [PipeRegValue =. 200]
+ delete k3
+ r1 <- get k1
+ r2 <- get k2
+ r3 <- get k3
+ liftIO $ r1 @?= Just (PipeReg "ff1" 100)
+ liftIO $ r2 @?= Just (PipeReg "ff2" 200)
+ liftIO $ r3 @?= Nothing
+
+ it "replace then get" $ db $ do
+ k <- insert $ PipeReg "rep" 1
+ replace k (PipeReg "rep" 999)
+ r <- get k
+ liftIO $ r @?= Just (PipeReg "rep" 999)
+
+ it "deleteWhere then count" $ db $ do
+ mapM_ insert [ PipeReg (T.pack $ "dw" <> show i) i | i <- [1..10 :: Int] ]
+ deleteWhere [PipeRegValue <=. 5]
+ n <- count ([] :: [Filter PipeReg])
+ liftIO $ n @?= 5
+
+ it "updateWhere then selectList" $ db $ do
+ mapM_ insert [ PipeReg (T.pack $ "uw" <> show i) i | i <- [1..5 :: Int] ]
+ updateWhere [PipeRegValue <=. 3] [PipeRegValue =. 0]
+ items <- selectList ([] :: [Filter PipeReg]) [Asc PipeRegValue]
+ let vals = map (pipeRegValue . entityVal) items
+ liftIO $ filter (== 0) vals @?= [0, 0, 0]
+ liftIO $ length vals @?= 5
+
+ ---------------------------------------------------------------------------
+ -- Cross-entity operations
+ ---------------------------------------------------------------------------
+
+ describe "cross-entity operations" $ do
+ it "interleave inserts across tables" $ db $ do
+ k1 <- insert $ PipeReg "cross1" 1
+ k2 <- insert $ PipeRegOther "other1" 1.5
+ k3 <- insert $ PipeReg "cross2" 2
+ k4 <- insert $ PipeRegOther "other2" 2.5
+ r1 <- get k1
+ r2 <- get k2
+ r3 <- get k3
+ r4 <- get k4
+ liftIO $ r1 @?= Just (PipeReg "cross1" 1)
+ liftIO $ r2 @?= Just (PipeRegOther "other1" 1.5)
+ liftIO $ r3 @?= Just (PipeReg "cross2" 2)
+ liftIO $ r4 @?= Just (PipeRegOther "other2" 2.5)
+
+ ---------------------------------------------------------------------------
+ -- Large batches
+ ---------------------------------------------------------------------------
+
+ describe "large batches" $ do
+ it "100 inserts then 100 gets" $ db $ do
+ keys <- mapM insert
+ [ PipeReg (T.pack $ "batch" <> show i) i | i <- [1..100 :: Int] ]
+ results <- mapM get keys
+ liftIO $ length (catMaybes results) @?= 100
+
+ it "100 deletes then count" $ db $ do
+ keys <- mapM insert
+ [ PipeReg (T.pack $ "del" <> show i) i | i <- [1..100 :: Int] ]
+ forM_ keys delete
+ n <- count ([] :: [Filter PipeReg])
+ liftIO $ n @?= 0
+
+ it "100 updates then verify" $ db $ do
+ keys <- mapM insert
+ [ PipeReg (T.pack $ "upd" <> show i) i | i <- [1..100 :: Int] ]
+ forM_ keys $ \k -> update k [PipeRegValue =. 42]
+ results <- mapM get keys
+ liftIO $ all (== Just 42) (map (fmap pipeRegValue) results) @?= True
+
+ ---------------------------------------------------------------------------
+ -- Edge cases
+ ---------------------------------------------------------------------------
+
+ describe "edge cases" $ do
+ it "empty transaction" $ db $ return ()
+
+ it "get nonexistent key" $ db $ do
+ r <- get (toSqlKey 999999 :: Key PipeReg)
+ liftIO $ r @?= Nothing
+
+ it "count empty table" $ db $ do
+ n <- count ([] :: [Filter PipeReg])
+ liftIO $ n @?= 0
+
+ it "exists empty table" $ db $ do
+ b <- exists ([] :: [Filter PipeReg])
+ liftIO $ b @?= False
+
+ it "getBy nonexistent unique" $ db $ do
+ r <- getBy (UniquePipeRegName "nonexistent")
+ liftIO $ r @?= Nothing
+
+ it "insert then immediate get same transaction" $ db $ do
+ k <- insert $ PipeReg "immediate" 42
+ r <- get k
+ liftIO $ r @?= Just (PipeReg "immediate" 42)
+
+ it "upsert pipeline" $ db $ do
+ e1 <- upsert (PipeReg "upserted" 1) [PipeRegValue =. 1]
+ e2 <- upsert (PipeReg "upserted" 2) [PipeRegValue =. 2]
+ r <- getBy (UniquePipeRegName "upserted")
+ liftIO $ pipeRegValue (entityVal e1) @?= 1
+ liftIO $ pipeRegValue (entityVal e2) @?= 2
+ liftIO $ (pipeRegValue . entityVal) <$> r @?= Just 2
diff --git a/persistent-postgresql-ng/test/PlaceholderSpec.hs b/persistent-postgresql-ng/test/PlaceholderSpec.hs
new file mode 100644
index 000000000..9252bb5e3
--- /dev/null
+++ b/persistent-postgresql-ng/test/PlaceholderSpec.hs
@@ -0,0 +1,59 @@
+{-# LANGUAGE OverloadedStrings #-}
+
+module PlaceholderSpec (specs) where
+
+import Test.Hspec
+import Database.Persist.Postgresql.Internal.Placeholders (rewritePlaceholders)
+
+specs :: Spec
+specs = describe "rewritePlaceholders" $ do
+ it "rewrites a single ? to $1" $ do
+ rewritePlaceholders "SELECT * FROM t WHERE id = ?"
+ `shouldBe` ("SELECT * FROM t WHERE id = $1", 1)
+
+ it "rewrites multiple ? to $1, $2, ..." $ do
+ rewritePlaceholders "INSERT INTO t (a, b, c) VALUES (?, ?, ?)"
+ `shouldBe` ("INSERT INTO t (a, b, c) VALUES ($1, $2, $3)", 3)
+
+ it "replaces ?? with literal ?" $ do
+ rewritePlaceholders "RETURNING ??"
+ `shouldBe` ("RETURNING ?", 0)
+
+ it "handles mixed ? and ??" $ do
+ rewritePlaceholders "INSERT INTO t (a) VALUES (?) RETURNING ??"
+ `shouldBe` ("INSERT INTO t (a) VALUES ($1) RETURNING ?", 1)
+
+ it "does not replace ? inside single-quoted strings" $ do
+ rewritePlaceholders "SELECT * FROM t WHERE name = '?'"
+ `shouldBe` ("SELECT * FROM t WHERE name = '?'", 0)
+
+ it "handles escaped quotes in string literals" $ do
+ rewritePlaceholders "SELECT * FROM t WHERE name = 'it''s a ?' AND id = ?"
+ `shouldBe` ("SELECT * FROM t WHERE name = 'it''s a ?' AND id = $1", 1)
+
+ it "does not replace ? inside double-quoted identifiers" $ do
+ rewritePlaceholders "SELECT \"what?\" FROM t WHERE id = ?"
+ `shouldBe` ("SELECT \"what?\" FROM t WHERE id = $1", 1)
+
+ it "does not replace ? inside line comments" $ do
+ rewritePlaceholders "SELECT * FROM t -- where id = ?\nWHERE name = ?"
+ `shouldBe` ("SELECT * FROM t -- where id = ?\nWHERE name = $1", 1)
+
+ it "does not replace ? inside block comments" $ do
+ rewritePlaceholders "SELECT * FROM t /* where id = ? */ WHERE name = ?"
+ `shouldBe` ("SELECT * FROM t /* where id = ? */ WHERE name = $1", 1)
+
+ it "handles nested block comments" $ do
+ rewritePlaceholders "SELECT /* outer /* inner ? */ still comment */ ?"
+ `shouldBe` ("SELECT /* outer /* inner ? */ still comment */ $1", 1)
+
+ it "handles empty input" $ do
+ rewritePlaceholders "" `shouldBe` ("", 0)
+
+ it "handles no placeholders" $ do
+ rewritePlaceholders "SELECT 1" `shouldBe` ("SELECT 1", 0)
+
+ it "handles large parameter counts" $ do
+ let sql = mconcat $ replicate 100 "?,"
+            (_, count) = rewritePlaceholders sql
+ count `shouldBe` 100
diff --git a/persistent-postgresql-ng/test/UpsertWhere.hs b/persistent-postgresql-ng/test/UpsertWhere.hs
new file mode 100644
index 000000000..ebd91ffbd
--- /dev/null
+++ b/persistent-postgresql-ng/test/UpsertWhere.hs
@@ -0,0 +1,207 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module UpsertWhere where
+
+import PgPipelineInit
+
+import Data.Time
+import Database.Persist.Postgresql.Pipeline
+
+share
+ [mkPersist sqlSettings, mkMigrate "upsertWhereMigrate"]
+ [persistLowerCase|
+
+Item
+ name Text sqltype=varchar(80)
+ description Text
+ price Double Maybe
+ quantity Int Maybe
+
+ UniqueName name
+ deriving Eq Show Ord
+
+ItemMigOnly
+ name Text
+ price Double
+ quantity Int
+
+ UniqueNameMigOnly name
+
+ createdAt UTCTime MigrationOnly default=CURRENT_TIMESTAMP
+
+|]
+
+wipe :: IO ()
+wipe = runConnAssert $ do
+ deleteWhere ([] :: [Filter Item])
+ deleteWhere ([] :: [Filter ItemMigOnly])
+
+itDb
+ :: String -> SqlPersistT (LoggingT (ResourceT IO)) a -> SpecWith (Arg (IO ()))
+itDb msg action = it msg $ runConnAssert $ void action
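+
+-- Argument order for 'upsertManyWhere', as exercised throughout: the records
+-- to insert, the copy/exclusion instructions applied on conflict, the updates
+-- applied on conflict, and a filter restricting which conflicting rows are
+-- updated.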
+
+specs :: Spec
+specs = describe "UpsertWhere" $ do
+ let
+ item1 = Item "item1" "" (Just 3) Nothing
+ item2 = Item "item2" "hello world" Nothing (Just 2)
+ items = [item1, item2]
+
+ describe "upsertWhere" $ before_ wipe $ do
+ itDb "inserts appropriately" $ do
+ upsertWhere item1 [ItemDescription =. "i am item 1"] []
+ Just item <- fmap entityVal <$> getBy (UniqueName "item1")
+ item `shouldBe` item1
+ itDb "performs only updates given if record already exists" $ do
+ let
+ newDescription = "I am a new description"
+ insert_ item1
+ upsertWhere
+ (Item "item1" "i am an inserted description" (Just 1) (Just 2))
+ [ItemDescription =. newDescription]
+ []
+ Just item <- fmap entityVal <$> getBy (UniqueName "item1")
+ item `shouldBe` item1{itemDescription = newDescription}
+
+ itDb "inserts with MigrationOnly fields (#1330)" $ do
+ upsertWhere
+ (ItemMigOnly "foobar" 20 1)
+ [ItemMigOnlyPrice +=. 2]
+ []
+
+ describe "upsertManyWhere" $ do
+ itDb "inserts fresh records" $ do
+ insertMany_ items
+ let
+ newItem = Item "item3" "fresh" Nothing Nothing
+ upsertManyWhere
+ (newItem : items)
+ [copyField ItemDescription]
+ []
+ []
+ dbItems <- map entityVal <$> selectList [] []
+ dbItems `shouldMatchList` (newItem : items)
+ itDb "updates existing records" $ do
+ let
+ postUpdate =
+ map (\i -> i{itemQuantity = fmap (+ 1) (itemQuantity i)}) items
+ insertMany_ items
+ upsertManyWhere
+ items
+ []
+ [ItemQuantity +=. Just 1]
+ []
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` postUpdate
+ itDb "only copies passing values" $ do
+ insertMany_ items
+ let
+ newItems =
+ map
+ (\i -> i{itemQuantity = Just 0, itemPrice = fmap (* 2) (itemPrice i)})
+ items
+ postUpdate = map (\i -> i{itemPrice = fmap (* 2) (itemPrice i)}) items
+ upsertManyWhere
+ newItems
+ [ copyUnlessEq ItemQuantity (Just 0)
+ , copyField ItemPrice
+ ]
+ []
+ []
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` postUpdate
+ itDb "inserts without modifying existing records if no updates specified" $ do
+ let
+ newItem = Item "item3" "hi friends!" Nothing Nothing
+ insertMany_ items
+ upsertManyWhere
+ (newItem : items)
+ []
+ []
+ []
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` (newItem : items)
+ itDb "inserts without modifying existing records with True filter condition" $ do
+ let
+ newItem = Item "item3" "hi friends!" Nothing Nothing
+ insertMany_ items
+ upsertManyWhere
+ (newItem : items)
+ []
+ []
+ [ItemDescription ==. "hi friends!"]
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` (newItem : items)
+ itDb "inserts without updating with False filter condition" $ do
+ let
+ newItem = Item "item3" "hi friends!" Nothing Nothing
+ insertMany_ items
+ upsertManyWhere
+ (newItem : items)
+ []
+ [ItemQuantity +=. Just 1]
+ [ItemDescription ==. "hi friends!"]
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` (newItem : items)
+ itDb "doesn't apply update with excludeNotEqualToOriginal" $ do
+ let
+ newItem = Item "item3" "hi friends!" Nothing Nothing
+ insertMany_ items
+ upsertManyWhere
+ (newItem : items)
+ []
+ [ItemQuantity +=. Just 1]
+ [excludeNotEqualToOriginal ItemDescription]
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` (newItem : items)
+ itDb "inserts new and updates existing with empty filter" $ do
+ let
+ newItem = Item "item3" "hello world" Nothing Nothing
+ postUpdate = map (\i -> i{itemQuantity = fmap (+ 1) (itemQuantity i)}) items
+ insertMany_ items
+ upsertManyWhere
+ (newItem : items)
+ []
+ [ItemQuantity +=. Just 1]
+ []
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` (newItem : postUpdate)
+ itDb "inserts new and updates existing with matching filter" $ do
+ let
+ newItem = Item "item3" "hi friends!" Nothing Nothing
+ postUpdate = map (\i -> i{itemQuantity = fmap (+ 1) (itemQuantity i)}) items
+ insertMany_ items
+ upsertManyWhere
+ (newItem : items)
+ [ copyUnlessEq ItemDescription "hi friends!"
+ , copyField ItemPrice
+ ]
+ [ItemQuantity +=. Just 1]
+ [ItemDescription !=. "bye friends!"]
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` (newItem : postUpdate)
+ itDb "insert item doesn't apply update with excludeNotEqualToOriginal" $ do
+ let
+ newItem = Item "item3" "hello world" Nothing Nothing
+ insertMany_ items
+ upsertManyWhere
+ (newItem : items)
+ []
+ [ItemQuantity +=. Just 1]
+ [excludeNotEqualToOriginal ItemDescription]
+ dbItems <- fmap entityVal <$> selectList [] []
+ dbItems `shouldMatchList` (newItem : items)
diff --git a/persistent-postgresql-ng/test/main.hs b/persistent-postgresql-ng/test/main.hs
new file mode 100644
index 000000000..5fde21e64
--- /dev/null
+++ b/persistent-postgresql-ng/test/main.hs
@@ -0,0 +1,280 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -Wno-unused-top-binds #-}
+
+import PgPipelineInit
+
+import Data.Aeson
+import qualified Data.ByteString as BS
+import Data.Fixed
+import Data.IntMap (IntMap)
+import qualified Data.Text as T
+import Data.Time
+import Test.QuickCheck
+
+import qualified ArrayAggTest
+import qualified BinaryRoundTripSpec
+import qualified DirectDecodeSpec
+import qualified DirectEntityPOC
+import qualified CompositeTest
+import qualified CustomConstraintTest
+import qualified CustomPersistFieldTest
+import qualified CustomPrimaryKeyReferenceTest
+import qualified DataTypeTest
+import qualified EmbedOrderTest
+import qualified EmbedTest
+import qualified EmptyEntityTest
+import qualified EquivalentTypeTestPostgres
+import qualified ForeignKey
+import qualified GeneratedColumnTestSQL
+import qualified HtmlTest
+import qualified ImplicitUuidSpec
+import qualified InCollapseSpec
+import qualified JSONTest
+import qualified LargeNumberTest
+import qualified LongIdentifierTest
+import qualified MaxLenTest
+import qualified MaybeFieldDefsTest
+import qualified MigrationColumnLengthTest
+import qualified MigrationOnlyTest
+import qualified MigrationReferenceSpec
+import qualified MigrationSpec
+import qualified MigrationTest
+import qualified MpsCustomPrefixTest
+import qualified MpsNoPrefixTest
+import qualified PersistUniqueTest
+import qualified PersistentTest
+import qualified PgPipelineIntervalTest
+import qualified PipelineDeferralSpec
+import qualified PipelineModeSpec
+import qualified PipelineRegressionSpec
+import qualified PlaceholderSpec
+import qualified PrimaryTest
+import qualified RawSqlTest
+import qualified ReadWriteTest
+import qualified Recursive
+import qualified RenameTest
+import qualified SumTypeTest
+import qualified TransactionLevelTest
+import qualified TreeTest
+import qualified TypeLitFieldDefsTest
+import qualified UniqueTest
+import qualified UpsertTest
+import qualified UpsertWhere
+
+type Tuple = (,)
+
+-- Test lower case names
+share
+ [mkPersist persistSettings, mkMigrate "dataTypeMigrate"]
+ [persistLowerCase|
+DataTypeTable no-json
+ text Text
+ textMaxLen Text maxlen=100
+ bytes ByteString
+ bytesTextTuple (Tuple ByteString Text)
+ bytesMaxLen ByteString maxlen=100
+ int Int
+ intList [Int]
+ intMap (IntMap Int)
+ double Double
+ bool Bool
+ day Day
+ pico Pico
+ time TimeOfDay
+ utc UTCTime
+ jsonb Value
+|]
+
+instance Arbitrary DataTypeTable where
+ arbitrary =
+ DataTypeTable
+ <$> arbText -- text
+        <*> (T.take 100 <$> arbText) -- textMaxLen
+ <*> arbitrary -- bytes
+ <*> liftA2 (,) arbitrary arbText -- bytesTextTuple
+ <*> (BS.take 100 <$> arbitrary) -- bytesMaxLen
+ <*> arbitrary -- int
+ <*> arbitrary -- intList
+ <*> arbitrary -- intMap
+ <*> arbitrary -- double
+ <*> arbitrary -- bool
+ <*> arbitrary -- day
+ <*> arbitrary -- pico
+        <*> (truncateTimeOfDay =<< arbitrary) -- time
+        <*> (truncateUTCTime =<< arbitrary) -- utc
+        <*> fmap getValue arbitrary -- jsonb
+
+setup :: (MonadIO m) => Migration -> ReaderT SqlBackend m ()
+setup migration = do
+ printMigration migration
+ runMigrationUnsafe migration
+
+-- | All fetch modes to test. FetchChunked uses a small chunk size to
+-- exercise the multi-result loop even on small result sets.
+allFetchModes :: [(String, FetchMode)]
+allFetchModes =
+ [ ("FetchAll", FetchAll)
+ , ("FetchSingleRow", FetchSingleRow)
+ , ("FetchChunked 2", FetchChunked 2)
+ ]
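+
+-- Roughly how a mode reaches the backend (a sketch of what 'runConnWith_'
+-- in PgPipelineInit does with it; the real settings also carry the
+-- connection string and pool size):
+--
+-- > let settings = defaultPipelineSettings{pipelineFetchMode = FetchChunked 2}
+-- > withPostgresqlPipelinePool settings (runSqlPool action)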
+
+main :: IO ()
+main = do
+ -- Pure tests run first (no database required)
+ hspec $ do
+ PlaceholderSpec.specs
+ InCollapseSpec.specs
+ BinaryRoundTripSpec.specs
+ DirectDecodeSpec.specs
+ DirectEntityPOC.specs
+
+ -- Database tests require PostgreSQL
+ -- Migrations run once under FetchAll (DDL doesn't use row fetch modes)
+ runConn $ do
+ mapM_
+ setup
+ [ PersistentTest.testMigrate
+ , PersistentTest.noPrefixMigrate
+ , PersistentTest.customPrefixMigrate
+ , PersistentTest.treeMigrate
+ , EmbedTest.embedMigrate
+ , EmbedOrderTest.embedOrderMigrate
+ , LargeNumberTest.numberMigrate
+ , UniqueTest.uniqueMigrate
+ , MaxLenTest.maxlenMigrate
+ , MaybeFieldDefsTest.maybeFieldDefMigrate
+ , TypeLitFieldDefsTest.typeLitFieldDefsMigrate
+ , Recursive.recursiveMigrate
+ , CompositeTest.compositeMigrate
+ , TreeTest.treeMigrate
+ , PersistUniqueTest.migration
+ , RenameTest.migration
+ , CustomPersistFieldTest.customFieldMigrate
+ , PrimaryTest.migration
+ , CustomPrimaryKeyReferenceTest.migration
+ , MigrationColumnLengthTest.migration
+ , TransactionLevelTest.migration
+ , LongIdentifierTest.migration
+ , ForeignKey.compositeMigrate
+ , MigrationTest.migrationMigrate
+ , PgPipelineIntervalTest.pgIntervalMigrate
+ , UpsertWhere.upsertWhereMigrate
+ , ImplicitUuidSpec.implicitUuidMigrate
+ ]
+ PersistentTest.cleanDB
+ ForeignKey.cleanDB
+
+ -- Run database tests for each fetch mode
+ forM_ allFetchModes $ \(modeName, mode) -> do
+ -- Clean up between mode runs to prevent test isolation issues
+ runConnWith mode $ do
+ PersistentTest.cleanDB
+ ForeignKey.cleanDB
+ let run = runConnAssertWith mode
+ hspec $ describe modeName $ do
+ -- Pipeline-specific tests
+ PipelineModeSpec.specsWith mode
+ PipelineDeferralSpec.specs
+ PgPipelineIntervalTest.specs
+ PipelineRegressionSpec.specs
+
+ -- DirectEntity integration: rawSqlDirectCompat through SqlBackend
+ DirectEntityPOC.integrationSpecs
+ describe "rawSqlDirectCompat through SqlBackend" $ do
+ it "decodes rows via DirectEntity + PgRowEnv" $ run $ do
+ result <- DirectEntityPOC.rawSqlDirectCompatTest
+ liftIO $ result `shouldBe` Just
+ [ DirectEntityPOC.PgPair (T.pack "Alice") 30
+ , DirectEntityPOC.PgPair (T.pack "Bob") 25
+ ]
+
+ -- Postgres-specific tests
+ ImplicitUuidSpec.spec
+ MigrationReferenceSpec.spec
+ MigrationSpec.spec
+ EquivalentTypeTestPostgres.specs
+ JSONTest.specs
+ CustomConstraintTest.specs
+ UpsertWhere.specs
+ ArrayAggTest.specs
+
+ -- Shared persistent-test suite
+ RenameTest.specsWith run
+ DataTypeTest.specsWith
+ run
+ (Just (runMigrationSilent dataTypeMigrate))
+ [ TestFn "text" dataTypeTableText
+ , TestFn "textMaxLen" dataTypeTableTextMaxLen
+ , TestFn "bytes" dataTypeTableBytes
+ , TestFn "bytesTextTuple" dataTypeTableBytesTextTuple
+ , TestFn "bytesMaxLen" dataTypeTableBytesMaxLen
+ , TestFn "int" dataTypeTableInt
+ , TestFn "intList" dataTypeTableIntList
+ , TestFn "intMap" dataTypeTableIntMap
+ , TestFn "bool" dataTypeTableBool
+ , TestFn "day" dataTypeTableDay
+ , TestFn "time" (DataTypeTest.roundTime . dataTypeTableTime)
+ , TestFn "utc" (DataTypeTest.roundUTCTime . dataTypeTableUtc)
+ , TestFn "jsonb" dataTypeTableJsonb
+ ]
+ [("pico", dataTypeTablePico)]
+ dataTypeTableDouble
+ HtmlTest.specsWith
+ run
+ (Just (runMigrationSilent HtmlTest.htmlMigrate))
+
+ EmbedTest.specsWith run
+ EmbedOrderTest.specsWith run
+ LargeNumberTest.specsWith run
+ ForeignKey.specsWith run
+ UniqueTest.specsWith run
+ MaxLenTest.specsWith run
+ MaybeFieldDefsTest.specsWith run
+ TypeLitFieldDefsTest.specsWith run
+ Recursive.specsWith run
+ SumTypeTest.specsWith
+ run
+ (Just (runMigrationSilent SumTypeTest.sumTypeMigrate))
+ MigrationTest.specsWith run
+ MigrationOnlyTest.specsWith
+ run
+ ( Just $
+ runMigrationSilent MigrationOnlyTest.migrateAll1
+ >> runMigrationSilent MigrationOnlyTest.migrateAll2
+ )
+ PersistentTest.specsWith run
+ ReadWriteTest.specsWith run
+ PersistentTest.filterOrSpecs run
+ RawSqlTest.specsWith run
+ UpsertTest.specsWith
+ run
+ UpsertTest.Don'tUpdateNull
+ UpsertTest.UpsertPreserveOldKey
+
+ MpsNoPrefixTest.specsWith run
+ MpsCustomPrefixTest.specsWith run
+ EmptyEntityTest.specsWith
+ run
+ (Just (runMigrationSilent EmptyEntityTest.migration))
+ CompositeTest.specsWith run
+ TreeTest.specsWith run
+ PersistUniqueTest.specsWith run
+ PrimaryTest.specsWith run
+ CustomPersistFieldTest.specsWith run
+ CustomPrimaryKeyReferenceTest.specsWith run
+ MigrationColumnLengthTest.specsWith run
+ TransactionLevelTest.specsWith run
+ LongIdentifierTest.specsWith run
+ GeneratedColumnTestSQL.specsWith run