From c015d2fd3723e23a28a6f29ef2bd8545755a7e52 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 22 Apr 2024 08:20:03 +0800 Subject: [PATCH 1/2] chore: merge from dev --- packages/plugins/swr/tests/test-model-meta.ts | 4 +- .../tanstack-query/tests/test-model-meta.ts | 4 +- .../routers/generated/routers/Post.router.ts | 23 +++ .../routers/generated/routers/User.router.ts | 23 +++ packages/runtime/src/cross/utils.ts | 20 +- .../src/enhancements/policy/handler.ts | 129 ++++++++---- .../runtime/src/enhancements/query-utils.ts | 53 ++++- packages/runtime/src/enhancements/utils.ts | 8 +- packages/sdk/src/model-meta-generator.ts | 45 +++- .../with-policy/multi-id-fields.test.ts | 146 ++++++++++++- .../with-policy/nested-to-many.test.ts | 95 +++++++++ .../with-policy/nested-to-one.test.ts | 58 ++++++ .../with-policy/toplevel-operations.test.ts | 76 +++++++ .../tests/regression/issue-1271.test.ts | 192 ++++++++++++++++++ 14 files changed, 808 insertions(+), 68 deletions(-) create mode 100644 tests/integration/tests/regression/issue-1271.test.ts diff --git a/packages/plugins/swr/tests/test-model-meta.ts b/packages/plugins/swr/tests/test-model-meta.ts index 71a657bad..001d773a9 100644 --- a/packages/plugins/swr/tests/test-model-meta.ts +++ b/packages/plugins/swr/tests/test-model-meta.ts @@ -32,7 +32,7 @@ export const modelMeta: ModelMeta = { name: 'posts', }, }, - uniqueConstraints: {}, + uniqueConstraints: { id: { name: 'id', fields: ['id'] } }, }, post: { name: 'post', @@ -48,7 +48,7 @@ export const modelMeta: ModelMeta = { owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, }, - uniqueConstraints: {}, + uniqueConstraints: { id: { name: 'id', fields: ['id'] } }, }, }, deleteCascade: { diff --git a/packages/plugins/tanstack-query/tests/test-model-meta.ts b/packages/plugins/tanstack-query/tests/test-model-meta.ts index 71a657bad..001d773a9 100644 --- a/packages/plugins/tanstack-query/tests/test-model-meta.ts +++ b/packages/plugins/tanstack-query/tests/test-model-meta.ts @@ -32,7 +32,7 @@ export const modelMeta: ModelMeta = { name: 'posts', }, }, - uniqueConstraints: {}, + uniqueConstraints: { id: { name: 'id', fields: ['id'] } }, }, post: { name: 'post', @@ -48,7 +48,7 @@ export const modelMeta: ModelMeta = { owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, }, - uniqueConstraints: {}, + uniqueConstraints: { id: { name: 'id', fields: ['id'] } }, }, }, deleteCascade: { diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts index fbc73cf06..15408f3ef 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts @@ -61,6 +61,29 @@ export interface ClientType; }; + createMany: { + useMutation: ( + opts?: UseTRPCMutationOptions< + Prisma.PostCreateManyArgs, + TRPCClientErrorLike, + Prisma.BatchPayload, + Context + >, + ) => Omit< + UseTRPCMutationResult< + Prisma.BatchPayload, + TRPCClientErrorLike, + Prisma.SelectSubset, + Context + >, + 'mutateAsync' + > & { + mutateAsync: ( + variables: T, + opts?: UseTRPCMutationOptions, Prisma.BatchPayload, Context>, + ) => Promise; + }; + }; create: { useMutation: (opts?: UseTRPCMutationOptions< diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts index c4bdb89de..cb9c8614b 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts @@ -61,6 +61,29 @@ export interface ClientType; }; + createMany: { + useMutation: ( + opts?: UseTRPCMutationOptions< + Prisma.UserCreateManyArgs, + TRPCClientErrorLike, + Prisma.BatchPayload, + Context + >, + ) => Omit< + UseTRPCMutationResult< + Prisma.BatchPayload, + TRPCClientErrorLike, + Prisma.SelectSubset, + Context + >, + 'mutateAsync' + > & { + mutateAsync: ( + variables: T, + opts?: UseTRPCMutationOptions, Prisma.BatchPayload, Context>, + ) => Promise; + }; + }; create: { useMutation: (opts?: UseTRPCMutationOptions< diff --git a/packages/runtime/src/cross/utils.ts b/packages/runtime/src/cross/utils.ts index 1982513b3..304b9b618 100644 --- a/packages/runtime/src/cross/utils.ts +++ b/packages/runtime/src/cross/utils.ts @@ -1,5 +1,5 @@ import { lowerCaseFirst } from 'lower-case-first'; -import { ModelInfo, ModelMeta } from '.'; +import { requireField, type ModelInfo, type ModelMeta } from '.'; /** * Gets field names in a data model entity, filtering out internal fields. @@ -47,19 +47,17 @@ export function zip(x: Enumerable, y: Enumerable): Array<[T1, T2 } export function getIdFields(modelMeta: ModelMeta, model: string, throwIfNotFound = false) { - let fields = modelMeta.models[lowerCaseFirst(model)]?.fields; - if (!fields) { + const uniqueConstraints = modelMeta.models[lowerCaseFirst(model)]?.uniqueConstraints ?? {}; + + const entries = Object.values(uniqueConstraints); + if (entries.length === 0) { if (throwIfNotFound) { - throw new Error(`Unable to load fields for ${model}`); - } else { - fields = {}; + throw new Error(`Model ${model} does not have any id field`); } + return []; } - const result = Object.values(fields).filter((f) => f.isId); - if (result.length === 0 && throwIfNotFound) { - throw new Error(`model ${model} does not have an id field`); - } - return result; + + return entries[0].fields.map((f) => requireField(modelMeta, model, f)); } export function getModelInfo( diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index ea27fc1db..d6d893d4e 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -22,7 +22,7 @@ import { Logger } from '../logger'; import { createDeferredPromise, createFluentPromise } from '../promise'; import { PrismaProxyHandler } from '../proxy'; import { QueryUtils } from '../query-utils'; -import { formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; +import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils'; import { PolicyUtil } from './policy-utils'; // a record for post-write policy check @@ -117,7 +117,7 @@ export class PolicyProxyHandler implements Pr // make a find query promise with fluent API call stubs installed private findWithFluent(method: FindOperations, args: any, handleRejection: () => any) { - args = this.policyUtils.clone(args); + args = clone(args); return createFluentPromise( () => this.doFind(args, method, handleRejection), args, @@ -128,7 +128,7 @@ export class PolicyProxyHandler implements Pr private async doFind(args: any, actionName: FindOperations, handleRejection: () => any) { const origArgs = args; - const _args = this.policyUtils.clone(args); + const _args = clone(args); if (!this.policyUtils.injectForRead(this.prisma, this.model, _args)) { if (this.shouldLogQuery) { this.logger.info(`[policy] \`${actionName}\` ${this.model}: unconditionally denied`); @@ -167,7 +167,7 @@ export class PolicyProxyHandler implements Pr this.policyUtils.tryReject(this.prisma, this.model, 'create'); const origArgs = args; - args = this.policyUtils.clone(args); + args = clone(args); // static input policy check for top-level create data const inputCheck = this.policyUtils.checkInputGuard(this.model, args.data, 'create'); @@ -364,7 +364,7 @@ export class PolicyProxyHandler implements Pr }); // return only the ids of the top-level entity - const ids = this.policyUtils.getEntityIds(this.model, result); + const ids = this.policyUtils.getEntityIds(model, result); return { result: ids, postWriteChecks: [...postCreateChecks.values()] }; } @@ -434,7 +434,7 @@ export class PolicyProxyHandler implements Pr return createDeferredPromise(async () => { this.policyUtils.tryReject(this.prisma, this.model, 'create'); - args = this.policyUtils.clone(args); + args = clone(args); // go through create items, statically check input to determine if post-create // check is needed, and also validate zod schema @@ -615,7 +615,7 @@ export class PolicyProxyHandler implements Pr } return createDeferredPromise(async () => { - args = this.policyUtils.clone(args); + args = clone(args); const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { // proceed with nested writes and collect post-write checks @@ -743,8 +743,10 @@ export class PolicyProxyHandler implements Pr } // proceed with the create and collect post-create checks - const { postWriteChecks: checks } = await this.doCreate(model, { data: createData }, db); + const { postWriteChecks: checks, result } = await this.doCreate(model, { data: createData }, db); postWriteChecks.push(...checks); + + return result; }; const _createMany = async ( @@ -832,18 +834,10 @@ export class PolicyProxyHandler implements Pr // check pre-update guard await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); - // handles the case where id fields are updated - const postUpdateIds = this.policyUtils.clone(existing); - for (const key of Object.keys(existing)) { - const updateValue = (args as any).data ? (args as any).data[key] : (args as any)[key]; - if ( - typeof updateValue === 'string' || - typeof updateValue === 'number' || - typeof updateValue === 'bigint' - ) { - postUpdateIds[key] = updateValue; - } - } + // handle the case where id fields are updated + const _args: any = args; + const updatePayload = _args.data && typeof _args.data === 'object' ? _args.data : _args; + const postUpdateIds = this.calculatePostUpdateIds(model, existing, updatePayload); // register post-update check await _registerPostUpdateCheck(model, existing, postUpdateIds); @@ -935,10 +929,13 @@ export class PolicyProxyHandler implements Pr // update case // check pre-update guard - await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, existing, 'update', db, args); + + // handle the case where id fields are updated + const postUpdateIds = this.calculatePostUpdateIds(model, existing, args.update); // register post-update check - await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter); + await _registerPostUpdateCheck(model, existing, postUpdateIds); // convert upsert to update const convertedUpdate = { @@ -972,9 +969,22 @@ export class PolicyProxyHandler implements Pr if (existing) { // connect await _connectDisconnect(model, args.where, context); + return true; } else { // create - await _create(model, args.create, context); + const created = await _create(model, args.create, context); + + const upperContext = context.nestingPath[context.nestingPath.length - 2]; + if (upperContext?.where && context.field) { + // check if the where clause of the upper context references the id + // of the connected entity, if so, we need to update it + this.overrideForeignKeyFields(upperContext.model, upperContext.where, context.field, created); + } + + // remove the payload from the parent + this.removeFromParent(context.parent, 'connectOrCreate', args); + + return false; } }, @@ -1044,6 +1054,52 @@ export class PolicyProxyHandler implements Pr return { result, postWriteChecks }; } + // calculate id fields used for post-update check given an update payload + private calculatePostUpdateIds(_model: string, currentIds: any, updatePayload: any) { + const result = clone(currentIds); + for (const key of Object.keys(currentIds)) { + const updateValue = updatePayload[key]; + if (typeof updateValue === 'string' || typeof updateValue === 'number' || typeof updateValue === 'bigint') { + result[key] = updateValue; + } + } + return result; + } + + // updates foreign key fields inside `payload` based on relation id fields in `newIds` + private overrideForeignKeyFields( + model: string, + payload: any, + relation: FieldInfo, + newIds: Record + ) { + if (!relation.foreignKeyMapping || Object.keys(relation.foreignKeyMapping).length === 0) { + return; + } + + // override foreign key values + for (const [id, fk] of Object.entries(relation.foreignKeyMapping)) { + if (payload[fk] !== undefined && newIds[id] !== undefined) { + payload[fk] = newIds[id]; + } + } + + // deal with compound id fields + const uniqueConstraints = this.policyUtils.getUniqueConstraints(model); + for (const [name, constraint] of Object.entries(uniqueConstraints)) { + if (constraint.fields.length > 1) { + const target = payload[name]; + if (target) { + for (const [id, fk] of Object.entries(relation.foreignKeyMapping)) { + if (target[fk] !== undefined && newIds[id] !== undefined) { + target[fk] = newIds[id]; + } + } + } + } + } + } + // Validates the given update payload against Zod schema if any private validateUpdateInputSchema(model: string, data: any) { const schema = this.policyUtils.getZodSchema(model, 'update'); @@ -1085,16 +1141,12 @@ export class PolicyProxyHandler implements Pr this.prismaModule, 'data field is required in query argument' ); - throw prismaClientValidationError(this.prisma, this.options, 'query argument is required'); - } - if (!args.data) { - throw prismaClientValidationError(this.prisma, this.options, 'data field is required in query argument'); } return createDeferredPromise(() => { this.policyUtils.tryReject(this.prisma, this.model, 'update'); - args = this.policyUtils.clone(args); + args = clone(args); this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); args.data = this.validateUpdateInputSchema(this.model, args.data); @@ -1174,18 +1226,25 @@ export class PolicyProxyHandler implements Pr this.policyUtils.tryReject(this.prisma, this.model, 'create'); this.policyUtils.tryReject(this.prisma, this.model, 'update'); - args = this.policyUtils.clone(args); + args = clone(args); // We can call the native "upsert" because we can't tell if an entity was created or updated // for doing post-write check accordingly. Instead, decompose it into create or update. const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { const { where, create, update, ...rest } = args; - const existing = await this.policyUtils.checkExistence(tx, this.model, args.where); + const existing = await this.policyUtils.checkExistence(tx, this.model, where); if (existing) { // update case - const { result, postWriteChecks } = await this.doUpdate({ where, data: update, ...rest }, tx); + const { result, postWriteChecks } = await this.doUpdate( + { + where: this.policyUtils.composeCompoundUniqueField(this.model, existing), + data: update, + ...rest, + }, + tx + ); await this.runPostWriteChecks(postWriteChecks, tx); return this.policyUtils.readBack(tx, this.model, 'update', args, result); } else { @@ -1281,7 +1340,7 @@ export class PolicyProxyHandler implements Pr } return createDeferredPromise(() => { - args = this.policyUtils.clone(args); + args = clone(args); // inject policy conditions this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); @@ -1299,7 +1358,7 @@ export class PolicyProxyHandler implements Pr } return createDeferredPromise(() => { - args = this.policyUtils.clone(args); + args = clone(args); // inject policy conditions this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); @@ -1314,7 +1373,7 @@ export class PolicyProxyHandler implements Pr count(args: any) { return createDeferredPromise(() => { // inject policy conditions - args = args ? this.policyUtils.clone(args) : {}; + args = args ? clone(args) : {}; this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { @@ -1350,7 +1409,7 @@ export class PolicyProxyHandler implements Pr // include all args = { create: {}, update: {}, delete: {} }; } else { - args = this.policyUtils.clone(args); + args = clone(args); } } diff --git a/packages/runtime/src/enhancements/query-utils.ts b/packages/runtime/src/enhancements/query-utils.ts index c161d5e2c..81c8d1da9 100644 --- a/packages/runtime/src/enhancements/query-utils.ts +++ b/packages/runtime/src/enhancements/query-utils.ts @@ -1,16 +1,16 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { - FieldInfo, - NestedWriteVisitorContext, getIdFields, getModelInfo, getUniqueConstraints, resolveField, + type FieldInfo, + type NestedWriteVisitorContext, } from '../cross'; -import { CrudContract, DbClientContract } from '../types'; +import type { CrudContract, DbClientContract } from '../types'; import { getVersion } from '../version'; import { InternalEnhancementOptions } from './create-enhancement'; -import { prismaClientUnknownRequestError, prismaClientValidationError } from './utils'; +import { clone, prismaClientUnknownRequestError, prismaClientValidationError } from './utils'; export class QueryUtils { constructor(private readonly prisma: DbClientContract, protected readonly options: InternalEnhancementOptions) {} @@ -56,7 +56,10 @@ export class QueryUtils { } } - buildReversedQuery(context: NestedWriteVisitorContext, mutating = false, unsafeOperation = false) { + /** + * Builds a reversed query for the given nested path. + */ + buildReversedQuery(context: NestedWriteVisitorContext, forMutationPayload = false, unsafeOperation = false) { let result, currQuery: any; let currField: FieldInfo | undefined; @@ -87,7 +90,7 @@ export class QueryUtils { throw this.unknownError(`missing backLink field ${currField.backLink} in ${currField.type}`); } - if (backLinkField.isArray && !mutating) { + if (backLinkField.isArray && !forMutationPayload) { // many-side of relationship, wrap with "some" query currQuery[currField.backLink] = { some: { ...visitWhere } }; currQuery = currQuery[currField.backLink].some; @@ -97,7 +100,7 @@ export class QueryUtils { // calculate if we should preserve the relation condition (e.g., { user: { id: 1 } }) const shouldPreserveRelationCondition = // doing a mutation - mutating && + forMutationPayload && // and it's a safe mutate !unsafeOperation && // and the current segment is the direct parent (the last one is the mutate itself), @@ -119,6 +122,15 @@ export class QueryUtils { // preserve the original structure currQuery[currField.backLink] = { ...visitWhere }; } + + if (forMutationPayload && currQuery[currField.backLink]) { + // reconstruct compound unique field + currQuery[currField.backLink] = this.composeCompoundUniqueField( + backLinkField.type, + currQuery[currField.backLink] + ); + } + currQuery = currQuery[currField.backLink]; } currField = field; @@ -127,8 +139,33 @@ export class QueryUtils { return result; } + /** + * Composes a compound unique field from multiple fields. E.g.: { a: '1', b: '1' } => { a_b: { a: '1', b: '1' } }. + */ + composeCompoundUniqueField(model: string, fieldData: any) { + const uniqueConstraints = getUniqueConstraints(this.options.modelMeta, model); + if (!uniqueConstraints) { + return fieldData; + } + + const result: any = clone(fieldData); + for (const [name, constraint] of Object.entries(uniqueConstraints)) { + if (constraint.fields.length > 1 && constraint.fields.every((f) => fieldData[f] !== undefined)) { + // multi-field unique constraint, compose it + result[name] = constraint.fields.reduce( + (prev, field) => ({ ...prev, [field]: fieldData[field] }), + {} + ); + constraint.fields.forEach((f) => delete result[f]); + } + } + return result; + } + + /** + * Flattens a generated unique field. E.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' }. + */ flattenGeneratedUniqueField(model: string, args: any) { - // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } const uniqueConstraints = getUniqueConstraints(this.options.modelMeta, model); if (uniqueConstraints && Object.keys(uniqueConstraints).length > 0) { for (const [field, value] of Object.entries(args)) { diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index 92c8b7726..5cd23610e 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -1,5 +1,6 @@ +import deepcopy from 'deepcopy'; import safeJsonStringify from 'safe-json-stringify'; -import { FieldInfo, ModelMeta, resolveField } from '..'; +import { resolveField, type FieldInfo, type ModelMeta } from '..'; import type { DbClientContract } from '../types'; /** @@ -42,3 +43,8 @@ export function isUnsafeMutate(model: string, args: any, modelMeta: ModelMeta) { export function isAutoIncrementIdField(field: FieldInfo) { return field.isId && field.isAutoIncrement; } + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function clone(value: unknown): any { + return value ? deepcopy(value) : {}; +} diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 3dc0f3f1e..3072ab202 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -1,6 +1,7 @@ import { ArrayExpr, DataModel, + DataModelAttribute, DataModelField, isArrayExpr, isBooleanLiteral, @@ -344,10 +345,7 @@ function getAttributes(target: DataModelField | DataModel): RuntimeAttribute[] { function getUniqueConstraints(model: DataModel) { const constraints: Array<{ name: string; fields: string[] }> = []; - // model-level constraints - for (const attr of model.attributes.filter( - (attr) => attr.decl.ref?.name === '@@unique' || attr.decl.ref?.name === '@@id' - )) { + const extractConstraint = (attr: DataModelAttribute) => { const argsMap = getAttributeArgs(attr); if (argsMap.fields) { const fieldNames = (argsMap.fields as ArrayExpr).items.map( @@ -358,14 +356,45 @@ function getUniqueConstraints(model: DataModel) { // default constraint name is fields concatenated with underscores constraintName = fieldNames.join('_'); } - constraints.push({ name: constraintName, fields: fieldNames }); + return { name: constraintName, fields: fieldNames }; + } else { + return undefined; + } + }; + + const addConstraint = (constraint: { name: string; fields: string[] }) => { + if (!constraints.some((c) => c.name === constraint.name)) { + constraints.push(constraint); + } + }; + + // field-level @id first + for (const field of model.fields) { + if (hasAttribute(field, '@id')) { + addConstraint({ name: field.name, fields: [field.name] }); } } - // field-level constraints + // then model-level @@id + for (const attr of model.attributes.filter((attr) => attr.decl.ref?.name === '@@id')) { + const constraint = extractConstraint(attr); + if (constraint) { + addConstraint(constraint); + } + } + + // then field-level @unique for (const field of model.fields) { - if (hasAttribute(field, '@id') || hasAttribute(field, '@unique')) { - constraints.push({ name: field.name, fields: [field.name] }); + if (hasAttribute(field, '@unique')) { + addConstraint({ name: field.name, fields: [field.name] }); + } + } + + // then model-level @@unique + for (const attr of model.attributes.filter((attr) => attr.decl.ref?.name === '@@unique')) { + const constraint = extractConstraint(attr); + if (constraint) { + addConstraint(constraint); } } diff --git a/tests/integration/tests/enhancements/with-policy/multi-id-fields.test.ts b/tests/integration/tests/enhancements/with-policy/multi-id-fields.test.ts index 227dc5a27..0abb45559 100644 --- a/tests/integration/tests/enhancements/with-policy/multi-id-fields.test.ts +++ b/tests/integration/tests/enhancements/with-policy/multi-id-fields.test.ts @@ -12,7 +12,7 @@ describe('With Policy: multiple id fields', () => { process.chdir(origDir); }); - it('multi-id fields', async () => { + it('multi-id fields crud', async () => { const { prisma, enhance } = await loadSchema( ` model A { @@ -69,6 +69,75 @@ describe('With Policy: multiple id fields', () => { ).toResolveTruthy(); }); + it('multi-id fields id update', async () => { + const { prisma, enhance } = await loadSchema( + ` + model A { + x String + y Int + value Int + b B? + @@id([x, y]) + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 0 && future().value > 1) + } + + model B { + b1 String + b2 String + value Int + a A @relation(fields: [ax, ay], references: [x, y]) + ax String + ay Int + + @@allow('read', value > 2) + @@allow('create', value > 1) + + @@unique([ax, ay]) + @@id([b1, b2]) + } + ` + ); + + const db = enhance(); + + await db.a.create({ data: { x: '1', y: 2, value: 1 } }); + + await expect( + db.a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 0 } }) + ).toBeRejectedByPolicy(); + + await expect( + db.a.update({ where: { x_y: { x: '1', y: 2 } }, data: { x: '2', y: 3, value: 2 } }) + ).resolves.toMatchObject({ + x: '2', + y: 3, + value: 2, + }); + + await expect( + db.a.upsert({ + where: { x_y: { x: '2', y: 3 } }, + update: { x: '3', y: 4, value: 0 }, + create: { x: '4', y: 5, value: 5 }, + }) + ).toBeRejectedByPolicy(); + + await expect( + db.a.upsert({ + where: { x_y: { x: '2', y: 3 } }, + update: { x: '3', y: 4, value: 3 }, + create: { x: '4', y: 5, value: 5 }, + }) + ).resolves.toMatchObject({ + x: '3', + y: 4, + value: 3, + }); + }); + it('multi-id auth', async () => { const { prisma, enhance } = await loadSchema( ` @@ -270,4 +339,79 @@ describe('With Policy: multiple id fields', () => { expect(await db.b.findUnique({ where: { id: 1 } })).toEqual(expect.objectContaining({ v: 5 })); expect(await db.c.findUnique({ where: { id: 1 } })).toEqual(expect.objectContaining({ v: 6 })); }); + + it('multi-id fields nested id update', async () => { + const { enhance } = await loadSchema( + ` + model A { + x String + y Int + value Int + b B @relation(fields: [bId], references: [id]) + bId Int + @@id([x, y]) + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 0 && future().value > 1) + } + + model B { + id Int @id @default(autoincrement()) + a A[] + @@allow('all', true) + } + ` + ); + + const db = enhance(); + + await db.b.create({ data: { id: 1, a: { create: { x: '1', y: 1, value: 1 } } } }); + + await expect( + db.b.update({ + where: { id: 1 }, + data: { a: { update: { where: { x_y: { x: '1', y: 1 } }, data: { x: '2', y: 2, value: 0 } } } }, + }) + ).toBeRejectedByPolicy(); + + await expect( + db.b.update({ + where: { id: 1 }, + data: { a: { update: { where: { x_y: { x: '1', y: 1 } }, data: { x: '2', y: 2, value: 2 } } } }, + include: { a: true }, + }) + ).resolves.toMatchObject({ a: expect.arrayContaining([expect.objectContaining({ x: '2', y: 2, value: 2 })]) }); + + await expect( + db.b.update({ + where: { id: 1 }, + data: { + a: { + upsert: { + where: { x_y: { x: '2', y: 2 } }, + update: { x: '3', y: 3, value: 0 }, + create: { x: '4', y: '4', value: 4 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + await expect( + db.b.update({ + where: { id: 1 }, + data: { + a: { + upsert: { + where: { x_y: { x: '2', y: 2 } }, + update: { x: '3', y: 3, value: 3 }, + create: { x: '4', y: '4', value: 4 }, + }, + }, + }, + include: { a: true }, + }) + ).resolves.toMatchObject({ a: expect.arrayContaining([expect.objectContaining({ x: '3', y: 3, value: 3 })]) }); + }); }); diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts index 777af1118..5cafbb361 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts @@ -284,6 +284,101 @@ describe('With Policy:nested to-many', () => { expect(r.m2).toEqual(expect.arrayContaining([expect.objectContaining({ id: '2', value: 3 })])); }); + it('update id field', async () => { + const { withPolicy } = await loadSchema( + ` + model M1 { + id String @id @default(uuid()) + m2 M2[] + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String + + @@allow('read', true) + @@allow('create', true) + @@allow('update', value > 1 && future().value > 2) + } + ` + ); + + const db = withPolicy(); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: { id: '1', value: 2 }, + }, + }, + }); + + await expect( + db.m1.update({ + where: { id: '1' }, + include: { m2: true }, + data: { + m2: { + update: { + where: { id: '1' }, + data: { id: '2', value: 1 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + let r = await db.m1.update({ + where: { id: '1' }, + include: { m2: true }, + data: { + m2: { + update: { + where: { id: '1' }, + data: { id: '2', value: 3 }, + }, + }, + }, + }); + expect(r.m2).toEqual(expect.arrayContaining([expect.objectContaining({ id: '2', value: 3 })])); + + await expect( + db.m1.update({ + where: { id: '1' }, + include: { m2: true }, + data: { + m2: { + upsert: { + where: { id: '2' }, + create: { id: '4', value: 4 }, + update: { id: '3', value: 1 }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + r = await db.m1.update({ + where: { id: '1' }, + include: { m2: true }, + data: { + m2: { + upsert: { + where: { id: '2' }, + create: { id: '4', value: 4 }, + update: { id: '3', value: 4 }, + }, + }, + }, + }); + expect(r.m2).toEqual(expect.arrayContaining([expect.objectContaining({ id: '3', value: 4 })])); + }); + it('update with create from one to many', async () => { const { enhance } = await loadSchema( ` diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts index 4b30c095f..7829ff415 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts @@ -212,6 +212,64 @@ describe('With Policy:nested to-one', () => { ).toBeRejectedByPolicy(); }); + it('nested update id tests', async () => { + const { withPolicy } = await loadSchema( + ` + model M1 { + id String @id @default(uuid()) + m2 M2? + + @@allow('all', true) + } + + model M2 { + id String @id @default(uuid()) + value Int + m1 M1 @relation(fields: [m1Id], references:[id]) + m1Id String @unique + + @@allow('read', true) + @@allow('create', value > 0) + @@allow('update', value > 1 && future().value > 2) + } + ` + ); + + const db = withPolicy(); + + await db.m1.create({ + data: { + id: '1', + m2: { + create: { id: '1', value: 2 }, + }, + }, + }); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { id: '2', value: 1 }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + await expect( + db.m1.update({ + where: { id: '1' }, + data: { + m2: { + update: { id: '2', value: 3 }, + }, + }, + include: { m2: true }, + }) + ).resolves.toMatchObject({ m2: expect.objectContaining({ id: '2', value: 3 }) }); + }); + it('nested create', async () => { const { enhance } = await loadSchema( ` diff --git a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts index 61f25dc25..da920c064 100644 --- a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts +++ b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts @@ -147,6 +147,82 @@ describe('With Policy: toplevel operations', () => { ).toBeTruthy(); }); + it('update id tests', async () => { + const { withPolicy } = await loadSchema( + ` + model Model { + id String @id @default(uuid()) + value Int + + @@allow('read', value > 1) + @@allow('create', value > 0) + @@allow('update', value > 1 && future().value > 2) + } + ` + ); + + const db = withPolicy(); + + await db.model.create({ + data: { + id: '1', + value: 2, + }, + }); + + // update denied + await expect( + db.model.update({ + where: { id: '1' }, + data: { + id: '2', + value: 1, + }, + }) + ).toBeRejectedByPolicy(); + + // update success + await expect( + db.model.update({ + where: { id: '1' }, + data: { + id: '2', + value: 3, + }, + }) + ).resolves.toMatchObject({ id: '2', value: 3 }); + + // upsert denied + await expect( + db.model.upsert({ + where: { id: '2' }, + update: { + id: '3', + value: 1, + }, + create: { + id: '4', + value: 5, + }, + }) + ).toBeRejectedByPolicy(); + + // upsert success + await expect( + db.model.upsert({ + where: { id: '2' }, + update: { + id: '3', + value: 4, + }, + create: { + id: '4', + value: 5, + }, + }) + ).resolves.toMatchObject({ id: '3', value: 4 }); + }); + it('delete tests', async () => { const { enhance, prisma } = await loadSchema( ` diff --git a/tests/integration/tests/regression/issue-1271.test.ts b/tests/integration/tests/regression/issue-1271.test.ts new file mode 100644 index 000000000..d25cabb3b --- /dev/null +++ b/tests/integration/tests/regression/issue-1271.test.ts @@ -0,0 +1,192 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1271', () => { + it('regression', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id @default(uuid()) + + @@auth + @@allow('all', true) + } + + model Test { + id String @id @default(uuid()) + linkingTable LinkingTable[] + key String @default('test') + locale String @default('EN') + + @@unique([key, locale]) + @@allow("all", true) + } + + model LinkingTable { + test_id String + test Test @relation(fields: [test_id], references: [id]) + + another_test_id String + another_test AnotherTest @relation(fields: [another_test_id], references: [id]) + + @@id([test_id, another_test_id]) + @@allow("all", true) + } + + model AnotherTest { + id String @id @default(uuid()) + status String + linkingTable LinkingTable[] + + @@allow("all", true) + } + `, + { logPrismaQuery: true } + ); + + const db = enhance(); + + const test = await db.test.create({ + data: { + key: 'test1', + }, + }); + const anotherTest = await db.anotherTest.create({ + data: { + status: 'available', + }, + }); + + const updated = await db.test.upsert({ + where: { + key_locale: { + key: test.key, + locale: test.locale, + }, + }, + create: { + linkingTable: { + create: { + another_test_id: anotherTest.id, + }, + }, + }, + update: { + linkingTable: { + create: { + another_test_id: anotherTest.id, + }, + }, + }, + include: { + linkingTable: true, + }, + }); + + expect(updated.linkingTable).toHaveLength(1); + expect(updated.linkingTable[0]).toMatchObject({ another_test_id: anotherTest.id }); + + const test2 = await db.test.upsert({ + where: { + key_locale: { + key: 'test2', + locale: 'locale2', + }, + }, + create: { + key: 'test2', + locale: 'locale2', + linkingTable: { + create: { + another_test_id: anotherTest.id, + }, + }, + }, + update: { + linkingTable: { + create: { + another_test_id: anotherTest.id, + }, + }, + }, + include: { + linkingTable: true, + }, + }); + expect(test2).toMatchObject({ key: 'test2', locale: 'locale2' }); + expect(test2.linkingTable).toHaveLength(1); + expect(test2.linkingTable[0]).toMatchObject({ another_test_id: anotherTest.id }); + + const linkingTable = test2.linkingTable[0]; + + // connectOrCreate: connect case + const test3 = await db.test.create({ + data: { + key: 'test3', + locale: 'locale3', + }, + }); + console.log('test3 created:', test3); + const updated2 = await db.linkingTable.update({ + where: { + test_id_another_test_id: { + test_id: linkingTable.test_id, + another_test_id: linkingTable.another_test_id, + }, + }, + data: { + test: { + connectOrCreate: { + where: { + key_locale: { + key: test3.key, + locale: test3.locale, + }, + }, + create: { + key: 'test4', + locale: 'locale4', + }, + }, + }, + another_test: { connect: { id: anotherTest.id } }, + }, + include: { test: true }, + }); + expect(updated2).toMatchObject({ + test: expect.objectContaining({ key: 'test3', locale: 'locale3' }), + another_test_id: anotherTest.id, + }); + + // connectOrCreate: create case + const updated3 = await db.linkingTable.update({ + where: { + test_id_another_test_id: { + test_id: updated2.test_id, + another_test_id: updated2.another_test_id, + }, + }, + data: { + test: { + connectOrCreate: { + where: { + key_locale: { + key: 'test4', + locale: 'locale4', + }, + }, + create: { + key: 'test4', + locale: 'locale4', + }, + }, + }, + another_test: { connect: { id: anotherTest.id } }, + }, + include: { test: true }, + }); + expect(updated3).toMatchObject({ + test: expect.objectContaining({ key: 'test4', locale: 'locale4' }), + another_test_id: anotherTest.id, + }); + }); +}); From 58cd94872e58f8ed35c2a263568e8ae4e6ac468a Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 23 Apr 2024 21:06:52 +0800 Subject: [PATCH 2/2] fix tests --- .../tests/enhancements/with-policy/nested-to-many.test.ts | 4 ++-- .../tests/enhancements/with-policy/nested-to-one.test.ts | 4 ++-- .../enhancements/with-policy/toplevel-operations.test.ts | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts index 5cafbb361..01d2c36e2 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts @@ -285,7 +285,7 @@ describe('With Policy:nested to-many', () => { }); it('update id field', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -307,7 +307,7 @@ describe('With Policy:nested to-many', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts index 7829ff415..e215a917b 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts @@ -213,7 +213,7 @@ describe('With Policy:nested to-one', () => { }); it('nested update id tests', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model M1 { id String @id @default(uuid()) @@ -235,7 +235,7 @@ describe('With Policy:nested to-one', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.m1.create({ data: { diff --git a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts index da920c064..3543dd7b5 100644 --- a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts +++ b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts @@ -148,7 +148,7 @@ describe('With Policy: toplevel operations', () => { }); it('update id tests', async () => { - const { withPolicy } = await loadSchema( + const { enhance } = await loadSchema( ` model Model { id String @id @default(uuid()) @@ -161,7 +161,7 @@ describe('With Policy: toplevel operations', () => { ` ); - const db = withPolicy(); + const db = enhance(); await db.model.create({ data: {