Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions packages/start-plugin-core/src/create-server-fn-plugin/compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,18 @@ type ExportEntry =

type Kind = 'None' | `Root` | `Builder` | LookupKind

type LookupKind = 'ServerFn' | 'Middleware'
export type LookupKind = 'ServerFn' | 'Middleware'

const LookupSetup: Record<
LookupKind,
{ candidateCallIdentifier: Set<string> }
> = {
ServerFn: { candidateCallIdentifier: new Set(['handler']) },
Middleware: {
candidateCallIdentifier: new Set(['server', 'client', 'createMiddlewares']),
},
}

const validLookupKinds: Array<LookupKind> = ['ServerFn', 'Middleware']
const candidateCallIdentifier = ['handler', 'server', 'client']
export type LookupConfig = {
libName: string
rootExport: string
Expand All @@ -48,14 +56,18 @@ interface ModuleInfo {
export class ServerFnCompiler {
private moduleCache = new Map<string, ModuleInfo>()
private initialized = false
private validLookupKinds: Set<LookupKind>
constructor(
private options: {
env: 'client' | 'server'
lookupConfigurations: Array<LookupConfig>
lookupKinds: Set<LookupKind>
loadModule: (id: string) => Promise<void>
resolveId: (id: string, importer?: string) => Promise<string | null>
},
) {}
) {
this.validLookupKinds = options.lookupKinds
}

private async init(id: string) {
await Promise.all(
Expand Down Expand Up @@ -207,7 +219,7 @@ export class ServerFnCompiler {
}> = []
for (const handler of candidates) {
const kind = await this.resolveExprKind(handler, id)
if (validLookupKinds.includes(kind as LookupKind)) {
if (this.validLookupKinds.has(kind as LookupKind)) {
toRewrite.push({ callExpression: handler, kind: kind as LookupKind })
}
}
Expand Down Expand Up @@ -261,7 +273,10 @@ export class ServerFnCompiler {

for (const binding of bindings.values()) {
if (binding.type === 'var') {
const handler = isCandidateCallExpression(binding.init)
const handler = isCandidateCallExpression(
binding.init,
this.validLookupKinds,
)
if (handler) {
candidates.push(handler)
}
Expand Down Expand Up @@ -368,7 +383,7 @@ export class ServerFnCompiler {
if (calleeKind === `Root` || calleeKind === `Builder`) {
return `Builder`
}
for (const kind of validLookupKinds) {
for (const kind of this.validLookupKinds) {
if (calleeKind === kind) {
return kind
}
Expand Down Expand Up @@ -407,16 +422,18 @@ export class ServerFnCompiler {
if (t.isMemberExpression(callee) && t.isIdentifier(callee.property)) {
const prop = callee.property.name

if (prop === 'handler') {
if (
this.validLookupKinds.has('ServerFn') &&
LookupSetup['ServerFn'].candidateCallIdentifier.has(prop)
) {
const base = await this.resolveExprKind(callee.object, fileId, visited)
if (base === 'Root' || base === 'Builder') {
return 'ServerFn'
}
return 'None'
} else if (
prop === 'client' ||
prop === 'server' ||
prop === 'createMiddleware'
this.validLookupKinds.has('Middleware') &&
LookupSetup['Middleware'].candidateCallIdentifier.has(prop)
) {
const base = await this.resolveExprKind(callee.object, fileId, visited)
if (base === 'Root' || base === 'Builder' || base === 'Middleware') {
Expand Down Expand Up @@ -483,16 +500,19 @@ export class ServerFnCompiler {

function isCandidateCallExpression(
node: t.Node | null | undefined,
lookupKinds: Set<LookupKind>,
): undefined | t.CallExpression {
if (!t.isCallExpression(node)) return undefined

const callee = node.callee
if (!t.isMemberExpression(callee) || !t.isIdentifier(callee.property)) {
return undefined
}
if (!candidateCallIdentifier.includes(callee.property.name)) {
return undefined
for (const kind of lookupKinds) {
if (LookupSetup[kind].candidateCallIdentifier.has(callee.property.name)) {
return node
}
}

return node
return undefined
}
56 changes: 39 additions & 17 deletions packages/start-plugin-core/src/create-server-fn-plugin/plugin.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { VITE_ENVIRONMENT_NAMES } from '../constants'
import { ServerFnCompiler } from './compiler'
import type { LookupConfig, LookupKind } from './compiler'
import type { CompileStartFrameworkOptions } from '../start-compiler-plugin/compilers'
import type { ViteEnvironmentNames } from '../constants'
import type { PluginOption } from 'vite'
Expand All @@ -8,6 +9,36 @@ function cleanId(id: string): string {
return id.split('?')[0]!
}

const LookupKindsPerEnv: Record<'client' | 'server', Set<LookupKind>> = {
client: new Set(['Middleware', 'ServerFn'] as const),
server: new Set(['ServerFn'] as const),
}

const getLookupConfigurationsForEnv = (
env: 'client' | 'server',
framework: CompileStartFrameworkOptions,
): Array<LookupConfig> => {
const createServerFnConfig: LookupConfig = {
libName: `@tanstack/${framework}-start`,
rootExport: 'createServerFn',
}
if (env === 'client') {
return [
{
libName: `@tanstack/${framework}-start`,
rootExport: 'createMiddleware',
},
{
libName: `@tanstack/${framework}-start`,
rootExport: 'createStart',
},

createServerFnConfig,
]
} else {
return [createServerFnConfig]
}
}
export function createServerFnPlugin(
framework: CompileStartFrameworkOptions,
): PluginOption {
Expand Down Expand Up @@ -52,8 +83,9 @@ export function createServerFnPlugin(
exclude: new RegExp(`${SERVER_FN_LOOKUP}$`),
},
code: {
// only scan files that mention `.handler(` | `.server(` | `.client(`
include: [/\.handler\(/, /\.server\(/, /\.client\(/],
// TODO apply this plugin with a different filter per environment so that .createMiddleware() calls are not scanned in server env
// only scan files that mention `.handler(` | `.createMiddleware()`
include: [/\.handler\(/, /.createMiddleware\(\)/],
Comment on lines +86 to +88
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix the createMiddleware filter before shipping

The new /.createMiddleware\(\)/ filter only matches a literal createMiddleware() call with empty parentheses. Real code always has arguments (createMiddleware({...}), start.createMiddleware(...)), so the transform hook never receives those modules and the middleware rewrite never runs. Please broaden the regex so any createMiddleware( call is caught.

-            include: [/\.handler\(/, /.createMiddleware\(\)/],
+            include: [/\.handler\(/, /createMiddleware\(/],
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// TODO apply this plugin with a different filter per environment so that .createMiddleware() calls are not scanned in server env
// only scan files that mention `.handler(` | `.createMiddleware()`
include: [/\.handler\(/, /.createMiddleware\(\)/],
// TODO apply this plugin with a different filter per environment so that .createMiddleware() calls are not scanned in server env
// only scan files that mention `.handler(` | `.createMiddleware()`
include: [/\.handler\(/, /createMiddleware\(/],
🤖 Prompt for AI Agents
In packages/start-plugin-core/src/create-server-fn-plugin/plugin.ts around lines
86 to 88, the include filter /\.createMiddleware\(\)/ only matches a literal
empty-argument call and misses real usages like createMiddleware({...}) or
start.createMiddleware(...); change that entry to a broader regex (e.g.
/createMiddleware\(/) so any call site containing "createMiddleware(" is matched
and the transform hook runs for those modules.

},
},
async handler(code, id) {
Expand All @@ -73,21 +105,11 @@ export function createServerFnPlugin(

compiler = new ServerFnCompiler({
env,
lookupConfigurations: [
{
libName: `@tanstack/${framework}-start`,
rootExport: 'createMiddleware',
},

{
libName: `@tanstack/${framework}-start`,
rootExport: 'createServerFn',
},
{
libName: `@tanstack/${framework}-start`,
rootExport: 'createStart',
},
],
lookupKinds: LookupKindsPerEnv[env],
lookupConfigurations: getLookupConfigurationsForEnv(
env,
framework,
),
loadModule: async (id: string) => {
if (this.environment.mode === 'build') {
const loaded = await this.load({ id })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
findReferencedIdentifiers,
} from 'babel-dead-code-elimination'
import { generateFromAst, parseAst } from '@tanstack/router-utils'
import { handleCreateMiddleware } from '../create-server-fn-plugin/handleCreateMiddleware'
import { transformFuncs } from './constants'
import { handleCreateIsomorphicFnCallExpression } from './isomorphicFn'
import {
Expand Down Expand Up @@ -38,6 +39,11 @@ export function compileStartOutputFactory(
handleCallExpression: handleCreateIsomorphicFnCallExpression,
paths: [],
},
createMiddleware: {
name: 'createMiddleware',
handleCallExpression: handleCreateMiddleware,
paths: [],
},
}

const ast = parseAst(opts)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ export const transformFuncs = [
'createServerOnlyFn',
'createClientOnlyFn',
'createIsomorphicFn',
'createMiddleware',
] as const
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,12 @@ async function compile(opts: {
loadModule: async (id) => {
// do nothing in test
},
lookupKinds: new Set(['Middleware']),
lookupConfigurations: [
{
libName: `@tanstack/react-start`,
rootExport: 'createMiddleware',
},

{
libName: `@tanstack/react-start`,
rootExport: 'createServerFn',
},
{
libName: `@tanstack/react-start`,
rootExport: 'createStart',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import { readFile, readdir } from 'node:fs/promises'
import path from 'node:path'
import { describe, expect, test } from 'vitest'
import { compileStartOutputFactory } from '../../src/start-compiler-plugin/compilers'

const compileStartOutput = compileStartOutputFactory('react')

async function getFilenames() {
return await readdir(path.resolve(import.meta.dirname, './test-files'))
}

describe('createMiddleware compiles correctly', async () => {
const filenames = await getFilenames()

describe.each(filenames)('should handle "%s"', async (filename) => {
const file = await readFile(
path.resolve(import.meta.dirname, `./test-files/${filename}`),
)
const code = file.toString()

test.each(['client', 'server'] as const)(
`should compile for ${filename} %s`,
async (env) => {
const compiledResult = compileStartOutput({
env,
code,
filename,
dce: false,
})

await expect(compiledResult.code).toMatchFileSnapshot(
`./snapshots/${env}/${filename}`,
)
},
)
})
})
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { createMiddleware } from '@tanstack/react-start';
import { foo } from '@some/lib';
export const fnMw = createMiddleware({
type: 'function'
}).client(() => {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { createMiddleware } from '@tanstack/react-start';
import { z } from 'zod';
export const withUseServer = createMiddleware({
id: 'test'
});
export const withoutUseServer = createMiddleware({
id: 'test'
});
export const withVariable = createMiddleware({
id: 'test'
});
async function abstractedFunction() {
console.info('Fetching posts...');
await new Promise(r => setTimeout(r, 500));
return axios.get<Array<PostType>>('https://jsonplaceholder.typicode.com/posts').then(r => r.data.slice(0, 10));
}
function zodValidator<TSchema extends z.ZodSchema, TResult>(schema: TSchema, fn: (input: z.output<TSchema>) => TResult) {
return async (input: unknown) => {
return fn(schema.parse(input));
};
}
export const withZodValidator = createMiddleware({
id: 'test'
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import { createMiddleware as middlewareFn } from '@tanstack/react-start';
import { z } from 'zod';
export const withUseServer = middlewareFn({
id: 'test'
});
export const withoutUseServer = middlewareFn({
id: 'test'
});
export const withVariable = middlewareFn({
id: 'test'
});
async function abstractedFunction() {
console.info('Fetching posts...');
await new Promise(r => setTimeout(r, 500));
return axios.get<Array<PostType>>('https://jsonplaceholder.typicode.com/posts').then(r => r.data.slice(0, 10));
}
function zodValidator<TSchema extends z.ZodSchema, TResult>(schema: TSchema, fn: (input: z.output<TSchema>) => TResult) {
return async (input: unknown) => {
return fn(schema.parse(input));
};
}
export const withZodValidator = middlewareFn({
id: 'test'
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import * as TanStackStart from '@tanstack/react-start';
import { z } from 'zod';
export const withUseServer = TanStackStart.createMiddleware({
id: 'test'
});
export const withoutUseServer = TanStackStart.createMiddleware({
id: 'test'
});
export const withVariable = TanStackStart.createMiddleware({
id: 'test'
});
async function abstractedFunction() {
console.info('Fetching posts...');
await new Promise(r => setTimeout(r, 500));
return axios.get<Array<PostType>>('https://jsonplaceholder.typicode.com/posts').then(r => r.data.slice(0, 10));
}
function zodValidator<TSchema extends z.ZodSchema, TResult>(schema: TSchema, fn: (input: z.output<TSchema>) => TResult) {
return async (input: unknown) => {
return fn(schema.parse(input));
};
}
export const withZodValidator = TanStackStart.createMiddleware({
id: 'test'
});
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { createMiddleware } from '@tanstack/react-start';
import { z } from 'zod';
export const withUseServer = createMiddleware({
id: 'test'
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { createMiddleware } from '@tanstack/react-start';
import { getCookie } from '@tanstack/react-start/server';
interface AuthMiddlewareOptions {
allowUnauthenticated?: boolean;
}
interface AuthContext {
session: {
id: string;
} | null;
}
export const createAuthMiddleware = (opts: AuthMiddlewareOptions = {
allowUnauthenticated: false
}) => createMiddleware({
type: 'function'
});
Loading
Loading