Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 200 additions & 33 deletions packages/start-plugin-core/src/create-server-fn-plugin/compiler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ function needsDirectCallDetection(kinds: Set<LookupKind>): boolean {
return false
}

/**
* Checks if all kinds in the set are guaranteed to be top-level only.
* Only ServerFn is always declared at module level (must be assigned to a variable).
* Middleware, IsomorphicFn, ServerOnlyFn, ClientOnlyFn can be nested inside functions.
* When all kinds are top-level-only, we can use a fast scan instead of full traversal.
*/
function areAllKindsTopLevelOnly(kinds: Set<LookupKind>): boolean {
return kinds.size === 1 && kinds.has('ServerFn')
}

/**
* Checks if a CallExpression is a direct-call candidate for NESTED detection.
* Returns true if the callee is a known factory function name.
Expand Down Expand Up @@ -234,6 +244,11 @@ export class ServerFnCompiler {
private moduleCache = new Map<string, ModuleInfo>()
private initialized = false
private validLookupKinds: Set<LookupKind>
private resolveIdCache = new Map<string, string | null>()
private exportResolutionCache = new Map<
string,
Map<string, { moduleInfo: ModuleInfo; binding: Binding } | null>
>()
// Fast lookup for direct imports from known libraries (e.g., '@tanstack/react-start')
// Maps: libName → (exportName → Kind)
// This allows O(1) resolution for the common case without async resolveId calls
Expand All @@ -246,11 +261,44 @@ export class ServerFnCompiler {
lookupKinds: Set<LookupKind>
loadModule: (id: string) => Promise<void>
resolveId: (id: string, importer?: string) => Promise<string | null>
/**
* In 'build' mode, resolution results are cached for performance.
* In 'dev' mode (default), caching is disabled to avoid invalidation complexity with HMR.
*/
mode?: 'dev' | 'build'
},
) {
this.validLookupKinds = options.lookupKinds
}

private get mode(): 'dev' | 'build' {
return this.options.mode ?? 'dev'
}

private async resolveIdCached(id: string, importer?: string) {
if (this.mode === 'dev') {
return this.options.resolveId(id, importer)
}

const cacheKey = importer ? `${importer}::${id}` : id
const cached = this.resolveIdCache.get(cacheKey)
if (cached !== undefined) {
return cached
}
const resolved = await this.options.resolveId(id, importer)
this.resolveIdCache.set(cacheKey, resolved)
return resolved
}

private getExportResolutionCache(moduleId: string) {
let cache = this.exportResolutionCache.get(moduleId)
if (!cache) {
cache = new Map()
this.exportResolutionCache.set(moduleId, cache)
}
return cache
}

private async init() {
// Register internal stub package exports for recognition.
// These don't need module resolution - only the knownRootImports fast path.
Expand All @@ -274,7 +322,7 @@ export class ServerFnCompiler {
}
libExports.set(config.rootExport, config.kind)

const libId = await this.options.resolveId(config.libName)
const libId = await this.resolveIdCached(config.libName)
if (!libId) {
throw new Error(`could not resolve "${config.libName}"`)
}
Expand Down Expand Up @@ -311,9 +359,14 @@ export class ServerFnCompiler {
this.initialized = true
}

public ingestModule({ code, id }: { code: string; id: string }) {
const ast = parseAst({ code })

/**
* Extracts bindings and exports from an already-parsed AST.
* This is the core logic shared by ingestModule and ingestModuleFromAst.
*/
private extractModuleInfo(
ast: ReturnType<typeof parseAst>,
id: string,
): ModuleInfo {
const bindings = new Map<string, Binding>()
const exports = new Map<string, ExportEntry>()
const reExportAllSources: Array<string> = []
Expand Down Expand Up @@ -414,10 +467,19 @@ export class ServerFnCompiler {
reExportAllSources,
}
this.moduleCache.set(id, info)
return info
}

public ingestModule({ code, id }: { code: string; id: string }) {
const ast = parseAst({ code })
const info = this.extractModuleInfo(ast, id)
return { info, ast }
}

public invalidateModule(id: string) {
// Note: Resolution caches (resolveIdCache, exportResolutionCache) are only
// used in build mode where there's no HMR. In dev mode, caching is disabled,
// so we only need to invalidate the moduleCache here.
return this.moduleCache.delete(id)
}

Expand Down Expand Up @@ -448,7 +510,13 @@ export class ServerFnCompiler {
}

const checkDirectCalls = needsDirectCallDetection(fileKinds)
// Optimization: ServerFn is always a top-level declaration (must be assigned to a variable).
// If the file only has ServerFn, we can skip full AST traversal and only visit
// the specific top-level declarations that have candidates.
const canUseFastPath = areAllKindsTopLevelOnly(fileKinds)

// Always parse and extract module info upfront.
// This ensures the module is cached for import resolution even if no candidates are found.
const { ast } = this.ingestModule({ code, id })

// Single-pass traversal to:
Expand All @@ -462,38 +530,110 @@ export class ServerFnCompiler {
babel.NodePath<t.CallExpression>
>()

babel.traverse(ast, {
CallExpression: (path) => {
const node = path.node
const parent = path.parent
if (canUseFastPath) {
// Fast path: only visit top-level statements that have potential candidates

// Check if this call is part of a larger chain (inner call)
// If so, store it for method chain lookup but don't treat as candidate
if (
t.isMemberExpression(parent) &&
t.isCallExpression(path.parentPath.parent)
) {
// This is an inner call in a chain - store for later lookup
chainCallPaths.set(node, path)
return
// Collect indices of top-level statements that contain candidates
const candidateIndices: Array<number> = []
for (let i = 0; i < ast.program.body.length; i++) {
const node = ast.program.body[i]!
let declarations: Array<t.VariableDeclarator> | undefined

if (t.isVariableDeclaration(node)) {
declarations = node.declarations
} else if (t.isExportNamedDeclaration(node) && node.declaration) {
if (t.isVariableDeclaration(node.declaration)) {
declarations = node.declaration.declarations
}
}

// Pattern 1: Method chain pattern (.handler(), .server(), .client(), etc.)
if (isMethodChainCandidate(node, fileKinds)) {
candidatePaths.push(path)
return
if (declarations) {
for (const decl of declarations) {
if (decl.init && t.isCallExpression(decl.init)) {
if (isMethodChainCandidate(decl.init, fileKinds)) {
candidateIndices.push(i)
break // Only need to mark this statement once
}
}
}
}
}

// Pattern 2: Direct call pattern
if (checkDirectCalls) {
if (isTopLevelDirectCallCandidate(path)) {
candidatePaths.push(path)
} else if (isNestedDirectCallCandidate(node)) {
// Early exit: no potential candidates found at top level
if (candidateIndices.length === 0) {
return null
}

// Targeted traversal: only visit the specific statements that have candidates
// This is much faster than traversing the entire AST
babel.traverse(ast, {
Program(programPath) {
const bodyPaths = programPath.get('body')
for (const idx of candidateIndices) {
const stmtPath = bodyPaths[idx]
if (!stmtPath) continue

// Traverse only this statement's subtree
stmtPath.traverse({
CallExpression(path) {
const node = path.node
const parent = path.parent

// Check if this call is part of a larger chain (inner call)
if (
t.isMemberExpression(parent) &&
t.isCallExpression(path.parentPath.parent)
) {
chainCallPaths.set(node, path)
return
}

// Method chain pattern
if (isMethodChainCandidate(node, fileKinds)) {
candidatePaths.push(path)
}
},
})
}
// Stop traversal after processing Program
programPath.stop()
},
})
} else {
// Normal path: full traversal for non-fast-path kinds
babel.traverse(ast, {
CallExpression: (path) => {
const node = path.node
const parent = path.parent

// Check if this call is part of a larger chain (inner call)
// If so, store it for method chain lookup but don't treat as candidate
if (
t.isMemberExpression(parent) &&
t.isCallExpression(path.parentPath.parent)
) {
// This is an inner call in a chain - store for later lookup
chainCallPaths.set(node, path)
return
}

// Pattern 1: Method chain pattern (.handler(), .server(), .client(), etc.)
if (isMethodChainCandidate(node, fileKinds)) {
candidatePaths.push(path)
return
}
}
},
})

// Pattern 2: Direct call pattern
if (checkDirectCalls) {
if (isTopLevelDirectCallCandidate(path)) {
candidatePaths.push(path)
} else if (isNestedDirectCallCandidate(node)) {
candidatePaths.push(path)
}
}
},
})
}

if (candidatePaths.length === 0) {
return null
Expand Down Expand Up @@ -651,6 +791,19 @@ export class ServerFnCompiler {
exportName: string,
visitedModules = new Set<string>(),
): Promise<{ moduleInfo: ModuleInfo; binding: Binding } | undefined> {
const isBuildMode = this.mode === 'build'

// Check cache first (only for top-level calls in build mode)
if (isBuildMode && visitedModules.size === 0) {
const moduleCache = this.exportResolutionCache.get(moduleInfo.id)
if (moduleCache) {
const cached = moduleCache.get(exportName)
if (cached !== undefined) {
return cached ?? undefined
}
}
}

// Prevent infinite loops in circular re-exports
if (visitedModules.has(moduleInfo.id)) {
return undefined
Expand All @@ -662,7 +815,12 @@ export class ServerFnCompiler {
if (directExport) {
const binding = moduleInfo.bindings.get(directExport.name)
if (binding) {
return { moduleInfo, binding }
const result = { moduleInfo, binding }
// Cache the result (build mode only)
if (isBuildMode) {
this.getExportResolutionCache(moduleInfo.id).set(exportName, result)
}
return result
}
}

Expand All @@ -671,10 +829,11 @@ export class ServerFnCompiler {
if (moduleInfo.reExportAllSources.length > 0) {
const results = await Promise.all(
moduleInfo.reExportAllSources.map(async (reExportSource) => {
const reExportTarget = await this.options.resolveId(
const reExportTarget = await this.resolveIdCached(
reExportSource,
moduleInfo.id,
)

if (reExportTarget) {
const reExportModule = await this.getModuleInfo(reExportTarget)
return this.findExportInModule(
Expand All @@ -689,11 +848,19 @@ export class ServerFnCompiler {
// Return the first valid result
for (const result of results) {
if (result) {
// Cache the result (build mode only)
if (isBuildMode) {
this.getExportResolutionCache(moduleInfo.id).set(exportName, result)
}
return result
}
}
}

// Cache negative result (build mode only)
if (isBuildMode) {
this.getExportResolutionCache(moduleInfo.id).set(exportName, null)
}
return undefined
}

Expand All @@ -719,7 +886,7 @@ export class ServerFnCompiler {
}

// Slow path: resolve through the module graph
const target = await this.options.resolveId(binding.source, fileId)
const target = await this.resolveIdCached(binding.source, fileId)
if (!target) {
return 'None'
}
Expand Down Expand Up @@ -863,7 +1030,7 @@ export class ServerFnCompiler {
binding.importedName === '*'
) {
// resolve the property from the target module
const targetModuleId = await this.options.resolveId(
const targetModuleId = await this.resolveIdCached(
binding.source,
fileId,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ export function createServerFnPlugin(opts: {
async handler(code, id) {
let compiler = compilers[this.environment.name]
if (!compiler) {
// Default to 'dev' mode for unknown environments (conservative: no caching)
const mode =
this.environment.mode === 'build' ? 'build' : ('dev' as const)
compiler = new ServerFnCompiler({
env: environment.type,
directive: opts.directive,
Expand All @@ -126,6 +129,7 @@ export function createServerFnPlugin(opts: {
environment.type,
opts.framework,
),
mode,
loadModule: async (id: string) => {
if (this.environment.mode === 'build') {
const loaded = await this.load({ id })
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ describe('createMiddleware compiles correctly', async () => {
// the fast path uses knownRootImports map for O(1) lookup
// Note: init() now resolves from project root, not from a specific file
expect(resolveIdMock).toHaveBeenCalledTimes(1)
expect(resolveIdMock).toHaveBeenCalledWith('@tanstack/react-start')
expect(resolveIdMock).toHaveBeenCalledWith(
'@tanstack/react-start',
undefined,
)
})

test('should use slow path for factory pattern (resolveId called for import resolution)', async () => {
Expand Down Expand Up @@ -149,7 +152,11 @@ describe('createMiddleware compiles correctly', async () => {
// Note: The factory module's import from '@tanstack/react-start' ALSO uses
// the fast path (knownRootImports), so no additional resolveId call is needed there.
expect(resolveIdMock).toHaveBeenCalledTimes(2)
expect(resolveIdMock).toHaveBeenNthCalledWith(1, '@tanstack/react-start')
expect(resolveIdMock).toHaveBeenNthCalledWith(
1,
'@tanstack/react-start',
undefined,
)
expect(resolveIdMock).toHaveBeenNthCalledWith(2, './factory', 'test.ts')
})
})
Loading
Loading