diff --git a/.changeset/olive-cars-run.md b/.changeset/olive-cars-run.md new file mode 100644 index 0000000000..7a90c3a79d --- /dev/null +++ b/.changeset/olive-cars-run.md @@ -0,0 +1,5 @@ +--- +"@workflow/core": patch +--- + +Ensure class serialization / deserialization only happens in the proper global context diff --git a/packages/core/src/class-serialization.ts b/packages/core/src/class-serialization.ts index f9fff1d06b..46f04c4bd8 100644 --- a/packages/core/src/class-serialization.ts +++ b/packages/core/src/class-serialization.ts @@ -51,25 +51,15 @@ export function registerSerializationClass(classId: string, cls: Function) { * Find a registered class constructor by ID (used during deserialization) * * @param classId - The class ID to look up - * @param global - The global object to check first. Defaults to globalThis. - * If the class is not found and `global` differs from `globalThis`, - * it will also check `globalThis` as a fallback. + * @param global - The global object to check. This ensures workflow code running + * in a VM only accesses classes registered on the VM's global, + * matching production serverless behavior where workflow code + * runs in isolation. */ export function getSerializationClass( classId: string, - global: Record = globalThis + global: Record // biome-ignore lint/complexity/noBannedTypes: We need to use Function to represent class constructors ): Function | undefined { - // Check the provided global first - const cls = getRegistry(global).get(classId); - if (cls) return cls; - - // Fallback: check globalThis if it differs from the provided global - // This handles the case where classes are registered in the host context - // but deserialization happens in a VM context - if (global !== globalThis) { - return getRegistry(globalThis).get(classId); - } - - return undefined; + return getRegistry(global).get(classId); } diff --git a/packages/core/src/serialization.test.ts b/packages/core/src/serialization.test.ts index 4274e32e8d..3f9f1f557d 100644 --- a/packages/core/src/serialization.test.ts +++ b/packages/core/src/serialization.test.ts @@ -1852,12 +1852,42 @@ describe('step function serialization', () => { }); describe('custom class serialization', () => { - const { globalThis: vmGlobalThis } = createContext({ + const { context, globalThis: vmGlobalThis } = createContext({ seed: 'test', fixedTimestamp: 1714857600000, }); + // Make the serialization symbols available inside the VM + (vmGlobalThis as any).WORKFLOW_SERIALIZE = WORKFLOW_SERIALIZE; + (vmGlobalThis as any).WORKFLOW_DESERIALIZE = WORKFLOW_DESERIALIZE; + + // Define registerSerializationClass inside the VM so that it uses the VM's globalThis. + // In production, the workflow bundle includes the full function code, so globalThis + // inside it refers to the VM's global. We simulate that here. + runInContext( + ` + const WORKFLOW_CLASS_REGISTRY = Symbol.for('workflow-class-registry'); + function registerSerializationClass(classId, cls) { + let registry = globalThis[WORKFLOW_CLASS_REGISTRY]; + if (!registry) { + registry = new Map(); + globalThis[WORKFLOW_CLASS_REGISTRY] = registry; + } + registry.set(classId, cls); + Object.defineProperty(cls, 'classId', { + value: classId, + writable: false, + enumerable: false, + configurable: false, + }); + } + globalThis.registerSerializationClass = registerSerializationClass; + `, + context + ); + it('should serialize and deserialize a class with WORKFLOW_SERIALIZE/DESERIALIZE', () => { + // Define the class in the host context (for serialization) class Point { constructor( public x: number, @@ -1876,9 +1906,31 @@ describe('custom class serialization', () => { // The classId is normally generated by the SWC compiler (Point as any).classId = 'test/Point'; - // Register the class for deserialization + // Register the class on the host for serialization registerSerializationClass('test/Point', Point); + // Define and register the class inside the VM (simulates workflow bundle) + // In production, the SWC plugin generates this code in the workflow bundle + runInContext( + ` + class Point { + constructor(x, y) { + this.x = x; + this.y = y; + } + static [WORKFLOW_SERIALIZE](instance) { + return { x: instance.x, y: instance.y }; + } + static [WORKFLOW_DESERIALIZE](data) { + return new Point(data.x, data.y); + } + } + Point.classId = 'test/Point'; + registerSerializationClass('test/Point', Point); + `, + context + ); + const point = new Point(10, 20); const serialized = dehydrateWorkflowArguments(point, [], mockRunId); @@ -1889,14 +1941,17 @@ describe('custom class serialization', () => { const serializedStr = new TextDecoder().decode(serialized); expect(serializedStr).toContain('test/Point'); - // Hydrate it back + // Hydrate it back (inside the VM context) const hydrated = hydrateWorkflowArguments(serialized, vmGlobalThis); - expect(hydrated).toBeInstanceOf(Point); + // Note: hydrated is an instance of the VM's Point class, not the host's + // so we check constructor.name instead of instanceof + expect(hydrated.constructor.name).toBe('Point'); expect(hydrated.x).toBe(10); expect(hydrated.y).toBe(20); }); it('should serialize nested custom serializable objects', () => { + // Define the class in the host context (for serialization) class Vector { constructor( public dx: number, @@ -1915,9 +1970,30 @@ describe('custom class serialization', () => { // The classId is normally generated by the SWC compiler (Vector as any).classId = 'test/Vector'; - // Register the class for deserialization + // Register the class on the host for serialization registerSerializationClass('test/Vector', Vector); + // Define and register the class inside the VM + runInContext( + ` + class Vector { + constructor(dx, dy) { + this.dx = dx; + this.dy = dy; + } + static [WORKFLOW_SERIALIZE](instance) { + return { dx: instance.dx, dy: instance.dy }; + } + static [WORKFLOW_DESERIALIZE](data) { + return new Vector(data.dx, data.dy); + } + } + Vector.classId = 'test/Vector'; + registerSerializationClass('test/Vector', Vector); + `, + context + ); + const data = { name: 'test', vector: new Vector(5, 10), @@ -1930,15 +2006,16 @@ describe('custom class serialization', () => { const hydrated = hydrateWorkflowArguments(serialized, vmGlobalThis); expect(hydrated.name).toBe('test'); - expect(hydrated.vector).toBeInstanceOf(Vector); + expect(hydrated.vector.constructor.name).toBe('Vector'); expect(hydrated.vector.dx).toBe(5); expect(hydrated.vector.dy).toBe(10); - expect(hydrated.nested.anotherVector).toBeInstanceOf(Vector); + expect(hydrated.nested.anotherVector.constructor.name).toBe('Vector'); expect(hydrated.nested.anotherVector.dx).toBe(1); expect(hydrated.nested.anotherVector.dy).toBe(2); }); it('should serialize custom class in an array', () => { + // Define the class in the host context (for serialization) class Item { constructor(public id: string) {} @@ -1954,9 +2031,29 @@ describe('custom class serialization', () => { // The classId is normally generated by the SWC compiler (Item as any).classId = 'test/Item'; - // Register the class for deserialization + // Register the class on the host for serialization registerSerializationClass('test/Item', Item); + // Define and register the class inside the VM + runInContext( + ` + class Item { + constructor(id) { + this.id = id; + } + static [WORKFLOW_SERIALIZE](instance) { + return { id: instance.id }; + } + static [WORKFLOW_DESERIALIZE](data) { + return new Item(data.id); + } + } + Item.classId = 'test/Item'; + registerSerializationClass('test/Item', Item); + `, + context + ); + const items = [new Item('a'), new Item('b'), new Item('c')]; const serialized = dehydrateWorkflowArguments(items, [], mockRunId); @@ -1964,11 +2061,11 @@ describe('custom class serialization', () => { expect(Array.isArray(hydrated)).toBe(true); expect(hydrated).toHaveLength(3); - expect(hydrated[0]).toBeInstanceOf(Item); + expect(hydrated[0].constructor.name).toBe('Item'); expect(hydrated[0].id).toBe('a'); - expect(hydrated[1]).toBeInstanceOf(Item); + expect(hydrated[1].constructor.name).toBe('Item'); expect(hydrated[1].id).toBe('b'); - expect(hydrated[2]).toBeInstanceOf(Item); + expect(hydrated[2].constructor.name).toBe('Item'); expect(hydrated[2].id).toBe('c'); }); diff --git a/packages/core/src/serialization.ts b/packages/core/src/serialization.ts index 98ae2abb20..e9d872bcb5 100644 --- a/packages/core/src/serialization.ts +++ b/packages/core/src/serialization.ts @@ -779,7 +779,9 @@ export function getCommonRevivers(global: Record = globalThis) { RegExp: (value) => new global.RegExp(value.source, value.flags), Class: (value) => { const classId = value.classId; - const cls = getSerializationClass(classId); + // Pass the global object to support VM contexts where classes are registered + // on the VM's global rather than the host's globalThis + const cls = getSerializationClass(classId, global); if (!cls) { throw new Error( `Class "${classId}" not found. Make sure the class is registered with registerSerializationClass.`