diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/CustomType.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/CustomType.hs
new file mode 100644
index 000000000..b71f082af
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/CustomType.hs
@@ -0,0 +1,275 @@
+{-# LANGUAGE OverloadedStrings #-}
+
+-- |
+-- Module      : Database.Persist.Postgresql.CustomType
+-- Description : Guide for writing custom PersistField instances with the binary protocol backend
+--
+-- = Overview
+--
+-- The pipeline backend uses PostgreSQL's binary wire protocol instead of the
+-- text protocol used by @persistent-postgresql@ (via @postgresql-simple@). This
+-- is faster — no parsing/rendering of text representations — but it changes
+-- how some custom types need to be encoded.
+--
+-- Most custom @PersistField@ instances work without modification. This module
+-- documents the cases where you may need to adjust your approach.
+--
+-- = How PersistValue maps to PostgreSQL
+--
+-- The binary backend encodes each @PersistValue@ with a specific PostgreSQL
+-- OID (type identifier) and binary representation:
+--
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | PersistValue                 | PG Type     | OID    | Notes                             |
+-- +==============================+=============+========+===================================+
+-- | @PersistText@                | text        | 25     | Binary UTF-8                      |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistByteString@          | bytea       | 17     | Raw binary                        |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistInt64@               | int4 / int8 | 23/20  | int4 when it fits, else int8      |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistDouble@              | float8      | 701    | IEEE 754 double                   |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistRational@            | numeric     | 1700   | Arbitrary precision               |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistBool@                | bool        | 16     | Single byte                       |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistDay@                 | date        | 1082   | Days since 2000-01-01             |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistTimeOfDay@           | time        | 1083   | Microseconds since midnight       |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistUTCTime@             | timestamptz | 1184   | Microseconds since 2000-01-01 UTC |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistList@                | (unknown)   | 0      | JSON text, PG infers type         |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistArray@               | \[]         | varies | Native PostgreSQL array           |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistMap@                 | (unknown)   | 0      | JSON text, PG infers type         |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistLiteral_ Escaped@    | (unknown)   | 0      | Text format, PG infers type       |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistLiteral_ DbSpecific@ | (unknown)   | 0      | Text format, PG infers type       |
+-- +------------------------------+-------------+--------+-----------------------------------+
+-- | @PersistLiteral_ Unescaped@  | —           | —      | Inlined into SQL text             |
+-- +------------------------------+-------------+--------+-----------------------------------+
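+--
+-- An OID of 0 (\"unknown\") means PostgreSQL resolves the parameter type from
+-- context, much like an untyped quoted literal written directly in the SQL
+-- text. For instance, with a @uuid@ column, an OID-0 parameter in
+-- @INSERT INTO t (u) VALUES ($1)@ is accepted and cast just as the literal
+-- @'123e4567-e89b-12d3-a456-426614174000'@ would be.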
+--
+-- = Custom Types That Just Work
+--
+-- If your custom type stores as one of the standard @PersistValue@ constructors,
+-- it will work without any changes:
+--
+-- @
+-- newtype Email = Email Text
+--
+-- instance PersistField Email where
+--     toPersistValue (Email t) = PersistText t
+--     fromPersistValue (PersistText t) = Right (Email t)
+--     fromPersistValue _ = Left "Expected PersistText for Email"
+--
+-- instance PersistFieldSql Email where
+--     sqlType _ = SqlString
+-- @
+--
+-- This covers the vast majority of custom types: newtypes over @Text@, @Int@,
+-- @Double@, @Bool@, @Day@, @UTCTime@, @ByteString@, etc.
+--
+-- = Types Requiring OID 0 (Type Inference)
+--
+-- When PostgreSQL's column type doesn't match any built-in @PersistValue@ OID,
+-- you need the value to be sent with OID 0 so PostgreSQL infers the type from
+-- the column. Use @PersistLiteral_ Escaped@ or @PersistLiteral_ DbSpecific@:
+--
+-- @
+-- newtype UUID = UUID Text
+--
+-- instance PersistField UUID where
+--     toPersistValue (UUID t) = PersistLiteral_ Escaped (encodeUtf8 t)
+--     fromPersistValue (PersistLiteral_ Escaped bs) = Right (UUID (decodeUtf8 bs))
+--     fromPersistValue _ = Left "Expected PersistLiteral_ Escaped for UUID"
+--
+-- instance PersistFieldSql UUID where
+--     sqlType _ = SqlOther "UUID"
+-- @
+--
+-- The bytes in @PersistLiteral_ Escaped@ are sent as-is with OID 0 in text
+-- format. PostgreSQL sees the text representation and casts it to the column
+-- type (UUID in this case). This is the same behavior as @postgresql-simple@'s
+-- @Unknown@ type.
+--
+-- __Other types that use this pattern:__ @INET@, @CIDR@, @MACADDR@, @LTREE@,
+-- @HSTORE@, custom enum types, PostGIS geometry types.
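+--
+-- The other types listed above follow the same shape. As a sketch (the
+-- @IpAddr@ wrapper is hypothetical, not something this package exports), an
+-- @INET@ column could be handled like this:
+--
+-- @
+-- newtype IpAddr = IpAddr Text
+--
+-- instance PersistField IpAddr where
+--     toPersistValue (IpAddr t) = PersistLiteral_ Escaped (encodeUtf8 t)
+--     fromPersistValue (PersistLiteral_ Escaped bs) = Right (IpAddr (decodeUtf8 bs))
+--     fromPersistValue _ = Left "Expected PersistLiteral_ Escaped for IpAddr"
+--
+-- instance PersistFieldSql IpAddr where
+--     sqlType _ = SqlOther "INET"
+-- @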
+--
+-- = Types Requiring SQL Inlining
+--
+-- Some PostgreSQL types cannot accept /any/ parameterized input — even with
+-- OID 0, PostgreSQL won't cast. For these, use @PersistLiteral_ Unescaped@
+-- to inline the value directly into the SQL text:
+--
+-- @
+-- newtype PgInterval = PgInterval NominalDiffTime
+--
+-- instance PersistField PgInterval where
+--     toPersistValue (PgInterval ndt) =
+--         let (sci, _) = fromRationalRepetendUnlimited (toRational ndt)
+--             s = formatScientific Fixed Nothing sci
+--         in PersistLiteral_ Unescaped (encodeUtf8 ("'" <> pack s <> " seconds'::interval"))
+--     fromPersistValue (PersistRational r) = Right (PgInterval (fromRational r))
+--     fromPersistValue _ = Left "Expected PersistRational for PgInterval"
+-- @
+--
+-- __Important:__ @PersistLiteral_ Unescaped@ is inlined into the SQL string
+-- /before/ parameter encoding. The value must be a valid SQL expression. Always
+-- use explicit casts (e.g. @'...'::interval@) and ensure the content is safe
+-- (no user-controlled input without validation).
+--
+-- = JSON and JSONB Columns
+--
+-- Aeson's @Value@ type is stored as @PersistLiteral_ Escaped@ with the JSON
+-- bytes. The binary backend sends this with OID 0, so PostgreSQL infers the
+-- @jsonb@ type from the column.
+--
+-- If you have a custom type that serializes to JSON for a @jsonb@ column:
+--
+-- @
+-- import Data.Aeson (ToJSON, FromJSON, encode, eitherDecodeStrict)
+-- import qualified Data.ByteString.Lazy as BSL
+--
+-- instance PersistField MyJsonType where
+--     toPersistValue = PersistLiteralEscaped . BSL.toStrict . encode
+--     fromPersistValue (PersistByteString bs) =
+--         case eitherDecodeStrict bs of
+--             Right v -> Right v
+--             Left err -> Left (pack err)
+--     fromPersistValue _ = Left "Expected PersistByteString for MyJsonType"
+--
+-- instance PersistFieldSql MyJsonType where
+--     sqlType _ = SqlOther "jsonb"
+-- @
+--
+-- The @PersistLiteralEscaped@ (= @PersistLiteral_ Escaped@) encoding sends
+-- with OID 0, so PostgreSQL accepts the JSON text for a @jsonb@ column.
+--
+-- = PersistList vs PersistArray
+--
+-- These two constructors serve different purposes in this backend:
+--
+-- [@PersistList@] Encodes as __JSON text__ with OID 0 (unknown). PostgreSQL
+--     infers the column type, so this works for both @VARCHAR@ and @jsonb@
+--     columns. Used by persistent internally for embedded entities. If your
+--     custom type uses @PersistList@, it will be stored as a JSON text array
+--     like @[1,2,3]@.
+--
+-- [@PersistArray@] Encodes as a __native PostgreSQL array__ (e.g. @int8[]@,
+--     @text[]@). The element type is inferred from the first non-null element.
+--     This is used by the @IN@ → @ANY@ rewriting: @WHERE id IN (?,?,?)@
+--     becomes @WHERE id = ANY($1)@ with a single @PersistArray@ parameter.
+--
+-- If you want your custom list type to use native arrays:
+--
+-- @
+-- newtype IntList = IntList [Int64]
+--
+-- instance PersistField IntList where
+--     toPersistValue (IntList xs) = PersistArray (map PersistInt64 xs)
+--     fromPersistValue (PersistArray xs) = IntList \<$\> mapM fromPersistValue xs
+--     fromPersistValue (PersistList xs) = IntList \<$\> mapM fromPersistValue xs
+--     fromPersistValue _ = Left "Expected PersistArray or PersistList for IntList"
+--
+-- instance PersistFieldSql IntList where
+--     sqlType _ = SqlOther "int8[]"
+-- @
+--
+-- Match both constructors on the way back: the decoder returns native arrays
+-- as @PersistList@, while values produced by @toPersistValue@ round trip as
+-- @PersistArray@.
+--
+-- = PostgreSQL Enum Types
+--
+-- Custom PostgreSQL enums work with @PersistLiteral_ Escaped@ or @PersistText@:
+--
+-- @
+-- data Color = Red | Green | Blue
+--
+-- instance PersistField Color where
+--     toPersistValue Red = PersistLiteral_ Escaped "red"
+--     toPersistValue Green = PersistLiteral_ Escaped "green"
+--     toPersistValue Blue = PersistLiteral_ Escaped "blue"
+--     fromPersistValue (PersistText "red") = Right Red
+--     fromPersistValue (PersistText "green") = Right Green
+--     fromPersistValue (PersistText "blue") = Right Blue
+--     fromPersistValue (PersistLiteral_ Escaped "red") = Right Red
+--     fromPersistValue (PersistLiteral_ Escaped "green") = Right Green
+--     fromPersistValue (PersistLiteral_ Escaped "blue") = Right Blue
+--     fromPersistValue _ = Left "Invalid Color"
+--
+-- instance PersistFieldSql Color where
+--     sqlType _ = SqlOther "color" -- must match CREATE TYPE color AS ENUM (...)
+-- @
+--
+-- Using @PersistLiteral_ Escaped@ sends with OID 0, letting PostgreSQL match
+-- the text @\"red\"@ against the enum type. On the way back, match both
+-- @PersistText@ (what the decoder produces when it recognizes the enum OID)
+-- and @PersistLiteral_ Escaped@ (the fallback for unrecognized OIDs).
+--
+-- = When Do You Need Explicit SQL Casts?
+--
+-- The short answer: __almost never__ with standard persistent operations.
+--
+-- Standard persistent operations (@insert@, @update@, @selectList@,
+-- @deleteWhere@, etc.) generate SQL like @INSERT INTO "table" ("col") VALUES (?)@
+-- or @UPDATE "table" SET "col" = ? WHERE ...@. In these contexts, PostgreSQL
+-- knows the column type from the table definition and will infer the parameter
+-- type from OID 0. This covers:
+--
+-- * Custom enums via @PersistLiteral_ Escaped@
+-- * UUIDs via @PersistLiteral_ Escaped@
+-- * JSONB via @PersistLiteral_ Escaped@ or @PersistList@/@PersistMap@
+-- * All standard @PersistValue@ types
+--
+-- You __do__ need explicit casts (@::type@) in these cases:
+--
+-- * __@rawSql@ / @rawExecute@__ where PostgreSQL can't infer the type from
+--   context. For example, @SELECT ? + 1@ is ambiguous — PostgreSQL doesn't
+--   know if @?@ is @int4@, @int8@, @numeric@, etc. Write
+--   @SELECT ?::int8 + 1@ instead.
+-- +-- * __Function arguments__ where PostgreSQL has multiple overloads. For example, +-- @jsonb_build_object('key', ?)@ may need @?::text@ if PostgreSQL can't +-- resolve the overload. +-- +-- * __@PersistLiteral_ Unescaped@__ values are always inlined into SQL and +-- should include their own cast (e.g. @'5 seconds'::interval@). +-- +-- = Unsupported Operations +-- +-- The following PostgreSQL features are __incompatible with pipeline mode__ and +-- will not work with this backend: +-- +-- * __COPY__ (@COPY FROM@ / @COPY TO@) — uses a separate sub-protocol that +-- conflicts with pipeline mode. +-- +-- * __LISTEN/NOTIFY__ — asynchronous notifications require consuming results +-- outside the pipeline flow, which conflicts with the pending result counter. +-- +-- * __Large Objects__ — large object operations use a separate API that is not +-- supported in pipeline mode. +-- +-- These operations are also uncommon in persistent-based applications. If you +-- need them, use the raw @LibPQ.Connection@ via @getPipelineConn@ and manage +-- the pipeline state yourself. +-- +-- = Migration from persistent-postgresql +-- +-- When migrating from @persistent-postgresql@ to @persistent-postgresql-ng@: +-- +-- 1. __Most types work unchanged.__ Standard @PersistField@ instances using +-- @PersistText@, @PersistInt64@, @PersistBool@, etc. need no changes. +-- +-- 2. __Types using @PersistRational@ for non-numeric columns__ (like +-- @interval@) need to switch to @PersistLiteral_ Unescaped@ with an +-- explicit cast. In the text protocol, PostgreSQL would auto-cast a text +-- number to @interval@; the binary protocol sends with OID 1700 (numeric) +-- which PostgreSQL won't cast. +-- +-- 3. __UUID types__ work if they use @PersistLiteral_ Escaped@ with the +-- hex text representation. The backend sends with OID 0 and decodes +-- binary UUIDs to hex text on the way back. +-- +-- 4. __Array columns__ can now use @PersistArray@ for native PostgreSQL +-- arrays instead of JSON-encoded text. The @IN@ operator automatically +-- uses native arrays via the @= ANY(...)@ rewriting. +-- +-- 5. __JSONB__ values stored via @PersistLiteral_ Escaped@ (the standard +-- pattern from the JSON module) work correctly — OID 0 lets PostgreSQL +-- accept the JSON text for @jsonb@ columns. +module Database.Persist.Postgresql.CustomType () where diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal.hs new file mode 100644 index 000000000..c0404ee6a --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal.hs @@ -0,0 +1,70 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} + +-- | Shared internal utilities for the pipeline backend. +-- +-- Contains escape functions, the 'PgInterval' type, and re-exports from the +-- migration module. +module Database.Persist.Postgresql.Internal + ( PgInterval (..) + , AlterDB (..) + , AlterTable (..) + , AlterColumn (..) 
+ , SafeToRemove + , migrateStructured + , migrateEntitiesStructured + , mockMigrateStructured + , addTable + , findAlters + , maySerial + , mayDefault + , showSqlType + , showColumn + , showAlter + , showAlterDb + , showAlterTable + , getAddReference + , udToPair + , safeToRemove + , postgresMkColumns + , getAlters + , escapeE + , escapeF + , escape + ) where + +import Data.Scientific (FPFormat (Fixed), formatScientific, fromRationalRepetendUnlimited) +import qualified Data.Text +import qualified Data.Text.Encoding +import Data.Time (NominalDiffTime) +import Database.Persist.Sql +import Database.Persist.Postgresql.Internal.Migration + +-- | Represent Postgres interval using NominalDiffTime. +-- +-- Note that this type cannot be losslessly round tripped through PostgreSQL. +-- For example the value @'PgInterval' 0.0000009@ will truncate extra +-- precision. And the value @'PgInterval' 9223372036854.775808@ will overflow. +newtype PgInterval = PgInterval {getPgInterval :: NominalDiffTime} + deriving (Eq, Show) + +instance PersistField PgInterval where + toPersistValue (PgInterval ndt) = + -- Inline the interval literal into the SQL text because binary-protocol + -- numeric (PersistRational) can't be implicitly cast to interval, and + -- text with OID 25 also can't. Unescaped literals get inlined before + -- parameter encoding. + let r = toRational ndt + -- Format as a quoted interval literal: '123.456000 seconds'::interval + -- Use Scientific for exact decimal representation (no Double precision loss) + (sci, _) = fromRationalRepetendUnlimited r + s = formatScientific Fixed Nothing sci + in PersistLiteral_ Unescaped (Data.Text.Encoding.encodeUtf8 + (Data.Text.pack ("'" <> s <> " seconds'::interval"))) + fromPersistValue (PersistRational r) = + Right $ PgInterval (fromRational r) + fromPersistValue x = + Left $ "PgInterval: expected PersistRational, got: " <> Data.Text.pack (show x) + +instance PersistFieldSql PgInterval where + sqlType _ = SqlOther "interval" diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Decoding.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Decoding.hs new file mode 100644 index 000000000..88d4655d0 --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Decoding.hs @@ -0,0 +1,162 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE ViewPatterns #-} + +-- | Decode binary result columns from PostgreSQL into 'PersistValue'. +-- +-- Replaces the @getGetter@/@builtinGetters@ from the existing +-- @persistent-postgresql@ package, using @postgresql-binary@ decoders +-- directly instead of going through @postgresql-simple@. +module Database.Persist.Postgresql.Internal.Decoding + ( decodePersistValue + ) where + +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS +import Data.Int (Int16, Int32, Int64) +import Data.Maybe (fromMaybe) +import Control.Monad (replicateM) +import Data.Word (Word8) +import Data.Scientific (Scientific) +import Data.Text (Text) +import Data.Time (localTimeToUTC, utc) +import qualified Database.PostgreSQL.LibPQ as LibPQ +import Database.Persist.Sql (PersistValue (..)) +import qualified PostgreSQL.Binary.Decoding as PD + +import Database.Persist.Postgresql.Internal.PgType + +-- | Decode a binary column value from PostgreSQL into a 'PersistValue'. +-- +-- If the value is @Nothing@, returns @PersistNull@. +-- Otherwise, dispatches on the column OID to the appropriate decoder. 
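+-- For example, a column whose OID classifies as @Scalar PgBool@ (OID 16) is
+-- decoded with @PD.bool@ into 'PersistBool', and @Scalar PgInt8@ (OID 20)
+-- with @PD.int@ into 'PersistInt64'.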
+-- Falls back to @PersistLiteralEscaped@ for unknown OIDs. +decodePersistValue :: LibPQ.Oid -> Maybe ByteString -> Either Text PersistValue +decodePersistValue _ Nothing = Right PersistNull +decodePersistValue (fromOid -> pt) (Just bs) = decodeByType pt bs + +-- | Decode based on classified 'PgType'. +decodeByType :: PgType -> ByteString -> Either Text PersistValue + +-- Scalar types +decodeByType (Scalar PgBool) bs = PersistBool <$> run PD.bool bs +decodeByType (Scalar PgBytea) bs = PersistByteString <$> run PD.bytea_strict bs +decodeByType (Scalar PgChar) bs = PersistText <$> run PD.text_strict bs +decodeByType (Scalar PgName) bs = PersistText <$> run PD.text_strict bs +decodeByType (Scalar PgInt8) bs = PersistInt64 <$> run (PD.int :: PD.Value Int64) bs +decodeByType (Scalar PgInt2) bs = PersistInt64 . fromIntegral <$> run (PD.int :: PD.Value Int16) bs +decodeByType (Scalar PgInt4) bs = PersistInt64 . fromIntegral <$> run (PD.int :: PD.Value Int32) bs +decodeByType (Scalar PgText) bs = PersistText <$> run PD.text_strict bs +decodeByType (Scalar PgXml) bs = PersistByteString <$> run PD.bytea_strict bs +decodeByType (Scalar PgFloat4) bs = PersistDouble . realToFrac <$> run PD.float4 bs +decodeByType (Scalar PgFloat8) bs = PersistDouble <$> run PD.float8 bs +decodeByType (Scalar PgMoney) bs = PersistRational . fromIntegral <$> run (PD.int :: PD.Value Int64) bs +decodeByType (Scalar PgBpchar) bs = PersistText <$> run PD.text_strict bs +decodeByType (Scalar PgVarchar) bs = PersistText <$> run PD.text_strict bs +decodeByType (Scalar PgDate) bs = PersistDay <$> run PD.date bs +decodeByType (Scalar PgTime) bs = PersistTimeOfDay <$> run PD.time_int bs +decodeByType (Scalar PgTimestamp) bs = PersistUTCTime . localTimeToUTC utc <$> run PD.timestamp_int bs +decodeByType (Scalar PgTimestamptz) bs = PersistUTCTime <$> run PD.timestamptz_int bs +decodeByType (Scalar PgInterval) bs = PersistRational . toRational <$> run PD.interval_int bs +decodeByType (Scalar PgBit) bs = PersistByteString <$> run PD.bytea_strict bs +decodeByType (Scalar PgVarbit) bs = PersistByteString <$> run PD.bytea_strict bs +decodeByType (Scalar PgNumeric) bs = decodeNumeric bs +decodeByType (Scalar PgVoid) _ = Right PersistNull +decodeByType (Scalar PgJson) bs = PersistByteString <$> run PD.bytea_strict bs +decodeByType (Scalar PgJsonb) bs = PersistByteString <$> decodeJsonb bs +decodeByType (Scalar PgUnknown) bs = PersistByteString <$> run PD.bytea_strict bs +decodeByType (Scalar PgUuid) bs = do + raw <- run PD.bytea_strict bs + Right $ PersistLiteralEscaped (uuidBytesToText raw) + +-- Array types +decodeByType (Array PgBool) bs = decodeArray (PersistBool <$> PD.bool) bs +decodeByType (Array PgBytea) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs +decodeByType (Array PgChar) bs = decodeArray (PersistText <$> PD.text_strict) bs +decodeByType (Array PgName) bs = decodeArray (PersistText <$> PD.text_strict) bs +decodeByType (Array PgInt8) bs = decodeArray (PersistInt64 <$> (PD.int :: PD.Value Int64)) bs +decodeByType (Array PgInt2) bs = decodeArray (PersistInt64 . fromIntegral <$> (PD.int :: PD.Value Int16)) bs +decodeByType (Array PgInt4) bs = decodeArray (PersistInt64 . fromIntegral <$> (PD.int :: PD.Value Int32)) bs +decodeByType (Array PgText) bs = decodeArray (PersistText <$> PD.text_strict) bs +decodeByType (Array PgXml) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs +decodeByType (Array PgFloat4) bs = decodeArray (PersistDouble . 
realToFrac <$> PD.float4) bs +decodeByType (Array PgFloat8) bs = decodeArray (PersistDouble <$> PD.float8) bs +decodeByType (Array PgTimestamp) bs = decodeArray (PersistUTCTime . localTimeToUTC utc <$> PD.timestamp_int) bs +decodeByType (Array PgTimestamptz) bs = decodeArray (PersistUTCTime <$> PD.timestamptz_int) bs +decodeByType (Array PgDate) bs = decodeArray (PersistDay <$> PD.date) bs +decodeByType (Array PgTime) bs = decodeArray (PersistTimeOfDay <$> PD.time_int) bs +decodeByType (Array PgMoney) bs = decodeArray (PersistRational . fromIntegral <$> (PD.int :: PD.Value Int64)) bs +decodeByType (Array PgBpchar) bs = decodeArray (PersistText <$> PD.text_strict) bs +decodeByType (Array PgVarchar) bs = decodeArray (PersistText <$> PD.text_strict) bs +decodeByType (Array PgInterval) bs = decodeArray (PersistRational . toRational <$> PD.interval_int) bs +decodeByType (Array PgBit) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs +decodeByType (Array PgVarbit) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs +decodeByType (Array PgNumeric) bs = decodeArrayNumeric bs +decodeByType (Array PgJson) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs +decodeByType (Array PgJsonb) bs = decodeArray (PersistByteString <$> PD.bytea_strict) bs +decodeByType (Array PgUuid) bs = decodeArray (PersistLiteralEscaped . uuidBytesToText <$> PD.bytea_strict) bs + +-- Unhandled array element types fall through to raw bytes +decodeByType (Array PgVoid) bs = Right $ PersistLiteralEscaped bs +decodeByType (Array PgUnknown) bs = Right $ PersistLiteralEscaped bs + +-- User-defined types: return raw bytes (callers should use the direct path) +decodeByType (Composite _ _) bs = Right $ PersistLiteralEscaped bs +decodeByType (CompositeArray _) bs = Right $ PersistLiteralEscaped bs +decodeByType (Enum _ _) bs = PersistText <$> run PD.text_strict bs +decodeByType (EnumArray _) bs = decodeArray (PersistText <$> PD.text_strict) bs + +-- Unrecognized OID: return raw bytes +decodeByType (Unrecognized _) bs = Right $ PersistLiteralEscaped bs + +run :: PD.Value a -> ByteString -> Either Text a +run = PD.valueParser + +decodeNumeric :: ByteString -> Either Text PersistValue +decodeNumeric bs = do + sci <- run PD.numeric bs + Right $ PersistRational (toRational (sci :: Scientific)) + +-- | Decode JSONB binary format. The binary format has a 1-byte version prefix +-- that @jsonb_bytes@ strips for us. +decodeJsonb :: ByteString -> Either Text ByteString +decodeJsonb = PD.valueParser (PD.jsonb_bytes Right) + +decodeArray :: PD.Value PersistValue -> ByteString -> Either Text PersistValue +decodeArray elemDecoder bs = + PersistList . map nullable <$> + run (PD.array (PD.dimensionArray replicateM (PD.nullableValueArray elemDecoder))) bs + where + nullable = fromMaybe PersistNull + +decodeArrayNumeric :: ByteString -> Either Text PersistValue +decodeArrayNumeric bs = + PersistList . map nullable <$> + run (PD.array (PD.dimensionArray replicateM (PD.nullableValueArray numDecoder))) bs + where + numDecoder = do + sci <- PD.numeric + pure $ PersistRational (toRational (sci :: Scientific)) + nullable = fromMaybe PersistNull + +-- | Convert 16 raw UUID bytes to the text representation +-- "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" as a ByteString. 
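+--
+-- For example, the bytes @0x00 0x01 ... 0x0F@ map to
+-- @\"00010203-0405-0607-0809-0a0b0c0d0e0f\"@; input that is not exactly 16
+-- bytes long is returned unchanged.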
+uuidBytesToText :: ByteString -> ByteString +uuidBytesToText raw + | BS.length raw /= 16 = raw + | otherwise = BS.pack $ concat + [ hexBytes 0 4, [0x2D] + , hexBytes 4 2, [0x2D] + , hexBytes 6 2, [0x2D] + , hexBytes 8 2, [0x2D] + , hexBytes 10 6 + ] + where + hexBytes offset count = + concatMap byteToHex [BS.index raw (offset + i) | i <- [0 .. count - 1]] + byteToHex :: Word8 -> [Word8] + byteToHex b = [hexNibble (b `div` 16), hexNibble (b `mod` 16)] + hexNibble :: Word8 -> Word8 + hexNibble n + | n < 10 = 0x30 + n + | otherwise = 0x61 + n - 10 diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectDecode.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectDecode.hs new file mode 100644 index 000000000..997d01013 --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectDecode.hs @@ -0,0 +1,255 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | Direct field decoding for PostgreSQL. +-- +-- 'prepareField' inspects the column OID once (classifying it into a +-- 'PgType'), then returns a 'FieldRunner' closure that calls the +-- appropriate @postgresql-binary@ decoder on each row without +-- re-dispatching on the OID. +module Database.Persist.Postgresql.Internal.DirectDecode + ( PgRowEnv (..) + , compositeFieldDecode + ) where + +import Data.ByteString (ByteString) +import Data.Int (Int16, Int32, Int64) +import Data.Scientific (Scientific) +import Data.Text (Text) +import qualified Data.Text as T +import Data.Time (Day, TimeOfDay, UTCTime, localTimeToUTC, utc) +import qualified Data.Vector as V +import qualified Database.PostgreSQL.LibPQ as LibPQ +import qualified PostgreSQL.Binary.Decoding as PD + +import Database.Persist.DirectDecode (FieldDecode (..), FieldRunner (..)) +import Database.Persist.Names (FieldNameDB (..)) +import Database.Persist.Postgresql.Internal.PgCodec (PgDecode (..), runPgDecoder) +import Database.Persist.Postgresql.Internal.PgType + +-- | Row environment for PostgreSQL direct decoding. +data PgRowEnv = PgRowEnv + { pgResult :: !LibPQ.Result + , pgRow :: !LibPQ.Row + , pgCols :: !(V.Vector (LibPQ.Column, PgType)) + , pgRowCache :: !OidCache + } + +--------------------------------------------------------------------------- +-- Helpers +--------------------------------------------------------------------------- + +{-# INLINE decodeWith #-} +decodeWith :: PD.Value a -> ByteString -> (Text -> IO r) -> (a -> IO r) -> IO r +decodeWith decoder bs onErr onOk = + case PD.valueParser decoder bs of + Left err -> onErr err + Right v -> onOk v + +{-# INLINE readBytes #-} +readBytes :: PgRowEnv -> Int -> (Text -> IO r) -> (ByteString -> IO r) -> IO r +readBytes env col onErr onOk = do + let (c, _) = pgCols env V.! col + mbs <- LibPQ.getvalue' (pgResult env) (pgRow env) c + case mbs of + Nothing -> onErr "unexpected NULL" + Just bs -> onOk bs + +{-# INLINE mismatch #-} +mismatch :: PgType -> Text -> (Text -> IO r) -> IO r +mismatch pt expected onErr = + onErr ("cannot decode " <> T.pack (show pt) <> " as " <> expected) + +-- | Build a FieldRunner that reads bytes from the row and decodes them. 
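+--
+-- For example, @mkRunner 0 PD.bool@ yields a runner that reads column 0 of
+-- the current row and decodes it with @PD.bool@; a NULL cell is reported
+-- through the error continuation.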
+{-# INLINE mkRunner #-} +mkRunner :: Int -> PD.Value a -> FieldRunner PgRowEnv a +mkRunner col decoder = FieldRunner $ \env onErr onOk -> + readBytes env col onErr $ \bs -> decodeWith decoder bs onErr onOk + +-- | Build a FieldRunner that reads bytes and applies a post-decode conversion. +{-# INLINE mkRunnerWith #-} +mkRunnerWith :: Int -> PD.Value a -> (a -> b) -> FieldRunner PgRowEnv b +mkRunnerWith col decoder f = FieldRunner $ \env onErr onOk -> + readBytes env col onErr $ \bs -> decodeWith decoder bs onErr (onOk . f) + +--------------------------------------------------------------------------- +-- FieldDecode instances: prepareField inspects OID once, returns FieldRunner +--------------------------------------------------------------------------- + +instance FieldDecode PgRowEnv Bool where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgBool -> onOk (mkRunner col PD.bool) + _ -> mismatch pt "Bool" onErr + +instance FieldDecode PgRowEnv Int16 where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgInt2 -> onOk (mkRunner col (PD.int :: PD.Value Int16)) + _ -> mismatch pt "Int16" onErr + +instance FieldDecode PgRowEnv Int32 where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgInt4 -> onOk (mkRunner col (PD.int :: PD.Value Int32)) + Scalar PgInt2 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int16) fromIntegral) + _ -> mismatch pt "Int32" onErr + +instance FieldDecode PgRowEnv Int64 where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgInt8 -> onOk (mkRunner col (PD.int :: PD.Value Int64)) + Scalar PgInt4 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int32) fromIntegral) + Scalar PgInt2 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int16) fromIntegral) + _ -> mismatch pt "Int64" onErr + +instance FieldDecode PgRowEnv Int where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgInt8 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int64) fromIntegral) + Scalar PgInt4 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int32) fromIntegral) + Scalar PgInt2 -> onOk (mkRunnerWith col (PD.int :: PD.Value Int16) fromIntegral) + _ -> mismatch pt "Int" onErr + +instance FieldDecode PgRowEnv Double where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgFloat8 -> onOk (mkRunner col PD.float8) + Scalar PgFloat4 -> onOk (mkRunnerWith col PD.float4 realToFrac) + _ -> mismatch pt "Double" onErr + +instance FieldDecode PgRowEnv Scientific where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgNumeric -> onOk (mkRunner col PD.numeric) + _ -> mismatch pt "Scientific" onErr + +instance FieldDecode PgRowEnv Rational where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgNumeric -> onOk (mkRunnerWith col PD.numeric (toRational :: Scientific -> Rational)) + Scalar PgMoney -> onOk (mkRunnerWith col (PD.int :: PD.Value Int64) fromIntegral) + _ -> mismatch pt "Rational" onErr + +instance FieldDecode PgRowEnv Text where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! 
col + in case pt of + Scalar PgText -> onOk (mkRunner col PD.text_strict) + Scalar PgVarchar -> onOk (mkRunner col PD.text_strict) + Scalar PgBpchar -> onOk (mkRunner col PD.text_strict) + Scalar PgChar -> onOk (mkRunner col PD.text_strict) + Scalar PgName -> onOk (mkRunner col PD.text_strict) + _ -> mismatch pt "Text" onErr + +instance FieldDecode PgRowEnv ByteString where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgBytea -> onOk (mkRunner col PD.bytea_strict) + Scalar PgXml -> onOk (mkRunner col PD.bytea_strict) + Scalar PgJson -> onOk (mkRunner col PD.bytea_strict) + Scalar PgJsonb -> onOk (mkRunner col (PD.jsonb_bytes Right)) + Scalar PgBit -> onOk (mkRunner col PD.bytea_strict) + Scalar PgVarbit -> onOk (mkRunner col PD.bytea_strict) + Scalar PgUnknown -> onOk (mkRunner col PD.bytea_strict) + _ -> mismatch pt "ByteString" onErr + +instance FieldDecode PgRowEnv Day where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgDate -> onOk (mkRunner col PD.date) + _ -> mismatch pt "Day" onErr + +instance FieldDecode PgRowEnv TimeOfDay where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgTime -> onOk (mkRunner col PD.time_int) + _ -> mismatch pt "TimeOfDay" onErr + +instance FieldDecode PgRowEnv UTCTime where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + in case pt of + Scalar PgTimestamptz -> onOk (mkRunner col PD.timestamptz_int) + Scalar PgTimestamp -> onOk (mkRunnerWith col PD.timestamp_int (localTimeToUTC utc)) + _ -> mismatch pt "UTCTime" onErr + +--------------------------------------------------------------------------- +-- Maybe wrapper: NULL -> Nothing, otherwise delegate +--------------------------------------------------------------------------- + +instance FieldDecode PgRowEnv a => FieldDecode PgRowEnv (Maybe a) where + prepareField env name col onErr onOk = + prepareField @PgRowEnv @a env name col onErr $ \inner -> + onOk $ FieldRunner $ \env' onErr' onOk' -> do + let (c, _) = pgCols env' V.! col + mbs <- LibPQ.getvalue' (pgResult env') (pgRow env') c + case mbs of + Nothing -> onOk' Nothing + Just _ -> runField inner env' onErr' (onOk' . Just) + +--------------------------------------------------------------------------- +-- Generic array instance via PgDecode +--------------------------------------------------------------------------- + +instance {-# OVERLAPPABLE #-} PgDecode a => FieldDecode PgRowEnv [a] where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + cache = pgRowCache env + decoder = runPgDecoder (pgDecoder @[a]) cache + in case pt of + Array _ -> onOk (mkRunner col decoder) + Unrecognized _ -> onOk (mkRunner col decoder) + _ -> mismatch pt "array" onErr + +instance {-# OVERLAPPING #-} PgDecode a => FieldDecode PgRowEnv [Maybe a] where + prepareField env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + cache = pgRowCache env + decoder = runPgDecoder (pgDecoder @[Maybe a]) cache + in case pt of + Array _ -> onOk (mkRunner col decoder) + Unrecognized _ -> onOk (mkRunner col decoder) + _ -> mismatch pt "array" onErr + +--------------------------------------------------------------------------- +-- Composite helper +--------------------------------------------------------------------------- + +-- | Build a 'FieldDecode' for a PostgreSQL composite type from its +-- 'PgDecode' instance. 
Accepts any 'Unrecognized' OID (composites +-- have dynamically-assigned OIDs) or a resolved 'Composite' name. +-- +-- @ +-- instance FieldDecode PgRowEnv Address where +-- prepareField = compositeFieldDecode +-- @ +compositeFieldDecode + :: PgDecode a + => PgRowEnv -> FieldNameDB -> Int + -> (Text -> IO r) -> (FieldRunner PgRowEnv a -> IO r) -> IO r +compositeFieldDecode env _ col onErr onOk = + let (_, pt) = pgCols env V.! col + cache = pgRowCache env + decoder = runPgDecoder pgDecoder cache + in case pt of + Composite _ _ -> onOk (mkRunner col decoder) + Unrecognized _ -> onOk (mkRunner col decoder) + _ -> mismatch pt "composite" onErr diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectEncode.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectEncode.hs new file mode 100644 index 000000000..14fb7dafc --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/DirectEncode.hs @@ -0,0 +1,214 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +-- | Direct parameter encoding for PostgreSQL, bypassing 'PersistValue'. +-- +-- Each Haskell type is encoded directly to the binary wire format that +-- @libpq@ expects, using @postgresql-binary@ encoders. No intermediate +-- 'PersistValue' wrapper is allocated. +-- +-- 'PgParam' is an unpacked ADT: 'PgNull' for SQL NULL, +-- 'PgValue' for a typed value with OID, payload, and format. +module Database.Persist.Postgresql.Internal.DirectEncode + ( -- * Parameter type + PgParam (..) + , pgParamToLibPQ + -- * Re-export for convenience + , module Database.Persist.DirectEncode + ) where + +import Data.ByteString (ByteString) +import Data.Int (Int16, Int32, Int64) +import Data.Scientific (Scientific, fromRationalRepetendUnlimited) +import Data.Text (Text) +import qualified Data.Text.Encoding as T +import Data.Time (Day, TimeOfDay, UTCTime) +import Data.Word (Word32) +import qualified Database.PostgreSQL.LibPQ as LibPQ +import qualified PostgreSQL.Binary.Encoding as PE + +import Database.Persist (PersistValue (..), listToJSON, mapToJSON) +import Database.Persist.DirectEncode +import Database.Persist.Postgresql.Internal.Encoding (encodePersistValue) +import Database.Persist.Postgresql.Internal.PgType +import Database.Persist.Types (LiteralType (..)) + +-- | A PostgreSQL query parameter in binary wire format. +-- +-- Unpacked ADT to avoid the boxing overhead of +-- @Maybe (Oid, ByteString, Format)@. +data PgParam + = PgNull + | PgValue + {-# UNPACK #-} !LibPQ.Oid + !ByteString + !LibPQ.Format + +-- | Convert to the representation that @libpq@ expects. +pgParamToLibPQ :: PgParam -> Maybe (LibPQ.Oid, ByteString, LibPQ.Format) +pgParamToLibPQ PgNull = Nothing +pgParamToLibPQ (PgValue oid bs fmt) = Just (oid, bs, fmt) +{-# INLINE pgParamToLibPQ #-} + +bin :: PgScalar -> PE.Encoding -> PgParam +bin s enc = PgValue (scalarOid s) (PE.encodingBytes enc) LibPQ.Binary +{-# INLINE bin #-} + +-- Scalar types + +instance FieldEncode PgParam Bool where + encodeField = bin PgBool . PE.bool + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Int16 where + encodeField = bin PgInt2 . PE.int2_int16 + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Int32 where + encodeField = bin PgInt4 . PE.int4_int32 + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Int64 where + encodeField = bin PgInt8 . 
PE.int8_int64 + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Int where + encodeField i + | i >= fromIntegral (minBound :: Int32) + && i <= fromIntegral (maxBound :: Int32) + = bin PgInt4 (PE.int4_int32 (fromIntegral i)) + | otherwise + = bin PgInt8 (PE.int8_int64 (fromIntegral i)) + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Double where + encodeField = bin PgFloat8 . PE.float8 + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Scientific where + encodeField = bin PgNumeric . PE.numeric + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Rational where + encodeField r = + let (sci, _) = fromRationalRepetendUnlimited r + in bin PgNumeric (PE.numeric sci) + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Text where + encodeField = bin PgText . PE.text_strict + {-# INLINE encodeField #-} + +instance FieldEncode PgParam ByteString where + encodeField = bin PgBytea . PE.bytea_strict + {-# INLINE encodeField #-} + +instance FieldEncode PgParam Day where + encodeField = bin PgDate . PE.date + {-# INLINE encodeField #-} + +instance FieldEncode PgParam TimeOfDay where + encodeField = bin PgTime . PE.time_int + {-# INLINE encodeField #-} + +instance FieldEncode PgParam UTCTime where + encodeField = bin PgTimestamptz . PE.timestamptz_int + {-# INLINE encodeField #-} + +-- Nullable wrapper: Nothing becomes SQL NULL + +instance FieldEncode PgParam a => FieldEncode PgParam (Maybe a) where + encodeField Nothing = PgNull + encodeField (Just a) = encodeField a + {-# INLINE encodeField #-} + +-- | Backward-compatibility instance: encode a 'PersistValue' directly. +-- Converts through the existing 'encodePersistValue' function. +instance FieldEncode PgParam PersistValue where + encodeField pv = case encodePersistValue pv of + Nothing -> PgNull + Just (oid, bs, fmt) -> PgValue oid bs fmt + {-# INLINE encodeField #-} + +-- Array encoding for = ANY($1) patterns + +instance FieldEncode PgParam [PersistValue] where + encodeField xs = encodeAsArray xs + +-- | Encode a list of 'PersistValue' as a native PostgreSQL array. +-- Reuses the logic from "Database.Persist.Postgresql.Internal.Encoding". 
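+--
+-- For example, @encodeAsArray [PersistInt64 1, PersistInt64 2]@ yields a
+-- binary @int8[]@ parameter, while an empty list falls back to the text
+-- literal @{}@ with OID 0 so the array type is inferred from the column.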
+encodeAsArray :: [PersistValue] -> PgParam
+encodeAsArray xs =
+    case inferElementType xs of
+        Just (elemOidW, arrOid, encElem) ->
+            let encoding = PE.array_foldable elemOidW encElem xs
+            in PgValue arrOid (PE.encodingBytes encoding) LibPQ.Binary
+        Nothing ->
+            PgValue (LibPQ.Oid 0) "{}" LibPQ.Text
+
+inferElementType
+    :: [PersistValue]
+    -> Maybe (Word32, LibPQ.Oid, PersistValue -> Maybe PE.Encoding)
+inferElementType = go
+  where
+    go [] = Nothing
+    go (PersistNull : rest) = go rest
+    go (PersistInt64 _ : _) = arrayEnc PgInt8 encInt
+    go (PersistText _ : _) = arrayEnc PgText encText
+    go (PersistBool _ : _) = arrayEnc PgBool encBool
+    go (PersistDouble _ : _) = arrayEnc PgFloat8 encDouble
+    go (PersistByteString _ : _) = arrayEnc PgBytea encBytea
+    go (PersistDay _ : _) = arrayEnc PgDate encDay
+    go (PersistTimeOfDay _ : _) = arrayEnc PgTime encTime
+    go (PersistUTCTime _ : _) = arrayEnc PgTimestamptz encTimestamptz
+    go (PersistRational _ : _) = arrayEnc PgNumeric encNumeric
+    go _ = Nothing
+
+    arrayEnc
+        :: PgScalar
+        -> (PersistValue -> Maybe PE.Encoding)
+        -> Maybe (Word32, LibPQ.Oid, PersistValue -> Maybe PE.Encoding)
+    arrayEnc s enc = do
+        arrOid <- arrayOid s
+        pure (scalarOidWord32 s, arrOid, enc)
+
+    encInt (PersistInt64 i) = Just (PE.int8_int64 i)
+    encInt PersistNull = Nothing
+    encInt _ = Nothing
+
+    encText (PersistText t) = Just (PE.text_strict t)
+    encText PersistNull = Nothing
+    encText _ = Nothing
+
+    encBool (PersistBool b) = Just (PE.bool b)
+    encBool PersistNull = Nothing
+    encBool _ = Nothing
+
+    encDouble (PersistDouble d) = Just (PE.float8 d)
+    encDouble PersistNull = Nothing
+    encDouble _ = Nothing
+
+    encBytea (PersistByteString bs) = Just (PE.bytea_strict bs)
+    encBytea PersistNull = Nothing
+    encBytea _ = Nothing
+
+    encDay (PersistDay d) = Just (PE.date d)
+    encDay PersistNull = Nothing
+    encDay _ = Nothing
+
+    encTime (PersistTimeOfDay t) = Just (PE.time_int t)
+    encTime PersistNull = Nothing
+    encTime _ = Nothing
+
+    encTimestamptz (PersistUTCTime t) = Just (PE.timestamptz_int t)
+    encTimestamptz PersistNull = Nothing
+    encTimestamptz _ = Nothing
+
+    encNumeric (PersistRational r) =
+        let (sci :: Scientific, _) = fromRationalRepetendUnlimited r
+        in Just (PE.numeric sci)
+    encNumeric PersistNull = Nothing
+    encNumeric _ = Nothing
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Encoding.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Encoding.hs
new file mode 100644
index 000000000..19c549c7a
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Encoding.hs
@@ -0,0 +1,179 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+
+-- | Encode 'PersistValue' to binary parameters for @LibPQ.execParams@.
+--
+-- Replaces the @P@ newtype and @ToField@ instance from the existing
+-- @persistent-postgresql@ package, using @postgresql-binary@ encoders
+-- directly instead of going through @postgresql-simple@.
+--
+-- 'PersistArray' values are encoded as native PostgreSQL arrays, with the
+-- element type inferred from the first non-null element. 'PersistList' and
+-- 'PersistMap' values are sent as JSON text with OID 0, so PostgreSQL infers
+-- the type from the column.
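+--
+-- For example, @PersistBool True@ becomes a single-byte binary parameter
+-- with OID 16, while a @PersistMap@ value is sent as JSON text with OID 0
+-- so that PostgreSQL can infer @jsonb@ or @text@ from the column.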
+module Database.Persist.Postgresql.Internal.Encoding + ( encodePersistValue + ) where + +import Data.ByteString (ByteString) +import Data.Int (Int32) +import Data.Scientific (Scientific, fromRationalRepetendUnlimited) +import qualified Data.Text.Encoding as T +import Data.Word (Word32) +import qualified Database.PostgreSQL.LibPQ as LibPQ +import Database.Persist (PersistValue (..), listToJSON, mapToJSON) +import Database.Persist.Types (LiteralType (..)) +import qualified PostgreSQL.Binary.Encoding as PE + +import Database.Persist.Postgresql.Internal.PgType + +-- | Encode a 'PersistValue' to a libpq parameter triple. +-- +-- Returns 'Nothing' for 'PersistNull'. +-- Returns 'Just (oid, bytes, format)' for non-null values. +-- +-- 'PersistLiteral_ Unescaped' values should have been inlined into the SQL +-- text before calling this function. If encountered here, they are sent +-- as text-format parameters as a fallback. +encodePersistValue :: PersistValue -> Maybe (LibPQ.Oid, ByteString, LibPQ.Format) +encodePersistValue PersistNull = Nothing +encodePersistValue (PersistText t) = + Just $ bin PgText (PE.text_strict t) +encodePersistValue (PersistByteString bs) = + Just $ bin PgBytea (PE.bytea_strict bs) +encodePersistValue (PersistInt64 i) + | i >= fromIntegral (minBound :: Int32) && i <= fromIntegral (maxBound :: Int32) + = Just $ bin PgInt4 (PE.int4_int32 (fromIntegral i)) + | otherwise + = Just $ bin PgInt8 (PE.int8_int64 i) +encodePersistValue (PersistDouble d) = + Just $ bin PgFloat8 (PE.float8 d) +encodePersistValue (PersistRational r) = + let (sci, _) = fromRationalRepetendUnlimited r + in Just $ bin PgNumeric (PE.numeric sci) +encodePersistValue (PersistBool b) = + Just $ bin PgBool (PE.bool b) +encodePersistValue (PersistDay d) = + Just $ bin PgDate (PE.date d) +encodePersistValue (PersistTimeOfDay t) = + Just $ bin PgTime (PE.time_int t) +encodePersistValue (PersistUTCTime t) = + Just $ bin PgTimestamptz (PE.timestamptz_int t) +encodePersistValue (PersistList xs) = + -- PersistList is used by persistent for embedded entities stored in varchar + -- columns as JSON text. Send with OID 0 so PostgreSQL infers the type from + -- column context, handling both text/varchar and jsonb columns correctly. + -- PersistArray is used for native PostgreSQL arrays (e.g. from IN→ANY). + let jsonText = listToJSON xs + in Just (LibPQ.Oid 0, T.encodeUtf8 jsonText, LibPQ.Text) +encodePersistValue (PersistArray xs) = encodeAsArray xs +encodePersistValue (PersistMap m) = + -- Send with OID 0 (unknown) in text format so PostgreSQL infers the type + -- from column context. This handles both text/varchar columns (embedded + -- entities as JSON text) and jsonb columns. Using oidText (25) would fail + -- for jsonb columns since there is no implicit cast from text to jsonb. + let jsonText = mapToJSON m + in Just (LibPQ.Oid 0, T.encodeUtf8 jsonText, LibPQ.Text) +encodePersistValue (PersistLiteral_ DbSpecific s) = + -- Send with OID 0 (unknown) in text format, letting PostgreSQL infer the + -- type from context. This matches postgresql-simple's Unknown behavior and + -- handles UUID, bytea, and other backend-specific types correctly. + Just (LibPQ.Oid 0, s, LibPQ.Text) +encodePersistValue (PersistLiteral_ Escaped e) = + -- Same as DbSpecific: send with OID 0 in text format for PostgreSQL to + -- infer the type. Used for UUIDs and other opaque byte values. + Just (LibPQ.Oid 0, e, LibPQ.Text) +encodePersistValue (PersistLiteral_ Unescaped l) = + -- Unescaped literals are raw SQL fragments. 
They should normally be + -- inlined into the SQL text by inlineUnescaped before reaching here. + -- As a fallback, send as text format. + Just (scalarOid PgText, l, LibPQ.Text) +encodePersistValue (PersistObjectId _) = + error "Refusing to serialize a PersistObjectId to a PostgreSQL value" + +-- | Construct a binary-format parameter triple from a scalar type and encoding. +bin :: PgScalar -> PE.Encoding -> (LibPQ.Oid, ByteString, LibPQ.Format) +bin s enc = (scalarOid s, PE.encodingBytes enc, LibPQ.Binary) + +-- | Encode a list of 'PersistValue' as a native PostgreSQL array. +-- +-- Infers the element type from the first non-null element. Falls back +-- to a text[] with text-encoded values for empty or heterogeneous lists. +encodeAsArray :: [PersistValue] -> Maybe (LibPQ.Oid, ByteString, LibPQ.Format) +encodeAsArray xs = + case inferElementType xs of + Just (elemOidW, arrOid, encElem) -> + let encoding = PE.array_foldable elemOidW encElem xs + in Just (arrOid, PE.encodingBytes encoding, LibPQ.Binary) + Nothing -> + -- Empty list or all nulls: send as empty text representation with + -- OID 0 so PostgreSQL infers the array type from column context. + -- Using a specific array OID (e.g. text[] 1009) would fail for + -- non-text array columns like int8[], bool[], etc. + Just (LibPQ.Oid 0, "{}", LibPQ.Text) + +-- | Infer the PostgreSQL element type from the first non-null element +-- of a list. Returns (element OID as Word32, array OID, element encoder). +-- +-- Returns 'Nothing' if the list is empty or contains only nulls. +inferElementType + :: [PersistValue] + -> Maybe (Word32, LibPQ.Oid, PersistValue -> Maybe PE.Encoding) +inferElementType = go + where + go [] = Nothing + go (PersistNull : rest) = go rest + go (PersistInt64 _ : _) = arrayEnc PgInt8 encInt + go (PersistText _ : _) = arrayEnc PgText encText + go (PersistBool _ : _) = arrayEnc PgBool encBool + go (PersistDouble _ : _) = arrayEnc PgFloat8 encDouble + go (PersistByteString _ : _) = arrayEnc PgBytea encBytea + go (PersistDay _ : _) = arrayEnc PgDate encDay + go (PersistTimeOfDay _ : _) = arrayEnc PgTime encTime + go (PersistUTCTime _ : _) = arrayEnc PgTimestamptz encTimestamptz + go (PersistRational _ : _) = arrayEnc PgNumeric encNumeric + go _ = Nothing + + arrayEnc + :: PgScalar + -> (PersistValue -> Maybe PE.Encoding) + -> Maybe (Word32, LibPQ.Oid, PersistValue -> Maybe PE.Encoding) + arrayEnc s enc = do + arrOid <- arrayOid s + pure (scalarOidWord32 s, arrOid, enc) + + encInt (PersistInt64 i) = Just (PE.int8_int64 i) + encInt PersistNull = Nothing + encInt _ = Nothing + + encText (PersistText t) = Just (PE.text_strict t) + encText PersistNull = Nothing + encText _ = Nothing + + encBool (PersistBool b) = Just (PE.bool b) + encBool PersistNull = Nothing + encBool _ = Nothing + + encDouble (PersistDouble d) = Just (PE.float8 d) + encDouble PersistNull = Nothing + encDouble _ = Nothing + + encBytea (PersistByteString bs) = Just (PE.bytea_strict bs) + encBytea PersistNull = Nothing + encBytea _ = Nothing + + encDay (PersistDay d) = Just (PE.date d) + encDay PersistNull = Nothing + encDay _ = Nothing + + encTime (PersistTimeOfDay t) = Just (PE.time_int t) + encTime PersistNull = Nothing + encTime _ = Nothing + + encTimestamptz (PersistUTCTime t) = Just (PE.timestamptz_int t) + encTimestamptz PersistNull = Nothing + encTimestamptz _ = Nothing + + encNumeric (PersistRational r) = + let (sci :: Scientific, _) = fromRationalRepetendUnlimited r + in Just (PE.numeric sci) + encNumeric PersistNull = Nothing + encNumeric _ = Nothing diff 
--git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Migration.hs new file mode 100644 index 000000000..e78c44ad3 --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Migration.hs @@ -0,0 +1,1186 @@ +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} + +-- | Generate postgresql migrations for a set of EntityDefs, either from scratch +-- or based on the current state of a database. +module Database.Persist.Postgresql.Internal.Migration where + +import Control.Arrow +import Control.Monad +import Control.Monad.Except +import Control.Monad.IO.Class +import Data.Acquire (with) +import Data.Conduit +import qualified Data.Conduit.List as CL +import Data.Either (partitionEithers) +import Data.FileEmbed (embedFileRelative) +import Data.List as List +import qualified Data.List.NonEmpty as NEL +import Data.Map (Map) +import qualified Data.Map as Map +import Data.Maybe +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Text (Text) +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import Data.Traversable +import Database.Persist.Sql +import qualified Database.Persist.Sql.Util as Util + +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity from its +-- current state in the database to the state described in +-- Haskell. +-- +-- @since 2.17.1.0 +migrateStructured + :: BackendSpecificOverrides + -> [EntityDef] + -> (Text -> IO Statement) + -> EntityDef + -> IO (Either [Text] [AlterDB]) +migrateStructured overrides allDefs getter entity = + migrateEntitiesStructured overrides getter allDefs [entity] + +-- | Returns a structured representation of all of the DB changes required to +-- migrate the listed entities from their current state in the database to the +-- state described in Haskell. This function avoids N+1 queries, so if you +-- have a lot of entities to migrate, it's much faster to use this rather than +-- using 'migrateStructured' in a loop. +-- +-- @since 2.14.1.0 +migrateEntitiesStructured + :: BackendSpecificOverrides + -> (Text -> IO Statement) + -> [EntityDef] + -> [EntityDef] + -> IO (Either [Text] [AlterDB]) +migrateEntitiesStructured overrides getStmt allDefs defsToMigrate = do + r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate) + pure $ case r of + Right schemaState -> + migrateEntitiesFromSchemaState overrides schemaState allDefs defsToMigrate + Left err -> + Left [err] + +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity to the state +-- described in Haskell, assuming it currently does not +-- exist in the database. +-- +-- @since 2.17.1.0 +mockMigrateStructured + :: BackendSpecificOverrides + -> [EntityDef] + -> EntityDef + -> [AlterDB] +mockMigrateStructured overrides allDefs entity = + migrateEntityFromSchemaState overrides EntityDoesNotExist allDefs entity + +-- | In order to ensure that generating migrations is fast and avoids N+1 +-- queries, we split it into two phases. The first phase involves querying the +-- database to gather all of the information we need about the existing schema. +-- The second phase then generates migrations based on the information from the +-- first phase. 
This data type represents all of the data that's gathered during +-- the first phase: information about the current state of the entities we're +-- migrating in the database. +newtype SchemaState = SchemaState (Map EntityNameDB EntitySchemaState) + deriving (Eq, Show) + +-- | The state of a particular entity (i.e. table) in the database; we generate +-- migrations based on the diff of this versus an EntityDef. +data EntitySchemaState + = -- | The table does not exist in the database + EntityDoesNotExist + | -- | The table does exist in the database + EntityExists ExistingEntitySchemaState + deriving (Eq, Show) + +-- | Information about an existing table in the database +data ExistingEntitySchemaState = ExistingEntitySchemaState + { essColumns :: Map FieldNameDB (Column, (Set ColumnReference)) + -- ^ The columns in this entity, together with the set of foreign key + -- constraints that they are subject to. Usually the ColumnReference list + -- will contain 0-1 elements, but in the event that there are multiple FK + -- constraints applying to a given column in the database we need to keep + -- track of them all because we don't yet know which one has the right name + -- (based on what is in the corresponding model's EntityDef). + -- + -- Note that cReference will be unset for these columns, for the same reason: + -- there may be multiple FK constraints and we don't yet know which one to + -- use. + , essUniqueConstraints :: Map ConstraintNameDB [FieldNameDB] + -- ^ A map of unique constraint names to the columns that are affected by + -- those constraints. + } + deriving (Eq, Show) + +-- | Query a database in order to assemble a SchemaState containing information +-- about each of the entities in the given list. Every entity name in the input +-- should be present in the returned Map. +collectSchemaState + :: (Text -> IO Statement) -> [EntityNameDB] -> IO (Either Text SchemaState) +collectSchemaState getStmt entityNames = runExceptT $ do + existence <- getTableExistence getStmt entityNames + columns <- getColumnsWithoutReferences getStmt entityNames + constraints <- getConstraints getStmt entityNames + foreignKeyReferences <- getForeignKeyReferences getStmt entityNames + + fmap (SchemaState . Map.fromList) $ + for entityNames $ \entityNameDB -> do + tableExists <- case Map.lookup entityNameDB existence of + Just e -> pure e + Nothing -> + throwError + ("Missing entity name from existence map: " <> unEntityNameDB entityNameDB) + + if tableExists + then do + essColumns <- case Map.lookup entityNameDB columns of + Just cols -> + pure $ Map.fromList $ flip map cols $ \c -> + ( cName c + , + ( c + , fromMaybe Set.empty $ + Map.lookup (cName c) =<< Map.lookup entityNameDB foreignKeyReferences + ) + ) + Nothing -> + throwError + ("Missing entity name from columns map: " <> unEntityNameDB entityNameDB) + + let + essUniqueConstraints = fromMaybe Map.empty (Map.lookup entityNameDB constraints) + pure + ( entityNameDB + , EntityExists $ ExistingEntitySchemaState{essColumns, essUniqueConstraints} + ) + else + pure + ( entityNameDB + , EntityDoesNotExist + ) + +runStmt + :: (Show a) + => (Text -> IO Statement) + -> Text + -> [PersistValue] + -> ([PersistValue] -> a) + -> IO [a] +runStmt getStmt sql values process = do + stmt <- getStmt sql + results <- + with + (stmtQuery stmt values) + (\src -> runConduit $ src .| CL.map process .| CL.consume) + pure results + +-- | Check for the existence of each of the input tables. 
The keys in the +-- returned Map are exactly the entity names in the argument; True means the +-- table exists. +getTableExistence + :: (Text -> IO Statement) + -> [EntityNameDB] + -> ExceptT Text IO (Map EntityNameDB Bool) +getTableExistence getStmt entityNames = do + results <- + liftIO $ + runStmt + getStmt + getTableExistenceSql + [PersistArray (map (PersistText . unEntityNameDB) entityNames)] + processTable + case partitionEithers results of + ([], xs) -> + let + existing = Set.fromList xs + in + pure $ Map.fromList $ map (\n -> (n, Set.member n existing)) entityNames + (errs, _) -> throwError (T.intercalate "\n" errs) + where + getTableExistenceSql = + "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" + <> " AND schemaname != 'information_schema' AND tablename=ANY (?)" + + processTable :: [PersistValue] -> Either Text EntityNameDB + processTable resultRow = do + fmap EntityNameDB $ + case resultRow of + [PersistText tableName] -> + pure tableName + [PersistByteString tableName] -> + pure (T.decodeUtf8 tableName) + other -> + throwError $ T.pack $ "Invalid result from information_schema: " ++ show other + +-- | Get all columns for the listed tables from the database, ignoring foreign +-- key references (those are filled in later). +getColumnsWithoutReferences + :: (Text -> IO Statement) + -> [EntityNameDB] + -> ExceptT Text IO (Map EntityNameDB [Column]) +getColumnsWithoutReferences getStmt entityNames = do + results <- + liftIO $ + runStmt + getStmt + getColumnsSql + [PersistArray (map (PersistText . unEntityNameDB) entityNames)] + processColumn + case partitionEithers results of + ([], xs) -> pure $ Map.fromListWith (++) $ map (second (: [])) xs + (errs, _) -> throwError (T.intercalate "\n" errs) + where + getColumnsSql = + T.concat + [ "SELECT " + , "table_name " + , ",column_name " + , ",is_nullable " + , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below + , ",column_default " + , ",generation_expression " + , ",numeric_precision " + , ",numeric_scale " + , ",character_maximum_length " + , "FROM information_schema.columns " + , "WHERE table_catalog=current_database() " + , "AND table_schema=current_schema() " + , "AND table_name=ANY (?) " + ] + + -- DOMAINS Postgres supports the concept of domains, which are data types + -- with optional constraints. An app might make an "email" domain over the + -- varchar type, with a CHECK that the emails are valid In this case the + -- generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN + -- foo TYPE email This code exists to use the domain name (email), instead + -- of the underlying type (varchar). 
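+    -- For example, given a hypothetical domain
+    --   CREATE DOMAIN email AS varchar CHECK (VALUE LIKE '%@%');
+    -- information_schema reports domain_name = 'email' and udt_name =
+    -- 'varchar', and the COALESCE above picks 'email'.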
This is tested in
+    -- EquivalentTypeTest.hs
+    processColumn :: [PersistValue] -> Either Text (EntityNameDB, Column)
+    processColumn resultRow = do
+        case resultRow of
+            [ PersistText tableName
+                , PersistText columnName
+                , PersistText isNullable
+                , PersistText typeName
+                , defaultValue
+                , generationExpression
+                , numericPrecision
+                , numericScale
+                , maxlen
+                ] -> mapLeft (addErrorContext tableName columnName) $ do
+                    defaultValue' <-
+                        case defaultValue of
+                            PersistNull ->
+                                pure Nothing
+                            PersistText t ->
+                                pure $ Just t
+                            _ ->
+                                throwError $ T.pack $ "Invalid default column: " ++ show defaultValue
+                    generationExpression' <-
+                        case generationExpression of
+                            PersistNull ->
+                                pure Nothing
+                            PersistText t ->
+                                pure $ Just t
+                            _ ->
+                                throwError $ T.pack $ "Invalid generated column: " ++ show generationExpression
+                    let
+                        typeStr =
+                            case maxlen of
+                                PersistInt64 n ->
+                                    T.concat [typeName, "(", T.pack (show n), ")"]
+                                _ ->
+                                    typeName
+
+                    t <- getType numericPrecision numericScale typeStr
+
+                    pure
+                        ( EntityNameDB tableName
+                        , Column
+                            { cName = FieldNameDB columnName
+                            , cNull = isNullable == "YES"
+                            , cSqlType = t
+                            , cDefault = fmap stripSuffixes defaultValue'
+                            , cGenerated = fmap stripSuffixes generationExpression'
+                            , cDefaultConstraintName = Nothing
+                            , cMaxLen = Nothing
+                            , cReference = Nothing
+                            }
+                        )
+            other ->
+                Left $
+                    T.pack $
+                        "Invalid result from information_schema: " ++ show other
+
+    stripSuffixes t =
+        loop'
+            [ "::character varying"
+            , "::text"
+            ]
+      where
+        loop' [] = t
+        loop' (p : ps) =
+            case T.stripSuffix p t of
+                Nothing -> loop' ps
+                Just t' -> t'
+
+    getType _ _ "int4" = pure SqlInt32
+    getType _ _ "int8" = pure SqlInt64
+    getType _ _ "varchar" = pure SqlString
+    getType _ _ "text" = pure SqlString
+    getType _ _ "date" = pure SqlDay
+    getType _ _ "bool" = pure SqlBool
+    getType _ _ "timestamptz" = pure SqlDayTime
+    getType _ _ "float4" = pure SqlReal
+    getType _ _ "float8" = pure SqlReal
+    getType _ _ "bytea" = pure SqlBlob
+    getType _ _ "time" = pure SqlTime
+    getType precision scale "numeric" = getNumeric precision scale
+    getType _ _ a = pure $ SqlOther a
+
+    getNumeric (PersistInt64 a) (PersistInt64 b) =
+        pure $ SqlNumeric (fromIntegral a) (fromIntegral b)
+    getNumeric PersistNull PersistNull =
+        throwError $
+            T.concat
+                [ "No precision and scale were specified. "
+                , "Postgres defaults to a maximum precision of 147,455 and scale of 16383,"
+                , " which is probably not what you intended."
+                , " Specify the values as numeric(total_digits, digits_after_decimal_place)."
+                ]
+    getNumeric a b =
+        throwError $
+            T.concat
+                [ "Cannot get numeric field precision. "
+                , "Expected an integer for both precision and scale, "
+                , "got: "
+                , T.pack $ show a
+                , " and "
+                , T.pack $ show b
+                , ", respectively."
+                , " Specify the values as numeric(total_digits, digits_after_decimal_place)."
+                ]
+
+-- | Prefix an error message with the table and column it concerns.
+addErrorContext :: Text -> Text -> Text -> Text
+addErrorContext tableName columnName originalMsg =
+    T.concat
+        [ "Error in column "
+        , tableName
+        , "."
+        , columnName
+        , ": "
+        , originalMsg
+        ]
+
+-- | Get all constraints for the listed tables from the database, except for
+-- foreign keys and primary keys (those are handled via the Column data type
+-- and the entity's id definition, respectively).
+getConstraints
+    :: (Text -> IO Statement)
+    -> [EntityNameDB]
+    -> ExceptT Text IO (Map EntityNameDB (Map ConstraintNameDB [FieldNameDB]))
+getConstraints getStmt entityNames = do
+    results <-
+        liftIO $
+            runStmt
+                getStmt
+                getConstraintsSql
+                [PersistArray (map (PersistText . 
unEntityNameDB) entityNames)] + processConstraint + case partitionEithers results of + ([], xs) -> pure $ Map.unionsWith (Map.unionWith (<>)) xs + (errs, _) -> throwError (T.intercalate "\n" errs) + where + getConstraintsSql = + T.concat + [ "SELECT " + , "c.table_name, " + , "c.constraint_name, " + , "c.column_name " + , "FROM information_schema.key_column_usage AS c, " + , "information_schema.table_constraints AS k " + , "WHERE c.table_catalog=current_database() " + , "AND c.table_catalog=k.table_catalog " + , "AND c.table_schema=current_schema() " + , "AND c.table_schema=k.table_schema " + , "AND c.table_name=ANY (?) " + , "AND c.table_name=k.table_name " + , "AND c.constraint_name=k.constraint_name " + , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') " + , "ORDER BY c.constraint_name, c.column_name" + ] + + processConstraint + :: [PersistValue] + -> Either Text (Map EntityNameDB (Map ConstraintNameDB [FieldNameDB])) + processConstraint resultRow = do + (tableName, constraintName, columnName) <- case resultRow of + [PersistText tab, PersistText con, PersistText col] -> + pure (tab, con, col) + [PersistByteString tab, PersistByteString con, PersistByteString col] -> + pure (T.decodeUtf8 tab, T.decodeUtf8 con, T.decodeUtf8 col) + o -> + throwError $ T.pack $ "unexpected datatype returned for postgres o=" ++ show o + + pure $ + Map.singleton + (EntityNameDB tableName) + (Map.singleton (ConstraintNameDB constraintName) [FieldNameDB columnName]) + +-- | Get foreign key constraint information for all columns in the supplied +-- tables from the database. We return a list of references per column because +-- there may be duplicate FK constraints in the database. +-- +-- Note that we only care about FKs where the column in question has ordinal +-- position 1 i.e. is the first column appearing in the FK constraint. +-- Eventually we may want to fill this gap so that multi-column FK constraints +-- can be dealt with by this migrator, but for now that is not something that +-- persistent-postgresql handles. +getForeignKeyReferences + :: (Text -> IO Statement) + -> [EntityNameDB] + -> ExceptT Text IO (Map EntityNameDB (Map FieldNameDB (Set ColumnReference))) +getForeignKeyReferences getStmt entityNames = do + results <- + liftIO $ + runStmt + getStmt + getForeignKeyReferencesSql + [PersistArray (map (PersistText . 
unEntityNameDB) entityNames)] + processForeignKeyReference + case partitionEithers results of + ([], xs) -> pure $ Map.unionsWith (Map.unionWith Set.union) xs + (errs, _) -> throwError (T.intercalate "\n" errs) + where + getForeignKeyReferencesSql = T.decodeUtf8 $(embedFileRelative "sql/getForeignKeyReferences.sql") + + processForeignKeyReference + :: [PersistValue] + -> Either Text (Map EntityNameDB (Map FieldNameDB (Set ColumnReference))) + processForeignKeyReference resultRow = do + ( sourceTableName + , sourceColumnName + , refTableName + , constraintName + , updRule + , delRule + ) <- + case resultRow of + [ PersistText constrName + , PersistText srcTable + , PersistText refTable + , PersistText srcColumn + , PersistText _refColumn + , PersistText updRule + , PersistText delRule + ] -> + pure + ( EntityNameDB srcTable + , FieldNameDB srcColumn + , EntityNameDB refTable + , ConstraintNameDB constrName + , updRule + , delRule + ) + other -> + throwError $ T.pack $ "unexpected row returned for postgres: " ++ show other + + fcOnUpdate <- parseCascade updRule + fcOnDelete <- parseCascade delRule + + let + columnRef = + ColumnReference + { crTableName = refTableName + , crConstraintName = constraintName + , crFieldCascade = + FieldCascade + { fcOnUpdate = Just fcOnUpdate + , fcOnDelete = Just fcOnDelete + } + } + + pure $ + Map.singleton + sourceTableName + (Map.singleton sourceColumnName (Set.singleton columnRef)) + +-- Parse a cascade action as represented in pg_constraint +parseCascade :: Text -> Either Text CascadeAction +parseCascade txt = + case txt of + "a" -> + Right NoAction + "c" -> + Right Cascade + "n" -> + Right SetNull + "d" -> + Right SetDefault + "r" -> + Right Restrict + _ -> + Left $ "Unexpected value in parseCascade: " <> txt + +mapLeft :: (a1 -> a2) -> Either a1 b -> Either a2 b +mapLeft _ (Right x) = Right x +mapLeft f (Left x) = Left (f x) + +migrateEntitiesFromSchemaState + :: BackendSpecificOverrides + -> SchemaState + -> [EntityDef] + -> [EntityDef] + -> Either [Text] [AlterDB] +migrateEntitiesFromSchemaState overrides (SchemaState schemaStateMap) allDefs defsToMigrate = + let + go :: EntityDef -> Either Text [AlterDB] + go entity = do + let + name = getEntityDBName entity + case Map.lookup name schemaStateMap of + Just entityState -> + Right $ migrateEntityFromSchemaState overrides entityState allDefs entity + Nothing -> + Left $ T.pack $ "No entry for entity in schemaState: " <> show name + in + case partitionEithers (map go defsToMigrate) of + ([], xs) -> Right (concat xs) + (errs, _) -> Left errs + +migrateEntityFromSchemaState + :: BackendSpecificOverrides + -> EntitySchemaState + -> [EntityDef] + -> EntityDef + -> [AlterDB] +migrateEntityFromSchemaState overrides schemaState allDefs entity = + case schemaState of + EntityDoesNotExist -> + (addTable newcols entity) : uniques ++ references ++ foreignsAlt + EntityExists ExistingEntitySchemaState{essColumns, essUniqueConstraints} -> + let + (acs, ats) = + getAlters + allDefs + entity + (newcols, udspair) + ( map pickColumnReference (Map.elems essColumns) + , Map.toList essUniqueConstraints + ) + acs' = map (AlterColumn name) acs + ats' = map (AlterTable name) ats + in + acs' ++ ats' + where + name = getEntityDBName entity + (newcols', udefs, fdefs) = postgresMkColumns overrides allDefs entity + newcols = filter (not . safeToRemove entity . 
cName) newcols' + udspair = map udToPair udefs + + uniques = flip concatMap udspair $ \(uname, ucols) -> + [AlterTable name $ AddUniqueConstraint uname ucols] + references = + mapMaybe + ( \Column{cName, cReference} -> + getAddReference allDefs entity cName =<< cReference + ) + newcols + foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs + + -- HACK! This was added to preserve existing behaviour during a refactor. + -- The migrator currently expects to only see cReference set in the old + -- columns if it is also set in the new ones. It also ignores any existing + -- FK constraints in the database that don't match the expected FK + -- constraint name as defined by the Persistent EntityDef. + -- + -- This means that the migrator sometimes behaves incorrectly for standalone + -- Foreign declarations, like Child in the ForeignKey test in + -- persistent-test, as well as in situations where there are duplicate FK + -- constraints for a given column. + -- + -- See https://github.com/yesodweb/persistent/issues/1611#issuecomment-3613251095 for + -- more info + pickColumnReference (oldCol, oldReferences) = + case List.find (\c -> cName c == cName oldCol) newcols of + Just new -> fromMaybe oldCol $ do + -- Note that if this do block evaluates to Nothing, it means + -- we'll return a Column that has cReference = Nothing - + -- effectively, we are telling the migrator that this particular + -- column has no FK constraints in the DB. + + -- If the persistent models don't define a FK constraint, ignore + -- any FK constraints that might exist in the DB (this is + -- arguably a bug, but it's a pre-existing one) + newRef <- cReference new + + -- If the persistent models _do_ define an FK constraint but + -- there's no matching FK constraint in the DB, we don't have + -- to do anything else here: `getAlters` should handle adding + -- the FK constraint for us + oldRef <- + List.find + (\oldRef -> crConstraintName oldRef == crConstraintName newRef) + oldReferences + + -- Finally, if the persistent models define an FK constraint and + -- an FK constraint of that name exists in the DB, return it, so + -- that `getAlters` can check that the constraint is set up + -- correctly + pure $ oldCol{cReference = Just oldRef} + Nothing -> + -- We have a column that exists in the DB but not in the + -- EntityDef. We can no-op here, since `getAlters` will handle + -- dropping this for us. + oldCol + +-- | Indicates whether a Postgres Column is safe to drop. +-- +-- @since 2.17.1.0 +newtype SafeToRemove = SafeToRemove Bool + deriving (Show, Eq) + +-- | Represents a change to a Postgres column in a DB statement. +-- +-- @since 2.17.1.0 +data AlterColumn + = ChangeType Column SqlType Text + | IsNull Column + | NotNull Column + | AddColumn Column + | Drop Column SafeToRemove + | Default Column Text + | NoDefault Column + | UpdateNullToValue Column Text + | AddReference + EntityNameDB + ConstraintNameDB + (NEL.NonEmpty FieldNameDB) + [Text] + FieldCascade + | DropReference ConstraintNameDB + deriving (Show, Eq) + +-- | Represents a change to a Postgres table in a DB statement. +-- +-- @since 2.17.1.0 +data AlterTable + = AddUniqueConstraint ConstraintNameDB [FieldNameDB] + | DropConstraint ConstraintNameDB + deriving (Show, Eq) + +-- | Represents a change to a Postgres DB in a statement. +-- +-- @since 2.17.1.0 +data AlterDB + = AddTable EntityNameDB EntityIdDef [Column] + | AlterColumn EntityNameDB AlterColumn + | AlterTable EntityNameDB AlterTable + deriving (Show, Eq) + +-- | Create a table if it doesn't exist. 
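+-- For entities with a default (surrogate) id, the id column is rendered via
+-- the entity's id definition rather than as an ordinary column, and columns
+-- marked safe to remove are omitted from the definition.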
+--
+-- @since 2.17.1.0
+addTable :: [Column] -> EntityDef -> AlterDB
+addTable cols entity =
+    AddTable name entityId nonIdCols
+  where
+    nonIdCols =
+        case entityPrimary entity of
+            Just _ ->
+                cols
+            _ ->
+                filter keepField cols
+      where
+        keepField c =
+            Just (cName c) /= fmap fieldDB (getEntityIdField entity)
+                && not (safeToRemove entity (cName c))
+    entityId = getEntityId entity
+    name = getEntityDBName entity
+
+maySerial :: SqlType -> Maybe Text -> Text
+maySerial SqlInt64 Nothing = " SERIAL8 "
+maySerial sType _ = " " <> showSqlType sType
+
+mayDefault :: Maybe Text -> Text
+mayDefault def = case def of
+    Nothing -> ""
+    Just d -> " DEFAULT " <> d
+
+getAlters
+    :: [EntityDef]
+    -> EntityDef
+    -> ([Column], [(ConstraintNameDB, [FieldNameDB])])
+    -> ([Column], [(ConstraintNameDB, [FieldNameDB])])
+    -> ([AlterColumn], [AlterTable])
+getAlters defs def (c1, u1) (c2, u2) =
+    (getAltersC c1 c2, getAltersU u1 u2)
+  where
+    getAltersC [] old =
+        map (\x -> Drop x $ SafeToRemove $ safeToRemove def $ cName x) old
+    getAltersC (new : news) old =
+        let
+            (alters, old') = findAlters defs def new old
+        in
+            alters ++ getAltersC news old'
+
+    getAltersU
+        :: [(ConstraintNameDB, [FieldNameDB])]
+        -> [(ConstraintNameDB, [FieldNameDB])]
+        -> [AlterTable]
+    getAltersU [] old =
+        map DropConstraint $ filter (not . isManual) $ map fst old
+    getAltersU ((name, cols) : news) old =
+        case lookup name old of
+            Nothing ->
+                AddUniqueConstraint name cols : getAltersU news old
+            Just ocols ->
+                let
+                    old' = filter (\(x, _) -> x /= name) old
+                in
+                    if sort cols == sort ocols
+                        then getAltersU news old'
+                        else
+                            DropConstraint name
+                                : AddUniqueConstraint name cols
+                                : getAltersU news old'
+
+    -- Don't drop constraints which were manually added.
+    isManual (ConstraintNameDB x) = "__manual_" `T.isPrefixOf` x
+
+-- | Postgres' default maximum identifier length in bytes.
+-- (You can re-compile Postgres with a new limit, but I'm assuming that
+-- virtually no one does this.)
+-- See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
+maximumIdentifierLength :: Int
+maximumIdentifierLength = 63
+
+-- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer.
+sqlTypeEq :: SqlType -> SqlType -> Bool
+sqlTypeEq x y =
+    let
+        -- Non-exhaustive helper to map postgres aliases to the same name. Based on
+        -- https://www.postgresql.org/docs/9.5/datatype.html.
+        -- This prevents needless `ALTER TYPE`s when the type is the same.
+        normalize "int8" = "bigint"
+        normalize "serial8" = "bigserial"
+        normalize v = v
+    in
+        normalize (T.toCaseFold (showSqlType x))
+            == normalize (T.toCaseFold (showSqlType y))
+
+-- | We check if we should alter a foreign key. This is almost an equality check,
+-- except we consider 'Nothing' and 'Just Restrict' equivalent.
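+--
+-- For example, a reference whose 'fcOnDelete' is 'Nothing' (no cascade
+-- declared in the model) compares equal to an otherwise-identical database
+-- constraint whose delete rule is RESTRICT, so no needless
+-- 'DropReference'\/'AddReference' pair is generated by 'findAlters'.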
+equivalentRef :: Maybe ColumnReference -> Maybe ColumnReference -> Bool
+equivalentRef Nothing Nothing = True
+equivalentRef (Just cr1) (Just cr2) =
+    crTableName cr1 == crTableName cr2
+        && crConstraintName cr1 == crConstraintName cr2
+        && eqCascade (fcOnUpdate $ crFieldCascade cr1) (fcOnUpdate $ crFieldCascade cr2)
+        && eqCascade (fcOnDelete $ crFieldCascade cr1) (fcOnDelete $ crFieldCascade cr2)
+  where
+    eqCascade :: Maybe CascadeAction -> Maybe CascadeAction -> Bool
+    eqCascade Nothing Nothing = True
+    eqCascade Nothing (Just Restrict) = True
+    eqCascade (Just Restrict) Nothing = True
+    eqCascade (Just cs1) (Just cs2) = cs1 == cs2
+    eqCascade _ _ = False
+equivalentRef _ _ = False
+
+-- | Generate the default foreign key constraint name for a given source table and
+-- source column name. Note that this function should generally not be used
+-- except as an argument to postgresMkColumns, because if you use it in other contexts,
+-- you're likely to miss nonstandard constraint names declared in the persistent
+-- models files via `constraint=`.
+refName :: EntityNameDB -> FieldNameDB -> ConstraintNameDB
+refName (EntityNameDB table) (FieldNameDB column) =
+    let
+        overhead = T.length $ T.concat ["_", "_fkey"]
+        (fromTable, fromColumn) = shortenNames overhead (T.length table, T.length column)
+    in
+        ConstraintNameDB $
+            T.concat [T.take fromTable table, "_", T.take fromColumn column, "_fkey"]
+  where
+    -- Postgres automatically truncates over-long foreign key names to a combination of
+    -- truncatedTableName + "_" + truncatedColumnName + "_fkey".
+    -- This works fine for normal use cases, but it creates an issue for Persistent:
+    -- after running the migrations, Persistent sees that the truncated foreign key
+    -- constraint doesn't have the expected name, and suggests that you migrate again.
+    -- To work around this, we copy the Postgres truncation approach before sending
+    -- foreign key constraints to it.
+    --
+    -- This can also be an issue for extremely long table names, but it's much more
+    -- likely to show up with foreign key constraints, because their names are
+    -- roughly twice the length of a table name.
+
+    -- Approximation of the algorithm Postgres uses to truncate identifiers.
+    -- See makeObjectName https://github.com/postgres/postgres/blob/5406513e997f5ee9de79d4076ae91c04af0c52f6/src/backend/commands/indexcmds.c#L2074-L2080
+    shortenNames :: Int -> (Int, Int) -> (Int, Int)
+    shortenNames overhead (x, y)
+        | x + y + overhead <= maximumIdentifierLength = (x, y)
+        | x > y = shortenNames overhead (x - 1, y)
+        | otherwise = shortenNames overhead (x, y - 1)
+
+postgresMkColumns
+    :: BackendSpecificOverrides
+    -> [EntityDef]
+    -> EntityDef
+    -> ([Column], [UniqueDef], [ForeignDef])
+postgresMkColumns overrides allDefs t =
+    mkColumns allDefs t $
+        setBackendSpecificForeignKeyName refName overrides
+
+-- | Check if a column name is marked "safe to remove" in the entity's
+-- field definitions.
+safeToRemove :: EntityDef -> FieldNameDB -> Bool
+safeToRemove def (FieldNameDB colName) =
+    any (elem FieldAttrSafeToRemove . fieldAttrs) $
+        filter ((== FieldNameDB colName) . fieldDB) $
+            allEntityFields
+  where
+    allEntityFields =
+        getEntityFieldsDatabase def <> case getEntityId def of
+            EntityIdField fdef ->
+                [fdef]
+            _ ->
+                []
+
+udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB])
+udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud)
+
+-- | Get the references to be added to a table for the given column.
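+-- Returns 'Nothing' for the entity's own id field, since persistent can
+-- otherwise generate an erroneous reference for id columns (see the
+-- corresponding check in 'findAlters').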
+getAddReference + :: [EntityDef] + -> EntityDef + -> FieldNameDB + -> ColumnReference + -> Maybe AlterDB +getAddReference allDefs entity cname cr@ColumnReference{crTableName = s, crConstraintName = constraintName} = do + guard $ Just cname /= fmap fieldDB (getEntityIdField entity) + pure $ + AlterColumn + table + (AddReference s constraintName (cname NEL.:| []) id_ (crFieldCascade cr)) + where + table = getEntityDBName entity + id_ = + fromMaybe + (error $ "Could not find ID of entity " ++ show s) + $ do + entDef <- find ((== s) . getEntityDBName) allDefs + return $ NEL.toList $ Util.dbIdColumnsEsc escapeF entDef + +mkForeignAlt + :: EntityDef + -> ForeignDef + -> Maybe AlterDB +mkForeignAlt entity fdef = case NEL.nonEmpty childfields of + Nothing -> Nothing + Just childfields' -> Just $ AlterColumn tableName_ addReference + where + addReference = + AddReference + (foreignRefTableDBName fdef) + constraintName + childfields' + escapedParentFields + (foreignFieldCascade fdef) + where + tableName_ = getEntityDBName entity + constraintName = + foreignConstraintNameDBName fdef + (childfields, parentfields) = + unzip (map (\((_, b), (_, d)) -> (b, d)) (foreignFields fdef)) + escapedParentFields = + map escapeF parentfields + +escapeC :: ConstraintNameDB -> Text +escapeC = escapeWith escape + +escapeE :: EntityNameDB -> Text +escapeE = escapeWith escape + +escapeF :: FieldNameDB -> Text +escapeF = escapeWith escape + +escape :: Text -> Text +escape s = T.concat ["\"", T.replace "\"" "\"\"" s, "\""] + +showAlterDb :: AlterDB -> (Bool, Text) +showAlterDb (AddTable name entityId nonIdCols) = (False, rawText) + where + idtxt = + case entityId of + EntityIdNaturalKey pdef -> + T.concat + [ " PRIMARY KEY (" + , T.intercalate "," $ map (escapeF . fieldDB) $ NEL.toList $ compositeFields pdef + , ")" + ] + EntityIdField field -> + let + defText = defaultAttribute $ fieldAttrs field + sType = fieldSqlType field + in + T.concat + [ escapeF $ fieldDB field + , maySerial sType defText + , " PRIMARY KEY UNIQUE" + , mayDefault defText + ] + rawText = + T.concat + -- Lower case e: see Database.Persist.Sql.Migration + [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! 
+ , escapeE name + , "(" + , idtxt + , if null nonIdCols then "" else "," + , T.intercalate "," $ map showColumn nonIdCols + , ")" + ] +showAlterDb (AlterColumn t ac) = + (isUnsafe ac, showAlter t ac) + where + isUnsafe (Drop _ (SafeToRemove safeRemove)) = not safeRemove + isUnsafe _ = False +showAlterDb (AlterTable t at) = (False, showAlterTable t at) + +showAlterTable :: EntityNameDB -> AlterTable -> Text +showAlterTable table (AddUniqueConstraint cname cols) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD CONSTRAINT " + , escapeC cname + , " UNIQUE(" + , T.intercalate "," $ map escapeF cols + , ")" + ] +showAlterTable table (DropConstraint cname) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP CONSTRAINT " + , escapeC cname + ] + +showAlter :: EntityNameDB -> AlterColumn -> Text +showAlter table (ChangeType c t extra) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " TYPE " + , showSqlType t + , extra + ] +showAlter table (IsNull c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " DROP NOT NULL" + ] +showAlter table (NotNull c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " SET NOT NULL" + ] +showAlter table (AddColumn col) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD COLUMN " + , showColumn col + ] +showAlter table (Drop c _) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP COLUMN " + , escapeF (cName c) + ] +showAlter table (Default c s) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " SET DEFAULT " + , s + ] +showAlter table (NoDefault c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " DROP DEFAULT" + ] +showAlter table (UpdateNullToValue c s) = + T.concat + [ "UPDATE " + , escapeE table + , " SET " + , escapeF (cName c) + , "=" + , s + , " WHERE " + , escapeF (cName c) + , " IS NULL" + ] +showAlter table (AddReference reftable fkeyname t2 id2 cascade) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD CONSTRAINT " + , escapeC fkeyname + , " FOREIGN KEY(" + , T.intercalate "," $ map escapeF $ NEL.toList t2 + , ") REFERENCES " + , escapeE reftable + , "(" + , T.intercalate "," id2 + , ")" + ] + <> renderFieldCascade cascade +showAlter table (DropReference cname) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP CONSTRAINT " + , escapeC cname + ] + +showColumn :: Column -> Text +showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) = + T.concat + [ escapeF n + , " " + , showSqlType sqlType' + , " " + , if nu then "NULL" else "NOT NULL" + , case def of + Nothing -> "" + Just s -> " DEFAULT " <> s + , case gen of + Nothing -> "" + Just s -> " GENERATED ALWAYS AS (" <> s <> ") STORED" + ] + +showSqlType :: SqlType -> Text +showSqlType SqlString = "VARCHAR" +showSqlType SqlInt32 = "INT4" +showSqlType SqlInt64 = "INT8" +showSqlType SqlReal = "DOUBLE PRECISION" +showSqlType (SqlNumeric s prec) = T.concat ["NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")"] +showSqlType SqlDay = "DATE" +showSqlType SqlTime = "TIME" +showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE" +showSqlType SqlBlob = "BYTEA" +showSqlType SqlBool = "BOOLEAN" +-- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682 +showSqlType (SqlOther (T.toLower -> "integer")) = "INT4" +showSqlType (SqlOther t) = t + +findAlters + :: [EntityDef] + -- ^ The list of all 
entity definitions that persistent is aware of.
+    -> EntityDef
+    -- ^ The entity definition for the entity that we're working on.
+    -> Column
+    -- ^ The column that we're searching for potential alterations for, derived
+    -- from the Persistent EntityDef. That is: this is how we _want_ the column
+    -- to look, and not necessarily how it actually looks in the database right
+    -- now.
+    -> [Column]
+    -- ^ The columns for this table, as they currently exist in the database.
+    -> ([AlterColumn], [Column])
+findAlters defs edef newCol oldCols =
+    case List.find (\c -> cName c == cName newCol) oldCols of
+        Nothing ->
+            ([AddColumn newCol] ++ refAdd (cReference newCol), oldCols)
+        Just oldCol ->
+            let
+                refDrop Nothing = []
+                refDrop (Just ColumnReference{crConstraintName = cname}) =
+                    [DropReference cname]
+
+                modRef =
+                    if equivalentRef (cReference oldCol) (cReference newCol)
+                        then []
+                        else refDrop (cReference oldCol) ++ refAdd (cReference newCol)
+                modNull = case (cNull newCol, cNull oldCol) of
+                    (True, False) -> do
+                        guard $ Just (cName newCol) /= fmap fieldDB (getEntityIdField edef)
+                        pure (IsNull newCol)
+                    (False, True) ->
+                        let
+                            up = case cDefault newCol of
+                                Nothing -> id
+                                Just s -> (:) (UpdateNullToValue newCol s)
+                        in
+                            up [NotNull newCol]
+                    _ -> []
+                modType
+                    | sqlTypeEq (cSqlType newCol) (cSqlType oldCol) = []
+                    -- When converting from Persistent pre-2.0 databases, we
+                    -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is
+                    -- treated as UTC.
+                    | cSqlType newCol == SqlDayTime && cSqlType oldCol == SqlOther "timestamp" =
+                        [ ChangeType newCol (cSqlType newCol) $
+                            T.concat
+                                [ " USING "
+                                , escapeF (cName newCol)
+                                , " AT TIME ZONE 'UTC'"
+                                ]
+                        ]
+                    | otherwise = [ChangeType newCol (cSqlType newCol) ""]
+                modDef =
+                    if cDefault newCol == cDefault oldCol
+                        || isJust (T.stripPrefix "nextval" =<< cDefault oldCol)
+                        then []
+                        else case cDefault newCol of
+                            Nothing -> [NoDefault newCol]
+                            Just s -> [Default newCol s]
+                dropSafe =
+                    if safeToRemove edef (cName newCol)
+                        -- Unreachable in practice: safe-to-remove columns are
+                        -- filtered out of the new column list before findAlters
+                        -- is called.
+                        then error "findAlters: column marked SafeToRemove was not filtered out"
+                        else []
+            in
+                ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe
+                , filter (\c -> cName c /= cName newCol) oldCols
+                )
+  where
+    refAdd Nothing = []
+    -- This check works around a bug where persistent will sometimes
+    -- generate an erroneous ForeignRef for ID fields.
+    -- See: https://github.com/yesodweb/persistent/issues/1615
+    refAdd _ | fmap fieldDB (getEntityIdField edef) == Just (cName newCol) = []
+    refAdd (Just colRef) =
+        case find ((== crTableName colRef) . getEntityDBName) defs of
+            Just refdef ->
+                [ AddReference
+                    (crTableName colRef)
+                    (crConstraintName colRef)
+                    (cName newCol NEL.:| [])
+                    (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef)
+                    (crFieldCascade colRef)
+                ]
+            Nothing ->
+                error $
+                    "could not find the entityDef for reftable["
+                        ++ show (crTableName colRef)
+                        ++ "]"
diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgCodec.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgCodec.hs
new file mode 100644
index 000000000..46b3016ee
--- /dev/null
+++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgCodec.hs
@@ -0,0 +1,367 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+-- | Composable binary codecs for PostgreSQL types.
+-- +-- 'PgDecode' and 'PgEncode' form the /value-level/ codec layer: they +-- operate on raw binary payloads ('PD.Value' \/ 'PE.Encoding') and +-- compose through PostgreSQL's compound-type wire formats (composites, +-- arrays, ranges). +-- +-- The column-level layer ('FieldDecode' \/ 'FieldEncode') delegates +-- to these for compound types while handling OID dispatch itself. +-- +-- Both layers receive an 'OidCache' so they can resolve +-- dynamically-assigned OIDs (composites, enums, domains). +module Database.Persist.Postgresql.Internal.PgCodec + ( -- * Decode + -- ** Reader types + PgDecoder (..) + , PgComposite (..) + -- ** Class + , PgDecode (..) + -- ** Combinators + , pgValue + , pgComposite + , pgField + , pgFieldNullable + , pgArray + , pgArrayNullable + + -- * Encode + -- ** Reader types + , PgEncoder (..) + -- ** Class + , PgEncode (..) + -- ** Combinators + , pgConst + , pgEncodeField + , pgArrayEncoder + ) where + +import Control.Monad (replicateM) +import Data.ByteString (ByteString) +import Data.Int (Int16, Int32, Int64) +import Data.Proxy (Proxy (..)) +import Data.Scientific (Scientific, fromRationalRepetendUnlimited) +import Data.Text (Text) +import Data.Time (Day, TimeOfDay, UTCTime) +import Data.Word (Word32) +import qualified PostgreSQL.Binary.Decoding as PD +import qualified PostgreSQL.Binary.Encoding as PE + +import Database.Persist.Postgresql.Internal.PgType + +--------------------------------------------------------------------------- +-- Decode: reader types +--------------------------------------------------------------------------- + +-- | A value-level binary decoder parameterised by 'OidCache'. +-- +-- Wraps @postgresql-binary@'s 'PD.Value' in a reader so the cache is +-- available to compound decoders (composites that want to validate +-- sub-field OIDs, etc.) without threading it manually. +newtype PgDecoder a = PgDecoder { runPgDecoder :: OidCache -> PD.Value a } + +instance Functor PgDecoder where + fmap f (PgDecoder g) = PgDecoder $ \c -> fmap f (g c) + {-# INLINE fmap #-} + +instance Applicative PgDecoder where + pure a = PgDecoder $ \_ -> pure a + {-# INLINE pure #-} + PgDecoder f <*> PgDecoder a = PgDecoder $ \c -> f c <*> a c + {-# INLINE (<*>) #-} + +instance Monad PgDecoder where + PgDecoder ma >>= k = PgDecoder $ \c -> do + a <- ma c + runPgDecoder (k a) c + {-# INLINE (>>=) #-} + +-- | A composite-field decoder parameterised by 'OidCache'. +-- +-- Wraps @postgresql-binary@'s 'PD.Composite' applicative. Build one +-- with 'pgField' \/ 'pgFieldNullable', then convert to a 'PgDecoder' +-- with 'pgComposite'. +newtype PgComposite a = PgComposite (OidCache -> PD.Composite a) + +instance Functor PgComposite where + fmap f (PgComposite g) = PgComposite $ \c -> fmap f (g c) + {-# INLINE fmap #-} + +instance Applicative PgComposite where + pure a = PgComposite $ \_ -> pure a + {-# INLINE pure #-} + PgComposite f <*> PgComposite a = PgComposite $ \c -> f c <*> a c + {-# INLINE (<*>) #-} + +--------------------------------------------------------------------------- +-- Decode: combinators +--------------------------------------------------------------------------- + +-- | Lift a raw @postgresql-binary@ decoder that doesn't need the cache. +pgValue :: PD.Value a -> PgDecoder a +pgValue v = PgDecoder $ \_ -> v +{-# INLINE pgValue #-} + +-- | Build a 'PgDecoder' for a PostgreSQL composite type from its fields. 
+-- +-- @ +-- instance PgDecode Address where +-- pgDecoder = pgComposite $ Address +-- \<$\> pgField pgDecoder +-- \<*\> pgField pgDecoder +-- \<*\> pgField pgDecoder +-- @ +pgComposite :: PgComposite a -> PgDecoder a +pgComposite (PgComposite f) = PgDecoder $ \c -> PD.composite (f c) +{-# INLINE pgComposite #-} + +-- | Decode a non-nullable field inside a composite. +pgField :: PgDecoder a -> PgComposite a +pgField (PgDecoder f) = PgComposite $ \c -> + PD.valueComposite (f c) +{-# INLINE pgField #-} + +-- | Decode a nullable field inside a composite. +pgFieldNullable :: PgDecoder a -> PgComposite (Maybe a) +pgFieldNullable (PgDecoder f) = PgComposite $ \c -> + PD.nullableValueComposite (f c) +{-# INLINE pgFieldNullable #-} + +-- | Decode a one-dimensional PostgreSQL array of non-nullable elements. +pgArray :: PgDecoder a -> PgDecoder [a] +pgArray (PgDecoder f) = PgDecoder $ \c -> + PD.array $ PD.dimensionArray replicateM (PD.valueArray (f c)) +{-# INLINE pgArray #-} + +-- | Decode a one-dimensional PostgreSQL array of nullable elements. +pgArrayNullable :: PgDecoder a -> PgDecoder [Maybe a] +pgArrayNullable (PgDecoder f) = PgDecoder $ \c -> + PD.array $ PD.dimensionArray replicateM (PD.nullableValueArray (f c)) +{-# INLINE pgArrayNullable #-} + +--------------------------------------------------------------------------- +-- Decode: class +--------------------------------------------------------------------------- + +-- | Value-level binary decoder for a Haskell type. +-- +-- Scalar instances wrap the corresponding @postgresql-binary@ decoder. +-- Compound instances compose via 'pgComposite', 'pgField', 'pgArray'. +class PgDecode a where + pgDecoder :: PgDecoder a + +-- Scalars + +instance PgDecode Bool where + pgDecoder = pgValue PD.bool + {-# INLINE pgDecoder #-} + +instance PgDecode Int16 where + pgDecoder = pgValue PD.int + {-# INLINE pgDecoder #-} + +instance PgDecode Int32 where + pgDecoder = pgValue PD.int + {-# INLINE pgDecoder #-} + +instance PgDecode Int64 where + pgDecoder = pgValue PD.int + {-# INLINE pgDecoder #-} + +instance PgDecode Int where + pgDecoder = pgValue (fromIntegral <$> (PD.int :: PD.Value Int64)) + {-# INLINE pgDecoder #-} + +instance PgDecode Double where + pgDecoder = pgValue PD.float8 + {-# INLINE pgDecoder #-} + +instance PgDecode Float where + pgDecoder = pgValue PD.float4 + {-# INLINE pgDecoder #-} + +instance PgDecode Scientific where + pgDecoder = pgValue PD.numeric + {-# INLINE pgDecoder #-} + +instance PgDecode Rational where + pgDecoder = pgValue (toRational <$> PD.numeric) + {-# INLINE pgDecoder #-} + +instance PgDecode Text where + pgDecoder = pgValue PD.text_strict + {-# INLINE pgDecoder #-} + +instance PgDecode ByteString where + pgDecoder = pgValue PD.bytea_strict + {-# INLINE pgDecoder #-} + +instance PgDecode Day where + pgDecoder = pgValue PD.date + {-# INLINE pgDecoder #-} + +instance PgDecode TimeOfDay where + pgDecoder = pgValue PD.time_int + {-# INLINE pgDecoder #-} + +instance PgDecode UTCTime where + pgDecoder = pgValue PD.timestamptz_int + {-# INLINE pgDecoder #-} + +-- Compound + +instance PgDecode a => PgDecode [a] where + pgDecoder = pgArray pgDecoder + {-# INLINE pgDecoder #-} + +instance {-# OVERLAPPING #-} PgDecode a => PgDecode [Maybe a] where + pgDecoder = pgArrayNullable pgDecoder + {-# INLINE pgDecoder #-} + +--------------------------------------------------------------------------- +-- Encode: reader types +--------------------------------------------------------------------------- + +-- | A value-level binary encoder 
parameterised by 'OidCache'. +-- +-- The cache is needed for composite and enum types whose OIDs are +-- assigned dynamically by PostgreSQL. +newtype PgEncoder a = PgEncoder { runPgEncoder :: OidCache -> a -> PE.Encoding } + +instance Semigroup (PgEncoder a) where + PgEncoder f <> PgEncoder g = PgEncoder $ \c a -> + f c a <> g c a + {-# INLINE (<>) #-} + +-- | Lift a pure encoder that doesn't need the cache. +pgConst :: (a -> PE.Encoding) -> PgEncoder a +pgConst f = PgEncoder $ \_ -> f +{-# INLINE pgConst #-} + +--------------------------------------------------------------------------- +-- Encode: combinators +--------------------------------------------------------------------------- + +-- | Encode a value as a composite field (OID + payload). +pgEncodeField :: forall a. PgEncode a => OidCache -> a -> PE.Composite +pgEncodeField cache a = + PE.field (pgTypeOid' cache (Proxy @a)) (runPgEncoder pgEncoder cache a) +{-# INLINE pgEncodeField #-} + +-- | Build a 'PgEncoder' for a one-dimensional array of non-nullable elements. +pgArrayEncoder :: forall a. PgEncode a => PgEncoder [a] +pgArrayEncoder = PgEncoder $ \cache xs -> + PE.array_foldable + (pgTypeOid' cache (Proxy @a)) + (\x -> Just (runPgEncoder pgEncoder cache x)) + xs +{-# INLINE pgArrayEncoder #-} + +-- | Build a 'PgEncoder' for a composite type from a function that +-- produces a 'PE.Composite' builder. +pgCompositeEncoder :: (OidCache -> a -> PE.Composite) -> PgEncoder a +pgCompositeEncoder f = PgEncoder $ \cache a -> PE.composite (f cache a) +{-# INLINE pgCompositeEncoder #-} + +--------------------------------------------------------------------------- +-- Encode: class +--------------------------------------------------------------------------- + +-- | Value-level binary encoder for a Haskell type. +-- +-- 'pgEncoder' produces the binary payload. 'pgTypeOid'' returns the +-- PostgreSQL OID for the type, needed when this value appears as an +-- element inside a composite or array. 
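+--
+-- As a sketch, a hypothetical user-defined composite could be encoded like
+-- this (@Address@ is made up, and a real instance would resolve the
+-- composite's dynamically-assigned OID from the 'OidCache' instead of
+-- returning a placeholder):
+--
+-- @
+-- data Address = Address { street :: Text, zipCode :: Text }
+--
+-- instance PgEncode Address where
+--   pgEncoder = pgCompositeEncoder encodeFields
+--     where
+--       encodeFields cache (Address s z) =
+--         pgEncodeField cache s <> pgEncodeField cache z
+--   pgTypeOid' _cache _ = 0 -- placeholder; see 'OidCache'
+-- @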
+class PgEncode a where + pgEncoder :: PgEncoder a + pgTypeOid' :: OidCache -> Proxy a -> Word32 + +-- Scalars + +instance PgEncode Bool where + pgEncoder = pgConst PE.bool + pgTypeOid' _ _ = scalarOidWord32 PgBool + {-# INLINE pgEncoder #-} + +instance PgEncode Int16 where + pgEncoder = pgConst PE.int2_int16 + pgTypeOid' _ _ = scalarOidWord32 PgInt2 + {-# INLINE pgEncoder #-} + +instance PgEncode Int32 where + pgEncoder = pgConst PE.int4_int32 + pgTypeOid' _ _ = scalarOidWord32 PgInt4 + {-# INLINE pgEncoder #-} + +instance PgEncode Int64 where + pgEncoder = pgConst PE.int8_int64 + pgTypeOid' _ _ = scalarOidWord32 PgInt8 + {-# INLINE pgEncoder #-} + +instance PgEncode Int where + pgEncoder = pgConst $ \i -> PE.int8_int64 (fromIntegral i) + pgTypeOid' _ _ = scalarOidWord32 PgInt8 + {-# INLINE pgEncoder #-} + +instance PgEncode Double where + pgEncoder = pgConst PE.float8 + pgTypeOid' _ _ = scalarOidWord32 PgFloat8 + {-# INLINE pgEncoder #-} + +instance PgEncode Float where + pgEncoder = pgConst PE.float4 + pgTypeOid' _ _ = scalarOidWord32 PgFloat4 + {-# INLINE pgEncoder #-} + +instance PgEncode Scientific where + pgEncoder = pgConst PE.numeric + pgTypeOid' _ _ = scalarOidWord32 PgNumeric + {-# INLINE pgEncoder #-} + +instance PgEncode Rational where + pgEncoder = pgConst $ \r -> + let (sci, _) = fromRationalRepetendUnlimited r + in PE.numeric sci + pgTypeOid' _ _ = scalarOidWord32 PgNumeric + {-# INLINE pgEncoder #-} + +instance PgEncode Text where + pgEncoder = pgConst PE.text_strict + pgTypeOid' _ _ = scalarOidWord32 PgText + {-# INLINE pgEncoder #-} + +instance PgEncode ByteString where + pgEncoder = pgConst PE.bytea_strict + pgTypeOid' _ _ = scalarOidWord32 PgBytea + {-# INLINE pgEncoder #-} + +instance PgEncode Day where + pgEncoder = pgConst PE.date + pgTypeOid' _ _ = scalarOidWord32 PgDate + {-# INLINE pgEncoder #-} + +instance PgEncode TimeOfDay where + pgEncoder = pgConst PE.time_int + pgTypeOid' _ _ = scalarOidWord32 PgTime + {-# INLINE pgEncoder #-} + +instance PgEncode UTCTime where + pgEncoder = pgConst PE.timestamptz_int + pgTypeOid' _ _ = scalarOidWord32 PgTimestamptz + {-# INLINE pgEncoder #-} + +-- Compound + +instance PgEncode a => PgEncode [a] where + pgEncoder = pgArrayEncoder + pgTypeOid' cache _ = case arrayOidWord32 (pgTypeOid' cache (Proxy @a)) of + Just oid -> oid + Nothing -> 0 + {-# INLINE pgEncoder #-} diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgType.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgType.hs new file mode 100644 index 000000000..c792b691d --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/PgType.hs @@ -0,0 +1,274 @@ +{-# LANGUAGE ViewPatterns #-} + +-- | Unified PostgreSQL type\/OID classification. +-- +-- Single source of truth for the mapping between PostgreSQL type OIDs and +-- the Haskell-side 'PgType' representation. Both "Encoding" and "Decoding" +-- modules reference this module instead of maintaining their own OID tables. +-- +-- 'fromOid' is designed for use as a view pattern: +-- +-- @ +-- f (fromOid -> Scalar PgBool) = … +-- f (fromOid -> Array PgInt8) = … +-- f (fromOid -> Unrecognized _) = … +-- @ +module Database.Persist.Postgresql.Internal.PgType + ( -- * Type classification + PgScalar (..) + , PgType (..) 
+ -- * OID → PgType (view pattern) + , fromOid + -- * PgType → OID + , pgTypeOid + , scalarOid + , scalarOidWord32 + , arrayOid + , arrayOidWord32 + -- * OID cache for custom types + , OidCache + , emptyOidCache + , resolveOid + ) where + +import Data.IntMap.Strict (IntMap) +import qualified Data.IntMap.Strict as IntMap +import Data.Text (Text) +import Data.Word (Word32) +import Foreign.C.Types (CUInt) +import qualified Database.PostgreSQL.LibPQ as LibPQ + +-- | Known PostgreSQL scalar types. +data PgScalar + = PgBool + | PgBytea + | PgChar + | PgName + | PgInt2 + | PgInt4 + | PgInt8 + | PgFloat4 + | PgFloat8 + | PgText + | PgXml + | PgMoney + | PgBpchar + | PgVarchar + | PgDate + | PgTime + | PgTimestamp + | PgTimestamptz + | PgInterval + | PgBit + | PgVarbit + | PgNumeric + | PgVoid + | PgJson + | PgJsonb + | PgUnknown + -- ^ The PostgreSQL @unknown@ pseudotype (OID 705), used for untyped + -- string literals. Not to be confused with 'Unrecognized'. + | PgUuid + deriving (Eq, Ord, Show) + +-- | Classified PostgreSQL type: a known scalar, a known array of a scalar, +-- a user-defined composite\/enum, or an OID we don't have a built-in +-- mapping for. +data PgType + = Scalar !PgScalar + | Array !PgScalar + | Composite !Text !Int + -- ^ User-defined composite type. Carries the type name and the + -- array-type OID (0 if unknown), looked up from @pg_type@ at + -- connection time. + | CompositeArray !Text + -- ^ Array of a user-defined composite. Carries the element type name. + | Enum !Text !Int + -- ^ User-defined enum. Carries the type name and the array-type OID. + | EnumArray !Text + -- ^ Array of a user-defined enum. Carries the element type name. + | Unrecognized !Int + deriving (Eq, Show) + +--------------------------------------------------------------------------- +-- OID → PgType +--------------------------------------------------------------------------- + +-- | Classify a @libpq@ 'LibPQ.Oid' into a 'PgType'. +-- +-- Intended for use as a view pattern so that callers can pattern-match on +-- structured types rather than raw integers. 
+fromOid :: LibPQ.Oid -> PgType +fromOid (LibPQ.Oid n) = fromCUInt n + +fromCUInt :: CUInt -> PgType +fromCUInt 16 = Scalar PgBool +fromCUInt 17 = Scalar PgBytea +fromCUInt 18 = Scalar PgChar +fromCUInt 19 = Scalar PgName +fromCUInt 20 = Scalar PgInt8 +fromCUInt 21 = Scalar PgInt2 +fromCUInt 23 = Scalar PgInt4 +fromCUInt 25 = Scalar PgText +fromCUInt 114 = Scalar PgJson +fromCUInt 142 = Scalar PgXml +fromCUInt 700 = Scalar PgFloat4 +fromCUInt 701 = Scalar PgFloat8 +fromCUInt 705 = Scalar PgUnknown +fromCUInt 790 = Scalar PgMoney +fromCUInt 1042 = Scalar PgBpchar +fromCUInt 1043 = Scalar PgVarchar +fromCUInt 1082 = Scalar PgDate +fromCUInt 1083 = Scalar PgTime +fromCUInt 1114 = Scalar PgTimestamp +fromCUInt 1184 = Scalar PgTimestamptz +fromCUInt 1186 = Scalar PgInterval +fromCUInt 1560 = Scalar PgBit +fromCUInt 1562 = Scalar PgVarbit +fromCUInt 1700 = Scalar PgNumeric +fromCUInt 2278 = Scalar PgVoid +fromCUInt 2950 = Scalar PgUuid +fromCUInt 3802 = Scalar PgJsonb +-- Arrays +fromCUInt 143 = Array PgXml +fromCUInt 199 = Array PgJson +fromCUInt 791 = Array PgMoney +fromCUInt 1000 = Array PgBool +fromCUInt 1001 = Array PgBytea +fromCUInt 1002 = Array PgChar +fromCUInt 1003 = Array PgName +fromCUInt 1005 = Array PgInt2 +fromCUInt 1007 = Array PgInt4 +fromCUInt 1009 = Array PgText +fromCUInt 1014 = Array PgBpchar +fromCUInt 1015 = Array PgVarchar +fromCUInt 1016 = Array PgInt8 +fromCUInt 1021 = Array PgFloat4 +fromCUInt 1022 = Array PgFloat8 +fromCUInt 1115 = Array PgTimestamp +fromCUInt 1182 = Array PgDate +fromCUInt 1183 = Array PgTime +fromCUInt 1185 = Array PgTimestamptz +fromCUInt 1187 = Array PgInterval +fromCUInt 1231 = Array PgNumeric +fromCUInt 1561 = Array PgBit +fromCUInt 1563 = Array PgVarbit +fromCUInt 2951 = Array PgUuid +fromCUInt 3807 = Array PgJsonb +fromCUInt n = Unrecognized (fromIntegral n) + +--------------------------------------------------------------------------- +-- PgType → OID +--------------------------------------------------------------------------- + +-- | The OID for a 'PgType'. +pgTypeOid :: PgType -> LibPQ.Oid +pgTypeOid (Scalar s) = scalarOid s +pgTypeOid (Array s) = maybe (LibPQ.Oid 0) id (arrayOid s) +pgTypeOid (Composite _ _) = LibPQ.Oid 0 +pgTypeOid (CompositeArray _) = LibPQ.Oid 0 +pgTypeOid (Enum _ _) = LibPQ.Oid 0 +pgTypeOid (EnumArray _) = LibPQ.Oid 0 +pgTypeOid (Unrecognized n) = LibPQ.Oid (fromIntegral n) + +-- | The scalar OID. +scalarOid :: PgScalar -> LibPQ.Oid +scalarOid = LibPQ.Oid . scalarCUInt + +-- | The scalar OID as 'Word32', for @postgresql-binary@ array element +-- encoding which takes the element OID as a 'Word32'. +scalarOidWord32 :: PgScalar -> Word32 +scalarOidWord32 = fromIntegral . scalarCUInt + +scalarCUInt :: PgScalar -> CUInt +scalarCUInt PgBool = 16 +scalarCUInt PgBytea = 17 +scalarCUInt PgChar = 18 +scalarCUInt PgName = 19 +scalarCUInt PgInt2 = 21 +scalarCUInt PgInt4 = 23 +scalarCUInt PgInt8 = 20 +scalarCUInt PgFloat4 = 700 +scalarCUInt PgFloat8 = 701 +scalarCUInt PgText = 25 +scalarCUInt PgXml = 142 +scalarCUInt PgMoney = 790 +scalarCUInt PgBpchar = 1042 +scalarCUInt PgVarchar = 1043 +scalarCUInt PgDate = 1082 +scalarCUInt PgTime = 1083 +scalarCUInt PgTimestamp = 1114 +scalarCUInt PgTimestamptz = 1184 +scalarCUInt PgInterval = 1186 +scalarCUInt PgBit = 1560 +scalarCUInt PgVarbit = 1562 +scalarCUInt PgNumeric = 1700 +scalarCUInt PgVoid = 2278 +scalarCUInt PgJson = 114 +scalarCUInt PgJsonb = 3802 +scalarCUInt PgUnknown = 705 +scalarCUInt PgUuid = 2950 + +-- | The array OID corresponding to a scalar type, if one exists. 
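+-- For example, @arrayOid PgInt8@ is @Just (LibPQ.Oid 1016)@, while
+-- @arrayOid PgVoid@ is @Nothing@.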
+arrayOid :: PgScalar -> Maybe LibPQ.Oid +arrayOid = fmap LibPQ.Oid . arrayCUInt + +arrayCUInt :: PgScalar -> Maybe CUInt +arrayCUInt PgBool = Just 1000 +arrayCUInt PgBytea = Just 1001 +arrayCUInt PgChar = Just 1002 +arrayCUInt PgName = Just 1003 +arrayCUInt PgInt2 = Just 1005 +arrayCUInt PgInt4 = Just 1007 +arrayCUInt PgInt8 = Just 1016 +arrayCUInt PgFloat4 = Just 1021 +arrayCUInt PgFloat8 = Just 1022 +arrayCUInt PgText = Just 1009 +arrayCUInt PgXml = Just 143 +arrayCUInt PgMoney = Just 791 +arrayCUInt PgBpchar = Just 1014 +arrayCUInt PgVarchar = Just 1015 +arrayCUInt PgDate = Just 1182 +arrayCUInt PgTime = Just 1183 +arrayCUInt PgTimestamp = Just 1115 +arrayCUInt PgTimestamptz = Just 1185 +arrayCUInt PgInterval = Just 1187 +arrayCUInt PgBit = Just 1561 +arrayCUInt PgVarbit = Just 1563 +arrayCUInt PgNumeric = Just 1231 +arrayCUInt PgJson = Just 199 +arrayCUInt PgJsonb = Just 3807 +arrayCUInt PgUuid = Just 2951 +arrayCUInt PgVoid = Nothing +arrayCUInt PgUnknown = Nothing + +-- | Reverse lookup: given a scalar element OID (as 'Word32'), return +-- the corresponding array OID. Used by 'PgEncode' instances for +-- @[a]@ to tag the encoded array parameter. +arrayOidWord32 :: Word32 -> Maybe Word32 +arrayOidWord32 w = case fromCUInt (fromIntegral w) of + Scalar s -> fromIntegral . (\(LibPQ.Oid o) -> o) <$> arrayOid s + _ -> Nothing + +--------------------------------------------------------------------------- +-- OID cache +--------------------------------------------------------------------------- + +-- | Cache for dynamically-discovered type OIDs (custom enums, composites, +-- domains, etc.) that aren't in the built-in table. Keyed by raw OID +-- integer. +-- +-- Currently unused by the decoding path, but wired into 'PgConn' so that +-- future custom-type support can populate and consult it without changing +-- any signatures. +type OidCache = IntMap PgType + +emptyOidCache :: OidCache +emptyOidCache = IntMap.empty + +-- | Classify an OID, falling back to the cache for OIDs not in the +-- built-in table. +resolveOid :: OidCache -> LibPQ.Oid -> PgType +resolveOid cache oid = case fromOid oid of + Unrecognized n -> maybe (Unrecognized n) id (IntMap.lookup n cache) + known -> known diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Placeholders.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Placeholders.hs new file mode 100644 index 000000000..6f45e71fb --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Internal/Placeholders.hs @@ -0,0 +1,186 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE OverloadedStrings #-} + +-- | Rewrite @?@ placeholders to PostgreSQL @$1, $2, ...@ numbered parameters. +-- +-- This replaces the functionality that @postgresql-simple@ provides internally. +-- We must be careful to: +-- +-- * Replace @?@ with @$N@ +-- * Replace @??@ with a literal @?@ (used by persistent for column expansion, e.g. @RETURNING ??@) +-- * Skip contents of string literals (@\'...\'@), identifiers (@\"...\"@), +-- and comments (@--@ line comments, @/* ... */@ block comments) +module Database.Persist.Postgresql.Internal.Placeholders + ( rewritePlaceholders + -- * Numbered parameter detection + , ParamStyle (..) 
+ , detectParamStyle + -- * SQL lexer helpers (shared with Pipeline.hs) + , skipStringLiteral + , skipQuotedIdent + , skipBlockComment + ) where + +import Data.ByteString (ByteString) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Builder as BB +import qualified Data.ByteString.Lazy as LBS +import Data.Char (isDigit, ord) +import Data.Word (Word8) + + +-- | Rewrite @?@-style placeholders to @$1, $2, ...@ numbered parameters. +-- +-- Returns the rewritten SQL and the total number of parameters found. +-- +-- @??@ is replaced with a literal @?@ (not a parameter). +rewritePlaceholders :: ByteString -> (ByteString, Int) +rewritePlaceholders input = + let (builder, count) = go input 0 mempty + in (LBS.toStrict (BB.toLazyByteString builder), count) + where + go :: ByteString -> Int -> BB.Builder -> (BB.Builder, Int) + go bs !n !acc + | BS.null bs = (acc, n) + | otherwise = + let c = BS.head bs + rest = BS.tail bs + in case c of + -- String literal: copy through including contents + 39 {- '\'' -} -> + let (literal, after) = skipStringLiteral rest + in go after n (acc <> BB.word8 39 <> BB.byteString literal) + + -- Quoted identifier: copy through including contents + 34 {- '"' -} -> + let (ident, after) = skipQuotedIdent rest + in go after n (acc <> BB.word8 34 <> BB.byteString ident) + + -- Line comment: copy through to end of line + 45 {- '-' -} + | not (BS.null rest) && BS.head rest == 45 -> + let (comment, after) = BS.break (== 10) rest + in go after n (acc <> BB.word8 45 <> BB.byteString comment) + + -- Block comment: copy through to closing */ + 47 {- '/' -} + | not (BS.null rest) && BS.head rest == 42 {- '*' -} -> + let (comment, after) = skipBlockComment (BS.tail rest) 1 + in go after n (acc <> BB.byteString "/*" <> BB.byteString comment) + + -- Placeholder + 63 {- '?' -} + -- ?? -> literal ? + | not (BS.null rest) && BS.head rest == 63 -> + go (BS.tail rest) n (acc <> BB.word8 63) + -- ? -> $N + | otherwise -> + let n' = n + 1 + in go rest n' (acc <> BB.word8 36 <> BB.intDec n') + + -- Anything else: copy through + _ -> go rest n (acc <> BB.word8 c) + +-- SQL lexer helpers. These use index-based scanning for efficiency +-- (no intermediate ByteString allocations per character), only doing +-- a single splitAt at the boundary. + +-- | Skip a string literal (everything after the opening @\'@). +-- Returns (consumed bytes including closing @\'@, remaining input). +-- Handles @\'\'@ escape sequences. +skipStringLiteral :: ByteString -> (ByteString, ByteString) +skipStringLiteral = goStr 0 + where + goStr !i bs + | i >= BS.length bs = (bs, BS.empty) + | BS.index bs i == 39 {- '\'' -} = + if i + 1 < BS.length bs && BS.index bs (i + 1) == 39 + then goStr (i + 2) bs + else BS.splitAt (i + 1) bs + | otherwise = goStr (i + 1) bs + +-- | Skip a quoted identifier (everything after the opening @\"@). +-- Returns (consumed bytes including closing @\"@, remaining input). +-- Handles @\"\"@ escape sequences. +skipQuotedIdent :: ByteString -> (ByteString, ByteString) +skipQuotedIdent = goId 0 + where + goId !i bs + | i >= BS.length bs = (bs, BS.empty) + | BS.index bs i == 34 {- '"' -} = + if i + 1 < BS.length bs && BS.index bs (i + 1) == 34 + then goId (i + 2) bs + else BS.splitAt (i + 1) bs + | otherwise = goId (i + 1) bs + +-- | Skip a block comment (after the opening delimiter). +-- Handles nested block comments. Returns (consumed bytes including +-- the closing delimiter, remaining input). 
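+--
+-- For example (with OverloadedStrings):
+--
+-- > skipBlockComment "foo */ rest" 1 == ("foo */", " rest")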
+skipBlockComment :: ByteString -> Int -> (ByteString, ByteString) +skipBlockComment = goBlk 0 + where + goBlk !i bs !depth + | i >= BS.length bs = (bs, BS.empty) + | i + 1 < BS.length bs + , BS.index bs i == 42 {- '*' -}, BS.index bs (i + 1) == 47 {- '/' -} = + if depth <= 1 + then BS.splitAt (i + 2) bs + else goBlk (i + 2) bs (depth - 1) + | i + 1 < BS.length bs + , BS.index bs i == 47 {- '/' -}, BS.index bs (i + 1) == 42 {- '*' -} = + goBlk (i + 2) bs (depth + 1) + | otherwise = goBlk (i + 1) bs depth + +--------------------------------------------------------------------------- +-- Numbered parameter detection +--------------------------------------------------------------------------- + +-- | Parameter placeholder style detected in a SQL statement. +data ParamStyle + = QuestionMarkParams + -- ^ Traditional @?@ placeholders (persistent style). Each @?@ is a + -- distinct positional parameter. + | NumberedParams !Int + -- ^ PostgreSQL-native @$N@ placeholders. The 'Int' is the highest + -- @$N@ found. A single @$1@ can appear multiple times in the query, + -- referencing the same parameter value. + deriving (Eq, Show) + +-- | Scan a SQL statement (skipping string literals, quoted identifiers, +-- and comments) for @$N@ patterns. If any are found, return +-- 'NumberedParams' with the max N; otherwise 'QuestionMarkParams'. +detectParamStyle :: ByteString -> ParamStyle +detectParamStyle = go 0 + where + go :: Int -> ByteString -> ParamStyle + go !maxN bs + | BS.null bs = if maxN > 0 then NumberedParams maxN else QuestionMarkParams + | otherwise = + let c = BS.head bs + rest = BS.tail bs + in case c of + 39 {- '\'' -} -> + let (_lit, after) = skipStringLiteral rest + in go maxN after + 34 {- '"' -} -> + let (_ident, after) = skipQuotedIdent rest + in go maxN after + 45 {- '-' -} + | not (BS.null rest) && BS.head rest == 45 -> + let (_comment, after) = BS.break (== 10) rest + in go maxN after + 47 {- '/' -} + | not (BS.null rest) && BS.head rest == 42 -> + let (_comment, after) = skipBlockComment (BS.tail rest) 1 + in go maxN after + 36 {- '$' -} -> + let (digits, after) = BS.span (\w -> isDigit (chr w)) rest + in if BS.null digits + then go maxN after + else + let n = BS.foldl' (\acc w -> acc * 10 + (ord (chr w) - 48)) 0 digits + in go (max maxN n) after + _ -> go maxN rest + + chr :: Word8 -> Char + chr = toEnum . fromIntegral diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/JSON.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/JSON.hs new file mode 100644 index 000000000..57828f371 --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/JSON.hs @@ -0,0 +1,404 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +-- | Filter operators for JSON values added to PostgreSQL 9.4 +module Database.Persist.Postgresql.JSON + ( (@>.) + , (<@.) + , (?.) + , (?|.) + , (?&.) + , Value () + ) where + +import Data.Aeson (FromJSON, ToJSON, Value, eitherDecodeStrict, encode) +import qualified Data.ByteString.Lazy as BSL +import Data.Proxy (Proxy) +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding as TE (encodeUtf8) + +import Database.Persist + ( EntityField + , Filter (..) + , PersistField (..) + , PersistFilter (..) + , PersistValue (..) + ) +import Database.Persist.Sql (PersistFieldSql (..), SqlType (..)) +import Database.Persist.Types (FilterValue (..)) + +infix 4 @>., <@., ?., ?|., ?&. 
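+
+-- All of these operators build ordinary persistent 'Filter's, so they can be
+-- combined freely with the standard query functions. For example, given a
+-- hypothetical entity @Doc@ with a @jsonb@ field @payload@:
+--
+-- > docs <- selectList [DocPayload @>. object ["published" .= True]] []
+--
+-- ('object' and '(.=)' come from "Data.Aeson".)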
+
+-- | This operator checks inclusion of the JSON value
+-- on the right hand side in the JSON value on the left
+-- hand side.
+--
+-- === __Objects__
+--
+-- An empty Object matches any object
+--
+-- @
+-- {} \@> {} == True
+-- {"a":1,"b":false} \@> {} == True
+-- @
+--
+-- Any key-value pair will be matched top-level
+--
+-- @
+-- {"a":1,"b":{"c":true}} \@> {"a":1} == True
+-- {"a":1,"b":{"c":true}} \@> {"b":1} == False
+-- {"a":1,"b":{"c":true}} \@> {"b":{}} == True
+-- {"a":1,"b":{"c":true}} \@> {"c":true} == False
+-- {"a":1,"b":{"c":true}} \@> {"b":{"c":true}} == True
+-- @
+--
+-- === __Arrays__
+--
+-- An empty Array matches any array
+--
+-- @
+-- [] \@> [] == True
+-- [1,2,"hi",false,null] \@> [] == True
+-- @
+--
+-- The array on the right hand side has to be a subset of the one on
+-- the left. Objects and arrays inside it are themselves compared by
+-- containment.
+--
+-- @
+-- [1,2,"hi",false,null] \@> [1] == True
+-- [1,2,"hi",false,null] \@> [null,"hi"] == True
+-- [1,2,"hi",false,null] \@> ["hi",true] == False
+-- [1,2,"hi",false,null] \@> ["hi",2,null,false,1] == True
+-- [1,2,"hi",false,null] \@> [1,2,"hi",false,null,{}] == False
+-- @
+--
+-- Arrays and objects inside arrays match the same way they'd
+-- be matched as being on their own.
+--
+-- @
+-- [1,"hi",[false,3],{"a":[null]}] \@> [{}] == True
+-- [1,"hi",[false,3],{"a":[null]}] \@> [{"a":[]}] == True
+-- [1,"hi",[false,3],{"a":[null]}] \@> [{"b":[null]}] == False
+-- [1,"hi",[false,3],{"a":[null]}] \@> [[]] == True
+-- [1,"hi",[false,3],{"a":[null]}] \@> [[3]] == True
+-- [1,"hi",[false,3],{"a":[null]}] \@> [[true,3]] == False
+-- @
+--
+-- A regular value has to be a member
+--
+-- @
+-- [1,2,"hi",false,null] \@> 1 == True
+-- [1,2,"hi",false,null] \@> 5 == False
+-- [1,2,"hi",false,null] \@> "hi" == True
+-- [1,2,"hi",false,null] \@> false == True
+-- [1,2,"hi",false,null] \@> "2" == False
+-- @
+--
+-- An object will never match with an array
+--
+-- @
+-- [1,2,"hi",[false,3],{"a":null}] \@> {} == False
+-- [1,2,"hi",[false,3],{"a":null}] \@> {"a":null} == False
+-- @
+--
+-- === __Other values__
+--
+-- For any other JSON values the `(\@>.)` operator
+-- functions like an equivalence operator.
+--
+-- @
+-- "hello" \@> "hello" == True
+-- "hello" \@> \"Hello" == False
+-- "hello" \@> "h" == False
+-- "hello" \@> {"hello":1} == False
+-- "hello" \@> ["hello"] == False
+--
+-- 5 \@> 5 == True
+-- 5 \@> 5.00 == True
+-- 5 \@> 1 == False
+-- 5 \@> 7 == False
+-- 12345 \@> 1234 == False
+-- 12345 \@> 2345 == False
+-- 12345 \@> "12345" == False
+-- 12345 \@> [1,2,3,4,5] == False
+--
+-- true \@> true == True
+-- true \@> false == False
+-- false \@> true == False
+-- true \@> "true" == False
+--
+-- null \@> null == True
+-- null \@> 23 == False
+-- null \@> "null" == False
+-- null \@> {} == False
+-- @
+--
+-- @since 2.8.2
+(@>.) :: EntityField record Value -> Value -> Filter record
+(@>.) field val = Filter field (FilterValue val) $ BackendSpecificFilter " @> "
+
+-- | Same as '@>.' except the inclusion check is reversed.
+-- i.e. is the JSON value on the left hand side included
+-- in the JSON value on the right hand side.
+--
+-- @since 2.8.2
+(<@.) :: EntityField record Value -> Value -> Filter record
+(<@.) field val = Filter field (FilterValue val) $ BackendSpecificFilter " <@ "
+
+-- | This operator takes a column and a string to find a
+-- top-level key/field in an object.
+--
+-- @column ?. string@
+--
+-- N.B. This operator might have some unexpected interactions
+-- with non-object values. Please reference the examples.
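+--
+-- (Internally the filter is rendered as @??@, which the placeholder
+-- rewriter passes through as a literal @?@ rather than treating it
+-- as a parameter.)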
+-- +-- === __Objects__ +-- +-- @ +-- {"a":null} ? "a" == True +-- {"test":false,"a":500} ? "a" == True +-- {"b":{"a":[]}} ? "a" == False +-- {} ? "a" == False +-- {} ? "{}" == False +-- {} ? "" == False +-- {"":9001} ? "" == True +-- @ +-- +-- === __Arrays__ +-- +-- This operator will match an array if the string to be matched +-- is an element of that array, but nothing else. +-- +-- @ +-- ["a"] ? "a" == True +-- [["a"]] ? "a" == False +-- [9,false,"1",null] ? "1" == True +-- [] ? "[]" == False +-- [{"a":true}] ? "a" == False +-- @ +-- +-- === __Other values__ +-- +-- This operator functions like an equivalence operator on strings only. +-- Any other value does not match. +-- +-- @ +-- "a" ? "a" == True +-- "1" ? "1" == True +-- "ab" ? "a" == False +-- 1 ? "1" == False +-- null ? "null" == False +-- true ? "true" == False +-- 1.5 ? "1.5" == False +-- @ +-- +-- @since 2.10.0 +(?.) :: EntityField record Value -> Text -> Filter record +(?.) = jsonFilter " ?? " + +-- | This operator takes a column and a list of strings to +-- test whether ANY of the elements of the list are top +-- level fields in an object. +-- +-- @column ?|. list@ +-- +-- /N.B. An empty list __will never match anything__. Also, this/ +-- /operator might have some unexpected interactions with/ +-- /non-object values. Please reference the examples./ +-- +-- === __Objects__ +-- +-- @ +-- {"a":null} ?| ["a","b","c"] == True +-- {"test":false,"a":500} ?| ["a","b","c"] == True +-- {} ?| ["a","{}"] == False +-- {"b":{"a":[]}} ?| ["a","c"] == False +-- {"b":{"a":[]},"test":null} ?| [] == False +-- @ +-- +-- === __Arrays__ +-- +-- This operator will match an array if __any__ of the elements +-- of the list are matching string elements of the array. +-- +-- @ +-- ["a"] ?| ["a","b","c"] == True +-- [["a"]] ?| ["a","b","c"] == False +-- [9,false,"1",null] ?| ["a","false"] == False +-- [] ?| ["a","b","c"] == False +-- [] ?| [] == False +-- [{"a":true}] ?| ["a","b","c"] == False +-- [null,4,"b",[]] ?| ["a","b","c"] == True +-- @ +-- +-- === __Other values__ +-- +-- This operator functions much like an equivalence operator +-- on strings only. If a string matches with __any__ element of +-- the given list, the comparison matches. No other values match. +-- +-- @ +-- "a" ?| ["a","b","c"] == True +-- "1" ?| ["a","b","1"] == True +-- "ab" ?| ["a","b","c"] == False +-- 1 ?| ["a","1"] == False +-- null ?| ["a","null"] == False +-- true ?| ["a","true"] == False +-- "a" ?| [] == False +-- @ +-- +-- @since 2.10.0 +(?|.) :: EntityField record Value -> [Text] -> Filter record +(?|.) field = jsonFilter " ??| " field . PostgresArray + +-- | This operator takes a column and a list of strings to +-- test whether ALL of the elements of the list are top +-- level fields in an object. +-- +-- @column ?&. list@ +-- +-- /N.B. An empty list __will match anything__. Also, this/ +-- /operator might have some unexpected interactions with/ +-- /non-object values. 
Please reference the examples./ +-- +-- === __Objects__ +-- +-- @ +-- {"a":null} ?& ["a"] == True +-- {"a":null} ?& ["a","a"] == True +-- {"test":false,"a":500} ?& ["a"] == True +-- {"test":false,"a":500} ?& ["a","b"] == False +-- {} ?& ["{}"] == False +-- {"b":{"a":[]}} ?& ["a"] == False +-- {"b":{"a":[]},"c":false} ?& ["a","c"] == False +-- {"a":1,"b":2,"c":3,"d":4} ?& ["b","d"] == True +-- {} ?& [] == True +-- {"b":{"a":[]},"test":null} ?& [] == True +-- @ +-- +-- === __Arrays__ +-- +-- This operator will match an array if __all__ of the elements +-- of the list are matching string elements of the array. +-- +-- @ +-- ["a"] ?& ["a"] == True +-- ["a"] ?& ["a","a"] == True +-- [["a"]] ?& ["a"] == False +-- ["a","b","c"] ?& ["a","b","d"] == False +-- [9,"false","1",null] ?& ["1","false"] == True +-- [] ?& ["a","b"] == False +-- [{"a":true}] ?& ["a"] == False +-- ["a","b","c","d"] ?& ["b","c","d"] == True +-- [null,4,{"test":false}] ?& [] == True +-- [] ?& [] == True +-- @ +-- +-- === __Other values__ +-- +-- This operator functions much like an equivalence operator +-- on strings only. If a string matches with all elements of +-- the given list, the comparison matches. +-- +-- @ +-- "a" ?& ["a"] == True +-- "1" ?& ["a","1"] == False +-- "b" ?& ["b","b"] == True +-- "ab" ?& ["a","b"] == False +-- 1 ?& ["1"] == False +-- null ?& ["null"] == False +-- true ?& ["true"] == False +-- 31337 ?& [] == True +-- true ?& [] == True +-- null ?& [] == True +-- @ +-- +-- @since 2.10.0 +(?&.) :: EntityField record Value -> [Text] -> Filter record +(?&.) field = jsonFilter " ??& " field . PostgresArray + +jsonFilter + :: (PersistField a) => Text -> EntityField record Value -> a -> Filter record +jsonFilter op field a = Filter field (UnsafeValue a) $ BackendSpecificFilter op + +----------------- +-- AESON VALUE -- +----------------- + +instance PersistField Value where + toPersistValue = toPersistValueJsonB + fromPersistValue = fromPersistValueJsonB + +instance PersistFieldSql Value where + sqlType = sqlTypeJsonB + +-- FIXME: PersistText might be a bit more efficient, +-- but needs testing/profiling before changing it. +-- (When entering into the DB the type isn't as important as fromPersistValue) +toPersistValueJsonB :: (ToJSON a) => a -> PersistValue +toPersistValueJsonB = PersistLiteralEscaped . BSL.toStrict . encode + +fromPersistValueJsonB :: (FromJSON a) => PersistValue -> Either Text a +fromPersistValueJsonB (PersistText t) = + case eitherDecodeStrict $ TE.encodeUtf8 t of + Left str -> Left $ fromPersistValueParseError "FromJSON" t $ T.pack str + Right v -> Right v +fromPersistValueJsonB (PersistByteString bs) = + case eitherDecodeStrict bs of + Left str -> Left $ fromPersistValueParseError "FromJSON" bs $ T.pack str + Right v -> Right v +fromPersistValueJsonB x = Left $ fromPersistValueError "FromJSON" "string or bytea" x + +-- Constraints on the type might not be necessary, +-- but better to leave them in. +sqlTypeJsonB :: (ToJSON a, FromJSON a) => Proxy a -> SqlType +sqlTypeJsonB _ = SqlOther "JSONB" + +fromPersistValueError + :: Text + -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64" + -> Text + -- ^ Database type(s), should appear different from Haskell name, e.g. "integer" or "INT", not "Int". 
+ -> PersistValue + -- ^ Incorrect value + -> Text + -- ^ Error message +fromPersistValueError haskellType databaseType received = + T.concat + [ "Failed to parse Haskell type `" + , haskellType + , "`; expected " + , databaseType + , " from database, but received: " + , T.pack (show received) + , ". Potential solution: Check that your database schema matches your Persistent model definitions." + ] + +fromPersistValueParseError + :: (Show a) + => Text + -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64" + -> a + -- ^ Received value + -> Text + -- ^ Additional error + -> Text + -- ^ Error message +fromPersistValueParseError haskellType received err = + T.concat + [ "Failed to parse Haskell type `" + , haskellType + , "`, but received " + , T.pack (show received) + , " | with error: " + , err + ] + +newtype PostgresArray a = PostgresArray [a] + +instance (PersistField a) => PersistField (PostgresArray a) where + toPersistValue (PostgresArray ts) = PersistArray $ toPersistValue <$> ts + fromPersistValue (PersistArray as) = PostgresArray <$> traverse fromPersistValue as + fromPersistValue wat = Left $ fromPersistValueError "PostgresArray" "array" wat diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline.hs new file mode 100644 index 000000000..c769985b5 --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline.hs @@ -0,0 +1,2694 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} + +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | A pipelined PostgreSQL backend for persistent. +-- +-- Uses binary wire protocol via @postgresql-libpq@ and @postgresql-binary@, +-- with support for libpq pipeline mode. +module Database.Persist.Postgresql.Pipeline + ( -- * Backend types + PostgreSQLBackend + , ReadBackend (..) + , WriteBackend (..) + , readToWrite + , readToUnknown + , writeToUnknown + -- * Connection pools + , withPostgresqlPipelinePool + , createPostgresqlPipelinePool + , withPostgresqlPipelineConn + -- * Re-exports from persistent + , module Database.Persist.Sql + , ConnectionString + -- * Configuration + -- $configuration + , PipelineSettings + -- ** Construction + , defaultPipelineSettings + , singleRowSettings + , chunkedSettings + -- ** Fields + , pipelineFetchMode + , pipelineConnStr + , pipelinePoolSize + , FetchMode (..) + , PostgresConf (..) + , PostgresConfHooks (..) + , defaultPostgresConfHooks + -- * Upsert utilities + , HandleUpdateCollision + , copyField + , copyUnlessNull + , copyUnlessEmpty + , copyUnlessEq + , excludeNotEqualToOriginal + , upsertWhere + , upsertManyWhere + -- * Misc utilities + , PgInterval (..) 
+ , tableName + , fieldName + , mockMigration + , migrateEnableExtension + -- * Optimized batch operations + , getManyKeys + , deleteManyKeys + , insertManyUnnest + , insertManyUnnest_ + , putManyUnnest + , repsertManyUnnest + -- * Pipeline mode + , withPipeline + , flushPipeline + -- * Pipeline-aware raw execution + , rawExecuteSync + , rawExecuteNoReturn + -- * Raw access + , RawPostgresqlPipeline (..) + , createRawPostgresqlPipelinePool + -- * Backend creation + , createBackend + -- * Low-level access (for testing) + , getPipelineConn + , inlineAndRewrite + , collapseInClauses + ) where + +import qualified Database.PostgreSQL.LibPQ as LibPQ + +import Control.Exception (throwIO) +import Control.Monad +import Control.Monad.IO.Unlift (MonadIO (..), MonadUnliftIO) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT) +import Control.Monad.Trans.Reader (ReaderT (..), ask, asks, runReaderT) +import Control.Monad.Trans.Writer (WriterT (..), runWriterT) +import qualified Data.Foldable as Foldable +import qualified Data.List.NonEmpty as NEL +import Data.Proxy (Proxy (..)) + +import Data.Acquire (Acquire, mkAcquire) +import qualified Data.Aeson as A +import Data.Aeson (FromJSON(..), ToJSON(..), (.:), (.:?), (.!=), withObject) +import Data.Aeson.Types (modifyFailure) +import GHC.Generics (Generic) +import Data.ByteString (ByteString) +import qualified Data.ByteString.Builder as BB +import qualified Data.ByteString.Char8 as B8 +import qualified Data.ByteString.Lazy as LBS +import Data.Conduit +import Data.Data (Data) +import Data.Either (partitionEithers) +import Data.IORef +import Data.Int (Int64) +import Data.Maybe (fromMaybe) +import Data.Function (on) +import Data.List (nubBy, transpose) +import Data.List.NonEmpty (NonEmpty) +import qualified Data.Map as Map +import qualified Data.Monoid as Monoid +import Data.Pool (Pool) +import Data.Text (Text) +import qualified Data.Text as T +import qualified Data.Text.Encoding as T +import qualified Data.Text.IO as TIO +import System.Environment (getEnvironment) + +import Database.Persist.Class.PersistUnique (persistUniqueKeyValues) +import Database.Persist.Compatible +import qualified Data.Vault.Strict as Vault +import Database.Persist.Postgresql.Internal +import Database.Persist.Postgresql.Pipeline.Internal +import Database.Persist.Postgresql.Pipeline.FFI + ( rawResultStatusInt + , setChunkedRowsMode + , pattern PGRES_SINGLE_TUPLE + , pattern PGRES_TUPLES_OK + , pattern PGRES_TUPLES_CHUNK + , pattern PGRES_FATAL_ERROR + ) +import Database.Persist.DirectDecode + (FromRow (..), RowDecoder (..), newCounter, runRowReaderCPS, runRowDecoderCPS) +import Database.Persist.Postgresql.Internal.Decoding (decodePersistValue) +import Database.Persist.Postgresql.Internal.DirectDecode (PgRowEnv (..)) +import Database.Persist.Postgresql.Internal.DirectEncode (PgParam (..), pgParamToLibPQ) +import Database.Persist.Postgresql.Internal.Encoding (encodePersistValue) +import Database.Persist.Postgresql.Internal.Placeholders + (skipStringLiteral, skipQuotedIdent, skipBlockComment, ParamStyle (..), detectParamStyle, rewritePlaceholders) +import Database.Persist.Postgresql.Internal.PgType (OidCache, PgType, emptyOidCache, fromOid) +import Database.Persist.Sql.DirectRaw (HasDirectQuery (..)) +import Database.Persist.SqlBackend.Internal + (SqlBackend (..), DirectQueryCap (..)) +import qualified Data.Vector as V +import Database.Persist.Sql hiding (readToWrite, readToUnknown, writeToUnknown) +import qualified Database.Persist.Sql.Util 
    as Util
+import Database.Persist.SqlBackend
+import System.IO.Unsafe (unsafeInterleaveIO, unsafePerformIO)
+
+-- | A @libpq@ connection string. A simple example of connection
+-- string would be @\"host=localhost port=5432 user=test
+-- dbname=test password=test\"@. Please read libpq's
+-- documentation at
+-- <https://www.postgresql.org/docs/current/libpq-connect.html>
+-- for more details on how to create such strings.
+type ConnectionString = ByteString
+
+-- $configuration
+--
+-- 'PipelineSettings' is an opaque configuration record. Start from
+-- 'defaultPipelineSettings' (or a convenience smart constructor like
+-- 'singleRowSettings' / 'chunkedSettings') and customise individual
+-- fields with record update syntax:
+--
+-- @
+-- mySettings :: PipelineSettings
+-- mySettings = defaultPipelineSettings
+--     { pipelineFetchMode = FetchChunked 500
+--     , pipelinePoolSize = 10
+--     , pipelineConnStr = \"host=localhost dbname=mydb\"
+--     }
+-- @
+--
+-- The constructor is intentionally hidden so that new fields can be
+-- added (with defaults) in future releases without breaking
+-- downstream code that uses record updates.
+
+-- | Configuration for a pipelined PostgreSQL backend.
+--
+-- Construct via 'defaultPipelineSettings', 'singleRowSettings', or
+-- 'chunkedSettings', then adjust the 'pipelineFetchMode',
+-- 'pipelineConnStr', and 'pipelinePoolSize' fields with record
+-- update syntax. The constructor is not exported; the field
+-- accessors can still be used to read values.
+data PipelineSettings = MkPipelineSettings
+    { pipelineFetchMode :: !FetchMode
+    -- ^ How SELECT-like query results are fetched. Default: 'FetchAll'.
+    , pipelineConnStr :: !ConnectionString
+    -- ^ libpq connection string. Default: @\"\"@ (empty -- uses libpq
+    -- defaults \/ environment variables).
+    , pipelinePoolSize :: !Int
+    -- ^ Number of connections in the pool. Default: @1@.
+    }
+
+-- | Sensible defaults: 'FetchAll', empty connection string, pool size 1.
+defaultPipelineSettings :: PipelineSettings
+defaultPipelineSettings = MkPipelineSettings
+    { pipelineFetchMode = FetchAll
+    , pipelineConnStr = ""
+    , pipelinePoolSize = 1
+    }
+
+-- | 'FetchSingleRow' preset. One row per @PGresult@ -- lowest memory,
+-- works with every PostgreSQL version.
+singleRowSettings :: ConnectionString -> Int -> PipelineSettings
+singleRowSettings cstr poolSz = MkPipelineSettings
+    { pipelineFetchMode = FetchSingleRow
+    , pipelineConnStr = cstr
+    , pipelinePoolSize = poolSz
+    }
+
+-- | 'FetchChunked' preset. Up to @chunkSize@ rows per @PGresult@ -- good
+-- balance of bounded memory and low allocation overhead. Requires
+-- libpq >= 17; automatically falls back to 'FetchSingleRow' otherwise.
+chunkedSettings :: Int -> ConnectionString -> Int -> PipelineSettings
+chunkedSettings chunkSize cstr poolSz = MkPipelineSettings
+    { pipelineFetchMode = FetchChunked chunkSize
+    , pipelineConnStr = cstr
+    , pipelinePoolSize = poolSz
+    }
+
+-- | Create a PostgreSQL connection pool and run the given action.
+--
+-- @
+-- withPostgresqlPipelinePool settings $ \\pool ->
+--     runSqlPool (selectList [] []) pool
+-- @
+withPostgresqlPipelinePool
+    :: (MonadLoggerIO m, MonadUnliftIO m)
+    => PipelineSettings
+    -> (Pool (WriteBackend PostgreSQLBackend) -> m a)
+    -> m a
+withPostgresqlPipelinePool settings =
+    withSqlPool (openPipelineWith settings) (pipelinePoolSize settings)
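+
+-- A fuller sketch tying settings, pool, and queries together. The
+-- @User@ entity and its query are hypothetical; 'runNoLoggingT'
+-- merely satisfies the 'MonadLoggerIO' constraint:
+--
+-- @
+-- main :: IO ()
+-- main = runNoLoggingT $
+--     withPostgresqlPipelinePool (chunkedSettings 500 \"host=localhost dbname=mydb\" 10) $ \\pool ->
+--         liftIO $ flip runSqlPool pool $ do
+--             users <- selectList [] [LimitTo 10]
+--             liftIO $ print (length (users :: [Entity User]))
+-- @
+
+-- | Create a PostgreSQL connection pool (returned, not bracketed).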
+createPostgresqlPipelinePool + :: (MonadUnliftIO m, MonadLoggerIO m) + => PipelineSettings + -> m (Pool (WriteBackend PostgreSQLBackend)) +createPostgresqlPipelinePool settings = + createSqlPool (openPipelineWith settings) (pipelinePoolSize settings) + +-- | Open a single pipelined PostgreSQL connection and run the given action. +withPostgresqlPipelineConn + :: (MonadUnliftIO m, MonadLoggerIO m) + => PipelineSettings + -> (WriteBackend PostgreSQLBackend -> m a) + -> m a +withPostgresqlPipelineConn settings = + withSqlConn (openPipelineWith settings) + +-- | Open a pipelined connection with the given settings, returning a +-- write-capable backend. +openPipelineWith :: PipelineSettings -> LogFunc -> IO (WriteBackend PostgreSQLBackend) +openPipelineWith settings logFunc = do + pc <- openPgConn (pipelineFetchMode settings) (pipelineConnStr settings) + smap <- newIORef mempty + return $ WriteBackend $ PostgreSQLBackend $ + createBackend logFunc (pgVersion pc) smap pc + +-- | Create the backend given a logging function, server version, +-- mutable statement cell, and connection. +createBackend + :: LogFunc + -> NonEmpty Word + -> IORef (Map.Map Text Statement) + -> PgConn + -> SqlBackend +createBackend logFunc serverVersion smap pc = + maybe id setConnPutManySql (upsertFunction putManySql serverVersion) $ + maybe id setConnUpsertSql (upsertFunction upsertSql' serverVersion) $ + setConnInsertManySql insertManySql' $ + maybe id setConnRepsertManySql (upsertFunction repsertManySql serverVersion) $ + modifyConnVault (Vault.insert underlyingConnectionKey pc) $ + setDirectQueryCap pc $ + mkSqlBackend + MkSqlBackendArgs + { connPrepare = prepare' pc + , connStmtMap = smap + , connInsertSql = insertSql' + , connClose = closePgConn pc + , connMigrateSql = migrate' emptyBackendSpecificOverrides + , connBegin = \_ mIsolation -> do + -- Safety: drain any leftover pending results + -- from a previous session that didn't go through + -- COMMIT/ROLLBACK (e.g. runSqlPoolNoTransaction). + -- Zero-cost when pgPending == 0. + drainPending pc + -- Fire-and-forget: BEGIN is pipelined with the + -- first user query for zero extra round-trips. + let isolationStr = case mIsolation of + Nothing -> "" + Just ReadUncommitted -> " ISOLATION LEVEL READ COMMITTED" + Just ReadCommitted -> " ISOLATION LEVEL READ COMMITTED" + Just RepeatableRead -> " ISOLATION LEVEL REPEATABLE READ" + Just Serializable -> " ISOLATION LEVEL SERIALIZABLE" + ok <- pgSendQueryParams pc ("BEGIN" <> isolationStr) [] LibPQ.Binary + unless ok $ throwIO $ PipelineExecError "failed to send BEGIN" + beginReply <- pgRecvResult pc + modifyIORef (pgPending pc) (beginReply :) + , connCommit = \_ -> do + -- Drain all outstanding results (lazy + fire-and-forget). + drainPending pc + -- Send COMMIT + sync, read COMMIT result + sync marker. + ok <- pgSendQueryParams pc "COMMIT" [] LibPQ.Binary + unless ok $ throwIO $ PipelineExecError "failed to send COMMIT" + ok2 <- pgPipelineSync pc + unless ok2 $ throwIO $ PipelineModeError "failed to sync pipeline" + commitReply <- pgRecvResult pc + st <- LibPQ.resultStatus commitReply + LibPQ.unsafeFreeResult commitReply + drainSyncResult pc + case st of + LibPQ.FatalError -> throwIO $ PipelineExecError "COMMIT failed" + _ -> return () + , connRollback = \_ -> do + -- Drain all outstanding results. 
+ drainPending pc + ok <- pgSendQueryParams pc "ROLLBACK" [] LibPQ.Binary + unless ok $ throwIO $ PipelineExecError "failed to send ROLLBACK" + ok2 <- pgPipelineSync pc + unless ok2 $ throwIO $ PipelineModeError "failed to sync pipeline" + drainToSync pc + writeIORef (pgPending pc) [] + , connEscapeFieldName = escapeF + , connEscapeTableName = escapeE . getEntityDBName + , connEscapeRawName = escape + , connNoLimit = "LIMIT ALL" + , connRDBMS = "postgresql" + , connLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL" + , connLogFunc = logFunc + } + +-- | Install the direct-query capability for PgRowEnv. +-- +-- The 'DirectQueryCap' receives a pre-built 'RowDecoder' from the call +-- site (produced by 'lookupDirectDecoder'). The backend just sends the +-- query, builds 'PgRowEnv' per row, and applies the decoder. +setDirectQueryCap :: PgConn -> SqlBackend -> SqlBackend +setDirectQueryCap pc' sb = sb { connDirectQueryCap = Just cap } + where + cap = MkDirectQueryCap (Proxy @PgRowEnv) $ \decoder sql pvs -> do + let sqlBS = T.encodeUtf8 sql + drainPending pc' + cache <- readIORef (pgOidCache pc') + let params = map encodePersistValue pvs + ok <- case detectParamStyle sqlBS of + QuestionMarkParams -> + let (rewritten, _) = rewritePlaceholders sqlBS + in pgSendQueryParams pc' rewritten params LibPQ.Binary + NumberedParams _ -> + pgSendQueryParams pc' sqlBS params LibPQ.Binary + unless ok $ do + merr <- LibPQ.errorMessage (pgConn pc') + throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr) + _ <- pgSendFlushRequest pc' + pgFlush pc' + ret <- readOneQueryResult pc' + colTypes <- getColumnPgTypes ret + rowCount <- LibPQ.ntuples ret + results <- decodeAllRows decoder ret rowCount colTypes cache + LibPQ.unsafeFreeResult ret + pure results + + decodeAllRows decoder ret rowCount colTypes cache = go [] (LibPQ.Row 0) + where + go !acc !row + | row == rowCount = pure (reverse acc) + | otherwise = do + let env = PgRowEnv ret row colTypes cache + val <- runRowDecoder decoder env + (\e -> throwIO $ PipelineExecError $ T.encodeUtf8 e) + pure + go (val : acc) (row + 1) + +-- | Key for storing PgConn in the SqlBackend vault. +underlyingConnectionKey :: Vault.Key PgConn +underlyingConnectionKey = unsafePerformIO Vault.newKey +{-# NOINLINE underlyingConnectionKey #-} + +-- | Access underlying PgConn, returning 'Nothing' if the 'SqlBackend' +-- provided isn't backed by this pipeline backend. +getPipelineConn + :: (BackendCompatible SqlBackend backend) => backend -> Maybe PgConn +getPipelineConn = Vault.lookup underlyingConnectionKey <$> getConnVault + +------------------------------------------------------------------------------- +-- PostgreSQLBackend / ReadBackend / WriteBackend +-- +-- PostgreSQLBackend wraps SqlBackend. ReadBackend and WriteBackend +-- are capability newtypes that mirror SqlReadBackend / SqlWriteBackend +-- but carry the UNNEST / = ANY optimised typeclass instances. +------------------------------------------------------------------------------- + +-- | The underlying PostgreSQL backend. Not used directly in user +-- signatures -- use @'ReadBackend' 'PostgreSQLBackend'@ or +-- @'WriteBackend' 'PostgreSQLBackend'@ to select read-only or +-- read-write capability. +newtype PostgreSQLBackend = PostgreSQLBackend { unPostgreSQLBackend :: SqlBackend } + +-- | A read-only capability wrapper. @ReadBackend PostgreSQLBackend@ +-- provides 'PersistStoreRead', 'PersistQueryRead', and +-- 'PersistUniqueRead' instances but no write instances. 
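+--
+-- For example (with a hypothetical entity @User@), a function can
+-- advertise in its type that it only needs read access:
+--
+-- @
+-- countUsers :: MonadIO m => ReaderT (ReadBackend PostgreSQLBackend) m Int
+-- countUsers = count ([] :: [Filter User])
+-- @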
+newtype ReadBackend backend = ReadBackend { unReadBackend :: backend } + +-- | A read-write capability wrapper. @WriteBackend PostgreSQLBackend@ +-- provides all read instances plus 'PersistStoreWrite', +-- 'PersistQueryWrite', and 'PersistUniqueWrite' with UNNEST-based +-- columnar inserts. +newtype WriteBackend backend = WriteBackend { unWriteBackend :: backend } + +-- | Lift a read-only action into a read-write context. +readToWrite + :: (Monad m) + => ReaderT (ReadBackend PostgreSQLBackend) m a + -> ReaderT (WriteBackend PostgreSQLBackend) m a +readToWrite ma = do + WriteBackend w <- ask + lift $ runReaderT ma (ReadBackend w) + +-- | Run a read-only action against a bare 'PostgreSQLBackend'. +readToUnknown + :: (Monad m) + => ReaderT (ReadBackend PostgreSQLBackend) m a + -> ReaderT PostgreSQLBackend m a +readToUnknown ma = do + unknown <- ask + lift $ runReaderT ma (ReadBackend unknown) + +-- | Run a read-write action against a bare 'PostgreSQLBackend'. +writeToUnknown + :: (Monad m) + => ReaderT (WriteBackend PostgreSQLBackend) m a + -> ReaderT PostgreSQLBackend m a +writeToUnknown ma = do + unknown <- ask + lift $ runReaderT ma (WriteBackend unknown) + +-- PostgreSQLBackend base instances + +instance HasPersistBackend PostgreSQLBackend where + type BaseBackend PostgreSQLBackend = SqlBackend + persistBackend = unPostgreSQLBackend + +instance BackendCompatible SqlBackend PostgreSQLBackend where + projectBackend = unPostgreSQLBackend + +instance PersistCore PostgreSQLBackend where + newtype BackendKey PostgreSQLBackend = PostgreSQLBackendKey + { unPostgreSQLBackendKey :: Int64 } + deriving stock (Show, Read, Eq, Ord, Generic) + deriving newtype + ( Num, Integral, PersistField, PersistFieldSql + , Real, Enum, Bounded + , A.ToJSON, A.FromJSON + ) + +-- ReadBackend instances + +instance HasPersistBackend (ReadBackend PostgreSQLBackend) where + type BaseBackend (ReadBackend PostgreSQLBackend) = SqlBackend + persistBackend = unPostgreSQLBackend . unReadBackend + +instance BackendCompatible SqlBackend (ReadBackend PostgreSQLBackend) where + projectBackend = unPostgreSQLBackend . 
unReadBackend + +instance PersistCore (ReadBackend PostgreSQLBackend) where + newtype BackendKey (ReadBackend PostgreSQLBackend) = ReadPostgreSQLBackendKey + { unReadPostgreSQLBackendKey :: Int64 } + deriving stock (Show, Read, Eq, Ord, Generic) + deriving newtype + ( Num, Integral, PersistField, PersistFieldSql + , Real, Enum, Bounded + , A.ToJSON, A.FromJSON + ) + +instance PersistStoreRead (ReadBackend PostgreSQLBackend) where + get k = liftViaWrite $ pipelinedGet k + getMany ks = withBaseBackend $ getMany ks + +instance PersistQueryRead (ReadBackend PostgreSQLBackend) where + count filts = liftViaWrite $ pipelinedCount filts + exists filts = liftViaWrite $ pipelinedExists filts + selectSourceRes filts opts = withBaseBackend $ selectSourceRes filts opts + selectKeysRes filts opts = withBaseBackend $ selectKeysRes filts opts + +instance PersistUniqueRead (ReadBackend PostgreSQLBackend) where + getBy uniq = liftViaWrite $ pipelinedGetBy uniq + existsBy uniq = withBaseBackend $ existsBy uniq + +liftViaWrite + :: MonadIO m + => ReaderT (WriteBackend PostgreSQLBackend) IO a + -> ReaderT (ReadBackend PostgreSQLBackend) m a +liftViaWrite action = do + backend <- ask + liftIO $ runReaderT action (WriteBackend $ unReadBackend backend) + +-- WriteBackend instances (read + write, with UNNEST overrides) + +instance HasPersistBackend (WriteBackend PostgreSQLBackend) where + type BaseBackend (WriteBackend PostgreSQLBackend) = SqlBackend + persistBackend = unPostgreSQLBackend . unWriteBackend + +instance BackendCompatible SqlBackend (WriteBackend PostgreSQLBackend) where + projectBackend = unPostgreSQLBackend . unWriteBackend + +instance PersistCore (WriteBackend PostgreSQLBackend) where + newtype BackendKey (WriteBackend PostgreSQLBackend) = WritePostgreSQLBackendKey + { unWritePostgreSQLBackendKey :: Int64 } + deriving stock (Show, Read, Eq, Ord, Generic) + deriving newtype + ( Num, Integral, PersistField, PersistFieldSql + , Real, Enum, Bounded + , A.ToJSON, A.FromJSON + ) + +instance PersistStoreRead (WriteBackend PostgreSQLBackend) where + get = pipelinedGet + getMany ks = withBaseBackend $ getMany ks + +instance PersistStoreWrite (WriteBackend PostgreSQLBackend) where + insert = pipelinedInsert + insert_ v = withBaseBackend $ insert_ v + insertMany vs = withBaseBackend $ insertMany vs + insertKey k v = withBaseBackend $ insertKey k v + insertEntityMany vs = withBaseBackend $ insertEntityMany vs + repsert k v = withBaseBackend $ repsert k v + replace k v = withBaseBackend $ replace k v + delete k = withBaseBackend $ delete k + update k upds = withBaseBackend $ update k upds + + insertMany_ [] = return () + insertMany_ records = do + let t = entityDef records + (sql, params) = insertMany_Hook t (map Util.mkInsertValues records) + rawExecute sql params + + repsertMany [] = return () + repsertMany krsDups = do + let krs = nubBy ((==) `on` fst) (reverse krsDups) + t = entityDef (map snd krs) + toVals (k, r) = + case entityPrimary t of + Nothing -> keyToValues k <> Util.mkInsertValues r + Just _ -> Util.mkInsertValues r + (sql, params) = repsertManyHook t (map toVals krs) + rawExecute sql params + +instance PersistQueryRead (WriteBackend PostgreSQLBackend) where + count = pipelinedCount + exists = pipelinedExists + selectSourceRes filts opts = withBaseBackend $ selectSourceRes filts opts + selectKeysRes filts opts = withBaseBackend $ selectKeysRes filts opts + +instance PersistQueryWrite (WriteBackend PostgreSQLBackend) where + deleteWhere filts = withBaseBackend $ deleteWhere filts + updateWhere 
filts upds = withBaseBackend $ updateWhere filts upds + +instance PersistUniqueRead (WriteBackend PostgreSQLBackend) where + getBy = pipelinedGetBy + existsBy uniq = withBaseBackend $ existsBy uniq + +instance PersistUniqueWrite (WriteBackend PostgreSQLBackend) where + deleteBy uniq = withBaseBackend $ deleteBy uniq + upsertBy uniqueKey record updates = pipelinedUpsertBy uniqueKey record updates + + putMany [] = return () + putMany rsD@(r:_) = do + let uKeys = persistUniqueKeys r + case uKeys of + [] -> insertMany_ rsD + _ -> do + let rs = nubBy ((==) `on` persistUniqueKeyValues) (reverse rsD) + t = entityDef rs + toVals rv = map toPersistValue (toPersistFields rv) + (sql, params) = putManyHook t (map toVals rs) + rawExecute sql params + +--------------------------------------------------------------------------- +-- Hedis-style automatic pipelining primitives +--------------------------------------------------------------------------- + +-- | Send a SQL query into the pipeline buffer and pop a lazy reply. +-- Does NOT flush or read -- the IO happens when the returned +-- 'LibPQ.Result' thunk is forced. +-- +-- Appends to 'pgPending' so commit\/rollback forces the result even +-- if the caller discards the lazy thunk. Also returns the thunk +-- to the caller for lazy decoding. +pipelinedSend :: PgConn -> Text -> [PersistValue] -> IO LibPQ.Result +pipelinedSend pc sql vals = do + let (rewrittenSQL, remainingVals) = inlineAndRewrite (T.encodeUtf8 sql) vals + params = map encodePersistValue remainingVals + ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary + unless ok $ do + merr <- LibPQ.errorMessage (pgConn pc) + throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr) + lazyReply <- pgRecvResult pc + modifyIORef (pgPending pc) (lazyReply :) + return lazyReply + +-- | Decode all rows from a lazy 'LibPQ.Result' into @[[PersistValue]]@ +-- and free the result (early free for better memory). Double-free is +-- safe since 'LibPQ.Result' is a 'ForeignPtr'. The entire decode+free +-- happens inside 'unsafeInterleaveIO', so calling this returns a lazy +-- thunk. +pipelinedDecodeRows :: LibPQ.Result -> IO [[PersistValue]] +pipelinedDecodeRows lazyReply = unsafeInterleaveIO $ do + st <- LibPQ.resultStatus lazyReply + case st of + LibPQ.TuplesOk -> do + rowCount <- LibPQ.ntuples lazyReply + cols <- LibPQ.nfields lazyReply + oids <- ireplicateM cols $ \col -> do + oid <- LibPQ.ftype lazyReply col + return (col, oid) + rows <- forM [LibPQ.Row 0 .. rowCount - 1] $ \row -> + forM oids $ \(col, oid) -> do + mbs <- LibPQ.getvalue' lazyReply row col + case decodePersistValue oid mbs of + Left err -> throwIO $ PersistMarshalError err + Right val -> return val + LibPQ.unsafeFreeResult lazyReply + return rows + LibPQ.CommandOk -> do + LibPQ.unsafeFreeResult lazyReply + return [] + _ -> do + merr <- LibPQ.resultErrorMessage lazyReply + LibPQ.unsafeFreeResult lazyReply + throwIO $ PipelineExecError + (fromMaybe "pipelined query error" merr) + +-- | Run a query through the pipeline and return decoded rows lazily. +-- Combines 'pipelinedSend' + 'pipelinedDecodeRows'. +pipelinedQuery :: PgConn -> Text -> [PersistValue] -> IO [[PersistValue]] +pipelinedQuery pc sql vals = pipelinedSend pc sql vals >>= pipelinedDecodeRows + +-- | Decode a single entity from a lazy reply, returning lazily. 
+pipelinedDecodeEntity + :: PersistEntity record + => EntityDef -> LibPQ.Result -> IO (Maybe record) +pipelinedDecodeEntity t lazyReply = unsafeInterleaveIO $ do + rows <- pipelinedDecodeRows lazyReply + case rows of + [] -> return Nothing + (vals:_) -> case Util.parseEntityValues t vals of + Left err -> throwIO $ PersistMarshalError err + Right entity -> return (Just (entityVal entity)) + +-- | Decode the first column of the first row as Int64, returning lazily. +pipelinedDecodeInt64 :: LibPQ.Result -> IO (Maybe Int64) +pipelinedDecodeInt64 lazyReply = unsafeInterleaveIO $ do + rows <- pipelinedDecodeRows lazyReply + case rows of + ([PersistInt64 n]:_) -> return (Just n) + _ -> return Nothing + +-- | Decode the first column of the first row as Bool, returning lazily. +pipelinedDecodeBool :: LibPQ.Result -> IO (Maybe Bool) +pipelinedDecodeBool lazyReply = unsafeInterleaveIO $ do + rows <- pipelinedDecodeRows lazyReply + case rows of + ([PersistBool b]:_) -> return (Just b) + _ -> return Nothing + +-- | Decode a key from the first row (RETURNING), returning lazily. +pipelinedDecodeKey + :: PersistEntity record + => LibPQ.Result -> IO (Key record) +pipelinedDecodeKey lazyReply = unsafeInterleaveIO $ do + rows <- pipelinedDecodeRows lazyReply + case rows of + ([PersistInt64 i]:_) -> case keyFromValues [PersistInt64 i] of + Left err -> throwIO $ PersistMarshalError err + Right k -> return k + (vals:_) -> case keyFromValues vals of + Left err -> throwIO $ PersistMarshalError err + Right k -> return k + _ -> throwIO $ PersistMarshalError "pipelinedDecodeKey: no key returned" + +-- | Get the 'PgConn' and 'SqlBackend' from a backend. +-- Does NOT drain pending results -- callers that need prior writes +-- to be visible should call 'drainPending' themselves. +withPipelineConn + :: (BackendCompatible SqlBackend backend, MonadIO m) + => backend + -> (PgConn -> SqlBackend -> IO a) + -> m a +withPipelineConn backend f = liftIO $ do + let conn = projectBackend backend :: SqlBackend + case Vault.lookup underlyingConnectionKey (connVault conn) of + Nothing -> fail "persistent-postgresql-ng: no PgConn" + Just pc -> f pc conn + +--------------------------------------------------------------------------- +-- Hedis-style pipelined operations +--------------------------------------------------------------------------- + +pipelinedGet + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + ) + => Key record + -> ReaderT (WriteBackend PostgreSQLBackend) m (Maybe record) +pipelinedGet k = do + backend <- ask + withPipelineConn backend $ \pc conn -> do + let t = entityDef (dummyFromKey k) + escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn + cols = T.intercalate "," + $ Foldable.toList + $ Util.keyAndEntityColumnNames t conn + wher = T.intercalate " AND " + $ Foldable.toList + $ fmap (<> "=? ") + $ Util.dbIdColumns conn t + sql = T.concat + [ "SELECT ", cols, " FROM ", escTbl t, " WHERE ", wher ] + lazyReply <- pipelinedSend pc sql (keyToValues k) + pipelinedDecodeEntity t lazyReply + where + dummyFromKey :: Key record -> Maybe record + dummyFromKey _ = Nothing + +pipelinedGetBy + :: forall record m + . 
( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + ) + => Unique record + -> ReaderT (WriteBackend PostgreSQLBackend) m (Maybe (Entity record)) +pipelinedGetBy uniq = do + backend <- ask + withPipelineConn backend $ \pc conn -> do + let t = entityDef (Nothing :: Maybe record) + escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn + escFld = Database.Persist.SqlBackend.Internal.connEscapeFieldName conn + cols = T.intercalate "," + $ Foldable.toList + $ Util.keyAndEntityColumnNames t conn + uniqs = persistUniqueToFieldNames uniq + wher = T.intercalate " AND " + $ Foldable.toList + $ fmap (\(_, n) -> escFld n <> "=?") uniqs + sql = T.concat + [ "SELECT ", cols, " FROM ", escTbl t, " WHERE ", wher ] + vals = persistUniqueToValues uniq + lazyReply <- pipelinedSend pc sql vals + unsafeInterleaveIO $ do + rows <- pipelinedDecodeRows lazyReply + case rows of + [] -> return Nothing + (v:_) -> case Util.parseEntityValues t v of + Left err -> throwIO $ PersistMarshalError err + Right entity -> return (Just entity) + +pipelinedInsert + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , SafeToInsert record + , MonadIO m + ) + => record + -> ReaderT (WriteBackend PostgreSQLBackend) m (Key record) +pipelinedInsert record = do + backend <- ask + withPipelineConn backend $ \pc conn -> do + let vals = Util.mkInsertValues record + case Database.Persist.SqlBackend.Internal.connInsertSql conn (entityDef (Just record)) vals of + ISRSingle sql -> pipelinedSend pc sql vals >>= pipelinedDecodeKey + ISRSingleCustom sql params -> pipelinedSend pc sql params >>= pipelinedDecodeKey + _ -> runReaderT (withBaseBackend $ insert record) backend + +pipelinedCount + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + ) + => [Filter record] + -> ReaderT (WriteBackend PostgreSQLBackend) m Int +pipelinedCount filts = do + backend <- ask + withPipelineConn backend $ \pc conn -> do + let t = entityDef (Nothing :: Maybe record) + escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn + (wher, vals) = if null filts + then ("", []) + else filterClauseWithVals (Just PrefixTableName) conn filts + sql = T.concat + [ "SELECT COUNT(*) FROM ", escTbl t, wher ] + lazyReply <- pipelinedSend pc sql vals + mn <- pipelinedDecodeInt64 lazyReply + return $ maybe 0 fromIntegral mn + +pipelinedExists + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + ) + => [Filter record] + -> ReaderT (WriteBackend PostgreSQLBackend) m Bool +pipelinedExists filts = do + backend <- ask + withPipelineConn backend $ \pc conn -> do + let t = entityDef (Nothing :: Maybe record) + escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn + (wher, vals) = if null filts + then ("", []) + else filterClauseWithVals (Just PrefixTableName) conn filts + sql = T.concat + [ "SELECT EXISTS(SELECT 1 FROM ", escTbl t, wher, ")" ] + lazyReply <- pipelinedSend pc sql vals + mb <- pipelinedDecodeBool lazyReply + return $ maybe False id mb + +pipelinedUpsertBy + :: forall record m + . 
( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , SafeToInsert record + , MonadIO m + ) + => Unique record + -> record + -> [Update record] + -> ReaderT (WriteBackend PostgreSQLBackend) m (Entity record) +pipelinedUpsertBy uniqueKey record updates = do + backend <- ask + case updates of + [] -> withBaseBackend $ upsertBy uniqueKey record updates + _ -> withPipelineConn backend $ \pc conn -> do + let escFld = Database.Persist.SqlBackend.Internal.connEscapeFieldName conn + escTbl = Database.Persist.SqlBackend.Internal.connEscapeTableName conn + t = entityDef (Just record) + refCol n = T.concat [escTbl t, ".", n] + mkUpd = Util.mkUpdateText' escFld refCol + case Database.Persist.SqlBackend.Internal.connUpsertSql conn of + Just upsertSqlFn -> do + let upds = T.intercalate "," $ map mkUpd updates + rawSql_ = upsertSqlFn t (persistUniqueToFieldNames uniqueKey) upds + colList = T.intercalate "," + $ Foldable.toList + $ Util.keyAndEntityColumnNames t conn + sql = T.replace "??" colList rawSql_ + vals = map toPersistValue (toPersistFields record) + ++ map Util.updatePersistValue updates + ++ concatMap persistUniqueToValues [uniqueKey] + lazyReply <- pipelinedSend pc sql vals + unsafeInterleaveIO $ do + rows <- pipelinedDecodeRows lazyReply + case rows of + (v:_) -> case Util.parseEntityValues t v of + Left err -> throwIO $ PersistMarshalError err + Right entity -> return entity + [] -> throwIO $ PersistMarshalError "pipelinedUpsertBy: no rows returned" + Nothing -> runReaderT (withBaseBackend $ upsertBy uniqueKey record updates) backend + +-- | Prepare a statement. Rewrites @?@ placeholders to @$1, $2, ...@ +-- and handles 'PersistLiteral_ Unescaped' values by inlining them +-- into the SQL text. +prepare' :: PgConn -> Text -> IO Statement +prepare' pc sql = do + return + Statement + { stmtFinalize = return () + , stmtReset = return () + , stmtExecute = execute' pc sqlBS + , stmtQuery = withStmt' pc sqlBS + } + where + sqlBS = T.encodeUtf8 sql + +-- | Inline 'PersistLiteral_ Unescaped' values into the SQL text, +-- removing them from the parameter list. Then rewrite remaining @?@ +-- placeholders to @$1, $2, ...@. +inlineAndRewrite + :: ByteString + -> [PersistValue] + -> (ByteString, [PersistValue]) +inlineAndRewrite sql vals + | any isUnescaped vals = + let (sql1, vals1) = inlineUnescaped sql vals + in collapseAndRewrite sql1 vals1 + | otherwise = collapseAndRewrite sql vals + where + isUnescaped (PersistLiteral_ Unescaped _) = True + isUnescaped _ = False + +-- | Fused IN-clause collapsing and @?@ → @$N@ placeholder rewriting in a +-- single pass using a 'BB.Builder'. Combines the work of 'collapseInClauses' +-- and 'rewritePlaceholders' to avoid an intermediate 'ByteString' allocation. 
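+--
+-- For example, @\"SELECT * FROM t WHERE a = ? AND b IN (?,?,?)\"@ with
+-- four parameters becomes @\"SELECT * FROM t WHERE a = $1 AND b = ANY($2)\"@,
+-- with the last three parameters collapsed into a single 'PersistArray'.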
+collapseAndRewrite :: ByteString -> [PersistValue] -> (ByteString, [PersistValue]) +collapseAndRewrite sql params = + let (builder, _n, keptParams) = go sql params 0 mempty [] + in (LBS.toStrict (BB.toLazyByteString builder), reverse keptParams) + where + go :: ByteString -> [PersistValue] -> Int -> BB.Builder -> [PersistValue] + -> (BB.Builder, Int, [PersistValue]) + go !bs ps !n !acc keep + | B8.null bs = (acc, n, keep) + | otherwise = + let c = B8.head bs + rest = B8.tail bs + in case c of + '\'' -> let (lit, after) = skipStringLiteral rest + in go after ps n (acc <> BB.char8 '\'' <> BB.byteString lit) keep + + '"' -> let (ident, after) = skipQuotedIdent rest + in go after ps n (acc <> BB.char8 '"' <> BB.byteString ident) keep + + '-' | not (B8.null rest) && B8.head rest == '-' -> + let (comment, after) = B8.break (== '\n') rest + in go after ps n (acc <> BB.char8 '-' <> BB.byteString comment) keep + + '/' | not (B8.null rest) && B8.head rest == '*' -> + let (comment, after) = skipBlockComment (B8.tail rest) 1 + in go after ps n (acc <> BB.byteString "/*" <> BB.byteString comment) keep + + -- ?? → literal ? + '?' | not (B8.null rest) && B8.head rest == '?' -> + go (B8.tail rest) ps n (acc <> BB.char8 '?') keep + + -- ? → $N + '?' -> case ps of + [] -> go rest [] n (acc <> BB.char8 '?') keep + (p : ps') -> + let n' = n + 1 + in go rest ps' n' (acc <> BB.char8 '$' <> BB.intDec n') (p : keep) + + _ | matchesIn bs -> + tryCollapseIn bs ps n acc keep + | matchesNotIn bs -> + tryCollapseNotIn bs ps n acc keep + | otherwise -> + go rest ps n (acc <> BB.char8 c) keep + + tryCollapseIn bs ps n acc keep = + let afterIn = B8.drop 5 bs + (nQMarks, afterParens) = countQMarks afterIn + in if nQMarks >= 2 + then let (collapsed, remaining) = splitAt nQMarks ps + n' = n + 1 + in go afterParens remaining n' + (acc <> BB.byteString " = ANY(" <> BB.char8 '$' <> BB.intDec n' <> BB.char8 ')') + (PersistArray collapsed : keep) + else go (B8.tail bs) ps n (acc <> BB.char8 (B8.head bs)) keep + + tryCollapseNotIn bs ps n acc keep = + let afterNotIn = B8.drop 9 bs + (nQMarks, afterParens) = countQMarks afterNotIn + in if nQMarks >= 2 + then let (collapsed, remaining) = splitAt nQMarks ps + n' = n + 1 + in go afterParens remaining n' + (acc <> BB.byteString " <> ALL(" <> BB.char8 '$' <> BB.intDec n' <> BB.char8 ')') + (PersistArray collapsed : keep) + else go (B8.tail bs) ps n (acc <> BB.char8 (B8.head bs)) keep + +-- IN-clause detection helpers, shared by collapseInClauses and collapseAndRewrite. + +-- | Check if current position matches @\" IN (\"@ (case-insensitive). +matchesIn :: ByteString -> Bool +matchesIn bs = + B8.length bs >= 5 + && isSqlSpace (B8.index bs 0) + && (B8.index bs 1 == 'I' || B8.index bs 1 == 'i') + && (B8.index bs 2 == 'N' || B8.index bs 2 == 'n') + && B8.index bs 3 == ' ' + && B8.index bs 4 == '(' + +-- | Check if current position matches @\" NOT IN (\"@ (case-insensitive). +matchesNotIn :: ByteString -> Bool +matchesNotIn bs = + B8.length bs >= 9 + && isSqlSpace (B8.index bs 0) + && (B8.index bs 1 == 'N' || B8.index bs 1 == 'n') + && (B8.index bs 2 == 'O' || B8.index bs 2 == 'o') + && (B8.index bs 3 == 'T' || B8.index bs 3 == 't') + && B8.index bs 4 == ' ' + && (B8.index bs 5 == 'I' || B8.index bs 5 == 'i') + && (B8.index bs 6 == 'N' || B8.index bs 6 == 'n') + && B8.index bs 7 == ' ' + && B8.index bs 8 == '(' + +-- | SQL whitespace: space, newline, carriage return, tab. 
+isSqlSpace :: Char -> Bool +isSqlSpace w = w == ' ' || w == '\n' || w == '\r' || w == '\t' +{-# INLINE isSqlSpace #-} + +-- | Count consecutive @?@ marks separated by commas (with optional whitespace) +-- inside parentheses. Input starts after the opening @(@. +-- Returns @(count, rest after closing paren)@. +-- Returns @(0, original input)@ to abort collapsing on unexpected content. +countQMarks :: ByteString -> (Int, ByteString) +countQMarks = countLoop 0 + where + countLoop !n bs0 + | B8.null bs0 = (n, bs0) + | otherwise = + let bs = B8.dropWhile isSqlSpace bs0 + in if B8.null bs + then (n, bs) + else case B8.head bs of + '?' + | not (B8.null (B8.tail bs)) && B8.head (B8.tail bs) == '?' -> + -- ?? is a literal ?, not a param -- abort collapsing + (0, bs0) + | otherwise -> + let bs' = B8.tail bs + bs'' = B8.dropWhile isSqlSpace bs' + in if B8.null bs'' + then (n + 1, bs'') + else case B8.head bs'' of + ',' -> countLoop (n + 1) (B8.tail bs'') + ')' -> (n + 1, B8.tail bs'') + _ -> (0, bs0) -- unexpected char, abort + ')' -> (n, B8.tail bs) -- empty parens or trailing + _ -> (0, bs0) -- unexpected char, abort + +-- | Rewrite @IN (?,?,?)@ to @= ANY(?)@ and @NOT IN (?,?,?)@ to +-- @<> ALL(?)@, collapsing multiple individual parameters into a single +-- 'PersistList' parameter. This reduces the number of bind parameters +-- sent to PostgreSQL when 'PersistList' is encoded as a native array. +-- +-- Only collapses when there are 2 or more @?@ placeholders in the +-- IN clause. Single-element @IN (?)@ is left unchanged. +collapseInClauses :: ByteString -> [PersistValue] -> (ByteString, [PersistValue]) +collapseInClauses sql params = + let (builder, keptParams) = go sql params mempty [] + in (LBS.toStrict (BB.toLazyByteString builder), reverse keptParams) + where + go :: ByteString -> [PersistValue] -> BB.Builder -> [PersistValue] + -> (BB.Builder, [PersistValue]) + go !bs ps !acc keep + | B8.null bs = (acc, keep) + | otherwise = + let c = B8.head bs + rest = B8.tail bs + in case c of + -- String literal: copy through (don't match IN inside strings) + '\'' -> let (lit, after) = skipStringLiteral rest + in go after ps (acc <> BB.char8 '\'' <> BB.byteString lit) keep + + -- Quoted identifier: copy through + '"' -> let (ident, after) = skipQuotedIdent rest + in go after ps (acc <> BB.char8 '"' <> BB.byteString ident) keep + + -- Line comment + '-' | not (B8.null rest) && B8.head rest == '-' -> + let (comment, after) = B8.break (== '\n') rest + in go after ps (acc <> BB.char8 '-' <> BB.byteString comment) keep + + -- Block comment + '/' | not (B8.null rest) && B8.head rest == '*' -> + let (comment, after) = skipBlockComment (B8.tail rest) 1 + in go after ps (acc <> BB.byteString "/*" <> BB.byteString comment) keep + + -- ?? -> literal ?, not a parameter + '?' | not (B8.null rest) && B8.head rest == '?' -> + go (B8.tail rest) ps (acc <> BB.byteString "??") keep + + -- ? -> regular parameter, just pass through + '?' 
-> case ps of + [] -> go rest [] (acc <> BB.char8 '?') keep + (p : ps') -> go rest ps' (acc <> BB.char8 '?') (p : keep) + + -- Check for IN/NOT IN pattern + _ | matchesIn bs -> + tryCollapseIn bs ps acc keep + | matchesNotIn bs -> + tryCollapseNotIn bs ps acc keep + | otherwise -> + go rest ps (acc <> BB.char8 c) keep + + tryCollapseIn bs ps acc keep = + let afterIn = B8.drop 5 bs -- skip " IN (" + (nQMarks, afterParens) = countQMarks afterIn + in if nQMarks >= 2 + then let (collapsed, remaining) = splitAt nQMarks ps + in go afterParens remaining (acc <> BB.byteString " = ANY(?)") (PersistArray collapsed : keep) + else go (B8.tail bs) ps (acc <> BB.char8 (B8.head bs)) keep + + tryCollapseNotIn bs ps acc keep = + let afterNotIn = B8.drop 9 bs -- skip " NOT IN (" + (nQMarks, afterParens) = countQMarks afterNotIn + in if nQMarks >= 2 + then let (collapsed, remaining) = splitAt nQMarks ps + in go afterParens remaining (acc <> BB.byteString " <> ALL(?)") (PersistArray collapsed : keep) + else go (B8.tail bs) ps (acc <> BB.char8 (B8.head bs)) keep + +-- | Scan the SQL for @?@ placeholders and the parameter list for +-- 'PersistLiteral_ Unescaped' values. When a @?@ corresponds to an +-- Unescaped value, inline the literal bytes directly into the SQL text. +-- Non-Unescaped values are kept in the output parameter list. +-- +-- This handles @??@ (literal @?@ for column expansion) by passing it +-- through without consuming a parameter. +-- +-- Correctly skips string literals, quoted identifiers, line comments, +-- and block comments. +inlineUnescaped :: ByteString -> [PersistValue] -> (ByteString, [PersistValue]) +inlineUnescaped sql vals = + let (builder, keptParams) = go sql vals mempty [] + in (LBS.toStrict (BB.toLazyByteString builder), reverse keptParams) + where + go :: ByteString -> [PersistValue] -> BB.Builder -> [PersistValue] + -> (BB.Builder, [PersistValue]) + go bs params !acc keepParams + | B8.null bs = (acc, keepParams) + | otherwise = + let c = B8.head bs + rest = B8.tail bs + in case c of + -- String literal: copy through (properly handles '' escapes) + '\'' -> + let (lit, after) = skipStringLiteral rest + in go after params (acc <> BB.char8 '\'' <> BB.byteString lit) keepParams + + -- Quoted identifier: copy through (properly handles "" escapes) + '"' -> + let (ident, after) = skipQuotedIdent rest + in go after params (acc <> BB.char8 '"' <> BB.byteString ident) keepParams + + -- Line comment: copy through + '-' | not (B8.null rest) && B8.head rest == '-' -> + let (comment, after) = B8.break (== '\n') rest + in go after params (acc <> BB.char8 '-' <> BB.byteString comment) keepParams + + -- Block comment: copy through + '/' | not (B8.null rest) && B8.head rest == '*' -> + let (comment, after) = skipBlockComment (B8.tail rest) 1 + in go after params (acc <> BB.byteString "/*" <> BB.byteString comment) keepParams + + -- Placeholder + '?' -> + -- ?? -> literal ? + if not (B8.null rest) && B8.head rest == '?' + then go (B8.tail rest) params (acc <> BB.byteString "??") keepParams + -- ? -> check if param is Unescaped + else case params of + [] -> go rest [] (acc <> BB.char8 '?') keepParams + (PersistLiteral_ Unescaped literal : ps) -> + go rest ps (acc <> BB.byteString literal) keepParams + (p : ps) -> + go rest ps (acc <> BB.char8 '?') (p : keepParams) + + _ -> go rest params (acc <> BB.char8 c) keepParams + +-- Pipeline helpers +-- These functions implement the libpq pipeline result protocol: +-- Each query result: getResult → Just result, then getResult → Nothing (separator). 
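+-- (e.g. two pipelined queries yield: Just r1, Nothing, Just r2, Nothing).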
+-- PipelineSync result: getResult → Just PipelineSync (no NULL terminator).
+
+-- | Drain one result from the pipeline via the lazy reply stream.
+-- Pops and forces one element, frees the PGresult.
+-- Returns @Just errorMsg@ on FatalError, @Nothing@ otherwise.
+drainOneResult :: PgConn -> IO (Maybe ByteString)
+drainOneResult pc = do
+    ret <- pgRecvResult pc -- pop lazy thunk
+    -- Force the thunk (triggers flush+read if needed)
+    st <- LibPQ.resultStatus ret
+    err <- case st of
+        LibPQ.FatalError -> do
+            merr <- LibPQ.resultErrorMessage ret
+            return (Just (fromMaybe "unknown error" merr))
+        _ -> return Nothing
+    LibPQ.unsafeFreeResult ret
+    return err
+
+-- | Read one query result from the pipeline via the lazy reply stream.
+-- Pops and forces one element. Returns the Result (caller must free).
+-- Throws on error status.
+readOneQueryResult :: PgConn -> IO LibPQ.Result
+readOneQueryResult pc = do
+    ret <- pgRecvResult pc -- pop lazy thunk
+    -- Force the thunk
+    st <- LibPQ.resultStatus ret
+    case st of
+        LibPQ.CommandOk -> return ret
+        LibPQ.TuplesOk -> return ret
+        _ -> do
+            merr <- LibPQ.resultErrorMessage ret
+            LibPQ.unsafeFreeResult ret
+            syncPipeline pc
+            throwIO $ PipelineExecError (fromMaybe "query error in pipeline" merr)
+
+-- | Drain N pending results, collecting error messages. Does NOT throw.
+drainNResults :: PgConn -> Int -> IO [ByteString]
+drainNResults pc n = do
+    errs <- replicateM n (drainOneResult pc)
+    return [e | Just e <- errs]
+
+-- | Read a PipelineSync result. No NULL terminator follows it.
+drainSyncResult :: PgConn -> IO ()
+drainSyncResult pc = do
+    mret <- pgGetResult pc
+    case mret of
+        Nothing -> throwIO $ PipelineModeError "expected PipelineSync, got NULL"
+        Just ret -> do
+            st <- LibPQ.resultStatus ret
+            LibPQ.unsafeFreeResult ret
+            case st of
+                LibPQ.PipelineSync -> return ()
+                _ -> throwIO $ PipelineModeError $
+                    "expected PipelineSync, got " ++ show st
+
+-- | Drain everything (results and NULL separators) until PipelineSync.
+-- Ignores all errors. Used by 'connRollback' and 'closePgConn'.
+drainToSync :: PgConn -> IO ()
+drainToSync pc = do
+    mret <- pgGetResult pc
+    case mret of
+        Nothing -> drainToSync pc -- NULL separator between results, keep reading
+        Just ret -> do
+            st <- LibPQ.resultStatus ret
+            LibPQ.unsafeFreeResult ret
+            case st of
+                LibPQ.PipelineSync -> return ()
+                _ -> drainToSync pc
+
+-- | Send a pipeline sync and drain to the sync point, clearing any
+-- aborted-pipeline state. Best-effort: if the sync itself fails to
+-- send we silently continue (the subsequent throw is more important).
+syncPipeline :: PgConn -> IO ()
+syncPipeline pc = do
+    ok <- pgPipelineSync pc
+    when ok $ drainToSync pc
+
+-- | Drain all pending results (both fire-and-forget and lazy reads).
+-- Forces each thunk (triggering flush+read via the lazy stream cascade),
+-- checks for errors, and frees the PGresult. Double-free is safe:
+-- 'LibPQ.Result' is a 'ForeignPtr', so 'unsafeFreeResult' is idempotent.
+-- If the caller's decode thunk already freed the result, this is a no-op.
+--
+-- Used by 'withStmt'' before sending a query that needs results. On
+-- error, a pipeline sync is issued and drained before throwing so
+-- that the pipeline is left in a clean (non-aborted) state. This
+-- allows callers who catch 'PipelineExecError' to continue issuing
+-- commands (e.g. @ROLLBACK TO SAVEPOINT@) without hitting
+-- @PGRES_PIPELINE_ABORTED@.
+drainPending :: PgConn -> IO () +drainPending pc = do + pending <- atomicModifyIORef (pgPending pc) (\ps -> ([], ps)) + case pending of + [] -> return () + _ -> do + _ <- pgSendFlushRequest pc + pgFlush pc + errors <- forM (reverse pending) $ \ret -> do + st <- LibPQ.resultStatus ret + err <- case st of + LibPQ.FatalError -> do + merr <- LibPQ.resultErrorMessage ret + return (Just (fromMaybe "unknown error" merr)) + _ -> return Nothing + LibPQ.unsafeFreeResult ret + return err + case [e | Just e <- errors] of + (e:_) -> do + syncPipeline pc + throwIO $ PipelineExecError e + [] -> return () + +-- | Execute a statement in pipeline mode (fire-and-forget). +-- Sends the query, pops a lazy reply from the stream, and appends +-- it to 'pgPending'. The result is forced at the next drain point +-- (read operation, commit, or rollback). +execute' :: PgConn -> ByteString -> [PersistValue] -> IO Int64 +execute' pc sql vals = do + let (rewrittenSQL, remainingVals) = inlineAndRewrite sql vals + params = map encodePersistValue remainingVals + ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary + unless ok $ do + merr <- LibPQ.errorMessage (pgConn pc) + throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr) + lazyReply <- pgRecvResult pc + modifyIORef (pgPending pc) (lazyReply :) + return 0 + +-- | Execute a query in pipeline mode and stream the results as a conduit. +-- First drains all pending fire-and-forget results, then sends the query +-- and reads its result. +-- +-- The fetch strategy depends on 'pgFetchMode': +-- +-- * 'FetchAll' (default): reads the entire result into one @PGresult@ +-- before yielding any rows. Memory = O(result set). +-- +-- * 'FetchSingleRow': activates @PQsetSingleRowMode@ so that each +-- @PGresult@ contains exactly one row. The conduit fetches, yields, +-- and frees one @PGresult@ at a time. Memory = O(1 row). +-- +-- * 'FetchChunked N': activates @PQsetChunkedRowsMode(N)@ (PG17+) so +-- that each @PGresult@ contains up to N rows. Same streaming +-- discipline as single-row mode but with lower per-row allocation +-- overhead. Memory = O(N rows). +withStmt' + :: (MonadIO m) + => PgConn + -> ByteString + -> [PersistValue] + -> Acquire (ConduitM () [PersistValue] m ()) +withStmt' pc sql vals = + case pgFetchMode pc of + FetchAll -> withStmtDefault pc sql vals + FetchSingleRow -> withStmtStreaming pc sql vals + FetchChunked _ -> withStmtStreaming pc sql vals + +------------------------------------------------------------------------------- +-- Default (FetchAll): entire result in one PGresult +------------------------------------------------------------------------------- + +-- | Default fetch: read entire result into memory, then stream rows from it. +withStmtDefault + :: (MonadIO m) + => PgConn + -> ByteString + -> [PersistValue] + -> Acquire (ConduitM () [PersistValue] m ()) +withStmtDefault pc sql vals = + pull `fmap` mkAcquire openS closeS + where + openS = do + drainPending pc + let (rewrittenSQL, remainingVals) = inlineAndRewrite sql vals + params = map encodePersistValue remainingVals + ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary + unless ok $ do + merr <- LibPQ.errorMessage (pgConn pc) + throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr) + -- readOneQueryResult pops from the lazy reply stream. + -- The stream thunk handles flush+read internally. 
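+        -- (No explicit pgSendFlushRequest/pgFlush is needed on this
+        -- path: forcing the reply thunk performs the flush. The
+        -- streaming path below flushes explicitly instead.)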
+ ret <- readOneQueryResult pc + cols <- LibPQ.nfields ret + oids <- ireplicateM cols $ \col -> do + oid <- LibPQ.ftype ret col + return (col, oid) + rowCount <- LibPQ.ntuples ret + return (ret, rowCount, oids) + + closeS (ret, _, _) = LibPQ.unsafeFreeResult ret + + pull (ret, rowCount, oids) = go (LibPQ.Row 0) + where + go !row + | row == rowCount = return () + | otherwise = do + vals' <- liftIO $ forM oids $ \(col, oid) -> do + mbs <- LibPQ.getvalue' ret row col + case decodePersistValue oid mbs of + Left err -> fail $ "decodePersistValue: " ++ T.unpack err + Right val -> return val + yield vals' + go (row + 1) + +------------------------------------------------------------------------------- +-- Streaming (FetchSingleRow / FetchChunked): incremental PGresult fetching +------------------------------------------------------------------------------- + +-- | Streaming fetch: the conduit drives libpq I/O, fetching one small +-- @PGresult@ at a time, yielding its rows, freeing it, then fetching the +-- next. At most one @PGresult@ is resident at any time. +withStmtStreaming + :: (MonadIO m) + => PgConn + -> ByteString + -> [PersistValue] + -> Acquire (ConduitM () [PersistValue] m ()) +withStmtStreaming pc sql vals = + pull `fmap` mkAcquire openS closeS + where + openS = do + drainPending pc + let (rewrittenSQL, remainingVals) = inlineAndRewrite sql vals + params = map encodePersistValue remainingVals + ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary + unless ok $ do + merr <- LibPQ.errorMessage (pgConn pc) + throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr) + -- Activate row-fetch mode immediately after sendQueryParams + activateFetchMode pc + _ <- pgSendFlushRequest pc + pgFlush pc + -- Return an IORef that tracks whether we have fully consumed + -- the result stream. closeS uses this for cleanup. + doneRef <- newIORef False + return doneRef + + closeS doneRef = do + done <- readIORef doneRef + unless done $ drainStreamRemainder pc + + -- The conduit body: all libpq result I/O happens here, interleaved + -- with yield. This is what makes it a true streaming source. + pull doneRef = do + mfirst <- liftIO $ pgGetResult pc + case mfirst of + Nothing -> do + -- NULL right away = empty result (protocol: no rows) + liftIO $ writeIORef doneRef True + Just first -> do + st <- liftIO $ rawResultStatusInt first + -- Extract column metadata from the first result. + -- All subsequent results share the same column layout. + oids <- liftIO $ getColumnOids first + -- Process the first result and enter the streaming loop + case () of + _ | st == PGRES_SINGLE_TUPLE || st == PGRES_TUPLES_CHUNK -> do + yieldRowsAndFree first oids + streamLoop oids doneRef + | st == PGRES_TUPLES_OK -> do + -- Zero-row result: the query returned no data. 
+ liftIO $ LibPQ.unsafeFreeResult first + -- Consume the NULL separator + _ <- liftIO $ pgGetResult pc + liftIO $ writeIORef doneRef True + | st == PGRES_FATAL_ERROR -> do + merr <- liftIO $ LibPQ.resultErrorMessage first + liftIO $ LibPQ.unsafeFreeResult first + -- Consume NULL separator + _ <- liftIO $ pgGetResult pc + liftIO $ syncPipeline pc + liftIO $ writeIORef doneRef True + liftIO $ throwIO $ PipelineExecError + (fromMaybe "query error in pipeline (streaming)" merr) + | otherwise -> do + liftIO $ LibPQ.unsafeFreeResult first + _ <- liftIO $ pgGetResult pc + liftIO $ writeIORef doneRef True + liftIO $ throwIO $ PipelineExecError $ + "unexpected result status in streaming mode: " + <> B8.pack (show st) + + -- Core streaming loop: fetch next PGresult, yield its rows, free it. + -- Each PGresult is freed BEFORE fetching the next one. + streamLoop oids doneRef = do + mret <- liftIO $ pgGetResult pc + case mret of + Nothing -> do + -- NULL separator: the query's result stream is fully consumed. + liftIO $ writeIORef doneRef True + Just ret -> do + st <- liftIO $ rawResultStatusInt ret + case () of + _ | st == PGRES_SINGLE_TUPLE || st == PGRES_TUPLES_CHUNK -> do + yieldRowsAndFree ret oids + streamLoop oids doneRef + | st == PGRES_TUPLES_OK -> do + -- End marker (0 rows). Free and consume NULL separator. + liftIO $ LibPQ.unsafeFreeResult ret + _ <- liftIO $ pgGetResult pc + liftIO $ writeIORef doneRef True + | st == PGRES_FATAL_ERROR -> do + merr <- liftIO $ LibPQ.resultErrorMessage ret + liftIO $ LibPQ.unsafeFreeResult ret + _ <- liftIO $ pgGetResult pc + liftIO $ syncPipeline pc + liftIO $ writeIORef doneRef True + liftIO $ throwIO $ PipelineExecError + (fromMaybe "query error in pipeline (streaming)" merr) + | otherwise -> do + liftIO $ LibPQ.unsafeFreeResult ret + _ <- liftIO $ pgGetResult pc + liftIO $ writeIORef doneRef True + liftIO $ throwIO $ PipelineExecError $ + "unexpected result status in streaming mode: " + <> B8.pack (show st) + +-- | Extract column (index, OID) pairs from a result. +getColumnOids :: LibPQ.Result -> IO [(LibPQ.Column, LibPQ.Oid)] +getColumnOids ret = do + cols <- LibPQ.nfields ret + ireplicateM cols $ \col -> do + oid <- LibPQ.ftype ret col + return (col, oid) + +-- | Yield all rows from a @PGresult@ via the conduit, then free the result. +yieldRowsAndFree + :: (MonadIO m) + => LibPQ.Result + -> [(LibPQ.Column, LibPQ.Oid)] + -> ConduitM () [PersistValue] m () +yieldRowsAndFree ret oids = do + rowCount <- liftIO $ LibPQ.ntuples ret + go (LibPQ.Row 0) rowCount + liftIO $ LibPQ.unsafeFreeResult ret + where + go !row rowCount + | row == rowCount = return () + | otherwise = do + vals' <- liftIO $ forM oids $ \(col, oid) -> do + mbs <- LibPQ.getvalue' ret row col + case decodePersistValue oid mbs of + Left err -> fail $ "decodePersistValue: " ++ T.unpack err + Right val -> return val + yield vals' + go (row + 1) rowCount + +-- | Extract column metadata as a 'V.Vector' of @(Column, PgType)@ pairs +-- for use with 'PgRowEnv'. +getColumnPgTypes :: LibPQ.Result -> IO (V.Vector (LibPQ.Column, PgType)) +getColumnPgTypes ret = do + cols <- LibPQ.nfields ret + V.fromList <$> ireplicateM cols (\col -> do + oid <- LibPQ.ftype ret col + return (col, fromOid oid)) + +-- | Yield all rows from a @PGresult@ decoded directly via 'FromRow', +-- then free the result. No 'PersistValue' intermediary. +-- +-- Uses 'prepareRow' to resolve all 'FieldRunner's once from column +-- metadata, then applies the prepared 'RowDecoder' to each row. 
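+--
+-- Resolving the decoders once per result rather than once per row keeps
+-- the per-row cost to a single 'runRowDecoderCPS' call.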
+yieldRowsDirectAndFree + :: (FromRow PgRowEnv a, MonadIO m) + => LibPQ.Result + -> V.Vector (LibPQ.Column, PgType) + -> OidCache + -> ConduitM () a m () +yieldRowsDirectAndFree ret colTypes cache = do + rowCount <- liftIO $ LibPQ.ntuples ret + case rowCount of + LibPQ.Row 0 -> return () + _ -> do + let metaEnv = PgRowEnv ret (LibPQ.Row 0) colTypes cache + ctr <- liftIO newCounter + decoder <- liftIO $ prepareRow metaEnv ctr + (\e -> throwIO $ PipelineExecError $ T.encodeUtf8 e) + pure + go decoder (LibPQ.Row 0) rowCount + liftIO $ LibPQ.unsafeFreeResult ret + where + go !decoder !row !rowCount + | row == rowCount = return () + | otherwise = do + let env = PgRowEnv ret row colTypes cache + val <- liftIO $ runRowDecoderCPS decoder env + (throwIO . PipelineExecError . T.encodeUtf8) + pure + yield val + go decoder (row + 1) rowCount + +-- | Send a query, handling both @?@ and @$N@ parameter styles. +-- +-- If the SQL contains @$N@ placeholders (detected by 'detectParamStyle'), +-- it is passed through to libpq as-is with the parameter array indexed by +-- @$N@ number. Otherwise, @?@ placeholders are rewritten to @$1, $2, ..@ +-- via 'inlineAndRewrite'. +sendQueryWithParams + :: PgConn + -> ByteString + -> [PersistValue] + -> IO Bool +sendQueryWithParams pc sql vals = case detectParamStyle sql of + QuestionMarkParams -> + let (rewritten, remaining) = inlineAndRewrite sql vals + params = map encodePersistValue remaining + in pgSendQueryParams pc rewritten params LibPQ.Binary + NumberedParams _maxN -> + let params = map encodePersistValue vals + in pgSendQueryParams pc sql params LibPQ.Binary + +--------------------------------------------------------------------------- +-- HasDirectQuery instance +--------------------------------------------------------------------------- + +instance HasDirectQuery (WriteBackend PostgreSQLBackend) where + type Env (WriteBackend PostgreSQLBackend) = PgRowEnv + type Param (WriteBackend PostgreSQLBackend) = PgParam + directQuerySource backend sql params = + case getPipelineConn backend of + Nothing -> error "persistent-postgresql-ng: no PgConn in vault" + Just pc -> + let sqlBS = T.encodeUtf8 sql + in pull `fmap` mkAcquire (openDQ pc sqlBS) closeDQ + where + paramList = V.toList params + + openDQ pc sqlBS = do + drainPending pc + cache <- readIORef (pgOidCache pc) + ok <- sendDirectParams pc sqlBS paramList + unless ok $ do + merr <- LibPQ.errorMessage (pgConn pc) + throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr) + _ <- pgSendFlushRequest pc + pgFlush pc + ret <- readOneQueryResult pc + colTypes <- getColumnPgTypes ret + rowCount <- LibPQ.ntuples ret + return (ret, rowCount, colTypes, cache) + + closeDQ (ret, _, _, _) = LibPQ.unsafeFreeResult ret + + pull (ret, rowCount, colTypes, cache) = go (LibPQ.Row 0) + where + go !row + | row == rowCount = return () + | otherwise = do + yield (PgRowEnv ret row colTypes cache) + go (row + 1) + +instance HasDirectQuery (ReadBackend PostgreSQLBackend) where + type Env (ReadBackend PostgreSQLBackend) = PgRowEnv + type Param (ReadBackend PostgreSQLBackend) = PgParam + directQuerySource backend sql params = + directQuerySource (writeFromRead backend) sql params + where + writeFromRead :: ReadBackend PostgreSQLBackend -> WriteBackend PostgreSQLBackend + writeFromRead (ReadBackend pg) = WriteBackend pg + +-- | Send a query with pre-encoded parameters (the direct-encode path). 
+-- +-- Handles both @?@ and @$N@ placeholder styles: if @?@ placeholders are +-- detected, they are rewritten to @$1, $2, ...@ via 'rewritePlaceholders'. +-- Unlike 'sendQueryWithParams', no IN-clause collapsing or unescaped +-- literal inlining is performed -- the direct path assumes the caller +-- provides final SQL and properly typed parameters. +sendDirectParams + :: PgConn + -> ByteString + -> [PgParam] + -> IO Bool +sendDirectParams pc sql params = case detectParamStyle sql of + QuestionMarkParams -> + let (rewritten, _count) = rewritePlaceholders sql + in pgSendQueryParams pc rewritten (map pgParamToLibPQ params) LibPQ.Binary + NumberedParams _maxN -> + pgSendQueryParams pc sql (map pgParamToLibPQ params) LibPQ.Binary + +-- | Activate the appropriate row-fetch mode for the query just sent. +-- Must be called immediately after 'pgSendQueryParams', before any +-- other operation on the connection. +activateFetchMode :: PgConn -> IO () +activateFetchMode pc = + case pgFetchMode pc of + FetchAll -> return () + FetchSingleRow -> do + ok <- LibPQ.setSingleRowMode (pgConn pc) + unless ok $ + throwIO $ PipelineExecError "failed to set single-row mode" + FetchChunked n -> do + ok <- setChunkedRowsMode (pgConn pc) n + unless ok $ + throwIO $ PipelineExecError "failed to set chunked rows mode" + +-- | Drain any remaining results from a streaming query. Used by closeS +-- when the conduit consumer stops pulling before the result stream is +-- exhausted (e.g. @take 5@ on a large query). +drainStreamRemainder :: PgConn -> IO () +drainStreamRemainder pc = go + where + go = do + mret <- pgGetResult pc + case mret of + Nothing -> return () -- NULL separator reached, done + Just ret -> do + st <- rawResultStatusInt ret + LibPQ.unsafeFreeResult ret + case () of + _ | st == PGRES_SINGLE_TUPLE || st == PGRES_TUPLES_CHUNK -> + go -- more data, keep draining + | st == PGRES_TUPLES_OK -> do + -- End marker. Consume the NULL separator. + _ <- pgGetResult pc + return () + | otherwise -> + -- Error or unexpected status; consume NULL separator + -- and stop. Pipeline sync will be handled elsewhere. + do _ <- pgGetResult pc + return () + +-- Index-based monadic iteration, avoiding intermediate list allocation. + +ireplicateM :: (Monad m, Eq i, Num i) => i -> (i -> m a) -> m [a] +ireplicateM n f = go 0 + where + go i + | i == n = pure [] + | otherwise = f i >>= \a -> (a :) <$> go (i + 1) +{-# INLINABLE ireplicateM #-} + +-- SQL generation functions, adapted from existing persistent-postgresql + +insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult +insertSql' ent vals = + case getEntityId ent of + EntityIdNaturalKey _pdef -> + ISRManyKeys sql vals + EntityIdField field -> + ISRSingle (sql <> " RETURNING " <> escapeF (fieldDB field)) + where + (fieldNames, placeholders) = unzip (Util.mkInsertPlaceholders ent escapeF) + sql = + T.concat + [ "INSERT INTO " + , escapeE $ getEntityDBName ent + , if null (getEntityFields ent) + then " DEFAULT VALUES" + else + T.concat + [ "(" + , T.intercalate "," fieldNames + , ") VALUES(" + , T.intercalate "," placeholders + , ")" + ] + ] + +upsertSql' :: EntityDef -> NonEmpty (FieldNameHS, FieldNameDB) -> Text -> Text +upsertSql' ent uniqs updateVal = + T.concat + [ "INSERT INTO " + , escapeE (getEntityDBName ent) + , "(" + , T.intercalate "," fieldNames + , ") VALUES (" + , T.intercalate "," placeholders + , ") ON CONFLICT (" + , T.intercalate "," $ map (escapeF . snd) (NEL.toList uniqs) + , ") DO UPDATE SET " + , updateVal + , " WHERE " + , wher + , " RETURNING ??" 
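+        -- ("??" is expanded to the entity's full column list by
+        -- persistent's rawSql machinery, not by this backend.)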
+ ] + where + (fieldNames, placeholders) = unzip (Util.mkInsertPlaceholders ent escapeF) + wher = T.intercalate " AND " $ map (singleClause . snd) $ NEL.toList uniqs + singleClause :: FieldNameDB -> Text + singleClause field = escapeE (getEntityDBName ent) <> "." <> (escapeF field) <> " =?" + +insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult +insertManySql' ent valss = + ISRSingleCustom sql params + where + fields = filter isFieldNotGenerated (getEntityFields ent) + fieldNames = map (escapeF . fieldDB) fields + typeCasts = map (unnestTypeCast . fieldSqlType) fields + unnestArgs = map ("?" <>) typeCasts + sql = + T.concat + [ "INSERT INTO " + , escapeE (getEntityDBName ent) + , "(" + , T.intercalate "," fieldNames + , ") SELECT * FROM UNNEST(" + , T.intercalate "," unnestArgs + , ") RETURNING " + , Util.commaSeparated $ NEL.toList $ Util.dbIdColumnsEsc escapeF ent + ] + params = map PersistArray (transpose valss) + +-- | Hook for 'connInsertMany_Custom': UNNEST-based bulk insert (no RETURNING). +insertMany_Hook :: EntityDef -> [[PersistValue]] -> (Text, [PersistValue]) +insertMany_Hook ent valss = (unnestInsertSql ent False, map PersistArray (transpose valss)) + +-- | Hook for 'connPutManyCustom': UNNEST-based bulk upsert. +putManyHook :: EntityDef -> [[PersistValue]] -> (Text, [PersistValue]) +putManyHook ent valss = + (sql, map PersistArray (transpose valss)) + where + fields = filter isFieldNotGenerated (getEntityFields ent) + fieldDbToText = escapeF . fieldDB + table = escapeE (getEntityDBName ent) + columns = map fieldDbToText fields + typeCasts = map (unnestTypeCast . fieldSqlType) fields + unnestArgs = map ("?" <>) typeCasts + conflictColumns = + concatMap + (map (escapeF . snd) . NEL.toList . uniqueFields) + (getEntityUniques ent) + updates = map (\c -> c <> "=EXCLUDED." <> c) columns + sql = T.concat + [ "INSERT INTO ", table + , "(", T.intercalate "," columns, ")" + , " SELECT * FROM UNNEST(" + , T.intercalate "," unnestArgs + , ")" + , " ON CONFLICT (" + , T.intercalate "," conflictColumns + , ") DO UPDATE SET " + , T.intercalate "," updates + ] + +-- | Hook for 'connRepsertManyCustom': UNNEST-based bulk repsert (conflict on PK). +repsertManyHook :: EntityDef -> [[PersistValue]] -> (Text, [PersistValue]) +repsertManyHook ent valss = + (sql, map PersistArray (transpose valss)) + where + allFields = NEL.toList $ keyAndEntityFields ent + fields = filter isFieldNotGenerated allFields + fieldDbToText = escapeF . fieldDB + table = escapeE (getEntityDBName ent) + columns = map fieldDbToText fields + typeCasts = map (unnestTypeCast . fieldSqlType) fields + unnestArgs = map ("?" <>) typeCasts + conflictColumns = + NEL.toList $ escapeF . fieldDB <$> getEntityKeyFields ent + keyFieldNames = map (escapeF . fieldDB) (NEL.toList $ getEntityKeyFields ent) + updates = + map (\c -> c <> "=EXCLUDED." <> c) $ + filter (`notElem` keyFieldNames) columns + sql = T.concat + [ "INSERT INTO ", table + , "(", T.intercalate "," columns, ")" + , " SELECT * FROM UNNEST(" + , T.intercalate "," unnestArgs + , ")" + , " ON CONFLICT (" + , T.intercalate "," conflictColumns + , ") DO UPDATE SET " + , case updates of + [] -> case columns of + (c:_) -> c <> "=" <> table <> "." 
<> c + _ -> error "repsertManyHook: entity has no columns" + _ -> T.intercalate "," updates + ] + +migrate' + :: BackendSpecificOverrides + -> [EntityDef] + -> (Text -> IO Statement) + -> EntityDef + -> IO (Either [Text] CautiousMigration) +migrate' overrides allDefs getter entity = + fmap (fmap $ map showAlterDb) $ + migrateStructured overrides allDefs getter entity + +-- | Version comparison for upsert support (>= 9.5) +upsertFunction :: a -> NonEmpty Word -> Maybe a +upsertFunction f version = + if version >= postgres9dot5 + then Just f + else Nothing + where + postgres9dot5 :: NonEmpty Word + postgres9dot5 = 9 NEL.:| [5] + +-- putManySql / repsertManySql + +putManySql :: EntityDef -> Int -> Text +putManySql ent n = putManySql' conflictColumns fields ent n + where + fields = getEntityFields ent + conflictColumns = + concatMap + (map (escapeF . snd) . NEL.toList . uniqueFields) + (getEntityUniques ent) + +repsertManySql :: EntityDef -> Int -> Text +repsertManySql ent n = putManySql' conflictColumns fields ent n + where + fields = NEL.toList $ keyAndEntityFields ent + conflictColumns = NEL.toList $ escapeF . fieldDB <$> getEntityKeyFields ent + +putManySql' :: [Text] -> [FieldDef] -> EntityDef -> Int -> Text +putManySql' conflictColumns (filter isFieldNotGenerated -> fields) ent n = q + where + fieldDbToText = escapeF . fieldDB + mkAssignment f = T.concat [f, "=EXCLUDED.", f] + + table = escapeE . getEntityDBName $ ent + columns = Util.commaSeparated $ map fieldDbToText fields + placeholders = map (const "?") fields + updates = map (mkAssignment . fieldDbToText) fields + + q = + T.concat + [ "INSERT INTO " + , table + , Util.parenWrapped columns + , " VALUES " + , Util.commaSeparated + . replicate n + . Util.parenWrapped + . Util.commaSeparated + $ placeholders + , " ON CONFLICT " + , Util.parenWrapped . Util.commaSeparated $ conflictColumns + , " DO UPDATE SET " + , Util.commaSeparated updates + ] + +------------------------------------------------------------------------------- +-- Optimized batch operations +-- +-- These standalone functions generate more efficient SQL than the default +-- persistent typeclass methods. They use PostgreSQL-specific features +-- (= ANY for key lookups, UNNEST for columnar inserts) to minimize bind +-- parameters and improve query plan caching. +------------------------------------------------------------------------------- + +-- | Retrieve multiple records by key using a single array parameter. +-- +-- Generates @SELECT ... WHERE id = ANY($1)@ instead of the default +-- @WHERE id=? OR id=? OR ...@, reducing bind parameters from N to 1 +-- and improving query plan caching (the query shape is constant +-- regardless of how many keys are requested). +-- +-- Only supports single-column keys (the common auto-increment case). +-- For composite keys, falls back to the default 'getMany'. +getManyKeys + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + ) + => [Key record] + -> ReaderT SqlBackend m (Map.Map (Key record) record) +getManyKeys [] = return Map.empty +getManyKeys ks = do + let t = entityDef (Nothing :: Maybe record) + idCols = NEL.toList $ Util.dbIdColumnsEsc escapeF t + case idCols of + [idCol] -> do + let sql = T.concat + [ "SELECT ?? FROM " + , escapeE (getEntityDBName t) + , " WHERE " + , idCol + , " = ANY(?)" + ] + keyVals = map (headErr . 
keyToValues) ks + entities <- rawSql sql [PersistArray keyVals] + return $ Map.fromList $ + map (\(e :: Entity record) -> (entityKey e, entityVal e)) entities + _ -> + -- Composite key: fall back to persistent's default OR-based approach + getMany ks + where + headErr [] = error "getManyKeys: keyToValues returned empty list" + headErr (x:_) = x + +-- | Delete multiple records by key using a single array parameter. +-- +-- Generates @DELETE FROM ... WHERE id = ANY($1)@ instead of issuing +-- N individual @DELETE@ statements. In pipeline mode this is +-- fire-and-forget, so N deletes that would normally be N pipelined +-- commands become a single command. +-- +-- Only supports single-column keys. For composite keys, falls back +-- to individual deletes. +deleteManyKeys + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + ) + => [Key record] + -> ReaderT SqlBackend m () +deleteManyKeys [] = return () +deleteManyKeys ks = do + let t = entityDef (Nothing :: Maybe record) + idCols = NEL.toList $ Util.dbIdColumnsEsc escapeF t + case idCols of + [idCol] -> do + let sql = T.concat + [ "DELETE FROM " + , escapeE (getEntityDBName t) + , " WHERE " + , idCol + , " = ANY(?)" + ] + keyVals = map (headErr . keyToValues) ks + rawExecute sql [PersistArray keyVals] + _ -> + mapM_ delete ks + where + headErr [] = error "deleteManyKeys: keyToValues returned empty list" + headErr (x:_) = x + +-- | Insert multiple records using PostgreSQL @UNNEST@ for columnar +-- parameter binding. Returns the generated keys. +-- +-- Generates: +-- +-- @ +-- INSERT INTO t (col1, col2, ...) +-- SELECT * FROM UNNEST(?::type1[], ?::type2[], ...) +-- RETURNING id +-- @ +-- +-- This uses M array parameters (one per column) instead of N*M scalar +-- parameters (one per field per row), avoiding PostgreSQL's per-query +-- parameter limit and improving throughput for large batches. +insertManyUnnest + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + , SafeToInsert record + ) + => [record] + -> ReaderT SqlBackend m [Key record] +insertManyUnnest [] = return [] +insertManyUnnest records = do + let t = entityDef (Nothing :: Maybe record) + sql = unnestInsertSql t True + params = unnestInsertParams records + rawSql sql params + +-- | Like 'insertManyUnnest' but does not return keys. +-- +-- In pipeline mode this is fire-and-forget: the @INSERT@ is sent without +-- waiting for a response, and errors surface at the next drain point. +insertManyUnnest_ + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + , SafeToInsert record + ) + => [record] + -> ReaderT SqlBackend m () +insertManyUnnest_ [] = return () +insertManyUnnest_ records = do + let t = entityDef (Nothing :: Maybe record) + sql = unnestInsertSql t False + params = unnestInsertParams records + rawExecute sql params + +-- | Upsert multiple records using @UNNEST@ + @ON CONFLICT DO UPDATE SET@. +-- +-- Generates: +-- +-- @ +-- INSERT INTO t (cols) +-- SELECT * FROM UNNEST(?::type1[], ?::type2[], ...) +-- ON CONFLICT (unique_cols) DO UPDATE SET col1=EXCLUDED.col1, ... +-- @ +-- +-- Uses M array parameters (one per column) for the entire batch. +-- Falls back to 'insertManyUnnest_' if the entity has no unique +-- constraints. +putManyUnnest + :: forall record m + . 
( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + , SafeToInsert record + ) + => [record] + -> ReaderT SqlBackend m () +putManyUnnest [] = return () +putManyUnnest records = do + let t = entityDef (Nothing :: Maybe record) + uniqs = getEntityUniques t + case uniqs of + [] -> insertManyUnnest_ records + _ -> do + let fields = filter isFieldNotGenerated (getEntityFields t) + fieldDbToText = escapeF . fieldDB + table = escapeE (getEntityDBName t) + columns = map fieldDbToText fields + typeCasts = map (unnestTypeCast . fieldSqlType) fields + unnestArgs = map ("?" <>) typeCasts + conflictColumns = + concatMap + (map (escapeF . snd) . NEL.toList . uniqueFields) + uniqs + updates = map (\c -> c <> "=EXCLUDED." <> c) columns + sql = T.concat + [ "INSERT INTO ", table + , "(", T.intercalate "," columns, ")" + , " SELECT * FROM UNNEST(" + , T.intercalate "," unnestArgs + , ")" + , " ON CONFLICT (" + , T.intercalate "," conflictColumns + , ") DO UPDATE SET " + , T.intercalate "," updates + ] + params = unnestInsertParams records + rawExecute sql params + +-- | Repsert (replace-or-insert) multiple records using @UNNEST@ + +-- @ON CONFLICT (primary key) DO UPDATE SET@. +-- +-- Unlike 'putManyUnnest' which conflicts on unique constraints, this +-- conflicts on the primary key columns and includes them in the +-- @INSERT@. +-- +-- Uses M array parameters (one per column) for the entire batch. +repsertManyUnnest + :: forall record m + . ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + ) + => [(Key record, record)] + -> ReaderT SqlBackend m () +repsertManyUnnest [] = return () +repsertManyUnnest krs = do + let t = entityDef (Nothing :: Maybe record) + allFields = NEL.toList $ keyAndEntityFields t + fields = filter isFieldNotGenerated allFields + fieldDbToText = escapeF . fieldDB + table = escapeE (getEntityDBName t) + columns = map fieldDbToText fields + typeCasts = map (unnestTypeCast . fieldSqlType) fields + unnestArgs = map (\tc -> "?" <> tc) typeCasts + conflictColumns = + NEL.toList $ escapeF . fieldDB <$> getEntityKeyFields t + -- Update all non-key fields on conflict + keyFieldNames = map fieldDB (NEL.toList $ getEntityKeyFields t) + updates = + map (\c -> c <> "=EXCLUDED." <> c) $ + filter (\c -> escapeF `fmap` (lookup' c keyFieldNames) == Nothing) columns + sql = T.concat + [ "INSERT INTO ", table + , "(", T.intercalate "," columns, ")" + , " SELECT * FROM UNNEST(" + , T.intercalate "," unnestArgs + , ") ON CONFLICT (" + , T.intercalate "," conflictColumns + , ") DO UPDATE SET " + , case updates of + [] -> case columns of + (c:_) -> c <> "=" <> table <> "." <> c + [] -> error "repsertManyUnnest: entity has no columns" + _ -> T.intercalate "," updates + ] + -- Build values: for each (key, record) pair, produce key columns ++ entity columns + -- matching keyAndEntityFields order (minus generated columns). + rowValues = map (repsertRowValues t) krs + params = map PersistArray (transpose rowValues) + rawExecute sql params + where + -- Check if any FieldNameDB matches (by comparing the raw DB name) + lookup' :: Text -> [FieldNameDB] -> Maybe FieldNameDB + lookup' _ [] = Nothing + lookup' col (fn:fns) + | escapeF fn == col = Just fn + | otherwise = lookup' col fns + + repsertRowValues :: EntityDef -> (Key record, record) -> [PersistValue] + repsertRowValues ent (k, r) = + case entityPrimary ent of + Nothing -> keyToValues k <> Util.mkInsertValues r + Just _ -> Util.mkInsertValues r + +-- | Generate @INSERT ... 
SELECT * FROM UNNEST(...)@ SQL, optionally
+-- with a @RETURNING@ clause.
+unnestInsertSql :: EntityDef -> Bool -> Text
+unnestInsertSql t withReturning =
+    T.concat
+        [ "INSERT INTO "
+        , escapeE (getEntityDBName t)
+        , "("
+        , T.intercalate "," fieldNames
+        , ") SELECT * FROM UNNEST("
+        , T.intercalate "," unnestArgs
+        , ")"
+        , if withReturning
+            then " RETURNING "
+                <> (T.intercalate "," . NEL.toList $
+                        Util.dbIdColumnsEsc escapeF t)
+            else ""
+        ]
+  where
+    fields = filter isFieldNotGenerated (getEntityFields t)
+    fieldNames = map (escapeF . fieldDB) fields
+    typeCasts = map (unnestTypeCast . fieldSqlType) fields
+    unnestArgs = map (\tc -> "?" <> tc) typeCasts
+
+-- | Transpose record values from row-major to column-major and wrap
+-- each column in a 'PersistArray'.
+unnestInsertParams :: (PersistEntity record) => [record] -> [PersistValue]
+unnestInsertParams records =
+    let rowValues = map Util.mkInsertValues records
+    in map PersistArray (transpose rowValues)
+
+-- | Map a 'SqlType' to a PostgreSQL array type cast for @UNNEST@
+-- parameters, e.g. @SqlInt64@ becomes @\"::int8[]\"@.
+unnestTypeCast :: SqlType -> Text
+unnestTypeCast SqlString = "::varchar[]"
+unnestTypeCast SqlInt32 = "::int4[]"
+unnestTypeCast SqlInt64 = "::int8[]"
+unnestTypeCast SqlReal = "::float8[]"
+unnestTypeCast (SqlNumeric _ _) = "::numeric[]"
+unnestTypeCast SqlBool = "::boolean[]"
+unnestTypeCast SqlDay = "::date[]"
+unnestTypeCast SqlTime = "::time[]"
+unnestTypeCast SqlDayTime = "::timestamptz[]"
+unnestTypeCast SqlBlob = "::bytea[]"
+unnestTypeCast (SqlOther t) = "::" <> t <> "[]"
+
+-- | Get the SQL string for the table that a PersistEntity represents.
+tableName :: (PersistEntity record) => record -> Text
+tableName = escapeE . tableDBName
+
+-- | Get the SQL string for the field that an EntityField represents.
+fieldName :: (PersistEntity record) => EntityField record typ -> Text
+fieldName = escapeF . fieldDBName
+
+-- HandleUpdateCollision and upsert functions
+
+-- | This type is used to determine how to update rows using Postgres'
+-- @INSERT ... ON CONFLICT ... DO UPDATE@ functionality.
+data HandleUpdateCollision record where
+    CopyField :: EntityField record typ -> HandleUpdateCollision record
+    CopyUnlessEq
+        :: (PersistField typ)
+        => EntityField record typ
+        -> typ
+        -> HandleUpdateCollision record
+
+-- | Copy the field into the database only if the value in the
+-- corresponding record is non-@NULL@.
+copyUnlessNull
+    :: (PersistField typ)
+    => EntityField record (Maybe typ)
+    -> HandleUpdateCollision record
+copyUnlessNull field = CopyUnlessEq field Nothing
+
+-- | Copy the field into the database only if the value in the
+-- corresponding record is non-empty.
+copyUnlessEmpty
+    :: (Monoid.Monoid typ, PersistField typ)
+    => EntityField record typ
+    -> HandleUpdateCollision record
+copyUnlessEmpty field = CopyUnlessEq field Monoid.mempty
+
+-- | Copy the field into the database only if the field is not equal to the
+-- provided value.
+copyUnlessEq
+    :: (PersistField typ)
+    => EntityField record typ
+    -> typ
+    -> HandleUpdateCollision record
+copyUnlessEq = CopyUnlessEq
+
+-- | Copy the field directly from the record.
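+--
+-- A hypothetical sketch of these helpers with 'upsertManyWhere' (the
+-- @User@ entity and its fields are illustrative, not part of this API):
+--
+-- @
+-- upsertManyWhere
+--     users
+--     [ copyField UserName
+--     , copyUnlessNull UserNickname
+--     , copyUnlessEmpty UserBio
+--     ]
+--     []
+--     []
+-- @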
+copyField + :: (PersistField typ) => EntityField record typ -> HandleUpdateCollision record +copyField = CopyField + +upsertWhere + :: ( backend ~ PersistEntityBackend record + , PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + , PersistStore backend + , BackendCompatible SqlBackend backend + , OnlyOneUniqueKey record + ) + => record + -> [Update record] + -> [Filter record] + -> ReaderT backend m () +upsertWhere record updates filts = + upsertManyWhere [record] [] updates filts + +upsertManyWhere + :: forall record backend m + . ( backend ~ PersistEntityBackend record + , BackendCompatible SqlBackend backend + , PersistEntityBackend record ~ SqlBackend + , PersistEntity record + , OnlyOneUniqueKey record + , MonadIO m + ) + => [record] + -> [HandleUpdateCollision record] + -> [Update record] + -> [Filter record] + -> ReaderT backend m () +upsertManyWhere [] _ _ _ = return () +upsertManyWhere records fieldValues updates filters = do + conn <- asks projectBackend + let + uniqDef = onlyOneUniqueDef (Proxy :: Proxy record) + uncurry rawExecute $ + mkBulkUpsertQuery records conn fieldValues updates filters uniqDef + +excludeNotEqualToOriginal + :: (PersistField typ, PersistEntity rec) + => EntityField rec typ + -> Filter rec +excludeNotEqualToOriginal field = + Filter + { filterField = field + , filterFilter = Ne + , filterValue = + UnsafeValue $ + PersistLiteral_ + Unescaped + bsForExcludedField + } + where + bsForExcludedField = + T.encodeUtf8 $ + "EXCLUDED." <> fieldName field + +mkBulkUpsertQuery + :: ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , OnlyOneUniqueKey record + ) + => [record] + -> SqlBackend + -> [HandleUpdateCollision record] + -> [Update record] + -> [Filter record] + -> UniqueDef + -> (Text, [PersistValue]) +mkBulkUpsertQuery records conn fieldValues updates filters uniqDef = + (q, recordValues <> updsValues <> copyUnlessValues <> whereVals) + where + mfieldDef x = case x of + CopyField rec -> Right (fieldDbToText (persistFieldDef rec)) + CopyUnlessEq rec val -> Left (fieldDbToText (persistFieldDef rec), toPersistValue val) + (fieldsToMaybeCopy, updateFieldNames) = partitionEithers $ map mfieldDef fieldValues + fieldDbToText = escapeF . fieldDB + entityDef' = entityDef records + conflictColumns = + map (escapeF . snd) $ NEL.toList $ uniqueFields uniqDef + firstField = case entityFieldNames of + [] -> error "The entity you're trying to insert does not have any fields." + (field : _) -> field + entityFieldNames = map fieldDbToText (getEntityFields entityDef') + nameOfTable = escapeE . getEntityDBName $ entityDef' + copyUnlessValues = map snd fieldsToMaybeCopy + recordValues = concatMap (map toPersistValue . toPersistFields) records + recordPlaceholders = + Util.commaSeparated + $ map + (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields) + $ records + mkCondFieldSet n _ = + T.concat + [ n + , "=COALESCE(" + , "NULLIF(" + , "EXCLUDED." + , n + , "," + , "?" + , ")" + , "," + , nameOfTable + , "." 
+            , n
+            , ")"
+            ]
+    condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy
+    fieldSets = map (\n -> T.concat [n, "=EXCLUDED.", n]) updateFieldNames
+    upds =
+        map
+            (Util.mkUpdateText' escapeF (\n -> T.concat [nameOfTable, ".", n]))
+            updates
+    updsValues = map (\(Update _ val _) -> toPersistValue val) updates
+    (wher, whereVals) =
+        if null filters
+            then ("", [])
+            else (filterClauseWithVals (Just PrefixTableName) conn filters)
+    updateText =
+        case fieldSets <> upds <> condFieldSets of
+            [] ->
+                T.concat [firstField, "=", nameOfTable, ".", firstField]
+            xs ->
+                Util.commaSeparated xs
+    q =
+        T.concat
+            [ "INSERT INTO "
+            , nameOfTable
+            , Util.parenWrapped . Util.commaSeparated $ entityFieldNames
+            , " VALUES "
+            , recordPlaceholders
+            , " ON CONFLICT "
+            , Util.parenWrapped $ Util.commaSeparated $ conflictColumns
+            , " DO UPDATE SET "
+            , updateText
+            , wher
+            ]
+
+-- | Enable a Postgres extension.
+migrateEnableExtension :: Text -> Migration
+migrateEnableExtension extName = WriterT $ WriterT $ do
+    res :: [Single Int] <-
+        rawSql
+            "SELECT COUNT(*) FROM pg_catalog.pg_extension WHERE extname = ?"
+            [PersistText extName]
+    if res == [Single 0]
+        then return (((), []), [(False, "CREATE EXTENSION \"" <> extName <> "\"")])
+        else return (((), []), [])
+
+-- | Mock a migration even when the database is not present.
+mockMigration :: Migration -> IO ()
+mockMigration mig = do
+    smap <- newIORef mempty
+    let
+        mockMigrateOverrides = emptyBackendSpecificOverrides
+        mockMigrateFn allDefs _ entity =
+            fmap (fmap $ map showAlterDb) $
+                return $
+                    Right $
+                        mockMigrateStructured mockMigrateOverrides allDefs entity
+        sqlbackend =
+            mkSqlBackend
+                MkSqlBackendArgs
+                    { connPrepare = \_ -> do
+                        return
+                            Statement
+                                { stmtFinalize = return ()
+                                , stmtReset = return ()
+                                , stmtExecute = undefined
+                                , stmtQuery = \_ -> return $ return ()
+                                }
+                    , connInsertSql = undefined
+                    , connStmtMap = smap
+                    , connClose = undefined
+                    , connMigrateSql = mockMigrateFn
+                    , connBegin = undefined
+                    , connCommit = undefined
+                    , connRollback = undefined
+                    , connEscapeFieldName = escapeF
+                    , connEscapeTableName = escapeE . getEntityDBName
+                    , connEscapeRawName = escape
+                    , connNoLimit = undefined
+                    , connRDBMS = undefined
+                    , connLimitOffset = undefined
+                    , connLogFunc = undefined
+                    }
+        result = runReaderT $ runWriterT $ runWriterT mig
+    resp <- result sqlbackend
+    mapM_ TIO.putStrLn $ map snd $ snd resp
+
+-- PostgresConf
+
+-- | Information required to connect to a PostgreSQL database.
+data PostgresConf = PostgresConf
+    { pgConnStr :: ConnectionString
+    , pgPoolStripes :: Int
+    , pgPoolIdleTimeout :: Integer
+    , pgPoolSize :: Int
+    }
+    deriving (Show, Read, Data)
+
+instance FromJSON PostgresConf where
+    parseJSON v = modifyFailure ("Persistent: error loading PostgreSQL conf: " ++) $
+        flip (withObject "PostgresConf") v $ \o -> do
+            let defaultPoolConfig = defaultConnectionPoolConfig
+            database <- o .: "database"
+            host <- o .: "host"
+            port <- o .:? "port" .!= (5432 :: Int)
+            user <- o .: "user"
+            password <- o .: "password"
+            poolSize <- o .:? "poolsize" .!= (connectionPoolConfigSize defaultPoolConfig)
+            poolStripes <-
+                o .:? "stripes" .!= (connectionPoolConfigStripes defaultPoolConfig)
+            poolIdleTimeout <-
+                o .:?
"idleTimeout" + .!= (floor $ connectionPoolConfigIdleTimeout defaultPoolConfig) + let + cstr = + B8.intercalate " " + [ "host='" <> B8.pack host <> "'" + , "port=" <> B8.pack (show port) + , "user='" <> B8.pack user <> "'" + , "password='" <> B8.pack password <> "'" + , "dbname='" <> B8.pack database <> "'" + ] + return $ PostgresConf cstr poolStripes poolIdleTimeout poolSize + +instance PersistConfig PostgresConf where + type PersistConfigBackend PostgresConf = ReaderT (WriteBackend PostgreSQLBackend) + type PersistConfigPool PostgresConf = Pool (WriteBackend PostgreSQLBackend) + createPoolConfig conf = + runNoLoggingT $ createPostgresqlPipelinePool + defaultPipelineSettings + { pipelineConnStr = pgConnStr conf + , pipelinePoolSize = pgPoolSize conf + } + runPool _ = runSqlPool + loadConfig = parseJSON + applyEnv c0 = do + env <- getEnvironment + return $ + addUser env $ + addPass env $ + addDatabase env $ + addPort env $ + addHost env c0 + where + addParam param val c = + c{pgConnStr = B8.concat [pgConnStr c, " ", param, "='", pgescape val, "'"]} + + pgescape = B8.pack . go + where + go ('\'' : rest) = '\\' : '\'' : go rest + go ('\\' : rest) = '\\' : '\\' : go rest + go (x : rest) = x : go rest + go [] = [] + + maybeAddParam param envvar env = + maybe id (addParam param) $ lookup envvar env + + addHost = maybeAddParam "host" "PGHOST" + addPort = maybeAddParam "port" "PGPORT" + addUser = maybeAddParam "user" "PGUSER" + addPass = maybeAddParam "password" "PGPASS" + addDatabase = maybeAddParam "dbname" "PGDATABASE" + +-- | Hooks for configuring the connection to Postgres. +data PostgresConfHooks = PostgresConfHooks + { pgConfHooksGetServerVersion :: LibPQ.Connection -> IO (NonEmpty Word) + , pgConfHooksAfterCreate :: LibPQ.Connection -> IO () + } + +-- | Default settings for 'PostgresConfHooks'. +defaultPostgresConfHooks :: PostgresConfHooks +defaultPostgresConfHooks = + PostgresConfHooks + { pgConfHooksGetServerVersion = \conn -> do + v <- LibPQ.serverVersion conn + let ver = fromIntegral v + major = ver `div` 10000 + minor = (ver `mod` 10000) `div` 100 + if ver == 0 + then return $ 9 NEL.:| [4] + else return $ major NEL.:| [minor] + , pgConfHooksAfterCreate = const $ pure () + } + +-- Pipeline mode + +-- | No-op: pipeline mode is always on. +-- +-- In previous versions, this function entered and exited libpq pipeline mode +-- around a block of operations. Now that pipelining is automatic and always +-- active, this is the identity function. It is kept for API compatibility. +withPipeline :: (MonadIO m) => ReaderT SqlBackend m a -> ReaderT SqlBackend m a +withPipeline = id + +-- | Flush all pending fire-and-forget results in the pipeline. +-- +-- Forces all queued DML operations to be sent to the server and their +-- results consumed. Throws 'PipelineExecError' if any operation +-- failed. This is useful for eagerly discovering errors from +-- fire-and-forget operations (INSERT, UPDATE, DELETE) without waiting +-- for the next SELECT or COMMIT. +flushPipeline :: (MonadIO m) => ReaderT SqlBackend m () +flushPipeline = do + backend <- ask + case getPipelineConn backend of + Nothing -> return () + Just pc -> liftIO $ drainPending pc + +-- | Execute raw SQL synchronously, returning the actual affected row count. +-- +-- Unlike 'rawExecuteCount' (which returns 0 in pipeline mode because DML +-- is fire-and-forget), this function drains all pending pipeline operations, +-- sends the statement, flushes, and reads the result -- giving you the real +-- number of rows affected. 
+-- +-- Use this when you need the count (e.g. for @deleteCount@, @updateCount@ +-- semantics). +rawExecuteSync + :: (MonadIO m) + => Text + -> [PersistValue] + -> ReaderT SqlBackend m Int64 +rawExecuteSync sql vals = do + backend <- ask + case getPipelineConn backend of + Nothing -> return 0 + Just pc -> liftIO $ do + drainPending pc + let sqlBS = T.encodeUtf8 sql + (rewrittenSQL, remainingVals) = inlineAndRewrite sqlBS vals + params = map encodePersistValue remainingVals + ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary + unless ok $ do + merr <- LibPQ.errorMessage (pgConn pc) + throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr) + _ <- pgSendFlushRequest pc + pgFlush pc + ret <- readOneQueryResult pc + mbs <- LibPQ.cmdTuples ret + LibPQ.unsafeFreeResult ret + case mbs of + Nothing -> return 0 + Just bs -> case B8.readInt bs of + Just (n, _) -> return (fromIntegral n) + Nothing -> return 0 + +-- | Execute raw SQL as fire-and-forget (non-flushing). +-- +-- The statement is queued in the pipeline and its result is consumed at the +-- next drain point (SELECT, COMMIT, or 'flushPipeline'). Errors from this +-- statement will surface at that drain point as a 'PipelineExecError'. +-- +-- This is the same behavior as 'rawExecute', but the name makes the +-- pipeline semantics explicit. +rawExecuteNoReturn + :: (MonadIO m) + => Text + -> [PersistValue] + -> ReaderT SqlBackend m () +rawExecuteNoReturn sql vals = do + backend <- ask + case getPipelineConn backend of + Nothing -> return () + Just pc -> liftIO $ do + let sqlBS = T.encodeUtf8 sql + (rewrittenSQL, remainingVals) = inlineAndRewrite sqlBS vals + params = map encodePersistValue remainingVals + ok <- pgSendQueryParams pc rewrittenSQL params LibPQ.Binary + unless ok $ do + merr <- LibPQ.errorMessage (pgConn pc) + throwIO $ PipelineExecError (fromMaybe "sendQueryParams failed" merr) + lazyReply <- pgRecvResult pc + modifyIORef (pgPending pc) (lazyReply :) + +-- RawPostgresqlPipeline + +-- | Wrapper for persistent SqlBackends that carry the corresponding +-- libpq 'PgConn'. +data RawPostgresqlPipeline backend = RawPostgresqlPipeline + { persistentBackend :: backend + , rawPipelineConnection :: PgConn + } + +instance BackendCompatible (RawPostgresqlPipeline b) (RawPostgresqlPipeline b) where + projectBackend = id + +instance BackendCompatible b (RawPostgresqlPipeline b) where + projectBackend = persistentBackend + +withRawConnection + :: (PgConn -> SqlBackend) + -> PgConn + -> RawPostgresqlPipeline SqlBackend +withRawConnection f pc = + RawPostgresqlPipeline + { persistentBackend = f pc + , rawPipelineConnection = pc + } + +-- | Create a PostgreSQL connection pool which also exposes the +-- raw PgConn. +createRawPostgresqlPipelinePool + :: (MonadUnliftIO m, MonadLoggerIO m) + => PipelineSettings + -> m (Pool (RawPostgresqlPipeline SqlBackend)) +createRawPostgresqlPipelinePool settings = + createSqlPool (\logFunc -> do + pc <- openPgConn (pipelineFetchMode settings) (pipelineConnStr settings) + smap <- newIORef mempty + return $ withRawConnection (\pc' -> createBackend logFunc (pgVersion pc') smap pc') pc + ) (pipelinePoolSize settings) + +instance (PersistCore b) => PersistCore (RawPostgresqlPipeline b) where + newtype BackendKey (RawPostgresqlPipeline b) = RawPostgresqlPipelineKey + { unRawPostgresqlPipelineKey :: BackendKey (Compatible b (RawPostgresqlPipeline b)) + } + +makeCompatibleKeyInstances [t| forall b. Compatible b (RawPostgresqlPipeline b) |] + +$(pure []) + +makeCompatibleInstances [t| forall b. 
Compatible b (RawPostgresqlPipeline b) |] diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/FFI.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/FFI.hs new file mode 100644 index 000000000..e2627ecc9 --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/FFI.hs @@ -0,0 +1,149 @@ +{-# LANGUAGE PatternSynonyms #-} + +-- | Additional libpq FFI bindings for features not yet in @postgresql-libpq@. +-- +-- Provides: +-- +-- * 'setChunkedRowsMode' — PG17's @PQsetChunkedRowsMode@, with a +-- compile-time stub when headers are pre-17. +-- * 'hasChunkedRowsSupport' — runtime check for chunked mode availability. +-- * 'rawResultStatusInt' — raw @ExecStatus@ as 'CInt', bypassing the +-- @postgresql-libpq@ enum which throws on unknown values (e.g. +-- @PGRES_TUPLES_CHUNK@). +-- * Pattern synonyms for the raw @ExecStatus@ integer constants. +module Database.Persist.Postgresql.Pipeline.FFI + ( -- * Chunked rows mode (PG17+) + setChunkedRowsMode + , hasChunkedRowsSupport + -- * Raw result status + , rawResultStatusInt + -- * Raw ExecStatus constants + , pattern PGRES_EMPTY_QUERY + , pattern PGRES_COMMAND_OK + , pattern PGRES_TUPLES_OK + , pattern PGRES_COPY_OUT + , pattern PGRES_COPY_IN + , pattern PGRES_BAD_RESPONSE + , pattern PGRES_NONFATAL_ERROR + , pattern PGRES_FATAL_ERROR + , pattern PGRES_COPY_BOTH + , pattern PGRES_SINGLE_TUPLE + , pattern PGRES_PIPELINE_SYNC + , pattern PGRES_PIPELINE_ABORTED + , pattern PGRES_TUPLES_CHUNK + ) where + +import Foreign.C.Types (CInt (..)) +import Foreign.ForeignPtr (ForeignPtr, withForeignPtr) +import Foreign.Ptr (Ptr) +import Unsafe.Coerce (unsafeCoerce) + +import Database.PostgreSQL.LibPQ.Internal (PGconn, withConn) +import qualified Database.PostgreSQL.LibPQ as LibPQ + +------------------------------------------------------------------------------- +-- Foreign imports from cbits/hs_libpq_extra.c +------------------------------------------------------------------------------- + +foreign import ccall unsafe "hs_PQsetChunkedRowsMode" + c_setChunkedRowsMode :: Ptr PGconn -> CInt -> IO CInt + +foreign import ccall unsafe "hs_has_chunked_rows_support" + c_hasChunkedRowsSupport :: IO CInt + +-- | Phantom type tag matching the C @PGresult@ struct. We declare our +-- own rather than reusing the one from @postgresql-libpq@ (which is not +-- exported). +data PGresult + +foreign import ccall unsafe "hs_PQresultStatusRaw" + c_resultStatusRaw :: Ptr PGresult -> IO CInt + +------------------------------------------------------------------------------- +-- Haskell wrappers +------------------------------------------------------------------------------- + +-- | Activate chunked row retrieval for the query that was just sent via +-- @sendQueryParams@. Must be called immediately after sending, before +-- any other operation on the connection. +-- +-- Returns 'True' on success, 'False' if the mode could not be set +-- (e.g. libpq < 17, or wrong call timing). +setChunkedRowsMode :: LibPQ.Connection -> Int -> IO Bool +setChunkedRowsMode conn chunkSize = + withConn conn $ \ptr -> do + rc <- c_setChunkedRowsMode ptr (fromIntegral chunkSize) + return (rc == 1) + +-- | 'True' if the linked libpq supports @PQsetChunkedRowsMode@ (PG17+). +-- Cached at first call; safe to call repeatedly. +hasChunkedRowsSupport :: IO Bool +hasChunkedRowsSupport = do + rc <- c_hasChunkedRowsSupport + return (rc == 1) + +-- | Get the raw @ExecStatus@ integer from a 'LibPQ.Result'. 
+-- +-- This bypasses @LibPQ.resultStatus@ which calls @fail@ on status codes +-- it doesn't recognise (e.g. @PGRES_TUPLES_CHUNK@ = 12 on older +-- @postgresql-libpq@ builds). +-- +-- __Safety__: 'LibPQ.Result' is a @newtype@ around @ForeignPtr PGresult@. +-- We use 'unsafeCoerce' to extract the 'ForeignPtr' without the +-- constructor being in scope. This is safe because: +-- +-- * The representation is guaranteed by @postgresql-libpq@'s source. +-- * 'withForeignPtr' keeps the pointer alive for the duration of the call. +-- * The C function is a pure read with no side effects. +rawResultStatusInt :: LibPQ.Result -> IO CInt +rawResultStatusInt result = + withForeignPtr (unsafeCoerce result :: ForeignPtr PGresult) c_resultStatusRaw + +------------------------------------------------------------------------------- +-- Raw ExecStatus constants +-- +-- These mirror the C enum ExecStatusType. Values 0–11 are stable across +-- all supported PostgreSQL versions; 12 (PGRES_TUPLES_CHUNK) was added +-- in PG17. +------------------------------------------------------------------------------- + +pattern PGRES_EMPTY_QUERY :: CInt +pattern PGRES_EMPTY_QUERY = 0 + +pattern PGRES_COMMAND_OK :: CInt +pattern PGRES_COMMAND_OK = 1 + +pattern PGRES_TUPLES_OK :: CInt +pattern PGRES_TUPLES_OK = 2 + +pattern PGRES_COPY_OUT :: CInt +pattern PGRES_COPY_OUT = 3 + +pattern PGRES_COPY_IN :: CInt +pattern PGRES_COPY_IN = 4 + +pattern PGRES_BAD_RESPONSE :: CInt +pattern PGRES_BAD_RESPONSE = 5 + +pattern PGRES_NONFATAL_ERROR :: CInt +pattern PGRES_NONFATAL_ERROR = 6 + +pattern PGRES_FATAL_ERROR :: CInt +pattern PGRES_FATAL_ERROR = 7 + +pattern PGRES_COPY_BOTH :: CInt +pattern PGRES_COPY_BOTH = 8 + +pattern PGRES_SINGLE_TUPLE :: CInt +pattern PGRES_SINGLE_TUPLE = 9 + +pattern PGRES_PIPELINE_SYNC :: CInt +pattern PGRES_PIPELINE_SYNC = 10 + +pattern PGRES_PIPELINE_ABORTED :: CInt +pattern PGRES_PIPELINE_ABORTED = 11 + +-- | Chunked tuple result (PG17+). Each @PGresult@ contains up to N rows +-- as specified by @PQsetChunkedRowsMode@. +pattern PGRES_TUPLES_CHUNK :: CInt +pattern PGRES_TUPLES_CHUNK = 12 diff --git a/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/Internal.hs b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/Internal.hs new file mode 100644 index 000000000..a2a328e3a --- /dev/null +++ b/persistent-postgresql-ng/Database/Persist/Postgresql/Pipeline/Internal.hs @@ -0,0 +1,352 @@ +{-# LANGUAGE OverloadedStrings #-} + +-- | Low-level PostgreSQL connection wrapper with server version info. +-- +-- Wraps 'LibPQ.Connection' and provides pipeline mode primitives. +module Database.Persist.Postgresql.Pipeline.Internal + ( PgConn (..) + , FetchMode (..) + , openPgConn + , closePgConn + , pgSendQueryParams + , pgGetResult + , pgEnterPipelineMode + , pgExitPipelineMode + , pgPipelineSync + , pgPipelineStatus + , pgFlush + , pgSendFlushRequest + , pgRecvResult + , PipelineError (..) + ) where + +import Control.Exception (Exception, throwIO) +import Control.Monad (when) +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as B8 +import Data.IORef +import Data.List.NonEmpty (NonEmpty (..)) +import Data.Maybe (mapMaybe) +import GHC.Conc (threadWaitWrite) +import System.IO.Unsafe (unsafeInterleaveIO) +import qualified Database.PostgreSQL.LibPQ as LibPQ + +import Database.Persist.Postgresql.Internal.PgType (OidCache, emptyOidCache) +import Database.Persist.Postgresql.Pipeline.FFI (hasChunkedRowsSupport) + +-- | Controls how query result rows are fetched from the server. 
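+--
+-- As a rule of thumb: stay with 'FetchAll' for typical OLTP queries,
+-- and switch to 'FetchSingleRow' or 'FetchChunked' when streaming
+-- result sets that may not fit comfortably in memory.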
+data FetchMode
+    = FetchAll
+    -- ^ Default. The server returns the entire result set in a single
+    -- @PGresult@. Simplest and fastest for small\/medium result sets,
+    -- but the full result must fit in memory.
+    | FetchSingleRow
+    -- ^ Single-row mode (@PQsetSingleRowMode@). Each @PGresult@ contains
+    -- exactly one row (status @PGRES_SINGLE_TUPLE@). Lowest memory
+    -- footprint but highest per-row allocation overhead. Works with
+    -- any PostgreSQL\/libpq version.
+    | FetchChunked !Int
+    -- ^ Chunked mode (@PQsetChunkedRowsMode@, PG17+ libpq). Each
+    -- @PGresult@ contains up to the given number of rows (status
+    -- @PGRES_TUPLES_CHUNK@). Good balance of bounded memory and reduced
+    -- allocation overhead.
+    --
+    -- If the linked libpq does not support chunked mode, this is
+    -- automatically downgraded to 'FetchSingleRow' at connection time.
+    deriving (Eq, Show)
+
+-- | A PostgreSQL connection with cached server version info.
+--
+-- Pipeline mode is always on. The 'pgPending' list holds the
+-- fire-and-forget query results that have not yet been read. These are
+-- drained at query time ('stmtQuery') or at transaction boundaries
+-- ('connCommit', 'connRollback').
+data PgConn = PgConn
+    { pgConn :: !LibPQ.Connection
+    , pgVersion :: !(NonEmpty Word)
+    , pgPending :: !(IORef [LibPQ.Result])
+    -- ^ All outstanding pipeline results (lazy thunks popped from
+    -- 'pgReplies'). Both fire-and-forget DML ('execute'') and lazy
+    -- reads ('pipelinedSend') append here. 'drainPending' forces,
+    -- error-checks, and frees all of them. Double-free is safe
+    -- because 'LibPQ.Result' is a 'ForeignPtr' -- 'unsafeFreeResult'
+    -- is idempotent.
+    --
+    -- Stored in reverse order (newest first) for O(1) append.
+    , pgFetchMode :: !FetchMode
+    -- ^ Row fetch strategy for SELECT-like queries on this connection.
+    , pgOidCache :: !(IORef OidCache)
+    -- ^ Cache for dynamically-discovered type OIDs (custom enums,
+    -- composites, domains). Starts empty; future custom-type support
+    -- can populate it at connection time or on first encounter.
+    -- Use 'resolveOid' from "Database.Persist.Postgresql.Internal.PgType"
+    -- to consult both the built-in table and this cache.
+    , pgReplies :: !(IORef [LibPQ.Result])
+    -- ^ Lazy reply stream (Hedis-style). Each element, when forced,
+    -- flushes the send buffer and reads one pipeline result + NULL
+    -- separator from the connection. Created at connection time via
+    -- 'unsafeInterleaveIO'. Used by 'pgRecvResult' which pops the
+    -- head lazily (without forcing the IO) so that back-to-back reads
+    -- are automatically pipelined.
+    }
+
+-- | Errors from pipeline operations.
+data PipelineError
+    = PipelineExecError !ByteString
+    -- ^ Error message from a failed query
+    | PipelineModeError !String
+    -- ^ Error entering/exiting pipeline mode
+    deriving (Show)
+
+instance Exception PipelineError
+
+-- | Resolve the requested 'FetchMode', downgrading 'FetchChunked' to
+-- 'FetchSingleRow' when the linked libpq doesn't support it.
+resolveFetchMode :: FetchMode -> IO FetchMode
+resolveFetchMode fm@FetchAll = return fm
+resolveFetchMode fm@FetchSingleRow = return fm
+resolveFetchMode fm@(FetchChunked _) = do
+    ok <- hasChunkedRowsSupport
+    if ok
+        then return fm
+        else return FetchSingleRow
+
+-- | Open a libpq connection, detect the server version, and enter pipeline mode.
+--
+-- The 'FetchMode' is resolved (see 'resolveFetchMode') and stored in the
+-- returned 'PgConn'.
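+--
+-- A minimal sketch (the conninfo string is illustrative):
+--
+-- @
+-- pc <- openPgConn (FetchChunked 256) "host=localhost dbname=test"
+-- @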
+openPgConn :: FetchMode -> ByteString -> IO PgConn +openPgConn requestedMode connStr = do + conn <- LibPQ.connectdb connStr + st <- LibPQ.status conn + case st of + LibPQ.ConnectionOk -> do + ver <- getServerVersionLibPQ conn + -- Enable nonblocking mode before entering pipeline mode. + -- In blocking mode, pipeline operations risk deadlock when + -- both client and server buffers fill simultaneously. + nbOk <- LibPQ.setnonblocking conn True + when (not nbOk) $ do + LibPQ.finish conn + fail "persistent-postgresql-ng: failed to set nonblocking mode" + pending <- newIORef ([] :: [LibPQ.Result]) + oidCache <- newIORef emptyOidCache + mode <- resolveFetchMode requestedMode + repliesRef <- newIORef [] + let pc = PgConn + { pgConn = conn + , pgVersion = ver + , pgPending = pending + , pgFetchMode = mode + , pgOidCache = oidCache + , pgReplies = repliesRef + } + replies <- mkReplyStream pc + writeIORef repliesRef replies + ok <- pgEnterPipelineMode pc + if ok + then return pc + else do + LibPQ.finish conn + fail "persistent-postgresql-ng: failed to enter pipeline mode" + _ -> do + merr <- LibPQ.errorMessage conn + LibPQ.finish conn + let msg = maybe "Unknown connection error" B8.unpack merr + fail $ "persistent-postgresql-ng: connection failed: " ++ msg + +-- | Close a libpq connection. Exits pipeline mode first. +closePgConn :: PgConn -> IO () +closePgConn pc = do + -- Best-effort: exit pipeline mode before closing. + -- If there are pending results, pipelineSync + drain them first. + pendingList <- readIORef (pgPending pc) + when (not (null pendingList)) $ do + _ <- pgPipelineSync pc + drainToSyncClose (pgConn pc) + _ <- pgExitPipelineMode pc + LibPQ.finish (pgConn pc) + where + -- Drain results until PipelineSync, ignoring all errors. + drainToSyncClose conn = do + mret <- LibPQ.getResult conn + case mret of + Nothing -> drainToSyncClose conn -- NULL separator, keep reading + Just ret -> do + st <- LibPQ.resultStatus ret + LibPQ.unsafeFreeResult ret + case st of + LibPQ.PipelineSync -> return () + _ -> drainToSyncClose conn + +-- | Send a query in pipeline mode (non-blocking). +pgSendQueryParams + :: PgConn + -> ByteString + -> [Maybe (LibPQ.Oid, ByteString, LibPQ.Format)] + -> LibPQ.Format + -> IO Bool +pgSendQueryParams pc sql params resultFmt = + LibPQ.sendQueryParams (pgConn pc) sql params resultFmt + +-- | Get the next result from the connection (for pipeline mode). +pgGetResult :: PgConn -> IO (Maybe LibPQ.Result) +pgGetResult = LibPQ.getResult . pgConn + +-- | Enter pipeline mode. +pgEnterPipelineMode :: PgConn -> IO Bool +pgEnterPipelineMode = LibPQ.enterPipelineMode . pgConn + +-- | Exit pipeline mode. +pgExitPipelineMode :: PgConn -> IO Bool +pgExitPipelineMode = LibPQ.exitPipelineMode . pgConn + +-- | Insert a sync point in pipeline mode. +pgPipelineSync :: PgConn -> IO Bool +pgPipelineSync = LibPQ.pipelineSync . pgConn + +-- | Get the current pipeline status. +pgPipelineStatus :: PgConn -> IO LibPQ.PipelineStatus +pgPipelineStatus = LibPQ.pipelineStatus . pgConn + +-- | Flush the client send buffer, waiting for socket writability as needed. +-- In nonblocking mode, 'LibPQ.flush' may return 'FlushWriting' when the +-- socket's send buffer is full. We wait for writability using GHC's I/O +-- manager ('threadWaitWrite'), then consume any pending server data +-- ('consumeInput') to prevent the server from blocking on its send buffer, +-- and retry. +-- Throws 'PipelineExecError' on failure. 
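+--
+-- For orientation, one manual pipeline round built from this module's
+-- primitives looks roughly like this (a sketch: assumes a connected
+-- 'PgConn' @pc@, ignores the 'Bool' results, omits error handling):
+--
+-- @
+-- _ <- pgSendQueryParams pc "SELECT 1" [] LibPQ.Binary
+-- _ <- pgSendFlushRequest pc
+-- pgFlush pc
+-- res <- pgRecvResult pc  -- lazy: the read happens when res is forced
+-- @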
+pgFlush :: PgConn -> IO ()
+pgFlush pc = do
+    st <- LibPQ.flush (pgConn pc)
+    case st of
+        LibPQ.FlushOk -> return ()
+        LibPQ.FlushFailed -> throwIO $ PipelineExecError "flush failed"
+        LibPQ.FlushWriting -> do
+            mfd <- LibPQ.socket (pgConn pc)
+            case mfd of
+                Nothing -> throwIO $ PipelineExecError "flush: no socket"
+                Just fd -> flushLoop fd
+  where
+    flushLoop fd = do
+        threadWaitWrite fd
+        -- Consume any pending server data to prevent server-side blocking.
+        -- The server may be unable to read our data because its send buffer
+        -- is full with results we haven't consumed yet.
+        _ <- LibPQ.consumeInput (pgConn pc)
+        st <- LibPQ.flush (pgConn pc)
+        case st of
+            LibPQ.FlushOk -> return ()
+            LibPQ.FlushFailed -> throwIO $ PipelineExecError "flush failed"
+            LibPQ.FlushWriting -> flushLoop fd
+
+-- | Send a request for the server to flush its output buffer.
+-- In pipeline mode, the server buffers results until it receives a
+-- Sync or Flush message.
+pgSendFlushRequest :: PgConn -> IO Bool
+pgSendFlushRequest = LibPQ.sendFlushRequest . pgConn
+
+-- | Pop the next result from the lazy reply stream without forcing the
+-- read. The actual flush+read happens when the returned 'LibPQ.Result'
+-- is evaluated (via 'unsafeInterleaveIO').
+--
+-- Uses @head@/@tail@ instead of pattern matching to avoid forcing the
+-- cons cell -- this is the key to Hedis-style automatic pipelining.
+-- @atomicModifyIORef@ stores @tail xs@ and returns @head xs@ as
+-- unevaluated thunks. The actual IO fires only when the caller
+-- inspects the returned 'LibPQ.Result'.
+pgRecvResult :: PgConn -> IO LibPQ.Result
+pgRecvResult pc = atomicModifyIORef (pgReplies pc) (\xs -> (tail xs, head xs))
+{-# INLINE pgRecvResult #-}
+
+-- | Build the infinite lazy reply stream. Each cons cell, when forced,
+-- flushes the send buffer and reads one query result (result + NULL
+-- separator) from the connection.
+mkReplyStream :: PgConn -> IO [LibPQ.Result]
+mkReplyStream pc = go
+  where
+    go = unsafeInterleaveIO $ do
+        -- Flush the send buffer so the server sees all queued queries.
+        _ <- LibPQ.sendFlushRequest (pgConn pc)
+        pgFlush pc
+        -- Read one result.
+        mret <- LibPQ.getResult (pgConn pc)
+        case mret of
+            Nothing ->
+                -- No result available -- shouldn't happen in a well-formed
+                -- pipeline, but protect against it by retrying lazily.
+                go
+            Just ret -> do
+                -- The status may be CommandOk (DML), TuplesOk (SELECT), or
+                -- an error result; the handling is identical in all cases:
+                -- consume the trailing NULL separator and hand the result
+                -- to the caller, who is responsible for status checking.
+                _ <- LibPQ.getResult (pgConn pc)
+                rest <- go
+                return (ret : rest)
+
+-- | Get the server version from a libpq connection.
+--
+-- Uses @LibPQ.serverVersion@, which returns the version as a single
+-- integer. PostgreSQL 10 and later encode it as @major * 10000 + minor@
+-- (so @140005@ is server version 14.5); releases before 10 used
+-- @major * 10000 + minor * 100 + patch@ (so @90405@ is 9.4.5). The
+-- decomposition below follows the older three-part convention, which is
+-- harmless for version checks since the major component dominates.
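+--
+-- For example, @140005@ decomposes below as @major = 14@, @minor = 0@,
+-- @patch = 5@, giving @14 :| [0, 5]@ (read as server version 14.5 on a
+-- modern server, where the middle component is always zero).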
+getServerVersionLibPQ :: LibPQ.Connection -> IO (NonEmpty Word)
+getServerVersionLibPQ conn = do
+    v <- LibPQ.serverVersion conn
+    let ver = fromIntegral v
+        major = ver `div` 10000
+        minor = (ver `mod` 10000) `div` 100
+        patch = ver `mod` 100
+    if ver == 0
+        then
+            -- Fallback: if serverVersion returns 0, try querying
+            getServerVersionQuery conn
+        else
+            return $ major :| [minor, patch]
+
+-- | Fallback: query the server version via SQL.
+getServerVersionQuery :: LibPQ.Connection -> IO (NonEmpty Word)
+getServerVersionQuery conn = do
+    mret <- LibPQ.exec conn "SHOW server_version"
+    case mret of
+        Nothing -> return $ 9 :| [4] -- Default to minimum supported
+        Just ret -> do
+            st <- LibPQ.resultStatus ret
+            case st of
+                LibPQ.TuplesOk -> do
+                    mbs <- LibPQ.getvalue ret (LibPQ.toRow (0 :: Int)) (LibPQ.toColumn (0 :: Int))
+                    LibPQ.unsafeFreeResult ret
+                    case mbs of
+                        Nothing -> return $ 9 :| [4]
+                        Just bs -> case parseVersionBS bs of
+                            Just v -> return v
+                            Nothing -> return $ 9 :| [4]
+                _ -> do
+                    LibPQ.unsafeFreeResult ret
+                    return $ 9 :| [4]
+
+-- | Parse a version string like "14.5" or "14.5.1" from a ByteString.
+parseVersionBS :: ByteString -> Maybe (NonEmpty Word)
+parseVersionBS bs =
+    let parts = B8.split '.' bs
+        nums = mapMaybe readWord parts
+    in case nums of
+        (x:xs) -> Just (x :| xs)
+        [] -> Nothing
+  where
+    readWord s = case reads (B8.unpack s) of
+        [(n, "")] -> Just n
+        _ -> Nothing
diff --git a/persistent-postgresql-ng/LICENSE b/persistent-postgresql-ng/LICENSE
new file mode 100644
index 000000000..f7f439fe3
--- /dev/null
+++ b/persistent-postgresql-ng/LICENSE
@@ -0,0 +1,20 @@
+Copyright (c) 2026 Ian Duncan
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be included
+in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/persistent-postgresql-ng/bench-baseline.html b/persistent-postgresql-ng/bench-baseline.html
new file mode 100644
index 000000000..5ccf8537f
--- /dev/null
+++ b/persistent-postgresql-ng/bench-baseline.html
@@ -0,0 +1,1159 @@
+[... ~1,150 lines of generated criterion HTML report elided. The only meaningful text: the title "criterion performance measurements", an "overview" chart section, and a "colophon" noting the report was created with the criterion benchmark tool, developed and maintained by Bryan O'Sullivan. ...]
+ + diff --git a/persistent-postgresql-ng/bench/Main.hs b/persistent-postgresql-ng/bench/Main.hs new file mode 100644 index 000000000..97b2952b1 --- /dev/null +++ b/persistent-postgresql-ng/bench/Main.hs @@ -0,0 +1,302 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | Benchmarks comparing persistent-postgresql-ng (binary protocol) +-- against persistent-postgresql (text protocol via postgresql-simple). +-- +-- Run with: +-- +-- @ +-- stack bench persistent-postgresql-ng +-- @ +-- +-- Requires PostgreSQL running on localhost:5432 with a @test@ database +-- accessible by the current user (or the @postgres@ role). +module Main (main) where + +import Control.Exception (evaluate) +import Control.Monad (forM_, void) +import Control.Monad.IO.Class (liftIO) +import Control.Monad.Logger (runNoLoggingT) +import Criterion.Main +import qualified Data.ByteString.Char8 as B8 +import Data.Int (Int64) +import Data.Maybe (fromMaybe) +import Data.Pool (Pool) +import Data.Text (Text) +import qualified Data.Text as T +import Control.Monad.Trans.Reader (ReaderT, withReaderT) +import Database.Persist +import Database.Persist.Class.PersistStore (BackendCompatible (..)) +import Database.Persist.Sql +import Database.Persist.TH +import System.Environment (lookupEnv) + +import qualified Database.Persist.Postgresql as PgSimple +import qualified Database.Persist.Postgresql.Pipeline as PgPipeline + +share + [mkPersist sqlSettings, mkMigrate "benchMigrate"] + [persistLowerCase| +BenchPerson + name Text + age Int + email Text + UniqueBenchPersonEmail email + deriving Show Eq +|] + +getConnString :: IO B8.ByteString +getConnString = do + host <- fromMaybe "localhost" <$> lookupEnv "PGHOST" + port <- fromMaybe "5432" <$> lookupEnv "PGPORT" + user <- fromMaybe "postgres" <$> lookupEnv "PGUSER" + db <- fromMaybe "test" <$> lookupEnv "PGDATABASE" + pure $ B8.pack $ "host=" <> host <> " port=" <> port + <> " user=" <> user <> " dbname=" <> db + +cleanTablesPl :: ReaderT (PgPipeline.WriteBackend PgPipeline.PostgreSQLBackend) IO () +cleanTablesPl = deleteWhere ([] :: [Filter BenchPerson]) + +cleanTablesSim :: SqlPersistT IO () +cleanTablesSim = deleteWhere ([] :: [Filter BenchPerson]) + +people :: Int -> [BenchPerson] +people n = + [ BenchPerson + (T.pack $ "Person " <> show i) + (20 + i `mod` 60) + (T.pack $ "person" <> show i <> "@test.com") + | i <- [1..n] + ] + +pl :: Pool (PgPipeline.WriteBackend PgPipeline.PostgreSQLBackend) + -> ReaderT (PgPipeline.WriteBackend PgPipeline.PostgreSQLBackend) IO a -> IO a +pl pool action = runSqlPool action pool + + +sim :: Pool SqlBackend -> SqlPersistT IO a -> IO a +sim pool action = runSqlPool action pool + +main :: IO () +main = do + connStr <- getConnString + + pipelinePool <- runNoLoggingT $ PgPipeline.createPostgresqlPipelinePool + PgPipeline.defaultPipelineSettings + { PgPipeline.pipelineConnStr = connStr + , PgPipeline.pipelinePoolSize = 1 + } + simplePool <- runNoLoggingT $ PgSimple.createPostgresqlPool connStr 1 + + pl pipelinePool $ withReaderT projectBackend $ void $ runMigrationSilent benchMigrate + + let p = pl pipelinePool + s = sim simplePool + + defaultMain + [ bgroup 
"delete x100 then select (pipeline sweet spot)" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl + keys <- mapM insert (people 100) + forM_ keys delete + void $ selectList ([] :: [Filter BenchPerson]) [] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim + keys <- mapM insert (people 100) + mapM_ delete keys + void $ selectList ([] :: [Filter BenchPerson]) [] + ] + + , bgroup "update x100 then select (pipeline sweet spot)" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl + keys <- mapM insert (people 100) + forM_ keys $ \k -> update k [BenchPersonAge =. 99] + void $ selectList ([] :: [Filter BenchPerson]) [] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim + keys <- mapM insert (people 100) + forM_ keys $ \k -> update k [BenchPersonAge =. 99] + void $ selectList ([] :: [Filter BenchPerson]) [] + ] + + , bgroup "mixed DML x100 (delete+update+replace, no reads until end)" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl + keys <- mapM insert (people 300) + let (delKeys, rest) = splitAt 100 keys + let (updKeys, repKeys) = splitAt 100 rest + forM_ delKeys delete + forM_ updKeys $ \k -> update k [BenchPersonAge =. 99] + forM_ (zip repKeys (people 100)) $ \(k, person) -> + replace k (person { benchPersonAge = 42 }) + void $ selectList ([] :: [Filter BenchPerson]) [] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim + keys <- mapM insert (people 300) + let (delKeys, rest) = splitAt 100 keys + let (updKeys, repKeys) = splitAt 100 rest + mapM_ delete delKeys + forM_ updKeys $ \k -> update k [BenchPersonAge =. 99] + forM_ (zip repKeys (people 100)) $ \(k, person) -> + replace k (person { benchPersonAge = 42 }) + void $ selectList ([] :: [Filter BenchPerson]) [] + ] + + , bgroup "getManyKeys x100 (single round-trip vs N)" + [ bench "pipeline (= ANY)" $ whnfIO $ p $ do + cleanTablesPl + keys <- mapM insert (people 100) + void $ withReaderT projectBackend $ PgPipeline.getManyKeys keys + , bench "simple (OR chain)" $ whnfIO $ s $ do + cleanTablesSim + keys <- mapM insert (people 100) + void $ getMany keys + ] + + , bgroup "single insert x100" + [ bench "pipeline" $ whnfIO $ p $ cleanTablesPl >> forM_ (people 100) insert_ + , bench "simple" $ whnfIO $ s $ cleanTablesSim >> forM_ (people 100) insert_ + ] + + , bgroup "insertMany x100" + [ bench "pipeline" $ whnfIO $ p $ cleanTablesPl >> insertMany_ (people 100) + , bench "simple" $ whnfIO $ s $ cleanTablesSim >> insertMany_ (people 100) + ] + + , bgroup "insertMany x1000" + [ bench "pipeline" $ whnfIO $ p $ cleanTablesPl >> insertMany_ (people 1000) + , bench "simple" $ whnfIO $ s $ cleanTablesSim >> insertMany_ (people 1000) + ] + + , bgroup "selectList x100" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl >> insertMany_ (people 100) + void $ selectList ([] :: [Filter BenchPerson]) [] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim >> insertMany_ (people 100) + void $ selectList ([] :: [Filter BenchPerson]) [] + ] + + , bgroup "selectList filtered" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl >> insertMany_ (people 100) + void $ selectList [BenchPersonAge >. 50] [] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim >> insertMany_ (people 100) + void $ selectList [BenchPersonAge >. 50] [] + ] + + , bgroup "select IN x20 (= ANY vs IN (?,?...))" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl >> insertMany_ (people 100) + let names = [T.pack $ "Person " <> show i | i <- [1..20 :: Int]] + void $ selectList [BenchPersonName <-. 
names] [] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim >> insertMany_ (people 100) + let names = [T.pack $ "Person " <> show i | i <- [1..20 :: Int]] + void $ selectList [BenchPersonName <-. names] [] + ] + + -- The key pipelining benchmark: mapM get sends all queries + -- before reading any results (Hedis-style lazy replies). + -- Keys are pre-materialized (as Int64) to isolate the get path. + , env (do rawKeys <- runSqlPool + (do cleanTablesSim + ks <- mapM insert (people 100) + _ <- liftIO $ evaluate (length ks) + return (map (\k -> fromSqlKey k) ks)) + simplePool + evaluate rawKeys) $ \rawKeys -> + bgroup "get by key x100 (pipelined reads)" + [ bench "pipeline" $ whnfIO $ p $ do + let keys = map (toSqlKey . fromIntegral) rawKeys :: [Key BenchPerson] + results <- mapM get keys + -- Force all Maybe values (triggers the deferred reads) + liftIO $ forM_ results $ evaluate . fmap benchPersonAge + , bench "simple" $ whnfIO $ s $ do + let keys = map (toSqlKey . fromIntegral) rawKeys :: [Key BenchPerson] + results <- mapM get keys + liftIO $ forM_ results $ evaluate . fmap benchPersonAge + ] + + -- Insert pipelining: mapM insert sends all INSERT RETURNING + -- queries before reading any keys back. + , bgroup "insert x100 (pipelined RETURNING)" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl + keys <- mapM insert (people 100) + liftIO $ void $ evaluate (length keys) + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim + keys <- mapM insert (people 100) + liftIO $ void $ evaluate (length keys) + ] + + , bgroup "update x100" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl + keys <- mapM insert (people 100) + forM_ keys $ \k -> update k [BenchPersonAge =. 99] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim + keys <- mapM insert (people 100) + forM_ keys $ \k -> update k [BenchPersonAge =. 99] + ] + + , bgroup "delete x100" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl + keys <- mapM insert (people 100) + forM_ keys delete + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim + keys <- mapM insert (people 100) + mapM_ delete keys + ] + + , bgroup "upsert x100" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl + forM_ (people 100) $ \person -> upsert person [BenchPersonAge =. benchPersonAge person] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim + forM_ (people 100) $ \person -> upsert person [BenchPersonAge =. benchPersonAge person] + ] + + , bgroup "replace x100" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl + keys <- mapM insert (people 100) + forM_ (zip keys (people 100)) $ \(k, person) -> + replace k (person { benchPersonAge = 99 }) + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim + keys <- mapM insert (people 100) + forM_ (zip keys (people 100)) $ \(k, person) -> + replace k (person { benchPersonAge = 99 }) + ] + + , bgroup "deleteWhere x100" + [ bench "pipeline" $ whnfIO $ p $ do + cleanTablesPl >> insertMany_ (people 100) + forM_ [1..100 :: Int] $ \i -> + deleteWhere [BenchPersonName ==. T.pack ("Person " <> show i)] + , bench "simple" $ whnfIO $ s $ do + cleanTablesSim >> insertMany_ (people 100) + forM_ [1..100 :: Int] $ \i -> + deleteWhere [BenchPersonName ==. T.pack ("Person " <> show i)] + ] + ] diff --git a/persistent-postgresql-ng/bench/delay-proxy.py b/persistent-postgresql-ng/bench/delay-proxy.py new file mode 100644 index 000000000..64be36ba7 --- /dev/null +++ b/persistent-postgresql-ng/bench/delay-proxy.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +"""TCP proxy that adds artificial latency to every packet. 
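+
+For example (hypothetical values), to proxy localhost:15432 to a local
+PostgreSQL on port 5432 with 0.5 ms of delay per direction:
+
+    python3 delay-proxy.py 15432 localhost 5432 0.5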
+
+Usage: python3 delay-proxy.py <listen_port> <target_host> <target_port> <delay_ms>
+
+Each direction (client→server and server→client) gets the delay added,
+so the effective round-trip latency increase is 2 * delay_ms.
+"""
+
+import asyncio
+import sys
+import time
+
+
+async def pipe(reader, writer, delay_sec, label):
+    try:
+        while True:
+            data = await reader.read(65536)
+            if not data:
+                break
+            if delay_sec > 0:
+                await asyncio.sleep(delay_sec)
+            writer.write(data)
+            await writer.drain()
+    except (ConnectionResetError, BrokenPipeError):
+        pass
+    finally:
+        writer.close()
+
+
+async def handle_client(client_reader, client_writer, target_host, target_port, delay_sec):
+    try:
+        server_reader, server_writer = await asyncio.open_connection(target_host, target_port)
+    except Exception as e:
+        print(f"Failed to connect to {target_host}:{target_port}: {e}")
+        client_writer.close()
+        return
+
+    await asyncio.gather(
+        pipe(client_reader, server_writer, delay_sec, "c→s"),
+        pipe(server_reader, client_writer, delay_sec, "s→c"),
+    )
+
+
+async def main():
+    if len(sys.argv) != 5:
+        print(f"Usage: {sys.argv[0]} <listen_port> <target_host> <target_port> <delay_ms>")
+        sys.exit(1)
+
+    listen_port = int(sys.argv[1])
+    target_host = sys.argv[2]
+    target_port = int(sys.argv[3])
+    delay_ms = float(sys.argv[4])
+    delay_sec = delay_ms / 1000.0
+
+    async def on_connect(reader, writer):
+        await handle_client(reader, writer, target_host, target_port, delay_sec)
+
+    server = await asyncio.start_server(on_connect, "127.0.0.1", listen_port)
+    addr = server.sockets[0].getsockname()
+    print(f"Delay proxy listening on {addr[0]}:{addr[1]}")
+    print(f"Forwarding to {target_host}:{target_port} with {delay_ms}ms delay per direction")
+    print(f"Effective RTT increase: {delay_ms * 2}ms")
+
+    async with server:
+        await server.serve_forever()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/persistent-postgresql-ng/bench/run-with-latency.sh b/persistent-postgresql-ng/bench/run-with-latency.sh
new file mode 100755
index 000000000..b07fffd1a
--- /dev/null
+++ b/persistent-postgresql-ng/bench/run-with-latency.sh
@@ -0,0 +1,51 @@
+#!/usr/bin/env bash
+# Run benchmarks with artificial network latency on the loopback interface.
+#
+# Usage:
+#   sudo ./bench/run-with-latency.sh [delay_ms]
+#
+# Default delay: 1ms (simulates a local datacenter hop)
+#
+# Requires root for pfctl/dnctl (macOS dummynet). Cleans up on exit.
+
+set -euo pipefail
+
+DELAY_MS="${1:-1}"
+PG_PORT="${PGPORT:-5432}"
+PIPE_NR=42
+
+cleanup() {
+  echo "Cleaning up dummynet rules..."
+  pfctl -f /etc/pf.conf 2>/dev/null || true
+  dnctl pipe "$PIPE_NR" delete 2>/dev/null || true
+  echo "Done."
+}
+trap cleanup EXIT
+
+echo "=== Adding ${DELAY_MS}ms latency to loopback port ${PG_PORT} ==="
+
+dnctl pipe "$PIPE_NR" config delay "${DELAY_MS}ms"
+
+# Build pf rules: anchor existing config + dummynet for localhost:PG_PORT
+# Traffic in both directions goes through the pipe (so RTT = 2 * DELAY_MS)
+{
+  # Keep existing pf.conf rules
+  cat /etc/pf.conf
+  echo ""
+  echo "dummynet in quick on lo0 proto tcp from any to any port ${PG_PORT} pipe ${PIPE_NR}"
+  echo "dummynet out quick on lo0 proto tcp from any port ${PG_PORT} to any pipe ${PIPE_NR}"
+} | pfctl -f - -e 2>/dev/null || true
+
+echo "=== Verifying latency ==="
+# Quick check: time a trivial psql round-trip. Connection setup costs
+# several round-trips of its own, so expect noticeably more than the
+# raw 2 * DELAY_MS RTT increase here.
+START=$(python3 -c "import time; print(time.time())")
+psql -h localhost -p "$PG_PORT" -c "SELECT 1" -t -q > /dev/null 2>&1 || true
+END=$(python3 -c "import time; print(time.time())")
+ELAPSED=$(python3 -c "print(f'{(${END} - ${START})*1000:.1f}ms')")
+echo "Single SELECT 1 round-trip: ${ELAPSED} (expected at least ~2 * ${DELAY_MS}ms added)"
+
+echo ""
+echo "=== Running benchmarks with ${DELAY_MS}ms latency ==="
+cd "$(dirname "$0")/../.."
+stack bench persistent-postgresql-ng \
+  --benchmark-arguments "--output persistent-postgresql-ng/bench/bench-latency-${DELAY_MS}ms.html"
diff --git a/persistent-postgresql-ng/cbits/hs_libpq_extra.c b/persistent-postgresql-ng/cbits/hs_libpq_extra.c
new file mode 100644
index 000000000..d53219b10
--- /dev/null
+++ b/persistent-postgresql-ng/cbits/hs_libpq_extra.c
@@ -0,0 +1,74 @@
+/*
+ * Extra libpq helpers for persistent-postgresql-ng.
+ *
+ * Provides:
+ *   - hs_PQsetChunkedRowsMode: compile-time guarded wrapper for PG17+
+ *   - hs_has_chunked_rows_support: reports whether chunked mode is available
+ *   - hs_PQresultStatusRaw: raw int result status (bypasses Haskell enum)
+ *
+ * We avoid #include <libpq-fe.h> so that this file compiles without
+ * needing the libpq header search path in the Cabal build. Instead
+ * we forward-declare the opaque types and functions we use.
+ */
+
+/* Opaque libpq types */
+typedef struct pg_conn PGconn;
+typedef struct pg_result PGresult;
+
+/* libpq functions we call -- these are resolved at link time from the
+ * libpq shared library that postgresql-libpq already links against. */
+extern int PQresultStatus(const PGresult *res);
+
+/*
+ * PGRES_TUPLES_CHUNK (= 12) was added in PostgreSQL 17. We hard-code the
+ * value here because the PG enum values are ABI-stable.
+ *
+ * For PQsetChunkedRowsMode we cannot simply call it when the header
+ * doesn't declare it. We use a weak symbol check on platforms that
+ * support it (GCC/Clang) to detect availability at load time.
+ */
+
+#if defined(__GNUC__) || defined(__clang__)
+
+/* Weak declaration: resolves to NULL if the symbol doesn't exist in
+ * the linked libpq (i.e. libpq < 17). */
+extern int PQsetChunkedRowsMode(PGconn *conn, int chunkSize)
+    __attribute__((weak));
+
+int hs_PQsetChunkedRowsMode(PGconn *conn, int chunkSize) {
+    if (PQsetChunkedRowsMode) {
+        return PQsetChunkedRowsMode(conn, chunkSize);
+    }
+    return 0;
+}
+
+int hs_has_chunked_rows_support(void) {
+    return PQsetChunkedRowsMode != 0;
+}
+
+#else
+
+/* Fallback for compilers without weak symbol support: no chunked mode. */
+int hs_PQsetChunkedRowsMode(PGconn *conn, int chunkSize) {
+    (void)conn;
+    (void)chunkSize;
+    return 0;
+}
+
+int hs_has_chunked_rows_support(void) {
+    return 0;
+}
+
+#endif
+
+/*
+ * Return the raw ExecStatus integer for a PGresult.
+ *
+ * The Haskell postgresql-libpq library's resultStatus throws on unknown
+ * status codes (e.g.
PGRES_TUPLES_CHUNK = 12 when compiled against + * older headers). This helper lets us inspect the raw value and branch + * in Haskell without hitting that failure path. + */ +int hs_PQresultStatusRaw(const PGresult *res) { + return PQresultStatus(res); +} diff --git a/persistent-postgresql-ng/persistent-postgresql-ng.cabal b/persistent-postgresql-ng/persistent-postgresql-ng.cabal new file mode 100644 index 000000000..ec2393b1a --- /dev/null +++ b/persistent-postgresql-ng/persistent-postgresql-ng.cabal @@ -0,0 +1,142 @@ +name: persistent-postgresql-ng +version: 0.1.0.0 +license: MIT +license-file: LICENSE +author: Ian Duncan +maintainer: ian@ianduncan.me +synopsis: Pipelined PostgreSQL backend for persistent using binary protocol. +description: + A PostgreSQL backend for persistent that uses the binary wire protocol + via postgresql-libpq (>= 0.11) and postgresql-binary, with support for + libpq pipeline mode to reduce round-trips. +category: Database +stability: Experimental +cabal-version: >=1.10 +build-type: Simple +extra-source-files: + sql/*.sql + +library + build-depends: + aeson >=1.0 + , base >=4.9 && <5 + , bytestring >=0.10 + , bytestring-strict-builder >=0.4 + , conduit >=1.2.12 + , containers >=0.5 + , file-embed >=0.0.16 + , monad-logger >=0.3.25 + , mtl + , persistent >=2.18.1 && <3 + , postgresql-binary >=0.13 && <0.15 + , postgresql-libpq >=0.11 && <0.12 + , resource-pool + , resourcet >=1.1.9 + , scientific >=0.3 + , text >=1.2 + , time >=1.6 + , transformers >=0.5 + , unliftio-core + , vault + , vector + + exposed-modules: + Database.Persist.Postgresql.Pipeline + Database.Persist.Postgresql.Pipeline.Internal + Database.Persist.Postgresql.Pipeline.FFI + Database.Persist.Postgresql.Internal + Database.Persist.Postgresql.Internal.Decoding + Database.Persist.Postgresql.Internal.Encoding + Database.Persist.Postgresql.Internal.DirectDecode + Database.Persist.Postgresql.Internal.DirectEncode + Database.Persist.Postgresql.Internal.PgCodec + Database.Persist.Postgresql.Internal.Migration + Database.Persist.Postgresql.Internal.PgType + Database.Persist.Postgresql.Internal.Placeholders + Database.Persist.Postgresql.JSON + Database.Persist.Postgresql.CustomType + + c-sources: cbits/hs_libpq_extra.c + + ghc-options: -Wall + default-language: Haskell2010 + +test-suite test + type: exitcode-stdio-1.0 + main-is: main.hs + hs-source-dirs: test + other-modules: + ArrayAggTest + BinaryRoundTripSpec + DirectDecodeSpec + DirectEntityPOC + CustomConstraintTest + EquivalentTypeTestPostgres + ImplicitUuidSpec + InCollapseSpec + JSONTest + MigrationReferenceSpec + MigrationSpec + PgPipelineInit + PgPipelineIntervalTest + PipelineDeferralSpec + PipelineModeSpec + PipelineRegressionSpec + PlaceholderSpec + UpsertWhere + + ghc-options: -Wall + build-depends: + aeson + , base >=4.9 && <5 + , bytestring + , containers + , fast-logger + , hspec >=2.4 + , hspec-expectations + , hspec-expectations-lifted + , HUnit + , monad-logger + , persistent + , persistent-postgresql-ng + , persistent-qq + , persistent-test + , postgresql-binary + , postgresql-libpq + , QuickCheck + , quickcheck-instances + , resourcet + , scientific + , text + , time + , transformers + , unliftio + , unliftio-core + , unordered-containers + , vector + + default-language: Haskell2010 + +benchmark bench + type: exitcode-stdio-1.0 + main-is: Main.hs + hs-source-dirs: bench + ghc-options: -Wall -O2 -rtsopts + build-depends: + base >=4.9 && <5 + , bytestring + , criterion >=1.5 + , monad-logger + , persistent + , persistent-postgresql + , 
persistent-postgresql-ng
    , resource-pool
    , text
    , time
    , transformers

  default-language: Haskell2010

source-repository head
  type: git
  location: https://github.com/yesodweb/persistent.git
diff --git a/persistent-postgresql-ng/sql/getForeignKeyReferences.sql b/persistent-postgresql-ng/sql/getForeignKeyReferences.sql
new file mode 100644
index 000000000..3bd1f8d35
--- /dev/null
+++ b/persistent-postgresql-ng/sql/getForeignKeyReferences.sql
@@ -0,0 +1,84 @@
+-- Get all foreign key references among the given set of table names in the
+-- current namespace/schema. This query is used by the migrator to check whether
+-- foreign key definitions are up to date.
+--
+-- This query takes one parameter: an array of table names.
+with
+  foreign_constraints as (
+    select
+      c.*
+    from
+      pg_constraint AS c
+      inner join pg_class src_table
+        on src_table.oid = c.conrelid
+      inner join pg_namespace ns
+        on ns.oid = c.connamespace
+    where
+      -- f = foreign key constraint
+      c.contype = 'f'
+      and src_table.relname = ANY (?)
+      and ns.nspname = current_schema()
+  ),
+  foreign_constraint_with_source_columns as (
+    select
+      c.oid,
+      array_agg(
+        a.attname::text
+        ORDER BY
+          k.n ASC
+      ) as column_names
+    from
+      foreign_constraints AS c
+      -- conkey is a list of the column indices on the source
+      -- table
+      CROSS JOIN LATERAL unnest(c.conkey) WITH ORDINALITY AS k (attnum, n)
+      INNER JOIN pg_attribute AS a
+        -- conrelid is the id of the source table
+        ON k.attnum = a.attnum AND c.conrelid = a.attrelid
+    group by
+      c.oid
+  ),
+  foreign_constraint_with_foreign_columns as (
+    select
+      c.oid,
+      array_agg(
+        a.attname::text
+        ORDER BY
+          k.n ASC
+      ) as foreign_column_names
+    from
+      foreign_constraints AS c
+      -- confkey is a list of the column indices on the foreign
+      -- table
+      CROSS JOIN LATERAL unnest(c.confkey) WITH ORDINALITY AS k (attnum, n)
+      JOIN pg_attribute AS a
+        -- confrelid is the id of the foreign table
+        ON k.attnum = a.attnum AND c.confrelid = a.attrelid
+    group by
+      c.oid
+  )
+SELECT
+  fkey_constraint.conname::text as fkey_name,
+  src_table.relname::text AS source_table,
+  foreign_table.relname::text AS referenced_table,
+  -- NB: postgres arrays are one-indexed!
+  src_columns.column_names[1],
+  foreign_columns.foreign_column_names[1],
+  fkey_constraint.confupdtype,
+  fkey_constraint.confdeltype
+from
+  foreign_constraints AS fkey_constraint
+  inner join foreign_constraint_with_source_columns src_columns
+    on src_columns.oid = fkey_constraint.oid
+  inner join foreign_constraint_with_foreign_columns foreign_columns
+    on foreign_columns.oid = fkey_constraint.oid
+  inner join pg_class src_table
+    on src_table.oid = fkey_constraint.conrelid
+  inner join pg_class foreign_table
+    on foreign_table.oid = fkey_constraint.confrelid
+
+-- In the future, we may want to look at multi-column FK constraints too, but
+-- for now we only care about single-column constraints.
+where + array_length(src_columns.column_names, 1) = 1 + and array_length(foreign_columns.foreign_column_names, 1) = 1; diff --git a/persistent-postgresql-ng/test/ArrayAggTest.hs b/persistent-postgresql-ng/test/ArrayAggTest.hs new file mode 100644 index 000000000..cf97a7c10 --- /dev/null +++ b/persistent-postgresql-ng/test/ArrayAggTest.hs @@ -0,0 +1,71 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module ArrayAggTest where + +import Control.Monad.IO.Class (MonadIO) +import Data.Aeson +import Data.List (sort) +import qualified Data.Text as T +import Test.Hspec.Expectations () + +import PersistentTestModels +import PgPipelineInit + +share + [mkPersist persistSettings, mkMigrate "jsonTestMigrate"] + [persistLowerCase| + TestValue + json Value +|] + +cleanDB + :: (BaseBackend backend ~ SqlBackend, PersistQueryWrite backend, MonadIO m) + => ReaderT backend m () +cleanDB = deleteWhere ([] :: [Filter TestValue]) + +emptyArr :: Value +emptyArr = toJSON ([] :: [Value]) + +specs :: Spec +specs = do + describe "rawSql/array_agg" $ do + let + runArrayAggTest :: (PersistField [a], Ord a, Show a) => Text -> [a] -> Assertion + runArrayAggTest dbField expected = runConnAssert $ do + void $ + insertMany + [ UserPT "a" $ Just "b" + , UserPT "c" $ Just "d" + , UserPT "e" Nothing + , UserPT "g" $ Just "h" + ] + escape <- getEscapeRawNameFunction + let + query = + T.concat + [ "SELECT array_agg(" + , escape dbField + , ") " + , "FROM " + , escape "UserPT" + ] + [Single xs] <- rawSql query [] + liftIO $ sort xs @?= expected + + it "works for [Text]" $ do + runArrayAggTest "ident" ["a", "c", "e", "g" :: Text] + it "works for [Maybe Text]" $ do + runArrayAggTest "password" [Nothing, Just "b", Just "d", Just "h" :: Maybe Text] diff --git a/persistent-postgresql-ng/test/BinaryRoundTripSpec.hs b/persistent-postgresql-ng/test/BinaryRoundTripSpec.hs new file mode 100644 index 000000000..fa50acca0 --- /dev/null +++ b/persistent-postgresql-ng/test/BinaryRoundTripSpec.hs @@ -0,0 +1,152 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} + +-- | Pure round-trip tests for binary encoding/decoding of PersistValues. +module BinaryRoundTripSpec (specs) where + +import Test.Hspec +import Data.ByteString (ByteString) +import Data.Int (Int64) +import Data.Text (Text) +import Data.Time +import qualified Database.PostgreSQL.LibPQ as LibPQ +import Database.Persist (PersistValue (..)) +import Database.Persist.Postgresql.Internal.Encoding (encodePersistValue) +import Database.Persist.Postgresql.Internal.Decoding (decodePersistValue) + +-- | Encode a value then decode it, verifying round-trip equality. 
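+--
+-- For example, @roundTrip (PersistBool True)@ should evaluate to
+-- @Right (PersistBool True)@; the specs below assert exactly this
+-- shape for each scalar constructor.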
+roundTrip :: PersistValue -> Either Text PersistValue +roundTrip pv = case encodePersistValue pv of + Nothing -> Right PersistNull + Just (oid, bs, _fmt) -> decodePersistValue oid (Just bs) + +specs :: Spec +specs = describe "Binary round-trip" $ do + describe "scalar types" $ do + it "PersistNull" $ do + roundTrip PersistNull `shouldBe` Right PersistNull + + it "PersistBool True" $ do + roundTrip (PersistBool True) `shouldBe` Right (PersistBool True) + + it "PersistBool False" $ do + roundTrip (PersistBool False) `shouldBe` Right (PersistBool False) + + it "PersistInt64" $ do + roundTrip (PersistInt64 42) `shouldBe` Right (PersistInt64 42) + + it "PersistInt64 negative" $ do + roundTrip (PersistInt64 (-9999)) `shouldBe` Right (PersistInt64 (-9999)) + + it "PersistInt64 max" $ do + roundTrip (PersistInt64 maxBound) `shouldBe` Right (PersistInt64 maxBound) + + it "PersistDouble" $ do + roundTrip (PersistDouble 3.14) `shouldBe` Right (PersistDouble 3.14) + + it "PersistText" $ do + roundTrip (PersistText "hello") `shouldBe` Right (PersistText "hello") + + it "PersistText empty" $ do + roundTrip (PersistText "") `shouldBe` Right (PersistText "") + + it "PersistText unicode" $ do + roundTrip (PersistText "\x1F600\x00E9") `shouldBe` Right (PersistText "\x1F600\x00E9") + + it "PersistByteString" $ do + roundTrip (PersistByteString "bytes\x00\xff") + `shouldBe` Right (PersistByteString "bytes\x00\xff") + + it "PersistDay" $ do + let d = fromGregorian 2024 6 15 + roundTrip (PersistDay d) `shouldBe` Right (PersistDay d) + + it "PersistTimeOfDay" $ do + let t = TimeOfDay 14 30 0 + roundTrip (PersistTimeOfDay t) `shouldBe` Right (PersistTimeOfDay t) + + it "PersistUTCTime" $ do + let t = UTCTime (fromGregorian 2024 6 15) (12 * 3600 + 30 * 60) + roundTrip (PersistUTCTime t) `shouldBe` Right (PersistUTCTime t) + + it "PersistRational" $ do + -- Rational -> numeric -> Rational round-trip + let r = 355 / 113 -- approximation of pi + case roundTrip (PersistRational r) of + Right (PersistRational r') -> + -- numeric has finite precision, so we check approximate equality + abs (fromRational (r - r') :: Double) `shouldSatisfy` (< 1e-10) + other -> expectationFailure $ "Expected PersistRational, got: " ++ show other + + it "PersistRational integer" $ do + roundTrip (PersistRational 42) `shouldBe` Right (PersistRational 42) + + describe "PersistList (JSON text for storage)" $ do + it "PersistList encodes with OID 0 (unknown)" $ do + let val = PersistList [PersistInt64 1, PersistInt64 2, PersistInt64 3] + case encodePersistValue val of + Just (oid, _, _) -> + oid `shouldBe` LibPQ.Oid 0 -- unknown, PG infers from column + Nothing -> expectationFailure "Expected Just for PersistList encoding" + + it "PersistList empty encodes with OID 0" $ do + let val = PersistList [] + case encodePersistValue val of + Just (oid, _, _) -> + oid `shouldBe` LibPQ.Oid 0 -- unknown, PG infers from column + Nothing -> expectationFailure "Expected Just for empty PersistList encoding" + + describe "PersistArray (native PostgreSQL arrays)" $ do + it "PersistArray of Int64" $ do + let val = PersistArray [PersistInt64 1, PersistInt64 2, PersistInt64 3] + case encodePersistValue val of + Just (oid, _, _) -> do + oid `shouldBe` LibPQ.Oid 1016 -- int8[] + case roundTrip val of + Right (PersistList xs) -> + xs `shouldBe` [PersistInt64 1, PersistInt64 2, PersistInt64 3] + other -> expectationFailure $ "Expected PersistList, got: " ++ show other + Nothing -> expectationFailure "Expected Just for array encoding" + + it "PersistArray of Text" $ 
do + let val = PersistArray [PersistText "hello", PersistText "world"] + case encodePersistValue val of + Just (oid, _, _) -> do + oid `shouldBe` LibPQ.Oid 1009 -- text[] + case roundTrip val of + Right (PersistList xs) -> + xs `shouldBe` [PersistText "hello", PersistText "world"] + other -> expectationFailure $ "Expected PersistList, got: " ++ show other + Nothing -> expectationFailure "Expected Just for array encoding" + + it "PersistArray of Bool" $ do + let val = PersistArray [PersistBool True, PersistBool False] + case encodePersistValue val of + Just (oid, _, _) -> + oid `shouldBe` LibPQ.Oid 1000 -- bool[] + Nothing -> expectationFailure "Expected Just for array encoding" + + it "PersistArray of Double" $ do + let val = PersistArray [PersistDouble 1.5, PersistDouble 2.5] + case encodePersistValue val of + Just (oid, _, _) -> + oid `shouldBe` LibPQ.Oid 1022 -- float8[] + Nothing -> expectationFailure "Expected Just for array encoding" + + it "PersistArray with nulls" $ do + let val = PersistArray [PersistInt64 1, PersistNull, PersistInt64 3] + case encodePersistValue val of + Just (oid, _, _) -> do + oid `shouldBe` LibPQ.Oid 1016 -- int8[] + case roundTrip val of + Right (PersistList xs) -> + xs `shouldBe` [PersistInt64 1, PersistNull, PersistInt64 3] + other -> expectationFailure $ "Expected PersistList, got: " ++ show other + Nothing -> expectationFailure "Expected Just for array encoding" + + it "PersistArray empty encodes with OID 0" $ do + let val = PersistArray [] + case encodePersistValue val of + Just (oid, _, _) -> + oid `shouldBe` LibPQ.Oid 0 -- unknown, PG infers from column + Nothing -> expectationFailure "Expected Just for empty array encoding" diff --git a/persistent-postgresql-ng/test/CustomConstraintTest.hs b/persistent-postgresql-ng/test/CustomConstraintTest.hs new file mode 100644 index 000000000..5944e2e60 --- /dev/null +++ b/persistent-postgresql-ng/test/CustomConstraintTest.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE EmptyDataDecls #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} + +module CustomConstraintTest where + +import qualified Data.Text as T +import PgPipelineInit + +share + [mkPersist sqlSettings, mkMigrate "customConstraintMigrate"] + [persistLowerCase| +CustomConstraint1 + some_field Text + deriving Show + +CustomConstraint2 + cc_id CustomConstraint1Id constraint=custom_constraint + deriving Show + +CustomConstraint3 + cc_id1 CustomConstraint1Id + cc_id2 CustomConstraint1Id + deriving Show +|] + +specs :: Spec +specs = do + describe "custom constraint used in migration" $ do + it "custom constraint is actually created" $ runConnAssert $ do + void $ runMigrationSilent customConstraintMigrate + void $ runMigrationSilent customConstraintMigrate + let + query = + T.concat + [ "SELECT DISTINCT COUNT(*) " + , "FROM information_schema.constraint_column_usage ccu, " + , "information_schema.key_column_usage kcu, " + , "information_schema.table_constraints tc " + , "WHERE tc.constraint_type='FOREIGN KEY' " + , "AND kcu.constraint_name=tc.constraint_name " + , "AND ccu.constraint_name=kcu.constraint_name " + , "AND kcu.ordinal_position=1 " + , "AND ccu.table_name=? 
" + , "AND ccu.column_name=? " + , "AND kcu.table_name=? " + , "AND kcu.column_name=? " + , "AND tc.constraint_name=?" + ] + [Single exists_] <- + rawSql + query + [ PersistText "custom_constraint1" + , PersistText "id" + , PersistText "custom_constraint2" + , PersistText "cc_id" + , PersistText "custom_constraint" + ] + liftIO $ 1 @?= (exists_ :: Int) + + it "allows multiple constraints on a single column" $ runConnAssert $ do + void $ runMigrationSilent customConstraintMigrate + rawExecute + "ALTER TABLE \"custom_constraint3\" ADD CONSTRAINT \"extra_constraint\" FOREIGN KEY(\"cc_id1\") REFERENCES \"custom_constraint1\"(\"id\")" + [] + void $ getMigration customConstraintMigrate + pure () diff --git a/persistent-postgresql-ng/test/DirectDecodeSpec.hs b/persistent-postgresql-ng/test/DirectDecodeSpec.hs new file mode 100644 index 000000000..c43ea7d3e --- /dev/null +++ b/persistent-postgresql-ng/test/DirectDecodeSpec.hs @@ -0,0 +1,128 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE FlexibleContexts #-} + +-- | Tests for the direct decode path: FieldDecode instances for PgRowEnv, +-- FromRow composition, and $N parameter detection. +module DirectDecodeSpec (specs) where + +import Test.Hspec + +import Data.ByteString (ByteString) +import Data.IORef (newIORef) +import Data.Int (Int16, Int32, Int64) +import Data.Text (Text) +import Data.Time +import qualified Data.Vector as V +import qualified Database.PostgreSQL.LibPQ as LibPQ +import qualified PostgreSQL.Binary.Encoding as PE + +import Database.Persist.DirectDecode + (FieldDecode (..), FromRow (..), RowReader, runRowReader, nextField) +import Database.Persist.Names (FieldNameDB (..)) +import Database.Persist.Postgresql.Internal.DirectDecode (PgRowEnv (..)) +import Database.Persist.Postgresql.Internal.PgType +import Database.Persist.Postgresql.Internal.Placeholders + (ParamStyle (..), detectParamStyle) + +specs :: Spec +specs = do + paramStyleSpecs + fieldDecodeRoundTripSpecs + +--------------------------------------------------------------------------- +-- Parameter style detection +--------------------------------------------------------------------------- + +paramStyleSpecs :: Spec +paramStyleSpecs = describe "detectParamStyle" $ do + it "detects ? as QuestionMarkParams" $ do + detectParamStyle "SELECT * FROM t WHERE a = ?" `shouldBe` QuestionMarkParams + + it "detects $1 as Numbered 1" $ do + detectParamStyle "SELECT * FROM t WHERE a = $1" `shouldBe` NumberedParams 1 + + it "detects max $N" $ do + detectParamStyle "SELECT * FROM t WHERE a = $1 AND b = $3" `shouldBe` NumberedParams 3 + + it "detects $1 used multiple times" $ do + detectParamStyle "SELECT * FROM t WHERE a = $1 OR b = $1" `shouldBe` NumberedParams 1 + + it "ignores $N inside string literals" $ do + detectParamStyle "SELECT * FROM t WHERE a = '$1'" `shouldBe` QuestionMarkParams + + it "ignores $N inside quoted identifiers" $ do + detectParamStyle "SELECT * FROM \"$1\" WHERE a = ?" `shouldBe` QuestionMarkParams + + it "ignores $N inside line comments" $ do + detectParamStyle "SELECT * -- $1\nFROM t WHERE a = ?" `shouldBe` QuestionMarkParams + + it "ignores $N inside block comments" $ do + detectParamStyle "SELECT /* $1 */ * FROM t WHERE a = ?" 
`shouldBe` QuestionMarkParams + + it "handles empty SQL" $ do + detectParamStyle "" `shouldBe` QuestionMarkParams + + it "handles no params" $ do + detectParamStyle "SELECT 1" `shouldBe` QuestionMarkParams + +--------------------------------------------------------------------------- +-- FieldDecode round-trip: encode with postgresql-binary → build a +-- fake PGresult-like PgRowEnv → decode via FieldDecode. +-- +-- These tests verify the FieldDecode instances decode correctly from +-- known binary encodings without needing a live PostgreSQL connection. +-- We construct a minimal PgRowEnv by encoding values into ByteStrings +-- with known OIDs and wrapping them in a structure the instances can +-- read from. +--------------------------------------------------------------------------- + +-- For pure tests without a real PGresult, we can't easily construct one. +-- Instead, we test the parameter style detection (above) and verify the +-- FieldDecode instances are correctly generated by TH by checking that +-- the FromRow instance for the DataTypeTable entity (from main.hs) is +-- available. Full integration testing requires a PostgreSQL connection. +-- +-- The key value of these tests: +-- 1. detectParamStyle correctly handles all SQL patterns +-- 2. The TH generates FromRow instances that compile +-- 3. The FieldDecode instance dispatch on PgType is correct (tested +-- via the BinaryRoundTripSpec for the PersistValue path, and via +-- integration tests for the direct path) + +fieldDecodeRoundTripSpecs :: Spec +fieldDecodeRoundTripSpecs = describe "FieldDecode PgRowEnv" $ do + -- These are compile-time tests: if the module compiles, the + -- FieldDecode instances and nextField combinators exist. + it "nextField @Text compiles" $ do + let _decoder :: RowReader PgRowEnv Text + _decoder = nextField (FieldNameDB "col") + True `shouldBe` True + + it "nextField @(Maybe Int64) compiles" $ do + let _decoder :: RowReader PgRowEnv (Maybe Int64) + _decoder = nextField (FieldNameDB "col") + True `shouldBe` True + + it "applicative composition compiles" $ do + let _decoder :: RowReader PgRowEnv (Text, Int64, Maybe Bool) + _decoder = (,,) + <$> nextField (FieldNameDB "name") + <*> nextField (FieldNameDB "age") + <*> nextField (FieldNameDB "active") + True `shouldBe` True + + it "all scalar FieldDecode instances compile" $ do + let _bool :: RowReader PgRowEnv Bool = nextField (FieldNameDB "") + _int16 :: RowReader PgRowEnv Int16 = nextField (FieldNameDB "") + _int32 :: RowReader PgRowEnv Int32 = nextField (FieldNameDB "") + _int64 :: RowReader PgRowEnv Int64 = nextField (FieldNameDB "") + _int :: RowReader PgRowEnv Int = nextField (FieldNameDB "") + _double :: RowReader PgRowEnv Double = nextField (FieldNameDB "") + _text :: RowReader PgRowEnv Text = nextField (FieldNameDB "") + _bs :: RowReader PgRowEnv ByteString = nextField (FieldNameDB "") + _day :: RowReader PgRowEnv Day = nextField (FieldNameDB "") + _tod :: RowReader PgRowEnv TimeOfDay = nextField (FieldNameDB "") + _utc :: RowReader PgRowEnv UTCTime = nextField (FieldNameDB "") + True `shouldBe` True diff --git a/persistent-postgresql-ng/test/DirectEntityPOC.hs b/persistent-postgresql-ng/test/DirectEntityPOC.hs new file mode 100644 index 000000000..bf88a4e36 --- /dev/null +++ b/persistent-postgresql-ng/test/DirectEntityPOC.hs @@ -0,0 +1,183 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# 
LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +-- | Proof-of-concept: the DirectEntity + Typeable/HRefl approach for +-- bridging the SqlBackend type-erasure gap. +-- +-- Part 1 (unit tests): Mock env, compile-time + runtime verification. +-- Part 2 (integration): Real PostgreSQL query through SqlBackend using +-- rawSqlDirectCompat + DirectEntity with PgRowEnv. +module DirectEntityPOC (specs, integrationSpecs, rawSqlDirectCompatTest, PgPair (..)) where + +import Control.Monad.IO.Class (MonadIO, liftIO) +import Control.Monad.Trans.Reader (ReaderT) +import Data.Int (Int64) +import Data.Proxy (Proxy (..)) +import Data.Text (Text) +import qualified Data.Text as T +import Data.Typeable (Typeable) +import Type.Reflection (eqTypeRep, typeRep, (:~~:)(HRefl)) +import Test.Hspec + +import Database.Persist.DirectDecode +import Database.Persist.Names (FieldNameDB (..)) +import Database.Persist.Sql.DirectRaw (rawSqlDirectCompat) +import Database.Persist.SqlBackend.Internal (SqlBackend) +import Database.Persist.Postgresql.Internal.DirectDecode (PgRowEnv (..)) + +--------------------------------------------------------------------------- +-- Mock env for unit tests +--------------------------------------------------------------------------- + +data MockEnv = MockEnv + { mockRow :: !Int + , mockData :: ![(Text, Text)] + } + +instance FieldDecode MockEnv Text where + prepareField _env name _col _onErr onOk = + onOk $ FieldRunner $ \env' _onErr' onOk' -> + case lookup (unFieldNameDB name) (mockData env') of + Just t -> onOk' t + Nothing -> onOk' "" + +--------------------------------------------------------------------------- +-- Test record with MockEnv support +--------------------------------------------------------------------------- + +data TestUser = TestUser + { testUserName :: !Text + , testUserCity :: !Text + } deriving (Eq, Show) + +instance FromRow MockEnv TestUser where + rowReader = TestUser + <$> nextField (FieldNameDB "name") + <*> nextField (FieldNameDB "city") + + prepareRow env ctr onErr onOk = do + col0 <- advanceCounter ctr + prepareField @MockEnv @Text env (FieldNameDB "name") col0 onErr $ \rName -> do + col1 <- advanceCounter ctr + prepareField @MockEnv @Text env (FieldNameDB "city") col1 onErr $ \rCity -> + onOk $ RowDecoder $ \env' onErr' onOk' -> + runField rName env' onErr' $ \n -> + runField rCity env' onErr' $ \c -> + onOk' (TestUser n c) + +instance DirectEntity TestUser where + lookupDirectDecoder :: forall env. 
Typeable env + => Proxy env -> Maybe (RowDecoder env TestUser) + lookupDirectDecoder _ = + case eqTypeRep (typeRep @env) (typeRep @MockEnv) of + Just HRefl -> Just $ RowDecoder $ \env' onErr' onOk' -> do + ctr <- newCounter + prepareRow env' ctr onErr' $ \decoder -> + runRowDecoder decoder env' onErr' onOk' + Nothing -> Nothing + +--------------------------------------------------------------------------- +-- A record for the real PgRowEnv integration test +--------------------------------------------------------------------------- + +data PgPair = PgPair + { pgPairName :: !Text + , pgPairAge :: !Int64 + } deriving (Eq, Show) + +instance FromRow PgRowEnv PgPair where + rowReader = PgPair + <$> nextField (FieldNameDB "name") + <*> nextField (FieldNameDB "age") + + prepareRow env ctr onErr onOk = do + col0 <- advanceCounter ctr + prepareField @PgRowEnv @Text env (FieldNameDB "name") col0 onErr $ \rName -> do + col1 <- advanceCounter ctr + prepareField @PgRowEnv @Int64 env (FieldNameDB "age") col1 onErr $ \rAge -> + onOk $ RowDecoder $ \env' onErr' onOk' -> + runField rName env' onErr' $ \n -> + runField rAge env' onErr' $ \a -> + onOk' (PgPair n a) + +instance DirectEntity PgPair where + lookupDirectDecoder :: forall env. Typeable env + => Proxy env -> Maybe (RowDecoder env PgPair) + lookupDirectDecoder _ = + case eqTypeRep (typeRep @env) (typeRep @PgRowEnv) of + Just HRefl -> Just $ RowDecoder $ \env' onErr' onOk' -> do + ctr <- newCounter + prepareRow env' ctr onErr' $ \decoder -> + runRowDecoder decoder env' onErr' onOk' + Nothing -> Nothing + +--------------------------------------------------------------------------- +-- Unit tests (no database needed) +--------------------------------------------------------------------------- + +specs :: Spec +specs = describe "DirectEntity / SqlBackend bridge (unit)" $ do + it "lookupDirectDecoder returns Just for matching env" $ do + let mDecoder = lookupDirectDecoder @TestUser (Proxy @MockEnv) + case mDecoder of + Nothing -> expectationFailure "expected Just" + Just _ -> pure () + + it "lookupDirectDecoder returns Nothing for non-matching env" $ do + let mDecoder = lookupDirectDecoder @TestUser (Proxy @PgRowEnv) + case mDecoder of + Nothing -> pure () + Just _ -> expectationFailure "expected Nothing" + + it "end-to-end mock: decode through existential" $ do + let rows = + [ MockEnv 0 [("name", "Alice"), ("city", "NYC")] + , MockEnv 1 [("name", "Bob"), ("city", "SF")] + ] + case lookupDirectDecoder @TestUser (Proxy @MockEnv) of + Nothing -> expectationFailure "expected Just" + Just decoder -> do + let onErr e = fail (T.unpack e) + results <- mapM (\env -> runRowDecoderCPS decoder env onErr pure) rows + results `shouldBe` + [ TestUser "Alice" "NYC" + , TestUser "Bob" "SF" + ] + + it "PgPair: lookupDirectDecoder returns Just for PgRowEnv" $ do + let mDecoder = lookupDirectDecoder @PgPair (Proxy @PgRowEnv) + case mDecoder of + Nothing -> expectationFailure "expected Just" + Just _ -> pure () + +--------------------------------------------------------------------------- +-- Integration tests (require PostgreSQL) +--------------------------------------------------------------------------- + +-- | Run inside a SqlPersistT from the test harness. +integrationSpecs :: Spec +integrationSpecs = describe "DirectEntity / SqlBackend bridge (integration)" $ + pure () + +-- | The actual integration test action, run inside SqlPersistT by main.hs. +-- +-- Executes a raw SQL query through 'rawSqlDirectCompat' which uses: +-- 1. 
connDirectQueryCap on SqlBackend (existential PgRowEnv) +-- 2. DirectEntity PgPair (eqTypeRep recovers PgRowEnv) +-- 3. RowDecoder applied to real PGresult rows +rawSqlDirectCompatTest + :: (MonadIO m) + => ReaderT SqlBackend m (Maybe [PgPair]) +rawSqlDirectCompatTest = do + -- Use a CTE to avoid needing a real table + rawSqlDirectCompat + "SELECT name, age FROM (VALUES ('Alice'::text, 30::int8), ('Bob'::text, 25::int8)) AS t(name, age)" + [] diff --git a/persistent-postgresql-ng/test/EquivalentTypeTestPostgres.hs b/persistent-postgresql-ng/test/EquivalentTypeTestPostgres.hs new file mode 100644 index 000000000..74f85e196 --- /dev/null +++ b/persistent-postgresql-ng/test/EquivalentTypeTestPostgres.hs @@ -0,0 +1,54 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -Wno-unused-top-binds #-} + +module EquivalentTypeTestPostgres (specs) where + +import Control.Monad.Trans.Resource (runResourceT) +import qualified Data.Text as T + +import Database.Persist.TH +import PgPipelineInit + +share + [mkPersist sqlSettings, mkMigrate "migrateAll1"] + [persistLowerCase| +EquivalentType sql=equivalent_types + field1 Int sqltype=bigint + field2 T.Text sqltype=text + field3 T.Text sqltype=us_postal_code + deriving Eq Show +|] + +share + [mkPersist sqlSettings, mkMigrate "migrateAll2"] + [persistLowerCase| +EquivalentType2 sql=equivalent_types + field1 Int sqltype=int8 + field2 T.Text + field3 T.Text sqltype=us_postal_code + deriving Eq Show +|] + +specs :: Spec +specs = describe "doesn't migrate equivalent types" $ do + it "works" $ asIO $ runResourceT $ runConn $ do + _ <- rawExecute "DROP DOMAIN IF EXISTS us_postal_code CASCADE" [] + _ <- + rawExecute "CREATE DOMAIN us_postal_code AS TEXT CHECK(VALUE ~ '^\\d{5}$')" [] + + _ <- runMigrationSilent migrateAll1 + xs <- getMigration migrateAll2 + liftIO $ xs @?= [] diff --git a/persistent-postgresql-ng/test/ImplicitUuidSpec.hs b/persistent-postgresql-ng/test/ImplicitUuidSpec.hs new file mode 100644 index 000000000..a8a7197cc --- /dev/null +++ b/persistent-postgresql-ng/test/ImplicitUuidSpec.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module ImplicitUuidSpec where + +import PgPipelineInit + +import Data.Proxy +import Database.Persist.Postgresql.Pipeline + +import Database.Persist.ImplicitIdDef +import Database.Persist.ImplicitIdDef.Internal (fieldTypeFromTypeable) + +share + [ mkPersist (sqlSettingsUuid "uuid_generate_v1mc()") + , mkEntityDefList "entities" + ] + [persistLowerCase| + +WithDefUuid + name Text sqltype=varchar(80) + + deriving Eq Show Ord + +|] + +implicitUuidMigrate :: 
Migration +implicitUuidMigrate = do + runSqlCommand $ rawExecute "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"" [] + migrateModels entities + +wipe :: IO () +wipe = runConnAssert $ do + rawExecute "DROP TABLE with_def_uuid;" [] + runMigration implicitUuidMigrate + +itDb + :: String -> SqlPersistT (LoggingT (ResourceT IO)) a -> SpecWith (Arg (IO ())) +itDb msg action = it msg $ runConnAssert $ void action + +pass :: IO () +pass = pure () + +spec :: Spec +spec = describe "ImplicitUuidSpec" $ before_ wipe $ do + describe "WithDefUuidKey" $ do + it "works on UUIDs" $ do + let + withDefUuidKey = WithDefUuidKey (UUID "Hello") + pass + describe "getEntityId" $ do + let + Just idField = getEntityIdField (entityDef (Proxy @WithDefUuid)) + it "has a UUID SqlType" $ asIO $ do + fieldSqlType idField `shouldBe` SqlOther "UUID" + it "is an implicit ID column" $ asIO $ do + fieldIsImplicitIdColumn idField `shouldBe` True + + describe "insert" $ do + itDb "successfully has a default" $ do + let + matt = + WithDefUuid + { withDefUuidName = + "Matt" + } + k <- insert matt + mrec <- get k + mrec `shouldBe` Just matt diff --git a/persistent-postgresql-ng/test/InCollapseSpec.hs b/persistent-postgresql-ng/test/InCollapseSpec.hs new file mode 100644 index 000000000..03b66d0bd --- /dev/null +++ b/persistent-postgresql-ng/test/InCollapseSpec.hs @@ -0,0 +1,76 @@ +{-# LANGUAGE OverloadedStrings #-} + +-- | Tests for the IN -> ANY SQL rewriting in the pipeline backend. +module InCollapseSpec (specs) where + +import Test.Hspec +import Database.Persist (PersistValue (..)) +import Database.Persist.Postgresql.Pipeline (inlineAndRewrite, collapseInClauses) + +specs :: Spec +specs = describe "IN -> ANY collapsing" $ do + describe "collapseInClauses" $ do + it "collapses IN (?,?,?) into = ANY(?) with PersistArray" $ do + let (sql, params) = collapseInClauses + "SELECT * FROM t WHERE \"id\" IN (?,?,?)" + [PersistInt64 1, PersistInt64 2, PersistInt64 3] + sql `shouldBe` "SELECT * FROM t WHERE \"id\" = ANY(?)" + params `shouldBe` [PersistArray [PersistInt64 1, PersistInt64 2, PersistInt64 3]] + + it "collapses NOT IN (?,?,?) into <> ALL(?) with PersistArray" $ do + let (sql, params) = collapseInClauses + "SELECT * FROM t WHERE \"id\" NOT IN (?,?,?)" + [PersistInt64 1, PersistInt64 2, PersistInt64 3] + sql `shouldBe` "SELECT * FROM t WHERE \"id\" <> ALL(?)" + params `shouldBe` [PersistArray [PersistInt64 1, PersistInt64 2, PersistInt64 3]] + + it "does NOT collapse IN (?) with single param" $ do + let (sql, params) = collapseInClauses + "SELECT * FROM t WHERE \"id\" IN (?)" + [PersistInt64 42] + sql `shouldBe` "SELECT * FROM t WHERE \"id\" IN (?)" + params `shouldBe` [PersistInt64 42] + + it "preserves params outside IN clauses" $ do + let (sql, params) = collapseInClauses + "SELECT * FROM t WHERE \"name\" IN (?,?) AND \"age\" > ?" + [PersistText "a", PersistText "b", PersistInt64 21] + sql `shouldBe` "SELECT * FROM t WHERE \"name\" = ANY(?) AND \"age\" > ?" + params `shouldBe` [PersistArray [PersistText "a", PersistText "b"], PersistInt64 21] + + it "handles multiple IN clauses" $ do + let (sql, params) = collapseInClauses + "WHERE \"a\" IN (?,?) AND \"b\" IN (?,?,?)" + [ PersistInt64 1, PersistInt64 2 + , PersistText "x", PersistText "y", PersistText "z" + ] + sql `shouldBe` "WHERE \"a\" = ANY(?) AND \"b\" = ANY(?)" + params `shouldBe` + [ PersistArray [PersistInt64 1, PersistInt64 2] + , PersistArray [PersistText "x", PersistText "y", PersistText "z"] + ] + + it "does not collapse ? 
inside string literals" $ do + let (sql, params) = collapseInClauses + "SELECT * FROM t WHERE name = 'IN (?,?)' AND id IN (?,?)" + [PersistInt64 1, PersistInt64 2] + sql `shouldBe` "SELECT * FROM t WHERE name = 'IN (?,?)' AND id = ANY(?)" + params `shouldBe` [PersistArray [PersistInt64 1, PersistInt64 2]] + + it "passes through when no IN clauses" $ do + let (sql, params) = collapseInClauses + "SELECT * FROM t WHERE id = ? AND name = ?" + [PersistInt64 1, PersistText "foo"] + sql `shouldBe` "SELECT * FROM t WHERE id = ? AND name = ?" + params `shouldBe` [PersistInt64 1, PersistText "foo"] + + describe "inlineAndRewrite (full pipeline)" $ do + it "collapses IN and rewrites to $N" $ do + let (sql, params) = inlineAndRewrite + "SELECT * FROM t WHERE \"id\" IN (?,?,?) AND \"name\" = ?" + [PersistInt64 1, PersistInt64 2, PersistInt64 3, PersistText "foo"] + sql `shouldBe` "SELECT * FROM t WHERE \"id\" = ANY($1) AND \"name\" = $2" + params `shouldBe` + [ PersistArray [PersistInt64 1, PersistInt64 2, PersistInt64 3] + , PersistText "foo" + ] diff --git a/persistent-postgresql-ng/test/JSONTest.hs b/persistent-postgresql-ng/test/JSONTest.hs new file mode 100644 index 000000000..a0ab87b11 --- /dev/null +++ b/persistent-postgresql-ng/test/JSONTest.hs @@ -0,0 +1,760 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} + +module JSONTest where + +import Control.Monad.IO.Class (MonadIO) +import Data.Aeson hiding (Key) +import qualified Data.Vector as V (fromList) +import Test.HUnit (assertBool) +import Test.Hspec.Expectations () + +import Database.Persist +import Database.Persist.Postgresql.JSON + +import PgPipelineInit + +share + [mkPersist persistSettings, mkMigrate "jsonTestMigrate"] + [persistLowerCase| + TestValue + json Value + deriving Show +|] + +cleanDB + :: (BaseBackend backend ~ SqlBackend, PersistQueryWrite backend, MonadIO m) + => ReaderT backend m () +cleanDB = deleteWhere ([] :: [Filter TestValue]) + +emptyArr :: Value +emptyArr = toJSON ([] :: [Value]) + +insert' + :: (MonadIO m, PersistStoreWrite backend, BaseBackend backend ~ SqlBackend) + => Value -> ReaderT backend m (Key TestValue) +insert' = insert . 
TestValue + +matchKeys + :: (Show record, Show (Key record), MonadIO m, Eq (Key record)) + => [Key record] -> [Entity record] -> m () +matchKeys ys xs = do + msg1 `assertBoolIO` (xLen == yLen) + forM_ ys $ \y -> msg2 y `assertBoolIO` (y `elem` ks) + where + ks = entityKey <$> xs + xLen = length xs + yLen = length ys + msg1 = + mconcat + [ "\nexpected: " + , show yLen + , "\n but got: " + , show xLen + , "\n[xs: " + , show xs + , "]" + , "\n[ys: " + , show ys + , "]" + ] + msg2 y = + mconcat + [ "key \"" + , show y + , "\" not in result:\n " + , show ks + ] + +setup :: IO TestKeys +setup = asIO $ runConn_ $ do + void $ runMigrationSilent jsonTestMigrate + testKeys + +teardown :: IO () +teardown = asIO $ runConn_ $ do + cleanDB + +shouldBeIO :: (Show a, Eq a, MonadIO m) => a -> a -> m () +shouldBeIO x y = liftIO $ shouldBe x y + +assertBoolIO :: (MonadIO m) => String -> Bool -> m () +assertBoolIO s b = liftIO $ assertBool s b + +testKeys :: (Monad m, MonadIO m) => ReaderT SqlBackend m TestKeys +testKeys = do + nullK <- insert' Null + + boolTK <- insert' $ Bool True + boolFK <- insert' $ toJSON False + + num0K <- insert' $ Number 0 + num1K <- insert' $ Number 1 + numBigK <- insert' $ toJSON (1234567890 :: Int) + numFloatK <- insert' $ Number 0.0 + numSmallK <- insert' $ Number 0.0000000000000000123 + numFloat2K <- insert' $ Number 1.5 + -- numBigFloatK will turn into 9876543210.123457 because JSON + numBigFloatK <- insert' $ toJSON (9876543210.123456789 :: Double) + + strNullK <- insert' $ String "" + strObjK <- insert' $ String "{}" + strArrK <- insert' $ String "[]" + strAK <- insert' $ String "a" + strTestK <- insert' $ toJSON ("testing" :: Text) + str2K <- insert' $ String "2" + strFloatK <- insert' $ String "0.45876" + + arrNullK <- insert' $ Array $ V.fromList [] + arrListK <- insert' $ toJSON [emptyArr, emptyArr, toJSON [emptyArr, emptyArr]] + arrList2K <- + insert' $ + toJSON + [ emptyArr + , toJSON [Number 3, Bool False] + , toJSON [emptyArr, toJSON [Object mempty]] + ] + arrFilledK <- + insert' $ + toJSON + [ Null + , Number 4 + , String "b" + , Object mempty + , emptyArr + , object ["test" .= [Null], "test2" .= String "yes"] + ] + arrList3K <- insert' $ toJSON [toJSON [String "a"], Number 1] + arrList4K <- insert' $ toJSON [String "a", String "b", String "c", String "d"] + + objNullK <- insert' $ Object mempty + objTestK <- insert' $ object ["test" .= Null, "test1" .= String "no"] + objDeepK <- + insert' $ object ["c" .= Number 24.986, "foo" .= object ["deep1" .= Bool True]] + objEmptyK <- insert' $ object ["" .= Number 9001] + objFullK <- + insert' $ + object + [ "a" .= Number 1 + , "b" .= Number 2 + , "c" .= Number 3 + , "d" .= Number 4 + ] + return TestKeys{..} + +data TestKeys + = TestKeys + { nullK :: Key TestValue + , boolTK :: Key TestValue + , boolFK :: Key TestValue + , num0K :: Key TestValue + , num1K :: Key TestValue + , numBigK :: Key TestValue + , numFloatK :: Key TestValue + , numSmallK :: Key TestValue + , numFloat2K :: Key TestValue + , numBigFloatK :: Key TestValue + , strNullK :: Key TestValue + , strObjK :: Key TestValue + , strArrK :: Key TestValue + , strAK :: Key TestValue + , strTestK :: Key TestValue + , str2K :: Key TestValue + , strFloatK :: Key TestValue + , arrNullK :: Key TestValue + , arrListK :: Key TestValue + , arrList2K :: Key TestValue + , arrFilledK :: Key TestValue + , objNullK :: Key TestValue + , objTestK :: Key TestValue + , objDeepK :: Key TestValue + , arrList3K :: Key TestValue + , arrList4K :: Key TestValue + , objEmptyK :: Key TestValue 
+    , objFullK :: Key TestValue
+    }
+    deriving (Eq, Ord, Show)
+
+specs :: Spec
+specs = afterAll_ teardown $ do
+  beforeAll setup $ do
+    describe "Testing JSON operators" $ do
+      describe "@>. object queries" $ do
+        it "matches an empty Object with any object" $
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. Object mempty] []
+            [objNullK, objTestK, objDeepK, objEmptyK, objFullK] `matchKeys` vals
+
+        it "matches a subset of object properties" $
+          -- {test: null, test1: no} @>. {test: null} == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. object ["test" .= Null]] []
+            [objTestK] `matchKeys` vals
+
+        it "matches a nested object against an empty object at the same key" $
+          -- {c: 24.986, foo: {deep1: true}} @>. {foo: {}} == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. object ["foo" .= object []]] []
+            [objDeepK] `matchKeys` vals
+
+        it "doesn't match a nested object against a string at the same key" $
+          -- {c: 24.986, foo: {deep1: true}} @>. {foo: nope} == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. object ["foo" .= String "nope"]] []
+            [] `matchKeys` vals
+
+        it "matches a nested object when the query object is identical" $
+          -- {c: 24.986, foo: {deep1: true}} @>. {foo: {deep1: true}} == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <-
+              selectList [TestValueJson @>. (object ["foo" .= object ["deep1" .= True]])] []
+            [objDeepK] `matchKeys` vals
+
+        it "doesn't match a nested object when queried with that exact object" $
+          -- {c: 24.986, foo: {deep1: true}} @>. {deep1: true} == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. object ["deep1" .= True]] []
+            [] `matchKeys` vals
+
+      describe "@>. array queries" $ do
+        it "matches an empty Array with any list" $
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. emptyArr] []
+            [arrNullK, arrListK, arrList2K, arrFilledK, arrList3K, arrList4K]
+              `matchKeys` vals
+
+        it "matches list when queried with subset (1 item)" $
+          -- [null, 4, 'b', {}, [], {test: [null], test2: 'yes'}] @>. [4] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [4 :: Int]] []
+            [arrFilledK] `matchKeys` vals
+
+        it "matches list when queried with subset (2 items)" $
+          -- [null, 4, 'b', {}, [], {test: [null], test2: 'yes'}] @>. [null,'b'] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [Null, String "b"]] []
+            [arrFilledK] `matchKeys` vals
+
+        it "doesn't match list when queried with intersecting list (1 match, 1 diff)" $
+          -- [null, 4, 'b', {}, [], {test: [null], test2: 'yes'}] @>. [[],'d'] == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [emptyArr, String "d"]] []
+            [] `matchKeys` vals
+
+        it "matches list when queried with same list in different order" $
+          -- [null, 4, 'b', {}, [], {test: [null], test2: 'yes'}] @>.
+          --   [[],'b',{test: [null],test2: 'yes'},4,null,{}] == True
+          \TestKeys{..} -> runConnAssert $ do
+            let
+              queryList =
+                toJSON
+                  [ emptyArr
+                  , String "b"
+                  , object ["test" .= [Null], "test2" .= String "yes"]
+                  , Number 4
+                  , Null
+                  , Object mempty
+                  ]
+
+            vals <- selectList [TestValueJson @>. queryList] []
+            [arrFilledK] `matchKeys` vals
+
+        it "doesn't match list when queried with same list + 1 item" $
+          -- [null,4,'b',{},[],{test:[null],test2:'yes'}] @>.
+          --   [null,4,'b',{},[],{test:[null],test2: 'yes'}, false] == False
+          \TestKeys{..} -> runConnAssert $ do
+            let
+              testList =
+                toJSON
+                  [ Null
+                  , Number 4
+                  , String "b"
+                  , Object mempty
+                  , emptyArr
+                  , object ["test" .= [Null], "test2" .= String "yes"]
+                  , Bool False
+                  ]
+
+            vals <- selectList [TestValueJson @>. testList] []
+            [] `matchKeys` vals
+
+        it "matches list when it shares an empty object with the query list" $
+          -- [null,4,'b',{},[],{test: [null],test2: 'yes'}] @>. [{}] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [Object mempty]] []
+            [arrFilledK] `matchKeys` vals
+
+        it "matches list with nested list, when queried with an empty nested list" $
+          -- [null,4,'b',{},[],{test:[null],test2:'yes'}] @>. [{test:[]}] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [object ["test" .= emptyArr]]] []
+            [arrFilledK] `matchKeys` vals
+
+        it "doesn't match list with nested list, when queried with a diff. nested list" $
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>.
+          --   [{"test1":[null]}] == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [object ["test1" .= [Null]]]] []
+            [] `matchKeys` vals
+
+        it "matches many nested lists when queried with empty nested list" $
+          -- [[],[],[[],[]]] @>. [[]] == True
+          -- [[],[3,false],[[],[{}]]] @>. [[]] == True
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. [[]] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [emptyArr]] []
+            [arrListK, arrList2K, arrFilledK, arrList3K] `matchKeys` vals
+
+        it "matches nested list when queried with a subset of that list" $
+          -- [[],[3,false],[[],[{}]]] @>. [[3]] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [[3 :: Int]]] []
+            [arrList2K] `matchKeys` vals
+
+        it "doesn't match nested list against a partial intersection of that list" $
+          -- [[],[3,false],[[],[{}]]] @>. [[true,3]] == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON [[Bool True, Number 3]]] []
+            [] `matchKeys` vals
+
+        it "matches list when queried with raw number contained in the list" $
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. 4 == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. Number 4] []
+            [arrFilledK] `matchKeys` vals
+
+        it "doesn't match list when queried with raw value not contained in the list" $
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. 99 == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. Number 99] []
+            [] `matchKeys` vals
+
+        it "matches list when queried with raw string contained in the list" $
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. "b" == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. String "b"] []
+            [arrFilledK, arrList4K] `matchKeys` vals
+
+        it "doesn't match list with empty object when queried with \"{}\"" $
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>. "{}" == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. String "{}"] []
+            [strObjK] `matchKeys` vals
+
+        it "doesn't match list with nested object when queried with object (not in list)" $
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] @>.
+          --   {"test":[null],"test2":"yes"} == False
+          \TestKeys{..} -> runConnAssert $ do
+            let
+              queryObject = object ["test" .= [Null], "test2" .= String "yes"]
+            vals <- selectList [TestValueJson @>. queryObject] []
+            [] `matchKeys` vals
+
+      describe "@>. string queries" $ do
+        it "matches identical strings" $
+          -- "testing" @>. "testing" == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. String "testing"] []
+            [strTestK] `matchKeys` vals
+
+        it "doesn't match strings case-insensitively" $
+          -- "testing" @>. "Testing" == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. String "Testing"] []
+            [] `matchKeys` vals
+
+        it "doesn't match substrings" $
+          -- "testing" @>. "test" == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. String "test"] []
+            [] `matchKeys` vals
+
+        it "doesn't match strings with object keys" $
+          -- "testing" @>. {"testing":1} == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. object ["testing" .= Number 1]] []
+            [] `matchKeys` vals
+
+      describe "@>. number queries" $ do
+        it "matches identical numbers" $
+          -- 1 @>. 1 == True
+          -- [1] @>. 1 == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON (1 :: Int)] []
+            [num1K, arrList3K] `matchKeys` vals
+
+        it "matches numbers when queried with float" $
+          -- 0 @>. 0.0 == True
+          -- 0.0 @>. 0.0 == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON (0.0 :: Double)] []
+            [num0K, numFloatK] `matchKeys` vals
+
+        it "does not match numbers when queried with a substring of that number" $
+          -- 1234567890 @>. 123456789 == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON (123456789 :: Int)] []
+            [] `matchKeys` vals
+
+        it "does not match number when queried with different number" $
+          -- 1234567890 @>. 234567890 == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON (234567890 :: Int)] []
+            [] `matchKeys` vals
+
+        it "does not match number when queried with string of that number" $
+          -- 1 @>. "1" == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. String "1"] []
+            [] `matchKeys` vals
+
+        it "does not match number when queried with list of digits" $
+          -- 1234567890 @>. [1,2,3,4,5,6,7,8,9,0] == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <-
+              selectList
+                [TestValueJson @>. toJSON ([1, 2, 3, 4, 5, 6, 7, 8, 9, 0] :: [Int])]
+                []
+            [] `matchKeys` vals
+
+      describe "@>. boolean queries" $ do
+        it "matches identical booleans (True)" $
+          -- true @>. true == True
+          -- false @>. true == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. toJSON True] []
+            [boolTK] `matchKeys` vals
+
+        it "matches identical booleans (False)" $
+          -- false @>. false == True
+          -- true @>. false == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. Bool False] []
+            [boolFK] `matchKeys` vals
+
+        it "does not match boolean with string of boolean" $
+          -- true @>. "true" == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. String "true"] []
+            [] `matchKeys` vals
+
+      describe "@>. null queries" $ do
+        it "matches nulls" $
+          -- null @>. null == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. Null] []
+            [nullK, arrFilledK] `matchKeys` vals
+
+        it "does not match null with string of null" $
+          -- null @>. "null" == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson @>. String "null"] []
+            [] `matchKeys` vals
+
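+      -- <@. wraps jsonb's <@ operator: containment flipped, i.e. true when
+      -- the column's JSON value is contained BY the query value.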
"null" == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson @>. String "null"] [] + [] `matchKeys` vals + + describe "<@. queries" $ do + it "matches subobject when queried with superobject" $ + -- {} <@. {"test":null,"test1":"no","blabla":[]} == True + -- {"test":null,"test1":"no"} <@. {"test":null,"test1":"no","blabla":[]} == True + \TestKeys{..} -> runConnAssert $ do + let + queryObject = + object + [ "test" .= Null + , "test1" .= String "no" + , "blabla" .= emptyArr + ] + vals <- selectList [TestValueJson <@. queryObject] [] + [objNullK, objTestK] `matchKeys` vals + + it "matches raw values and sublists when queried with superlist" $ + -- [] <@. [null,4,"b",{},[],{"test":[null],"test2":"yes"},false] == True + -- null <@. [null,4,"b",{},[],{"test":[null],"test2":"yes"},false] == True + -- false <@. [null,4,"b",{},[],{"test":[null],"test2":"yes"},false] == True + -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] <@. + -- [null,4,"b",{},[],{"test":[null],"test2":"yes"},false] == True + \TestKeys{..} -> runConnAssert $ do + let + queryList = + toJSON + [ Null + , Number 4 + , String "b" + , Object mempty + , emptyArr + , object ["test" .= [Null], "test2" .= String "yes"] + , Bool False + ] + + vals <- selectList [TestValueJson <@. queryList] [] + [arrNullK, arrFilledK, boolFK, nullK] `matchKeys` vals + + it "matches identical strings" $ + -- "a" <@. "a" == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson <@. String "a"] [] + [strAK] `matchKeys` vals + + it "matches identical big floats" $ + -- 9876543210.123457 <@ 9876543210.123457 == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson <@. Number 9876543210.123457] [] + [numBigFloatK] `matchKeys` vals + + it "doesn't match different big floats" $ + -- 9876543210.123457 <@. 9876543210.123456789 == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson <@. Number 9876543210.123456789] [] + [] `matchKeys` vals + + it "matches nulls" $ + -- null <@. null == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson <@. Null] [] + [nullK] `matchKeys` vals + + describe "?. queries" $ do + it "matches top level keys and not the keys of nested objects" $ + -- {"test":null,"test1":"no"} ?. "test" == True + -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?. "test" == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. "test"] [] + [objTestK] `matchKeys` vals + + it "doesn't match nested key" $ + -- {"c":24.986,"foo":{"deep1":true"}} ?. "deep1" == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. "deep1"] [] + [] `matchKeys` vals + + it "matches \"{}\" but not empty object when queried with \"{}\"" $ + -- "{}" ?. "{}" == True + -- {} ?. "{}" == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. "{}"] [] + [strObjK] `matchKeys` vals + + it "matches raw empty str and empty str key when queried with \"\"" $ + ---- {} ?. "" == False + ---- "" ?. "" == True + ---- {"":9001} ?. "" == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. ""] [] + [strNullK, objEmptyK] `matchKeys` vals + + it "matches lists containing string value when queried with raw string value" $ + -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?. "b" == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. 
"b"] [] + [arrFilledK, arrList4K, objFullK] `matchKeys` vals + + it "matches lists, objects, and raw values correctly when queried with string" $ + -- [["a"]] ?. "a" == False + -- "a" ?. "a" == True + -- ["a","b","c","d"] ?. "a" == True + -- {"a":1,"b":2,"c":3,"d":4} ?. "a" == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. "a"] [] + [strAK, arrList4K, objFullK] `matchKeys` vals + + it "matches string list but not real list when queried with \"[]\"" $ + -- "[]" ?. "[]" == True + -- [] ?. "[]" == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. "[]"] [] + [strArrK] `matchKeys` vals + + it "does not match null when queried with string null" $ + -- null ?. "null" == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. "null"] [] + [] `matchKeys` vals + + it "does not match bool whe nqueried with string bool" $ + -- true ?. "true" == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?. "true"] [] + [] `matchKeys` vals + + describe "?|. queries" $ do + it "matches raw vals, lists, objects, and nested objects" $ + -- "a" ?|. ["a","b","c"] == True + -- [["a"],1] ?|. ["a","b","c"] == False + -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?|. ["a","b","c"] == True + -- ["a","b","c","d"] ?|. ["a","b","c"] == True + -- {"a":1,"b":2,"c":3,"d":4} ?|. ["a","b","c"] == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?|. ["a", "b", "c"]] [] + [strAK, arrFilledK, objDeepK, arrList4K, objFullK] `matchKeys` vals + + it "matches str object but not object when queried with \"{}\"" $ + -- "{}" ?|. ["{}"] == True + -- {} ?|. ["{}"] == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?|. ["{}"]] [] + [strObjK] `matchKeys` vals + + it "doesn't match superstrings when queried with substring" $ + -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?|. ["test"] == False + -- "testing" ?|. ["test"] == False + -- {"test":null,"test1":"no"} ?|. ["test"] == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?|. ["test"]] [] + [objTestK] `matchKeys` vals + + it "doesn't match nested keys" $ + -- {"c":24.986,"foo":{"deep1":true"}} ?|. ["deep1"] == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?|. ["deep1"]] [] + [] `matchKeys` vals + + it "doesn't match anything when queried with empty list" $ + -- ANYTHING ?|. [] == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?|. []] [] + [] `matchKeys` vals + + it "doesn't match raw, non-string, values when queried with strings" $ + -- true ?|. ["true","null","1"] == False + -- null ?|. ["true","null","1"] == False + -- 1 ?|. ["true","null","1"] == False + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?|. ["true", "null", "1"]] [] + [] `matchKeys` vals + + it "matches string array when queried with \"[]\"" $ + -- [] ?|. ["[]"] == False + -- "[]" ?|. ["[]"] == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?|. ["[]"]] [] + [strArrK] `matchKeys` vals + + describe "?&. queries" $ do + it "matches anything when queried with an empty list" $ + -- ANYTHING ?&. [] == True + \TestKeys{..} -> runConnAssert $ do + vals <- selectList [TestValueJson ?&. 
+      describe "?&. queries" $ do
+        it "matches anything when queried with an empty list" $
+          -- ANYTHING ?&. [] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. []] []
+            flip
+              matchKeys
+              vals
+              [ nullK
+              , boolTK
+              , boolFK
+              , num0K
+              , num1K
+              , numBigK
+              , numFloatK
+              , numSmallK
+              , numFloat2K
+              , numBigFloatK
+              , strNullK
+              , strObjK
+              , strArrK
+              , strAK
+              , strTestK
+              , str2K
+              , strFloatK
+              , arrNullK
+              , arrListK
+              , arrList2K
+              , arrFilledK
+              , arrList3K
+              , arrList4K
+              , objNullK
+              , objTestK
+              , objDeepK
+              , objEmptyK
+              , objFullK
+              ]
+
+        it "matches raw values, lists, and objects when queried with string" $
+          -- "a" ?&. ["a"] == True
+          -- [["a"],1] ?&. ["a"] == False
+          -- ["a","b","c","d"] ?&. ["a"] == True
+          -- {"a":1,"b":2,"c":3,"d":4} ?&. ["a"] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. ["a"]] []
+            [strAK, arrList4K, objFullK] `matchKeys` vals
+
+        it "matches raw values, lists, and objects when queried with multiple strings" $
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?&. ["b","c"] == False
+          -- {"c":24.986,"foo":{"deep1":true}} ?&. ["b","c"] == False
+          -- ["a","b","c","d"] ?&. ["b","c"] == True
+          -- {"a":1,"b":2,"c":3,"d":4} ?&. ["b","c"] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. ["b", "c"]] []
+            [arrList4K, objFullK] `matchKeys` vals
+
+        it "matches object string when queried with \"{}\"" $
+          -- {} ?&. ["{}"] == False
+          -- "{}" ?&. ["{}"] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. ["{}"]] []
+            [strObjK] `matchKeys` vals
+
+        it "doesn't match superstrings when queried with substring" $
+          -- [null,4,"b",{},[],{"test":[null],"test2":"yes"}] ?&. ["test"] == False
+          -- "testing" ?&. ["test"] == False
+          -- {"test":null,"test1":"no"} ?&. ["test"] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. ["test"]] []
+            [objTestK] `matchKeys` vals
+
+        it "doesn't match nested keys" $
+          -- {"c":24.986,"foo":{"deep1":true}} ?&. ["deep1"] == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. ["deep1"]] []
+            [] `matchKeys` vals
+
+        it "doesn't match anything when there is a partial match" $
+          -- "a" ?&. ["a","e"] == False
+          -- ["a","b","c","d"] ?&. ["a","e"] == False
+          -- {"a":1,"b":2,"c":3,"d":4} ?&. ["a","e"] == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. ["a", "e"]] []
+            [] `matchKeys` vals
+
+        it "matches string array when queried with \"[]\"" $
+          -- [] ?&. ["[]"] == False
+          -- "[]" ?&. ["[]"] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. ["[]"]] []
+            [strArrK] `matchKeys` vals
+
+        it "doesn't match null when queried with string null" $
+          -- THIS WILL FAIL IF THE IMPLEMENTATION USES
+          -- @ '{null}' @
+          -- INSTEAD OF
+          -- @ ARRAY['null'] @
+          -- null ?&. ["null"] == False
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. ["null"]] []
+            [] `matchKeys` vals
+
+        it "doesn't match number when queried with str of that number" $
+          -- [["a"],1] ?&. ["1"] == False
+          -- "1" ?&. ["1"] == True
+          \TestKeys{..} -> runConnAssert $ do
+            str1 <- insert' $ toJSON $ String "1"
+            vals <- selectList [TestValueJson ?&. ["1"]] []
+            [str1] `matchKeys` vals
+
+        it "doesn't match empty objs or list when queried with empty string" $
+          -- {} ?&. [""] == False
+          -- [] ?&. [""] == False
+          -- "" ?&. [""] == True
+          -- {"":9001} ?&. [""] == True
+          \TestKeys{..} -> runConnAssert $ do
+            vals <- selectList [TestValueJson ?&. 
[""]] [] + [strNullK, objEmptyK] `matchKeys` vals diff --git a/persistent-postgresql-ng/test/MigrationReferenceSpec.hs b/persistent-postgresql-ng/test/MigrationReferenceSpec.hs new file mode 100644 index 000000000..f2079a065 --- /dev/null +++ b/persistent-postgresql-ng/test/MigrationReferenceSpec.hs @@ -0,0 +1,61 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -Wno-unused-top-binds #-} + +module MigrationReferenceSpec where + +import PgPipelineInit + +import Control.Monad.Trans.Writer (censor, mapWriterT) +import Data.Text (Text, isInfixOf) + +share + [mkPersist sqlSettings, mkMigrate "referenceMigrate"] + [persistLowerCase| + +LocationCapabilities + Id Text + bio Text + +LocationCapabilitiesPrintingProcess + locationCapabilitiesId LocationCapabilitiesId + +LocationCapabilitiesPrintingFinish + locationCapabilitiesId LocationCapabilitiesId + +LocationCapabilitiesSubstrate + locationCapabilitiesId LocationCapabilitiesId + +|] + +spec :: Spec +spec = describe "MigrationReferenceSpec" $ do + it "works" $ runConnAssert $ do + let + noForeignKeys :: CautiousMigration -> CautiousMigration + noForeignKeys = filter ((not . isReference) . snd) + + onlyForeignKeys :: CautiousMigration -> CautiousMigration + onlyForeignKeys = filter (isReference . snd) + + isReference :: Text -> Bool + isReference migration = "REFERENCES" `isInfixOf` migration + + runMigration $ + mapWriterT (censor noForeignKeys) $ + referenceMigrate + + runMigration $ + mapWriterT (censor onlyForeignKeys) $ + referenceMigrate diff --git a/persistent-postgresql-ng/test/MigrationSpec.hs b/persistent-postgresql-ng/test/MigrationSpec.hs new file mode 100644 index 000000000..44ed41876 --- /dev/null +++ b/persistent-postgresql-ng/test/MigrationSpec.hs @@ -0,0 +1,687 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module MigrationSpec where + +import PgPipelineInit + +import Data.Foldable (traverse_) +import qualified Data.Map as Map +import Data.Proxy +import qualified Data.Set as Set +import qualified Data.Text as T +import Database.Persist.Postgresql.Internal.Migration + +getStmtGetter + :: (Monad m) => SqlPersistT m (Text -> IO Statement) +getStmtGetter = do + backend <- ask + pure (getStmtConn backend) + +-- NB: we do not perform these migrations in main.hs +share + [mkPersist persistSettings{mpsGeneric = False}] + [persistLowerCase| +User sql=users + name Text + title Text Maybe + deriving Show Eq + +UserFriendship sql=user_friendships + user1Id UserId Maybe + user2Id UserId Maybe + deriving Show Eq + +Password sql=passwords + passwordHash Text + userId UserId Maybe + UniqueUserId userId !force + +Password2 sql=passwords_2 + passwordHash Text + userId UserId Maybe 
OnDeleteCascade OnUpdateSetNull + UniqueUserId2 userId !force + +AdminUser sql=admin_users + userId UserId + Primary userId + + promotedByUserId UserId + UniquePromotedByUserId promotedByUserId + +FKParent sql=migration_fk_parent + +FKChildV1 sql=migration_fk_child + +-- Simulate creating a new FK field on an existing table +FKChildV2 sql=migration_fk_child + parentId FKParentId + +ExplicitPrimaryKey sql=explicit_primary_key + Id Text +|] + +userEntityDef :: EntityDef +userEntityDef = entityDef (Proxy :: Proxy User) + +userFriendshipEntityDef :: EntityDef +userFriendshipEntityDef = entityDef (Proxy :: Proxy UserFriendship) + +passwordEntityDef :: EntityDef +passwordEntityDef = entityDef (Proxy :: Proxy Password) + +password2EntityDef :: EntityDef +password2EntityDef = entityDef (Proxy :: Proxy Password2) + +adminUserEntityDef :: EntityDef +adminUserEntityDef = entityDef (Proxy :: Proxy AdminUser) + +fkParentEntityDef :: EntityDef +fkParentEntityDef = entityDef (Proxy :: Proxy FKParent) + +fkChildV1EntityDef :: EntityDef +fkChildV1EntityDef = entityDef (Proxy :: Proxy FKChildV1) + +fkChildV2EntityDef :: EntityDef +fkChildV2EntityDef = entityDef (Proxy :: Proxy FKChildV2) + +explicitPrimaryKeyEntityDef :: EntityDef +explicitPrimaryKeyEntityDef = entityDef (Proxy :: Proxy ExplicitPrimaryKey) + +-- Note that FKChild is deliberately omitted here because we have two +-- versions of it +allEntityDefs :: [EntityDef] +allEntityDefs = + [ userEntityDef + , userFriendshipEntityDef + , passwordEntityDef + , password2EntityDef + , adminUserEntityDef + , fkParentEntityDef + , explicitPrimaryKeyEntityDef + ] + +-- Note that this function migrates to the schema expected by FKChildV1 +migrateManually :: (HasCallStack, MonadIO m) => SqlPersistT m () +migrateManually = do + cleanDB + let + rawEx sql = rawExecute sql [] + rawEx + "CREATE TABLE users(id int8 primary key, name text not null, title text);" + rawEx $ + T.concat + [ "CREATE TABLE user_friendships(" + , " id int8 primary key," + , " user1_id int8 references users(id) on delete restrict on update restrict," + , " user2_id int8 references users(id) on delete restrict on update restrict" + , ");" + ] + rawEx $ + T.concat + [ "CREATE TABLE passwords(" + , " id int8 primary key," + , " password_hash text not null," + , " user_id int8 references users(id) on delete restrict on update restrict" + , ");" + ] + rawEx $ + T.concat + [ "ALTER TABLE passwords" + , " ADD CONSTRAINT unique_user_id" + , " UNIQUE(user_id);" + ] + rawEx $ + T.concat + [ "CREATE TABLE passwords_2(" + , " id int8 primary key," + , " password_hash text not null," + , " user_id int8 references users(id) on delete cascade on update set null" + , ");" + ] + rawEx $ + T.concat + [ "ALTER TABLE passwords_2" + , " ADD CONSTRAINT unique_user_id2" + , " UNIQUE(user_id);" + ] + -- Add an extra redundant FK constraint on passwords_2.user_id, so that we + -- can test that the migrator ignores it + rawEx $ + T.concat + [ "ALTER TABLE passwords_2" + , " ADD CONSTRAINT duplicate_passwords_2_user_id_fkey" + , " FOREIGN KEY (user_id) REFERENCES users(id);" + ] + rawEx $ + T.concat + [ "CREATE TABLE admin_users(" + , " user_id int8 not null references users(id) on delete restrict on update restrict primary key," + , " promoted_by_user_id int8 not null references users(id) on delete restrict on update restrict" + , ");" + ] + rawEx $ + T.concat + [ "ALTER TABLE admin_users" + , " ADD CONSTRAINT unique_promoted_by_user_id" + , " UNIQUE(promoted_by_user_id);" + ] + rawEx "CREATE TABLE 
migration_fk_parent(id int8 primary key);" + rawEx "CREATE TABLE migration_fk_child(id int8 primary key);" + rawEx "CREATE TABLE explicit_primary_key(id text primary key);" + rawEx "CREATE TABLE ignored(id int8 primary key);" + +cleanDB :: (HasCallStack, MonadIO m) => SqlPersistT m () +cleanDB = do + let + rawEx sql = rawExecute sql [] + rawEx "DROP TABLE IF EXISTS user_friendships;" + rawEx "DROP TABLE IF EXISTS passwords;" + rawEx "DROP TABLE IF EXISTS passwords_2;" + rawEx "DROP TABLE IF EXISTS ignored;" + rawEx "DROP TABLE IF EXISTS admin_users;" + rawEx "DROP TABLE IF EXISTS users;" + rawEx "DROP TABLE IF EXISTS migration_fk_child;" + rawEx "DROP TABLE IF EXISTS migration_fk_parent;" + rawEx "DROP TABLE IF EXISTS explicit_primary_key;" + +spec :: Spec +spec = describe "MigrationSpec" $ do + it "gathers schema state" $ runConnAssert $ do + migrateManually + + getter <- getStmtGetter + actual <- + liftIO $ + collectSchemaState getter $ + map + EntityNameDB + [ "users" + , "admin_users" + , "user_friendships" + , "passwords" + , "passwords_2" + , "nonexistent" + ] + + cleanDB + + let + expected = + SchemaState + ( Map.fromList + [ + ( EntityNameDB{unEntityNameDB = "admin_users"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + Map.fromList + [ + ( FieldNameDB{unFieldNameDB = "promoted_by_user_id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "promoted_by_user_id"} + , cNull = False + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList + [ ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "admin_users_promoted_by_user_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict} + } + ] + ) + ) + , + ( FieldNameDB{unFieldNameDB = "user_id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "user_id"} + , cNull = False + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList + [ ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "admin_users_user_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict} + } + ] + ) + ) + ] + , essUniqueConstraints = + Map.fromList + [ + ( ConstraintNameDB{unConstraintNameDB = "unique_promoted_by_user_id"} + , [FieldNameDB{unFieldNameDB = "promoted_by_user_id"}] + ) + ] + } + ) + ) + , (EntityNameDB{unEntityNameDB = "nonexistent"}, EntityDoesNotExist) + , + ( EntityNameDB{unEntityNameDB = "passwords"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + Map.fromList + [ + ( FieldNameDB{unFieldNameDB = "id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "id"} + , cNull = False + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList [] + ) + ) + , + ( FieldNameDB{unFieldNameDB = "password_hash"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "password_hash"} + , cNull = False + , cSqlType = SqlString + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList [] + ) + ) + , + ( 
FieldNameDB{unFieldNameDB = "user_id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "user_id"} + , cNull = True + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList + [ ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "passwords_user_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict} + } + ] + ) + ) + ] + , essUniqueConstraints = + Map.fromList + [ + ( ConstraintNameDB{unConstraintNameDB = "unique_user_id"} + , [FieldNameDB{unFieldNameDB = "user_id"}] + ) + ] + } + ) + ) + , + ( EntityNameDB{unEntityNameDB = "passwords_2"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + Map.fromList + [ + ( FieldNameDB{unFieldNameDB = "id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "id"} + , cNull = False + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList [] + ) + ) + , + ( FieldNameDB{unFieldNameDB = "password_hash"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "password_hash"} + , cNull = False + , cSqlType = SqlString + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList [] + ) + ) + , + ( FieldNameDB{unFieldNameDB = "user_id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "user_id"} + , cNull = True + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList + [ ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "duplicate_passwords_2_user_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just NoAction, fcOnDelete = Just NoAction} + } + , ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "passwords_2_user_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just SetNull, fcOnDelete = Just Cascade} + } + ] + ) + ) + ] + , essUniqueConstraints = + Map.fromList + [ + ( ConstraintNameDB{unConstraintNameDB = "unique_user_id2"} + , [FieldNameDB{unFieldNameDB = "user_id"}] + ) + ] + } + ) + ) + , + ( EntityNameDB{unEntityNameDB = "user_friendships"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + Map.fromList + [ + ( FieldNameDB{unFieldNameDB = "id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "id"} + , cNull = False + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList [] + ) + ) + , + ( FieldNameDB{unFieldNameDB = "user1_id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "user1_id"} + , cNull = True + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList + [ ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "user_friendships_user1_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just Restrict, 
fcOnDelete = Just Restrict} + } + ] + ) + ) + , + ( FieldNameDB{unFieldNameDB = "user2_id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "user2_id"} + , cNull = True + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList + [ ColumnReference + { crTableName = EntityNameDB{unEntityNameDB = "users"} + , crConstraintName = + ConstraintNameDB{unConstraintNameDB = "user_friendships_user2_id_fkey"} + , crFieldCascade = + FieldCascade{fcOnUpdate = Just Restrict, fcOnDelete = Just Restrict} + } + ] + ) + ) + ] + , essUniqueConstraints = Map.fromList [] + } + ) + ) + , + ( EntityNameDB{unEntityNameDB = "users"} + , EntityExists + ( ExistingEntitySchemaState + { essColumns = + Map.fromList + [ + ( FieldNameDB{unFieldNameDB = "id"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "id"} + , cNull = False + , cSqlType = SqlInt64 + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList [] + ) + ) + , + ( FieldNameDB{unFieldNameDB = "name"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "name"} + , cNull = False + , cSqlType = SqlString + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList [] + ) + ) + , + ( FieldNameDB{unFieldNameDB = "title"} + , + ( Column + { cName = FieldNameDB{unFieldNameDB = "title"} + , cNull = True + , cSqlType = SqlString + , cDefault = Nothing + , cGenerated = Nothing + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = Nothing + } + , Set.fromList [] + ) + ) + ] + , essUniqueConstraints = Map.fromList [] + } + ) + ) + ] + ) + + actual `shouldBe` Right expected + + it "no-ops on a migrated DB" $ runConnAssert $ do + migrateManually + + getter <- getStmtGetter + result <- + liftIO $ + migrateEntitiesStructured + emptyBackendSpecificOverrides + getter + allEntityDefs + allEntityDefs + + cleanDB + + case result of + Right [] -> + pure () + Left err -> + expectationFailure $ show err + Right alters -> + map (snd . showAlterDb) alters `shouldBe` [] + + it "migrates a clean DB" $ runConnAssert $ do + cleanDB + + getter <- getStmtGetter + result <- + liftIO $ + migrateEntitiesStructured + emptyBackendSpecificOverrides + getter + allEntityDefs + allEntityDefs + + cleanDB + + case result of + Right [] -> + pure () + Left err -> + expectationFailure $ show err + Right alters -> do + traverse_ (flip rawExecute [] . snd . showAlterDb) alters + result2 <- + liftIO $ + migrateEntitiesStructured + emptyBackendSpecificOverrides + getter + allEntityDefs + allEntityDefs + result2 `shouldBe` Right [] + + it "suggests FK constraints for new fields first time" $ runConnAssert $ do + migrateManually + + getter <- getStmtGetter + result <- + liftIO $ + migrateEntitiesStructured + emptyBackendSpecificOverrides + getter + (fkChildV2EntityDef : allEntityDefs) + [fkChildV2EntityDef] + + cleanDB + + case result of + Right [] -> + pure () + Left err -> + expectationFailure $ show err + Right alters -> + map (snd . 
showAlterDb) alters + `shouldBe` [ "ALTER TABLE \"migration_fk_child\" ADD COLUMN \"parent_id\" INT8 NOT NULL" + , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE RESTRICT ON UPDATE RESTRICT" + ] + + it "Uses overrides for empty cascade action" $ runConnAssert $ do + migrateManually + + getter <- getStmtGetter + + let + overrideWithDefault = + setBackendSpecificForeignKeyCascadeDefault Cascade emptyBackendSpecificOverrides + result <- + liftIO $ + migrateEntitiesStructured + overrideWithDefault + getter + (fkChildV2EntityDef : allEntityDefs) + [fkChildV2EntityDef] + + cleanDB + + case result of + Right [] -> + pure () + Left err -> + expectationFailure $ show err + Right alters -> + map (snd . showAlterDb) alters + `shouldBe` [ "ALTER TABLE \"migration_fk_child\" ADD COLUMN \"parent_id\" INT8 NOT NULL" + , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE CASCADE ON UPDATE CASCADE" + ] diff --git a/persistent-postgresql-ng/test/PgPipelineInit.hs b/persistent-postgresql-ng/test/PgPipelineInit.hs new file mode 100644 index 000000000..53de436e0 --- /dev/null +++ b/persistent-postgresql-ng/test/PgPipelineInit.hs @@ -0,0 +1,242 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} + +module PgPipelineInit + ( runConn + , runConn_ + , runConnAssert + , runConnWith + , runConnWith_ + , runConnAssertWith + , FetchMode (..) + , MonadIO + , persistSettings + , MkPersistSettings (..) + , BackendKey (..) + , GenerateKey (..) + -- re-exports + , module Control.Monad.Trans.Reader + , module Control.Monad + , module Database.Persist.Sql + , module Database.Persist.SqlBackend + , module Database.Persist + , module Database.Persist.Sql.Raw.QQ + , module Init + , module Test.Hspec + , module Test.Hspec.Expectations.Lifted + , module Test.HUnit + , AValue (..) + , BS.ByteString + , Int32 + , Int64 + , liftIO + , mkPersist + , migrateModels + , mkMigrate + , share + , sqlSettings + , persistLowerCase + , persistUpperCase + , mkEntityDefList + , setImplicitIdDef + , SomeException + , Text + , TestFn (..) + , LoggingT + , ResourceT + , UUID (..) + , sqlSettingsUuid + ) where + +import Init + ( GenerateKey (..) + , MonadFail + , RunDb + , TestFn (..) + , UUID (..) + , arbText + , asIO + , assertEmpty + , assertNotEmpty + , assertNotEqual + , isTravis + , liftA2 + , sqlSettingsUuid + , truncateTimeOfDay + , truncateToMicro + , truncateUTCTime + , (==@) + , (@/=) + , (@==) + ) + +-- re-exports +import Control.Exception (SomeException) +import Control.Monad (forM_, liftM, replicateM, void, when) +import Control.Monad.Trans.Reader +import Data.Aeson (FromJSON, ToJSON, Value (..), object) +import qualified Data.Text.Encoding as TE +import Database.Persist.Postgresql.JSON () +import Database.Persist.Sql.Raw.QQ +import Database.Persist.SqlBackend +import Database.Persist.TH + ( MkPersistSettings (..) 
+ , migrateModels + , mkEntityDefList + , mkMigrate + , mkPersist + , persistLowerCase + , persistUpperCase + , setImplicitIdDef + , share + , sqlSettings + ) +import Test.Hspec + ( Arg + , Spec + , SpecWith + , afterAll_ + , before + , beforeAll + , before_ + , describe + , fdescribe + , fit + , hspec + , it + ) +import Test.Hspec.Expectations.Lifted +import Test.QuickCheck.Instances () +import UnliftIO + +-- testing +import Test.HUnit (Assertion, assertBool, assertFailure, (@=?), (@?=)) +import Test.QuickCheck + +import Control.Monad (unless, (>=>)) +import Control.Monad.IO.Unlift (MonadUnliftIO) +import Control.Monad.Logger +import Control.Monad.Trans.Resource (ResourceT, runResourceT) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Char8 as B8 +import qualified Data.HashMap.Strict as HM +import Data.Int (Int32, Int64) +import Data.Maybe (fromMaybe) +import Data.Monoid ((<>)) +import Data.Text (Text) +import Data.Vector (Vector) +import System.Environment (getEnvironment, lookupEnv) +import System.Log.FastLogger (fromLogStr) + +import Database.Persist +import Database.Persist.Postgresql.Pipeline +import Database.Persist.Sql +import Database.Persist.TH () + +_debugOn :: Bool +_debugOn = False + +dockerPg :: IO (Maybe BS.ByteString) +dockerPg = do + env <- liftIO getEnvironment + return $ case lookup "POSTGRES_NAME" env of + Just _name -> Just "postgres" + _ -> Nothing + +persistSettings :: MkPersistSettings +persistSettings = sqlSettings{mpsGeneric = True} + +-- | Run with default 'FetchAll' mode. +runConn :: (MonadUnliftIO m) => SqlPersistT (LoggingT m) t -> m () +runConn f = runConn_ f >>= const (return ()) + +-- | Run with default 'FetchAll' mode, returning the result. +runConn_ :: (MonadUnliftIO m) => SqlPersistT (LoggingT m) t -> m t +runConn_ = runConnWith_ FetchAll + +-- | Run with the given 'FetchMode', discarding the result. +runConnWith :: (MonadUnliftIO m) => FetchMode -> SqlPersistT (LoggingT m) t -> m () +runConnWith mode f = runConnWith_ mode f >>= const (return ()) + +-- | Run with the given 'FetchMode', returning the result. +runConnWith_ :: (MonadUnliftIO m) => FetchMode -> SqlPersistT (LoggingT m) t -> m t +runConnWith_ fetchMode f = do + travis <- liftIO isTravis + let + debugPrint = not travis && _debugOn + printDebug = if debugPrint then print . fromLogStr else void . return + poolSize = 1 + connString <- + if travis + then do + pure + "host=localhost port=5432 user=perstest password=perstest dbname=persistent" + else do + host <- fromMaybe "localhost" <$> liftIO dockerPg + port <- fromMaybe "5432" <$> liftIO (lookupEnv "PGPORT") + mpass <- liftIO (lookupEnv "PGPASSWORD") + let passFragment = case mpass of + Just pw -> " password=" <> B8.pack pw + Nothing -> "" + pure ("host=" <> host <> " port=" <> B8.pack port <> " user=postgres" <> passFragment <> " dbname=test") + + flip runLoggingT (\_ _ _ s -> printDebug s) $ do + logInfoN (if travis then "Running in CI" else "CI not detected") + let settings = defaultPipelineSettings + { pipelineFetchMode = fetchMode + , pipelineConnStr = connString + , pipelinePoolSize = poolSize + } + go = withPostgresqlPipelinePool settings $ runSqlPool (withBaseBackend f) + -- Retry on connection failure, same as persistent-postgresql + eres <- try go + case eres of + Left (err :: SomeException) -> do + eres' <- try go + case eres' of + Left (err' :: SomeException) -> + if show err == show err' + then throwIO err + else throwIO err' + Right a -> + pure a + Right a -> + pure a + +-- | Default mode assertion runner. 
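+--
+-- A usage sketch (@SomeEntity@ is hypothetical; any 'SqlPersistT' action
+-- fits, and 'runConnAssertWith' below appends 'transactionUndo', so each
+-- spec item is rolled back):
+--
+-- @
+-- it "starts empty" $ runConnAssert $ do
+--     n <- count ([] :: [Filter SomeEntity])
+--     n `shouldBe` 0
+-- @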
+runConnAssert :: SqlPersistT (LoggingT (ResourceT IO)) () -> Assertion +runConnAssert = runConnAssertWith FetchAll + +-- | Assertion runner parameterized by 'FetchMode'. +runConnAssertWith :: FetchMode -> SqlPersistT (LoggingT (ResourceT IO)) () -> Assertion +runConnAssertWith mode actions = do + runResourceT $ runConnWith mode $ actions >> transactionUndo + +newtype AValue = AValue {getValue :: Value} + +-- Need a specialized Arbitrary instance +instance Arbitrary AValue where + arbitrary = + AValue + <$> frequency + [ (1, pure Null) + , (1, Bool <$> arbitrary) + , (2, Number <$> arbitrary) + , (2, String <$> arbText) + , (3, Array <$> limitIt 4 (fmap (fmap getValue) arbitrary)) + , (3, object <$> arbObject) + ] + where + limitIt :: Int -> Gen a -> Gen a + limitIt i x = sized $ \n -> do + let + m = if n > i then i else n + resize m x + arbObject = + limitIt 4 + $ listOf + . liftA2 (,) arbText + $ limitIt 4 (fmap getValue arbitrary) diff --git a/persistent-postgresql-ng/test/PgPipelineIntervalTest.hs b/persistent-postgresql-ng/test/PgPipelineIntervalTest.hs new file mode 100644 index 000000000..71b4e486f --- /dev/null +++ b/persistent-postgresql-ng/test/PgPipelineIntervalTest.hs @@ -0,0 +1,56 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE EmptyDataDecls #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module PgPipelineIntervalTest where + +import Data.Fixed (Fixed (MkFixed), Micro, Pico) +import Data.Time.Clock (secondsToNominalDiffTime) +import Database.Persist.Postgresql.Pipeline (PgInterval (..)) +import PgPipelineInit +import Test.Hspec.QuickCheck + +share + [mkPersist sqlSettings, mkMigrate "pgIntervalMigrate"] + [persistLowerCase| +PgIntervalDb + interval_field PgInterval + deriving Eq + deriving Show +|] + +clamp :: (Ord a) => a -> a -> a -> a +clamp lo hi = max lo . min hi + +microsecondLimit :: Int64 +microsecondLimit = 2147483647 * 60 * 60 * 1000000 + +specs :: Spec +specs = do + describe "Postgres Interval Property tests" $ do + prop "Round trips" $ \int64 -> runConnAssert $ do + let + eg = + PgIntervalDb + . PgInterval + . secondsToNominalDiffTime + . (realToFrac :: Micro -> Pico) + . MkFixed + . toInteger + $ clamp (-microsecondLimit) microsecondLimit int64 + rid <- insert eg + r <- getJust rid + liftIO $ r `shouldBe` eg diff --git a/persistent-postgresql-ng/test/PipelineDeferralSpec.hs b/persistent-postgresql-ng/test/PipelineDeferralSpec.hs new file mode 100644 index 000000000..2cb47486c --- /dev/null +++ b/persistent-postgresql-ng/test/PipelineDeferralSpec.hs @@ -0,0 +1,407 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | Tests proving that pipeline mode defers result evaluation. 
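+--
+-- The shape being verified, sketched with names from this module:
+--
+-- @
+-- delete k1    -- fire-and-forget: queued, pgPending becomes 1
+-- delete k2    -- queued, pgPending becomes 2
+-- _ <- get k3  -- drain point: forces both deletes, pgPending back to 0
+-- @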
+-- +-- Reads the @pgPending@ counter directly to verify that fire-and-forget +-- operations (delete, update, replace, etc.) increment the counter without +-- reading results, and that drain points (get, selectList, commit) reset it. +module PipelineDeferralSpec (specs) where + +import Control.Exception (SomeException, try) +import Control.Monad.Trans.Resource (runResourceT) +import Data.IORef (readIORef) +import Database.Persist.Postgresql.Pipeline (getPipelineConn) +import Database.Persist.Postgresql.Pipeline.Internal (PgConn (..)) +import PgPipelineInit + +share + [mkPersist sqlSettings, mkMigrate "deferralMigrate"] + [persistLowerCase| +DeferItem + name Text + value Int + deriving Show Eq +|] + +db :: SqlPersistT (LoggingT (ResourceT IO)) a -> IO a +db actions = runResourceT $ runConn_ $ do + void $ runMigrationSilent deferralMigrate + actions + +-- | Read the pending result count from the pipeline. +getPending :: (MonadIO m) => SqlPersistT m Int +getPending = do + backend <- ask + case getPipelineConn backend of + Nothing -> error "getPending: no PgConn (not a pipeline backend)" + Just pc -> liftIO $ length <$> readIORef (pgPending pc) + +specs :: Spec +specs = describe "Pipeline deferral (pgPending counter)" $ do + + it "delete increments pgPending, get drains it" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + k1 <- insert $ DeferItem "a" 1 + k2 <- insert $ DeferItem "b" 2 + k3 <- insert $ DeferItem "c" 3 + + -- insert uses stmtQuery (RETURNING), so pending should be 0 + n0 <- getPending + liftIO $ n0 `shouldBe` 0 + + -- delete uses stmtExecute (fire-and-forget) + delete k1 + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + delete k2 + n2 <- getPending + liftIO $ n2 `shouldBe` 2 + + -- get uses stmtQuery → drains all pending first + _ <- get k3 + n3 <- getPending + liftIO $ n3 `shouldBe` 0 + + it "update increments pgPending, selectList drains it" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + k1 <- insert $ DeferItem "x" 10 + k2 <- insert $ DeferItem "y" 20 + + n0 <- getPending + liftIO $ n0 `shouldBe` 0 + + update k1 [DeferItemValue =. 99] + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + update k2 [DeferItemValue =. 99] + n2 <- getPending + liftIO $ n2 `shouldBe` 2 + + -- selectList drains pending + items <- selectList ([] :: [Filter DeferItem]) [] + n3 <- getPending + liftIO $ n3 `shouldBe` 0 + liftIO $ map (deferItemValue . entityVal) items `shouldBe` [99, 99] + + it "replace increments pgPending" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + k <- insert $ DeferItem "orig" 1 + + n0 <- getPending + liftIO $ n0 `shouldBe` 0 + + replace k (DeferItem "replaced" 2) + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + mitem <- get k + n2 <- getPending + liftIO $ n2 `shouldBe` 0 + liftIO $ mitem `shouldBe` Just (DeferItem "replaced" 2) + + it "deleteWhere increments pgPending" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + _ <- insert $ DeferItem "del1" 1 + _ <- insert $ DeferItem "del2" 2 + + n0 <- getPending + liftIO $ n0 `shouldBe` 0 + + deleteWhere [DeferItemValue ==. 1] + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + remaining <- selectList ([] :: [Filter DeferItem]) [] + n2 <- getPending + liftIO $ n2 `shouldBe` 0 + liftIO $ length remaining `shouldBe` 1 + + it "updateWhere increments pgPending" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + _ <- insert $ DeferItem "uw1" 1 + _ <- insert $ DeferItem "uw2" 2 + + n0 <- getPending + liftIO $ n0 `shouldBe` 0 + + updateWhere [DeferItemValue ==. 1] [DeferItemValue =. 
100] + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + updateWhere [DeferItemValue ==. 2] [DeferItemValue =. 200] + n2 <- getPending + liftIO $ n2 `shouldBe` 2 + + items <- selectList ([] :: [Filter DeferItem]) [Asc DeferItemValue] + n3 <- getPending + liftIO $ n3 `shouldBe` 0 + liftIO $ map (deferItemValue . entityVal) items `shouldBe` [100, 200] + + it "many fire-and-forget operations accumulate" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + keys <- mapM insert + [ DeferItem "m1" 1, DeferItem "m2" 2, DeferItem "m3" 3 + , DeferItem "m4" 4, DeferItem "m5" 5, DeferItem "m6" 6 + , DeferItem "m7" 7, DeferItem "m8" 8, DeferItem "m9" 9 + , DeferItem "m10" 10 + ] + + -- 10 deletes, all fire-and-forget + forM_ keys delete + n <- getPending + liftIO $ n `shouldBe` 10 + + -- selectList drains all 10 at once + remaining <- selectList ([] :: [Filter DeferItem]) [] + nAfter <- getPending + liftIO $ nAfter `shouldBe` 0 + liftIO $ remaining `shouldBe` [] + + it "first select forces only prior DML, second batch stays pending" $ db $ do + -- Proves that drain is scoped: only DML before the read is forced. + -- DML after the read remains pending until the next drain point. + deleteWhere ([] :: [Filter DeferItem]) + k1 <- insert $ DeferItem "a" 1 + k2 <- insert $ DeferItem "b" 2 + + -- Batch 1: two updates (fire-and-forget) + update k1 [DeferItemValue =. 10] + update k2 [DeferItemValue =. 20] + n1 <- getPending + liftIO $ n1 `shouldBe` 2 + + -- First select: drains batch 1 + items1 <- selectList ([] :: [Filter DeferItem]) [Asc DeferItemName] + n2 <- getPending + liftIO $ n2 `shouldBe` 0 + liftIO $ map (deferItemValue . entityVal) items1 `shouldBe` [10, 20] + + -- Batch 2: two more updates (new pending, independent of batch 1) + update k1 [DeferItemValue =. 100] + update k2 [DeferItemValue =. 200] + n3 <- getPending + liftIO $ n3 `shouldBe` 2 -- batch 2 is pending, NOT forced by first select + + -- Second select: drains batch 2 + items2 <- selectList ([] :: [Filter DeferItem]) [Asc DeferItemName] + n4 <- getPending + liftIO $ n4 `shouldBe` 0 + liftIO $ map (deferItemValue . entityVal) items2 `shouldBe` [100, 200] + + it "without a second read, second batch remains pending" $ db $ do + -- Proves that DML results are ONLY forced by explicit drain points. + -- If you never read after the second batch, those results stay pending + -- (commit at transaction end will eventually drain them). + deleteWhere ([] :: [Filter DeferItem]) + k <- insert $ DeferItem "x" 1 + + -- First DML + read + update k [DeferItemValue =. 10] + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + Just item1 <- get k + n2 <- getPending + liftIO $ n2 `shouldBe` 0 + liftIO $ deferItemValue item1 `shouldBe` 10 + + -- Second DML — NOT forced by the prior get + update k [DeferItemValue =. 20] + n3 <- getPending + liftIO $ n3 `shouldBe` 1 -- still pending, first read didn't force this + + -- Third DML — also not forced + update k [DeferItemValue =. 30] + n4 <- getPending + liftIO $ n4 `shouldBe` 2 -- both second and third are pending + + -- Only when we read again does the second batch drain + Just item2 <- get k + n5 <- getPending + liftIO $ n5 `shouldBe` 0 -- both forced now + liftIO $ deferItemValue item2 `shouldBe` 30 + + it "interleaved read-write: each read only forces prior writes" $ db $ do + -- Step-by-step interleaving to prove each read forces exactly + -- the writes that preceded it, and no more. + deleteWhere ([] :: [Filter DeferItem]) + k <- insert $ DeferItem "z" 0 + + -- Write 1 + update k [DeferItemValue =. 1] + liftIO . 
(1 `shouldBe`) =<< getPending + + -- Read 1: forces write 1 + Just v1 <- get k + liftIO . (0 `shouldBe`) =<< getPending + liftIO $ deferItemValue v1 `shouldBe` 1 + + -- Write 2 + update k [DeferItemValue =. 2] + liftIO . (1 `shouldBe`) =<< getPending + + -- Write 3 (stacks on write 2) + update k [DeferItemValue =. 3] + liftIO . (2 `shouldBe`) =<< getPending + + -- Read 2: forces writes 2 and 3 together + Just v2 <- get k + liftIO . (0 `shouldBe`) =<< getPending + liftIO $ deferItemValue v2 `shouldBe` 3 + + -- Write 4 + update k [DeferItemValue =. 4] + liftIO . (1 `shouldBe`) =<< getPending + + -- Read 3: forces write 4 + Just v3 <- get k + liftIO . (0 `shouldBe`) =<< getPending + liftIO $ deferItemValue v3 `shouldBe` 4 + + it "commit drains all pending (no explicit read needed)" $ db $ do + -- This test uses transactionSave (which commits the current + -- transaction and begins a new one) to prove that commit is + -- a drain point even without any read operations. + deleteWhere ([] :: [Filter DeferItem]) + k1 <- insert $ DeferItem "c1" 1 + k2 <- insert $ DeferItem "c2" 2 + + delete k1 + delete k2 + n <- getPending + liftIO $ n `shouldBe` 2 + + -- transactionSave commits + begins a new transaction + transactionSave + + -- After commit, pending should be 0 + -- (the new BEGIN from transactionSave is also fire-and-forget, + -- but it gets drained by the next read operation) + nAfterCommit <- getPending + -- BEGIN from the new transaction is pending + liftIO $ nAfterCommit `shouldBe` 1 + + -- Verify the deletes actually took effect + remaining <- selectList ([] :: [Filter DeferItem]) [] + liftIO $ remaining `shouldBe` [] + + it "insert (RETURNING) is a drain point" $ db $ do + -- insert uses stmtQuery for the RETURNING clause, so it must + -- drain all pending DML before executing. + deleteWhere ([] :: [Filter DeferItem]) + k1 <- insert $ DeferItem "drain1" 1 + + -- Fire-and-forget DML + update k1 [DeferItemValue =. 99] + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + -- insert triggers drain (stmtQuery for RETURNING) + _ <- insert $ DeferItem "drain2" 2 + n2 <- getPending + liftIO $ n2 `shouldBe` 0 + + -- Verify the update was applied (it was drained by the insert) + Just item <- get k1 + liftIO $ deferItemValue item `shouldBe` 99 + + it "rawExecute is fire-and-forget" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + _ <- insert $ DeferItem "raw" 1 + + n0 <- getPending + liftIO $ n0 `shouldBe` 0 + + -- rawExecute uses stmtExecute → fire-and-forget + rawExecute "UPDATE \"defer_item\" SET \"value\" = 42" [] + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + rawExecute "UPDATE \"defer_item\" SET \"value\" = 99" [] + n2 <- getPending + liftIO $ n2 `shouldBe` 2 + + -- Read forces drain + items <- selectList ([] :: [Filter DeferItem]) [] + n3 <- getPending + liftIO $ n3 `shouldBe` 0 + liftIO $ map (deferItemValue . entityVal) items `shouldBe` [99] + + it "count is a drain point" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + _ <- insert $ DeferItem "cnt1" 1 + _ <- insert $ DeferItem "cnt2" 2 + + -- Fire-and-forget + deleteWhere [DeferItemValue ==. 
1] + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + -- count uses stmtQuery → drains pending + c <- count ([] :: [Filter DeferItem]) + n2 <- getPending + liftIO $ n2 `shouldBe` 0 + liftIO $ c `shouldBe` 1 + + it "exists is a drain point" $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + k <- insert $ DeferItem "ex" 1 + + delete k + n1 <- getPending + liftIO $ n1 `shouldBe` 1 + + -- exists uses stmtQuery → drains pending + e <- exists [DeferItemName ==. "ex"] + n2 <- getPending + liftIO $ n2 `shouldBe` 0 + liftIO $ e `shouldBe` False + + it "error in deferred DML surfaces at next drain point" $ do + -- A fire-and-forget query that will fail (bad table name). + -- The error is deferred until the next drain point (selectList). + result <- (try $ db $ do + deleteWhere ([] :: [Filter DeferItem]) + _ <- insert $ DeferItem "err" 1 + -- Fire-and-forget an invalid query + rawExecute "UPDATE \"nonexistent_table\" SET x = 1" [] + n <- getPending + liftIO $ n `shouldBe` 1 + -- selectList triggers drain → error surfaces + _ <- selectList ([] :: [Filter DeferItem]) [] + return () + ) :: IO (Either SomeException ()) + case result of + Left _e -> return () -- error propagated correctly + Right () -> expectationFailure "Expected exception from deferred bad query" + + it "connection pool works after deferred error + rollback" $ do + -- After the above error, verify the pool gives us a working + -- connection. (runSqlPool called connRollback, cleaning up + -- the pipeline before returning the connection to the pool.) + _ <- try $ db $ do + rawExecute "UPDATE \"nonexistent_table\" SET x = 1" [] + _ <- selectList ([] :: [Filter DeferItem]) [] + return () + :: IO (Either SomeException ()) + -- Fresh transaction should work fine + db $ do + deleteWhere ([] :: [Filter DeferItem]) + k <- insert $ DeferItem "recovery" 42 + Just item <- get k + liftIO $ deferItemValue item `shouldBe` 42 diff --git a/persistent-postgresql-ng/test/PipelineModeSpec.hs b/persistent-postgresql-ng/test/PipelineModeSpec.hs new file mode 100644 index 000000000..347662b2a --- /dev/null +++ b/persistent-postgresql-ng/test/PipelineModeSpec.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +module PipelineModeSpec (specs, specsWith) where + +import Control.Monad.Trans.Resource (runResourceT) +import Database.Persist.Postgresql.Pipeline (withPipeline) +import PgPipelineInit + +share + [mkPersist sqlSettings, mkMigrate "pipelineMigrate"] + [persistLowerCase| +PipelineItem + name Text + value Int + deriving Show Eq +|] + +db :: SqlPersistT (LoggingT (ResourceT IO)) a -> IO a +db = dbWith FetchAll + +dbWith :: FetchMode -> SqlPersistT (LoggingT (ResourceT IO)) a -> IO a +dbWith mode actions = runResourceT $ runConnWith_ mode $ do + void $ runMigrationSilent pipelineMigrate + actions + +specs :: Spec +specs = specsWith FetchAll + +specsWith :: FetchMode -> Spec +specsWith mode = describe "Pipeline mode (automatic)" $ do + let db' = dbWith mode + + it "multiple deletes are pipelined" $ db' $ do + deleteWhere ([] :: [Filter PipelineItem]) + keys <- mapM insert + [ PipelineItem "a" 1 + , PipelineItem 
"b" 2 + , PipelineItem "c" 3 + ] + forM_ keys delete + remaining <- selectList ([] :: [Filter PipelineItem]) [] + liftIO $ remaining @?= [] + + it "multiple updates are pipelined" $ db' $ do + deleteWhere ([] :: [Filter PipelineItem]) + keys <- mapM insert + [ PipelineItem "x" 10 + , PipelineItem "y" 20 + , PipelineItem "z" 30 + ] + forM_ keys $ \k -> update k [PipelineItemValue =. 99] + items <- selectList ([] :: [Filter PipelineItem]) [] + let vals = map (pipelineItemValue . entityVal) items + liftIO $ vals @?= [99, 99, 99] + + it "mixed DML then read sees prior writes" $ db' $ do + deleteWhere ([] :: [Filter PipelineItem]) + k1 <- insert $ PipelineItem "first" 1 + update k1 [PipelineItemValue =. 42] + items <- selectList ([] :: [Filter PipelineItem]) [] + liftIO $ length items @?= 1 + liftIO $ (pipelineItemValue . entityVal . head) items @?= 42 + + it "deleteWhere is pipelined" $ db' $ do + deleteWhere ([] :: [Filter PipelineItem]) + _ <- insert $ PipelineItem "del1" 1 + _ <- insert $ PipelineItem "del2" 2 + deleteWhere [PipelineItemValue ==. 1] + remaining <- selectList ([] :: [Filter PipelineItem]) [] + liftIO $ length remaining @?= 1 + liftIO $ (pipelineItemName . entityVal . head) remaining @?= "del2" + + it "connection works across multiple transactions" $ db' $ do + deleteWhere ([] :: [Filter PipelineItem]) + _ <- insert $ PipelineItem "txn1" 100 + items1 <- selectList ([] :: [Filter PipelineItem]) [] + liftIO $ length items1 @?= 1 + + it "empty transaction succeeds" $ db' $ do + return () + + it "replace is pipelined" $ db' $ do + deleteWhere ([] :: [Filter PipelineItem]) + k <- insert $ PipelineItem "orig" 1 + replace k (PipelineItem "replaced" 2) + mitem <- get k + liftIO $ mitem @?= Just (PipelineItem "replaced" 2) + + it "withPipeline is a no-op (pipeline always on)" $ db' $ do + deleteWhere ([] :: [Filter PipelineItem]) + withPipeline $ do + _ <- insert $ PipelineItem "inside" 42 + items <- selectList ([] :: [Filter PipelineItem]) [] + liftIO $ length items @?= 1 diff --git a/persistent-postgresql-ng/test/PipelineRegressionSpec.hs b/persistent-postgresql-ng/test/PipelineRegressionSpec.hs new file mode 100644 index 000000000..e290542e5 --- /dev/null +++ b/persistent-postgresql-ng/test/PipelineRegressionSpec.hs @@ -0,0 +1,278 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +-- | Regression tests for Hedis-style automatic pipelining. 
+-- +-- These tests verify that all pipelined operations produce correct +-- results when interleaved in various patterns: +-- +-- * Back-to-back reads (get, getBy, count, exists) +-- * Back-to-back inserts with RETURNING +-- * Writes followed by reads +-- * Reads followed by writes followed by reads +-- * Mixed operations across entity types +-- * Large batches (100+ operations) +module PipelineRegressionSpec (specs) where + +import Control.Monad.Trans.Resource (runResourceT) +import qualified Data.Map.Strict as Map +import Data.Maybe (catMaybes, isJust, isNothing) +import qualified Data.Text as T +import PgPipelineInit + +share + [mkPersist sqlSettings, mkMigrate "pipeRegMigrate"] + [persistLowerCase| +PipeReg + name Text + value Int + UniquePipeRegName name + deriving Show Eq + +PipeRegOther + label Text + score Double + deriving Show Eq +|] + +dbWith :: FetchMode -> SqlPersistT (LoggingT (ResourceT IO)) a -> IO a +dbWith mode actions = runResourceT $ runConnWith_ mode $ do + void $ runMigrationSilent pipeRegMigrate + deleteWhere ([] :: [Filter PipeReg]) + deleteWhere ([] :: [Filter PipeRegOther]) + actions + +specs :: Spec +specs = describe "Pipeline regression" $ do + forM_ [("FetchAll", FetchAll), ("FetchSingleRow", FetchSingleRow)] $ + \(modeName, mode) -> describe modeName $ do + let db = dbWith mode + + --------------------------------------------------------------------------- + -- Back-to-back gets (Hedis-style pipelining) + --------------------------------------------------------------------------- + + describe "back-to-back get" $ do + it "returns correct results for 10 keys" $ db $ do + keys <- mapM insert + [ PipeReg (T.pack $ "g" <> show i) (i * 10) | i <- [1..10 :: Int] ] + results <- mapM get keys + liftIO $ length (catMaybes results) @?= 10 + liftIO $ forM_ (zip [1..] results) $ \(i :: Int, mr) -> + mr @?= Just (PipeReg (T.pack $ "g" <> show i) (i * 10)) + + it "returns Nothing for missing keys" $ db $ do + k <- insert $ PipeReg "exists" 1 + let missingK = toSqlKey 999999 + results <- sequence [get k, get missingK, get k] + liftIO $ map isJust results @?= [True, False, True] + + it "handles 100 gets" $ db $ do + keys <- mapM insert + [ PipeReg (T.pack $ "h" <> show i) i | i <- [1..100 :: Int] ] + results <- mapM get keys + liftIO $ length (catMaybes results) @?= 100 + + --------------------------------------------------------------------------- + -- Back-to-back inserts (pipelined RETURNING) + --------------------------------------------------------------------------- + + describe "back-to-back insert" $ do + it "returns distinct keys for 10 inserts" $ db $ do + keys <- mapM insert + [ PipeReg (T.pack $ "ins" <> show i) i | i <- [1..10 :: Int] ] + liftIO $ length keys @?= 10 + liftIO $ length (Map.fromList (zip keys keys)) @?= 10 + + it "inserted records are readable" $ db $ do + keys <- mapM insert + [ PipeReg (T.pack $ "ir" <> show i) (i * 100) | i <- [1..5 :: Int] ] + results <- mapM get keys + liftIO $ forM_ (zip [1..] results) $ \(i :: Int, mr) -> + mr @?= Just (PipeReg (T.pack $ "ir" <> show i) (i * 100)) + + --------------------------------------------------------------------------- + -- Back-to-back getBy (unique lookups) + --------------------------------------------------------------------------- + + describe "back-to-back getBy" $ do + it "finds existing records" $ db $ do + mapM_ insert + [ PipeReg "alice" 1, PipeReg "bob" 2, PipeReg "carol" 3 ] + results <- mapM (getBy . 
UniquePipeRegName) + ["alice", "bob", "carol"] + liftIO $ length (catMaybes results) @?= 3 + liftIO $ map (fmap (pipeRegValue . entityVal)) results + @?= [Just 1, Just 2, Just 3] + + it "returns Nothing for missing" $ db $ do + _ <- insert $ PipeReg "present" 1 + results <- sequence + [ getBy (UniquePipeRegName "present") + , getBy (UniquePipeRegName "absent") + ] + liftIO $ isJust (head results) @?= True + liftIO $ isNothing (results !! 1) @?= True + + --------------------------------------------------------------------------- + -- Back-to-back count / exists + --------------------------------------------------------------------------- + + describe "back-to-back count and exists" $ do + it "counts correctly" $ db $ do + mapM_ insert [ PipeReg (T.pack $ "c" <> show i) i | i <- [1..5 :: Int] ] + counts <- sequence + [ count ([] :: [Filter PipeReg]) + , count [PipeRegValue >. 3] + , count [PipeRegValue <. 0] + ] + liftIO $ counts @?= [5, 2, 0] + + it "exists correctly" $ db $ do + mapM_ insert [ PipeReg "e1" 1, PipeReg "e2" 2 ] + results <- sequence + [ exists ([] :: [Filter PipeReg]) + , exists [PipeRegValue ==. 1] + , exists [PipeRegValue ==. 999] + ] + liftIO $ results @?= [True, True, False] + + --------------------------------------------------------------------------- + -- Interleaved writes and reads + --------------------------------------------------------------------------- + + describe "interleaved writes and reads" $ do + it "write-read-write-read" $ db $ do + k1 <- insert $ PipeReg "wr1" 10 + r1 <- get k1 + k2 <- insert $ PipeReg "wr2" 20 + r2 <- get k2 + liftIO $ r1 @?= Just (PipeReg "wr1" 10) + liftIO $ r2 @?= Just (PipeReg "wr2" 20) + + it "fire-and-forget writes then reads see results" $ db $ do + k1 <- insert $ PipeReg "ff1" 1 + k2 <- insert $ PipeReg "ff2" 2 + k3 <- insert $ PipeReg "ff3" 3 + update k1 [PipeRegValue =. 100] + update k2 [PipeRegValue =. 200] + delete k3 + r1 <- get k1 + r2 <- get k2 + r3 <- get k3 + liftIO $ r1 @?= Just (PipeReg "ff1" 100) + liftIO $ r2 @?= Just (PipeReg "ff2" 200) + liftIO $ r3 @?= Nothing + + it "replace then get" $ db $ do + k <- insert $ PipeReg "rep" 1 + replace k (PipeReg "rep" 999) + r <- get k + liftIO $ r @?= Just (PipeReg "rep" 999) + + it "deleteWhere then count" $ db $ do + mapM_ insert [ PipeReg (T.pack $ "dw" <> show i) i | i <- [1..10 :: Int] ] + deleteWhere [PipeRegValue <=. 5] + n <- count ([] :: [Filter PipeReg]) + liftIO $ n @?= 5 + + it "updateWhere then selectList" $ db $ do + mapM_ insert [ PipeReg (T.pack $ "uw" <> show i) i | i <- [1..5 :: Int] ] + updateWhere [PipeRegValue <=. 3] [PipeRegValue =. 0] + items <- selectList ([] :: [Filter PipeReg]) [Asc PipeRegValue] + let vals = map (pipeRegValue . 
entityVal) items + liftIO $ filter (== 0) vals @?= [0, 0, 0] + liftIO $ length vals @?= 5 + + --------------------------------------------------------------------------- + -- Cross-entity operations + --------------------------------------------------------------------------- + + describe "cross-entity operations" $ do + it "interleave inserts across tables" $ db $ do + k1 <- insert $ PipeReg "cross1" 1 + k2 <- insert $ PipeRegOther "other1" 1.5 + k3 <- insert $ PipeReg "cross2" 2 + k4 <- insert $ PipeRegOther "other2" 2.5 + r1 <- get k1 + r2 <- get k2 + r3 <- get k3 + r4 <- get k4 + liftIO $ r1 @?= Just (PipeReg "cross1" 1) + liftIO $ r2 @?= Just (PipeRegOther "other1" 1.5) + liftIO $ r3 @?= Just (PipeReg "cross2" 2) + liftIO $ r4 @?= Just (PipeRegOther "other2" 2.5) + + --------------------------------------------------------------------------- + -- Large batches + --------------------------------------------------------------------------- + + describe "large batches" $ do + it "100 inserts then 100 gets" $ db $ do + keys <- mapM insert + [ PipeReg (T.pack $ "batch" <> show i) i | i <- [1..100 :: Int] ] + results <- mapM get keys + liftIO $ length (catMaybes results) @?= 100 + + it "100 deletes then count" $ db $ do + keys <- mapM insert + [ PipeReg (T.pack $ "del" <> show i) i | i <- [1..100 :: Int] ] + forM_ keys delete + n <- count ([] :: [Filter PipeReg]) + liftIO $ n @?= 0 + + it "100 updates then verify" $ db $ do + keys <- mapM insert + [ PipeReg (T.pack $ "upd" <> show i) i | i <- [1..100 :: Int] ] + forM_ keys $ \k -> update k [PipeRegValue =. 42] + results <- mapM get keys + liftIO $ all (== Just 42) (map (fmap pipeRegValue) results) @?= True + + --------------------------------------------------------------------------- + -- Edge cases + --------------------------------------------------------------------------- + + describe "edge cases" $ do + it "empty transaction" $ db $ return () + + it "get nonexistent key" $ db $ do + r <- get (toSqlKey 999999 :: Key PipeReg) + liftIO $ r @?= Nothing + + it "count empty table" $ db $ do + n <- count ([] :: [Filter PipeReg]) + liftIO $ n @?= 0 + + it "exists empty table" $ db $ do + b <- exists ([] :: [Filter PipeReg]) + liftIO $ b @?= False + + it "getBy nonexistent unique" $ db $ do + r <- getBy (UniquePipeRegName "nonexistent") + liftIO $ r @?= Nothing + + it "insert then immediate get same transaction" $ db $ do + k <- insert $ PipeReg "immediate" 42 + r <- get k + liftIO $ r @?= Just (PipeReg "immediate" 42) + + it "upsert pipeline" $ db $ do + e1 <- upsert (PipeReg "upserted" 1) [PipeRegValue =. 1] + e2 <- upsert (PipeReg "upserted" 2) [PipeRegValue =. 2] + r <- getBy (UniquePipeRegName "upserted") + liftIO $ pipeRegValue (entityVal e1) @?= 1 + liftIO $ pipeRegValue (entityVal e2) @?= 2 + liftIO $ (pipeRegValue . entityVal) <$> r @?= Just 2 diff --git a/persistent-postgresql-ng/test/PlaceholderSpec.hs b/persistent-postgresql-ng/test/PlaceholderSpec.hs new file mode 100644 index 000000000..9252bb5e3 --- /dev/null +++ b/persistent-postgresql-ng/test/PlaceholderSpec.hs @@ -0,0 +1,59 @@ +{-# LANGUAGE OverloadedStrings #-} + +module PlaceholderSpec (specs) where + +import Test.Hspec +import Database.Persist.Postgresql.Internal.Placeholders (rewritePlaceholders) + +specs :: Spec +specs = describe "rewritePlaceholders" $ do + it "rewrites a single ? to $1" $ do + rewritePlaceholders "SELECT * FROM t WHERE id = ?" + `shouldBe` ("SELECT * FROM t WHERE id = $1", 1) + + it "rewrites multiple ? to $1, $2, ..." 
$ do
+        rewritePlaceholders "INSERT INTO t (a, b, c) VALUES (?, ?, ?)"
+            `shouldBe` ("INSERT INTO t (a, b, c) VALUES ($1, $2, $3)", 3)
+
+    it "replaces ?? with literal ?" $ do
+        rewritePlaceholders "RETURNING ??"
+            `shouldBe` ("RETURNING ?", 0)
+
+    it "handles mixed ? and ??" $ do
+        rewritePlaceholders "INSERT INTO t (a) VALUES (?) RETURNING ??"
+            `shouldBe` ("INSERT INTO t (a) VALUES ($1) RETURNING ?", 1)
+
+    it "does not replace ? inside single-quoted strings" $ do
+        rewritePlaceholders "SELECT * FROM t WHERE name = '?'"
+            `shouldBe` ("SELECT * FROM t WHERE name = '?'", 0)
+
+    it "handles escaped quotes in string literals" $ do
+        rewritePlaceholders "SELECT * FROM t WHERE name = 'it''s a ?' AND id = ?"
+            `shouldBe` ("SELECT * FROM t WHERE name = 'it''s a ?' AND id = $1", 1)
+
+    it "does not replace ? inside double-quoted identifiers" $ do
+        rewritePlaceholders "SELECT \"what?\" FROM t WHERE id = ?"
+            `shouldBe` ("SELECT \"what?\" FROM t WHERE id = $1", 1)
+
+    it "does not replace ? inside line comments" $ do
+        rewritePlaceholders "SELECT * FROM t -- where id = ?\nWHERE name = ?"
+            `shouldBe` ("SELECT * FROM t -- where id = ?\nWHERE name = $1", 1)
+
+    it "does not replace ? inside block comments" $ do
+        rewritePlaceholders "SELECT * FROM t /* where id = ? */ WHERE name = ?"
+            `shouldBe` ("SELECT * FROM t /* where id = ? */ WHERE name = $1", 1)
+
+    it "handles nested block comments" $ do
+        rewritePlaceholders "SELECT /* outer /* inner ? */ still comment */ ?"
+            `shouldBe` ("SELECT /* outer /* inner ? */ still comment */ $1", 1)
+
+    it "handles empty input" $ do
+        rewritePlaceholders "" `shouldBe` ("", 0)
+
+    it "handles no placeholders" $ do
+        rewritePlaceholders "SELECT 1" `shouldBe` ("SELECT 1", 0)
+
+    it "handles large parameter counts" $ do
+        let sql = mconcat $ replicate 100 "?,"
+            (_, count) = rewritePlaceholders sql
+        count `shouldBe` 100
diff --git a/persistent-postgresql-ng/test/UpsertWhere.hs b/persistent-postgresql-ng/test/UpsertWhere.hs
new file mode 100644
index 000000000..ebd91ffbd
--- /dev/null
+++ b/persistent-postgresql-ng/test/UpsertWhere.hs
@@ -0,0 +1,207 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DerivingStrategies #-}
+{-# LANGUAGE ExistentialQuantification #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TemplateHaskell #-}
+{-# LANGUAGE TypeApplications #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE UndecidableInstances #-}
+
+module UpsertWhere where
+
+import PgPipelineInit
+
+import Data.Time
+import Database.Persist.Postgresql.Pipeline
+
+share
+    [mkPersist sqlSettings, mkMigrate "upsertWhereMigrate"]
+    [persistLowerCase|
+
+Item
+    name Text sqltype=varchar(80)
+    description Text
+    price Double Maybe
+    quantity Int Maybe
+
+    UniqueName name
+    deriving Eq Show Ord
+
+ItemMigOnly
+    name Text
+    price Double
+    quantity Int
+
+    UniqueNameMigOnly name
+
+    createdAt UTCTime MigrationOnly default=CURRENT_TIMESTAMP
+
+|]
+
+wipe :: IO ()
+wipe = runConnAssert $ do
+    deleteWhere ([] :: [Filter Item])
+    deleteWhere ([] :: [Filter ItemMigOnly])
+
+itDb
+    :: String -> SqlPersistT (LoggingT (ResourceT IO)) a -> SpecWith (Arg (IO ()))
+itDb msg action = it msg $ runConnAssert $ void action
+
+specs :: Spec
+specs = describe "UpsertWhere" $ do
+    let
+        item1 = Item "item1" "" (Just 3) Nothing
+        item2 = Item "item2"
"hello world" Nothing (Just 2) + items = [item1, item2] + + describe "upsertWhere" $ before_ wipe $ do + itDb "inserts appropriately" $ do + upsertWhere item1 [ItemDescription =. "i am item 1"] [] + Just item <- fmap entityVal <$> getBy (UniqueName "item1") + item `shouldBe` item1 + itDb "performs only updates given if record already exists" $ do + let + newDescription = "I am a new description" + insert_ item1 + upsertWhere + (Item "item1" "i am an inserted description" (Just 1) (Just 2)) + [ItemDescription =. newDescription] + [] + Just item <- fmap entityVal <$> getBy (UniqueName "item1") + item `shouldBe` item1{itemDescription = newDescription} + + itDb "inserts with MigrationOnly fields (#1330)" $ do + upsertWhere + (ItemMigOnly "foobar" 20 1) + [ItemMigOnlyPrice +=. 2] + [] + + describe "upsertManyWhere" $ do + itDb "inserts fresh records" $ do + insertMany_ items + let + newItem = Item "item3" "fresh" Nothing Nothing + upsertManyWhere + (newItem : items) + [copyField ItemDescription] + [] + [] + dbItems <- map entityVal <$> selectList [] [] + dbItems `shouldMatchList` (newItem : items) + itDb "updates existing records" $ do + let + postUpdate = + map (\i -> i{itemQuantity = fmap (+ 1) (itemQuantity i)}) items + insertMany_ items + upsertManyWhere + items + [] + [ItemQuantity +=. Just 1] + [] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` postUpdate + itDb "only copies passing values" $ do + insertMany_ items + let + newItems = + map + (\i -> i{itemQuantity = Just 0, itemPrice = fmap (* 2) (itemPrice i)}) + items + postUpdate = map (\i -> i{itemPrice = fmap (* 2) (itemPrice i)}) items + upsertManyWhere + newItems + [ copyUnlessEq ItemQuantity (Just 0) + , copyField ItemPrice + ] + [] + [] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` postUpdate + itDb "inserts without modifying existing records if no updates specified" $ do + let + newItem = Item "item3" "hi friends!" Nothing Nothing + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [] + [] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` (newItem : items) + itDb "inserts without modifying existing records with True filter condition" $ do + let + newItem = Item "item3" "hi friends!" Nothing Nothing + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [] + [ItemDescription ==. "hi friends!"] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` (newItem : items) + itDb "inserts without updating with False filter condition" $ do + let + newItem = Item "item3" "hi friends!" Nothing Nothing + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [ItemQuantity +=. Just 1] + [ItemDescription ==. "hi friends!"] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` (newItem : items) + itDb "doesn't apply update with excludeNotEqualToOriginal" $ do + let + newItem = Item "item3" "hi friends!" Nothing Nothing + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [ItemQuantity +=. Just 1] + [excludeNotEqualToOriginal ItemDescription] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` (newItem : items) + itDb "inserts new and updates existing with empty filter" $ do + let + newItem = Item "item3" "hello world" Nothing Nothing + postUpdate = map (\i -> i{itemQuantity = fmap (+ 1) (itemQuantity i)}) items + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [ItemQuantity +=. 
Just 1] + [] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` (newItem : postUpdate) + itDb "inserts new and updates existing with matching filter" $ do + let + newItem = Item "item3" "hi friends!" Nothing Nothing + postUpdate = map (\i -> i{itemQuantity = fmap (+ 1) (itemQuantity i)}) items + insertMany_ items + upsertManyWhere + (newItem : items) + [ copyUnlessEq ItemDescription "hi friends!" + , copyField ItemPrice + ] + [ItemQuantity +=. Just 1] + [ItemDescription !=. "bye friends!"] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` (newItem : postUpdate) + itDb "insert item doesn't apply update with excludeNotEqualToOriginal" $ do + let + newItem = Item "item3" "hello world" Nothing Nothing + insertMany_ items + upsertManyWhere + (newItem : items) + [] + [ItemQuantity +=. Just 1] + [excludeNotEqualToOriginal ItemDescription] + dbItems <- fmap entityVal <$> selectList [] [] + dbItems `shouldMatchList` (newItem : items) diff --git a/persistent-postgresql-ng/test/main.hs b/persistent-postgresql-ng/test/main.hs new file mode 100644 index 000000000..5fde21e64 --- /dev/null +++ b/persistent-postgresql-ng/test/main.hs @@ -0,0 +1,280 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE ExistentialQuantification #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -Wno-unused-top-binds #-} + +import PgPipelineInit + +import Data.Aeson +import qualified Data.ByteString as BS +import Data.Fixed +import Data.IntMap (IntMap) +import qualified Data.Text as T +import Data.Time +import Test.QuickCheck + +import qualified ArrayAggTest +import qualified BinaryRoundTripSpec +import qualified DirectDecodeSpec +import qualified DirectEntityPOC +import qualified CompositeTest +import qualified CustomConstraintTest +import qualified CustomPersistFieldTest +import qualified CustomPrimaryKeyReferenceTest +import qualified DataTypeTest +import qualified EmbedOrderTest +import qualified EmbedTest +import qualified EmptyEntityTest +import qualified EquivalentTypeTestPostgres +import qualified ForeignKey +import qualified GeneratedColumnTestSQL +import qualified HtmlTest +import qualified ImplicitUuidSpec +import qualified InCollapseSpec +import qualified JSONTest +import qualified LargeNumberTest +import qualified LongIdentifierTest +import qualified MaxLenTest +import qualified MaybeFieldDefsTest +import qualified MigrationColumnLengthTest +import qualified MigrationOnlyTest +import qualified MigrationReferenceSpec +import qualified MigrationSpec +import qualified MigrationTest +import qualified MpsCustomPrefixTest +import qualified MpsNoPrefixTest +import qualified PersistUniqueTest +import qualified PersistentTest +import qualified PgPipelineIntervalTest +import qualified PipelineDeferralSpec +import qualified PipelineModeSpec +import qualified PipelineRegressionSpec +import qualified PlaceholderSpec +import qualified PrimaryTest +import qualified RawSqlTest +import qualified ReadWriteTest +import qualified Recursive +import qualified RenameTest +import qualified SumTypeTest +import qualified TransactionLevelTest +import qualified TreeTest +import qualified TypeLitFieldDefsTest +import qualified UniqueTest 
+import qualified UpsertTest
+import qualified UpsertWhere
+
+type Tuple = (,)
+
+-- Test lower case names
+share
+    [mkPersist persistSettings, mkMigrate "dataTypeMigrate"]
+    [persistLowerCase|
+DataTypeTable no-json
+    text Text
+    textMaxLen Text maxlen=100
+    bytes ByteString
+    bytesTextTuple (Tuple ByteString Text)
+    bytesMaxLen ByteString maxlen=100
+    int Int
+    intList [Int]
+    intMap (IntMap Int)
+    double Double
+    bool Bool
+    day Day
+    pico Pico
+    time TimeOfDay
+    utc UTCTime
+    jsonb Value
+|]
+
+instance Arbitrary DataTypeTable where
+    arbitrary =
+        DataTypeTable
+            <$> arbText -- text
+            <*> (T.take 100 <$> arbText) -- textMaxLen
+            <*> arbitrary -- bytes
+            <*> liftA2 (,) arbitrary arbText -- bytesTextTuple
+            <*> (BS.take 100 <$> arbitrary) -- bytesMaxLen
+            <*> arbitrary -- int
+            <*> arbitrary -- intList
+            <*> arbitrary -- intMap
+            <*> arbitrary -- double
+            <*> arbitrary -- bool
+            <*> arbitrary -- day
+            <*> arbitrary -- pico
+            <*> arbitrary -- time
+            <*> (truncateUTCTime =<< arbitrary) -- utc
+            <*> fmap getValue arbitrary -- jsonb
+
+setup :: (MonadIO m) => Migration -> ReaderT SqlBackend m ()
+setup migration = do
+    printMigration migration
+    runMigrationUnsafe migration
+
+-- | All fetch modes to test. FetchChunked uses a small chunk size to
+-- exercise the multi-result loop even on small result sets.
+allFetchModes :: [(String, FetchMode)]
+allFetchModes =
+    [ ("FetchAll", FetchAll)
+    , ("FetchSingleRow", FetchSingleRow)
+    , ("FetchChunked 2", FetchChunked 2)
+    ]
+
+main :: IO ()
+main = do
+    -- Pure tests run first (no database required)
+    hspec $ do
+        PlaceholderSpec.specs
+        InCollapseSpec.specs
+        BinaryRoundTripSpec.specs
+        DirectDecodeSpec.specs
+        DirectEntityPOC.specs
+
+    -- Database tests require PostgreSQL.
+    -- Migrations run once under FetchAll (DDL doesn't use row fetch modes).
+    runConn $ do
+        mapM_
+            setup
+            [ PersistentTest.testMigrate
+            , PersistentTest.noPrefixMigrate
+            , PersistentTest.customPrefixMigrate
+            , PersistentTest.treeMigrate
+            , EmbedTest.embedMigrate
+            , EmbedOrderTest.embedOrderMigrate
+            , LargeNumberTest.numberMigrate
+            , UniqueTest.uniqueMigrate
+            , MaxLenTest.maxlenMigrate
+            , MaybeFieldDefsTest.maybeFieldDefMigrate
+            , TypeLitFieldDefsTest.typeLitFieldDefsMigrate
+            , Recursive.recursiveMigrate
+            , CompositeTest.compositeMigrate
+            , TreeTest.treeMigrate
+            , PersistUniqueTest.migration
+            , RenameTest.migration
+            , CustomPersistFieldTest.customFieldMigrate
+            , PrimaryTest.migration
+            , CustomPrimaryKeyReferenceTest.migration
+            , MigrationColumnLengthTest.migration
+            , TransactionLevelTest.migration
+            , LongIdentifierTest.migration
+            , ForeignKey.compositeMigrate
+            , MigrationTest.migrationMigrate
+            , PgPipelineIntervalTest.pgIntervalMigrate
+            , UpsertWhere.upsertWhereMigrate
+            , ImplicitUuidSpec.implicitUuidMigrate
+            ]
+        PersistentTest.cleanDB
+        ForeignKey.cleanDB
+
+    -- Run database tests for each fetch mode
+    forM_ allFetchModes $ \(modeName, mode) -> do
+        -- Clean up between mode runs so state left behind by one mode
+        -- cannot affect the next
+        runConnWith mode $ do
+            PersistentTest.cleanDB
+            ForeignKey.cleanDB
+        let run = runConnAssertWith mode
+        hspec $ describe modeName $ do
+            -- Pipeline-specific tests
+            PipelineModeSpec.specsWith mode
+            PipelineDeferralSpec.specs
+            PgPipelineIntervalTest.specs
+            PipelineRegressionSpec.specs
+
+            -- DirectEntity integration: rawSqlDirectCompat through SqlBackend
+            DirectEntityPOC.integrationSpecs
+            describe "rawSqlDirectCompat through SqlBackend" $ do
+                it "decodes rows via DirectEntity + PgRowEnv" $ run $ do
+                    result <-
DirectEntityPOC.rawSqlDirectCompatTest + liftIO $ result `shouldBe` Just + [ DirectEntityPOC.PgPair (T.pack "Alice") 30 + , DirectEntityPOC.PgPair (T.pack "Bob") 25 + ] + + -- Postgres-specific tests + ImplicitUuidSpec.spec + MigrationReferenceSpec.spec + MigrationSpec.spec + EquivalentTypeTestPostgres.specs + JSONTest.specs + CustomConstraintTest.specs + UpsertWhere.specs + ArrayAggTest.specs + + -- Shared persistent-test suite + RenameTest.specsWith run + DataTypeTest.specsWith + run + (Just (runMigrationSilent dataTypeMigrate)) + [ TestFn "text" dataTypeTableText + , TestFn "textMaxLen" dataTypeTableTextMaxLen + , TestFn "bytes" dataTypeTableBytes + , TestFn "bytesTextTuple" dataTypeTableBytesTextTuple + , TestFn "bytesMaxLen" dataTypeTableBytesMaxLen + , TestFn "int" dataTypeTableInt + , TestFn "intList" dataTypeTableIntList + , TestFn "intMap" dataTypeTableIntMap + , TestFn "bool" dataTypeTableBool + , TestFn "day" dataTypeTableDay + , TestFn "time" (DataTypeTest.roundTime . dataTypeTableTime) + , TestFn "utc" (DataTypeTest.roundUTCTime . dataTypeTableUtc) + , TestFn "jsonb" dataTypeTableJsonb + ] + [("pico", dataTypeTablePico)] + dataTypeTableDouble + HtmlTest.specsWith + run + (Just (runMigrationSilent HtmlTest.htmlMigrate)) + + EmbedTest.specsWith run + EmbedOrderTest.specsWith run + LargeNumberTest.specsWith run + ForeignKey.specsWith run + UniqueTest.specsWith run + MaxLenTest.specsWith run + MaybeFieldDefsTest.specsWith run + TypeLitFieldDefsTest.specsWith run + Recursive.specsWith run + SumTypeTest.specsWith + run + (Just (runMigrationSilent SumTypeTest.sumTypeMigrate)) + MigrationTest.specsWith run + MigrationOnlyTest.specsWith + run + ( Just $ + runMigrationSilent MigrationOnlyTest.migrateAll1 + >> runMigrationSilent MigrationOnlyTest.migrateAll2 + ) + PersistentTest.specsWith run + ReadWriteTest.specsWith run + PersistentTest.filterOrSpecs run + RawSqlTest.specsWith run + UpsertTest.specsWith + run + UpsertTest.Don'tUpdateNull + UpsertTest.UpsertPreserveOldKey + + MpsNoPrefixTest.specsWith run + MpsCustomPrefixTest.specsWith run + EmptyEntityTest.specsWith + run + (Just (runMigrationSilent EmptyEntityTest.migration)) + CompositeTest.specsWith run + TreeTest.specsWith run + PersistUniqueTest.specsWith run + PrimaryTest.specsWith run + CustomPersistFieldTest.specsWith run + CustomPrimaryKeyReferenceTest.specsWith run + MigrationColumnLengthTest.specsWith run + TransactionLevelTest.specsWith run + LongIdentifierTest.specsWith run + GeneratedColumnTestSQL.specsWith run