Skip to content

Commit d27ac2a

Browse files
wip
1 parent b92454c commit d27ac2a

File tree

6 files changed

+467
-396
lines changed

6 files changed

+467
-396
lines changed

e2e/react-start/server-functions/src/routes/index.tsx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,8 @@ function Home() {
9191
<li>
9292
<Link to="/server-only-fn">
9393
Server Function only called by Server Environment is kept in the
94-
server build</Link>
94+
server build
95+
</Link>
9596
</li>
9697
<li>
9798
<Link to="/middleware/unhandled-exception">

packages/start-client-core/src/createServerFn.ts

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { isRedirect, parseRedirect } from '@tanstack/router-core'
44
import { TSS_SERVER_FUNCTION_FACTORY } from './constants'
55
import { getStartOptions } from './getStartOptions'
66
import { getStartContextServerOnly } from './getStartContextServerOnly'
7+
import { createNullProtoObject, safeObjectMerge } from './safeObjectMerge'
78
import type {
89
AnyValidator,
910
Constrain,
@@ -118,7 +119,7 @@ export const createServerFn: CreateServerFn<Register> = (options, __opts) => {
118119
data: opts?.data as any,
119120
headers: opts?.headers,
120121
signal: opts?.signal,
121-
context: {},
122+
context: createNullProtoObject(),
122123
})
123124

124125
const redirect = parseRedirect(result.error)
@@ -138,13 +139,14 @@ export const createServerFn: CreateServerFn<Register> = (options, __opts) => {
138139
const startContext = getStartContextServerOnly()
139140
const serverContextAfterGlobalMiddlewares =
140141
startContext.contextAfterGlobalMiddlewares
142+
// Use safeObjectMerge for opts.context which comes from client
141143
const ctx = {
142144
...extractedFn,
143145
...opts,
144-
context: {
145-
...serverContextAfterGlobalMiddlewares,
146-
...opts.context,
147-
},
146+
context: safeObjectMerge(
147+
serverContextAfterGlobalMiddlewares,
148+
opts.context,
149+
),
148150
signal,
149151
request: startContext.request,
150152
}
@@ -239,17 +241,12 @@ export async function executeMiddleware(
239241
userCtx: ServerFnMiddlewareResult | undefined = {} as any,
240242
) => {
241243
// Return the next middleware
244+
// Use safeObjectMerge for context objects to prevent prototype pollution
242245
const nextCtx = {
243246
...ctx,
244247
...userCtx,
245-
context: {
246-
...ctx.context,
247-
...userCtx.context,
248-
},
249-
sendContext: {
250-
...ctx.sendContext,
251-
...(userCtx.sendContext ?? {}),
252-
},
248+
context: safeObjectMerge(ctx.context, userCtx.context),
249+
sendContext: safeObjectMerge(ctx.sendContext, userCtx.sendContext),
253250
headers: mergeHeaders(ctx.headers, userCtx.headers),
254251
result:
255252
userCtx.result !== undefined
@@ -315,7 +312,7 @@ export async function executeMiddleware(
315312
...opts,
316313
headers: opts.headers || {},
317314
sendContext: opts.sendContext || {},
318-
context: opts.context || {},
315+
context: opts.context || createNullProtoObject(),
319316
})
320317
}
321318

@@ -652,18 +649,21 @@ export interface ServerFnTypes<
652649
allOutput: IntersectAllValidatorOutputs<TMiddlewares, TInputValidator>
653650
}
654651

655-
export function flattenMiddlewares(
656-
middlewares: Array<AnyFunctionMiddleware | AnyRequestMiddleware>,
657-
): Array<AnyFunctionMiddleware | AnyRequestMiddleware> {
658-
const seen = new Set<AnyFunctionMiddleware | AnyRequestMiddleware>()
659-
const flattened: Array<AnyFunctionMiddleware | AnyRequestMiddleware> = []
652+
export function flattenMiddlewares<
653+
T extends AnyFunctionMiddleware | AnyRequestMiddleware,
654+
>(middlewares: Array<T>, maxDepth: number = 100): Array<T> {
655+
const seen = new Set<T>()
656+
const flattened: Array<T> = []
660657

661-
const recurse = (
662-
middleware: Array<AnyFunctionMiddleware | AnyRequestMiddleware>,
663-
) => {
658+
const recurse = (middleware: Array<T>, depth: number) => {
659+
if (depth > maxDepth) {
660+
throw new Error(
661+
`Middleware nesting depth exceeded maximum of ${maxDepth}. Check for circular references.`,
662+
)
663+
}
664664
middleware.forEach((m) => {
665665
if (m.options.middleware) {
666-
recurse(m.options.middleware)
666+
recurse(m.options.middleware as Array<T>, depth + 1)
667667
}
668668

669669
if (!seen.has(m)) {
@@ -673,7 +673,7 @@ export function flattenMiddlewares(
673673
})
674674
}
675675

676-
recurse(middlewares)
676+
recurse(middlewares, 0)
677677

678678
return flattened
679679
}

packages/start-client-core/src/index.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,4 @@ export type { Register } from '@tanstack/router-core'
100100
export { getRouterInstance } from './getRouterInstance'
101101
export { getDefaultSerovalPlugins } from './getDefaultSerovalPlugins'
102102
export { getGlobalStartContext } from './getGlobalStartContext'
103+
export { safeObjectMerge, createNullProtoObject } from './safeObjectMerge'
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
function isSafeKey(key: string): boolean {
2+
return key !== '__proto__' && key !== 'constructor' && key !== 'prototype'
3+
}
4+
5+
/**
6+
* Merge target and source into a new null-proto object, filtering dangerous keys.
7+
*/
8+
export function safeObjectMerge<T extends Record<string, unknown>>(
9+
target: T | undefined,
10+
source: Record<string, unknown> | null | undefined,
11+
): T {
12+
const result = Object.create(null) as T
13+
if (target) {
14+
for (const key of Object.keys(target)) {
15+
if (isSafeKey(key)) result[key as keyof T] = target[key] as T[keyof T]
16+
}
17+
}
18+
if (source && typeof source === 'object') {
19+
for (const key of Object.keys(source)) {
20+
if (isSafeKey(key)) result[key as keyof T] = source[key] as T[keyof T]
21+
}
22+
}
23+
return result
24+
}
25+
26+
/**
27+
* Create a null-prototype object, optionally copying from source.
28+
*/
29+
export function createNullProtoObject<T extends object>(
30+
source?: T,
31+
): { [K in keyof T]: T[K] } {
32+
if (!source) return Object.create(null)
33+
const obj = Object.create(null)
34+
for (const key of Object.keys(source)) {
35+
if (isSafeKey(key)) obj[key] = (source as Record<string, unknown>)[key]
36+
}
37+
return obj
38+
}

0 commit comments

Comments
 (0)