Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/db/migrations/0035_add_job_id_to_runs.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "agent_runs" ADD COLUMN "job_id" TEXT;
7 changes: 7 additions & 0 deletions src/db/migrations/meta/_journal.json
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@
"when": 1769000000000,
"tag": "0034_remove_subscription_cost_zero",
"breakpoints": false
},
{
"idx": 35,
"version": "7",
"when": 1770000000000,
"tag": "0035_add_job_id_to_runs",
"breakpoints": false
}
]
}
15 changes: 15 additions & 0 deletions src/db/repositories/runsRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ const enrichedRunSelect = {
error: agentRuns.error,
prUrl: agentRuns.prUrl,
outputSummary: agentRuns.outputSummary,
jobId: agentRuns.jobId,
workItemUrl: prWorkItems.workItemUrl,
workItemTitle: prWorkItems.workItemTitle,
prTitle: prWorkItems.prTitle,
Expand Down Expand Up @@ -127,6 +128,20 @@ export async function updateRunPRNumber(runId: string, prNumber: number): Promis
.where(and(eq(agentRuns.id, runId), isNull(agentRuns.prNumber)));
}

export async function updateRunJobId(runId: string, jobId: string): Promise<void> {
const db = getDb();
await db.update(agentRuns).set({ jobId }).where(eq(agentRuns.id, runId));
}

export async function getRunJobId(runId: string): Promise<string | null> {
const db = getDb();
const [row] = await db
.select({ jobId: agentRuns.jobId })
.from(agentRuns)
.where(eq(agentRuns.id, runId));
return row?.jobId ?? null;
}

export async function completeRun(runId: string, input: CompleteRunInput): Promise<void> {
const db = getDb();
await db
Expand Down
1 change: 1 addition & 0 deletions src/db/schema/runs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export const agentRuns = pgTable(
error: text('error'),
prUrl: text('pr_url'),
outputSummary: text('output_summary'),
jobId: text('job_id'),
},
(table) => [
index('idx_agent_runs_project_id').on(table.projectId),
Expand Down
93 changes: 93 additions & 0 deletions src/queue/cancel.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/**
* Redis pub/sub module for cancel command distribution.
*
* Provides a mechanism for the Dashboard to publish cancel commands that the Router
* receives and uses to terminate running agent jobs.
*/

import { Redis } from 'ioredis';

// ── Types ────────────────────────────────────────────────────────────────

export interface CancelCommandPayload {
runId: string;
reason: string;
}

type CancelCommandHandler = (payload: CancelCommandPayload) => Promise<void>;

// ── Channel ──────────────────────────────────────────────────────────────

const CANCEL_CHANNEL = 'cascade:cancel';

// ── Instance caching ────────────────────────────────────────────────────

let publisherInstance: Redis | null = null;
let subscriberInstance: Redis | null = null;

function getPublisher(): Redis {
if (!publisherInstance) {
const redisUrl = process.env.REDIS_URL;
if (!redisUrl) {
throw new Error('REDIS_URL is required for cancel pub/sub');
}
publisherInstance = new Redis(redisUrl);
}
return publisherInstance;
}

function getSubscriber(): Redis {
if (!subscriberInstance) {
const redisUrl = process.env.REDIS_URL;
if (!redisUrl) {
throw new Error('REDIS_URL is required for cancel pub/sub');
}
subscriberInstance = new Redis(redisUrl);
}
return subscriberInstance;
}

// ── Publish ──────────────────────────────────────────────────────────────

/**
* Publish a cancel command to the cascade:cancel channel.
*
* The Router process subscribes to this channel and uses the runId to
* identify and terminate the corresponding job.
*
* @param runId - The agent run ID to cancel
* @param reason - Human-readable reason for cancellation (e.g., "user requested", "timeout")
*/
export async function publishCancelCommand(runId: string, reason: string): Promise<void> {
const publisher = getPublisher();
const payload: CancelCommandPayload = { runId, reason };
await publisher.publish(CANCEL_CHANNEL, JSON.stringify(payload));
}

// ── Subscribe ────────────────────────────────────────────────────────────

/**
* Subscribe to cancel commands from the cascade:cancel channel.
*
* Invokes the handler callback for each cancel command received.
* The handler should look up the run's jobId from the database and
* use it to kill the job in BullMQ.
*
* @param handler - Callback function invoked with each cancel payload
*/
export async function subscribeToCancelCommands(handler: CancelCommandHandler): Promise<void> {
const subscriber = getSubscriber();

subscriber.on('message', async (channel, message) => {
if (channel === CANCEL_CHANNEL) {
try {
const payload = JSON.parse(message) as CancelCommandPayload;
await handler(payload);
} catch (error) {
console.error('[cancel] Failed to handle cancel command:', error);
}
}
});

await subscriber.subscribe(CANCEL_CHANNEL);
}
156 changes: 156 additions & 0 deletions tests/unit/db/runsRepository-jobId.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/**
* Unit tests for jobId-related functions in src/db/repositories/runsRepository.ts
*
* Tests updateRunJobId and getRunJobId functions.
*/
import { beforeEach, describe, expect, it, vi } from 'vitest';

// Mock the database client
const mockUpdate = vi.fn();
const mockSelect = vi.fn();
const mockSet = vi.fn();
const mockWhere = vi.fn();
const mockFrom = vi.fn();

vi.mock('../../../src/db/client.js', () => ({
getDb: () => ({
update: mockUpdate,
select: mockSelect,
}),
}));

vi.mock('../../../src/db/schema/index.js', () => ({
agentRuns: {
id: 'id',
jobId: 'job_id',
projectId: 'project_id',
workItemId: 'work_item_id',
agentType: 'agent_type',
status: 'status',
startedAt: 'started_at',
prNumber: 'pr_number',
durationMs: 'duration_ms',
costUsd: 'cost_usd',
engine: 'engine',
triggerType: 'trigger_type',
model: 'model',
maxIterations: 'max_iterations',
completedAt: 'completed_at',
llmIterations: 'llm_iterations',
gadgetCalls: 'gadget_calls',
success: 'success',
error: 'error',
prUrl: 'pr_url',
outputSummary: 'output_summary',
},
prWorkItems: {
projectId: 'project_id',
prNumber: 'pr_number',
workItemUrl: 'work_item_url',
workItemTitle: 'work_item_title',
prTitle: 'pr_title',
},
agentRunLogs: { runId: 'run_id' },
agentRunLlmCalls: {
runId: 'run_id',
callNumber: 'call_number',
id: 'id',
},
debugAnalyses: { id: 'id' },
projects: { id: 'id', orgId: 'org_id', name: 'name' },
organizations: { id: 'id', name: 'name' },
}));

vi.mock('../../../src/db/repositories/joinHelpers.js', () => ({
buildAgentRunWorkItemJoin: () => 'mock-join-condition',
}));

import { getRunJobId, updateRunJobId } from '../../../src/db/repositories/runsRepository.js';

describe('updateRunJobId', () => {
beforeEach(() => {
vi.resetAllMocks();

// Set up chained mock returns for update
mockUpdate.mockReturnValue({ set: mockSet });
mockSet.mockReturnValue({ where: mockWhere });
mockWhere.mockResolvedValue(undefined);
});

it('updates the job_id column for a given run', async () => {
const runId = 'run-123';
const jobId = 'job-456';

await updateRunJobId(runId, jobId);

expect(mockUpdate).toHaveBeenCalled();
expect(mockSet).toHaveBeenCalledWith({ jobId });
expect(mockWhere).toHaveBeenCalled();
});

it('handles multiple jobId updates independently', async () => {
await updateRunJobId('run-1', 'job-1');
await updateRunJobId('run-2', 'job-2');

expect(mockUpdate).toHaveBeenCalledTimes(2);
expect(mockSet).toHaveBeenNthCalledWith(1, { jobId: 'job-1' });
expect(mockSet).toHaveBeenNthCalledWith(2, { jobId: 'job-2' });
});
});

describe('getRunJobId', () => {
beforeEach(() => {
vi.resetAllMocks();

// Set up chained mock returns for select
mockSelect.mockReturnValue({ from: mockFrom });
mockFrom.mockReturnValue({ where: mockWhere });
});

it('returns the job_id for a given run', async () => {
const jobId = 'job-789';
mockWhere.mockResolvedValue([{ jobId }]);

const result = await getRunJobId('run-123');

expect(result).toBe(jobId);
expect(mockSelect).toHaveBeenCalled();
expect(mockWhere).toHaveBeenCalled();
});

it('returns null when no job_id is found', async () => {
mockWhere.mockResolvedValue([]);

const result = await getRunJobId('run-nonexistent');

expect(result).toBeNull();
});

it('returns null when the jobId field is null in the database', async () => {
mockWhere.mockResolvedValue([{ jobId: null }]);

const result = await getRunJobId('run-123');

expect(result).toBeNull();
});

it('returns null when the row has no jobId property', async () => {
mockWhere.mockResolvedValue([{}]);

const result = await getRunJobId('run-123');

expect(result).toBeNull();
});

it('handles multiple getRunJobId calls independently', async () => {
mockWhere.mockResolvedValueOnce([{ jobId: 'job-1' }]);
mockWhere.mockResolvedValueOnce([{ jobId: 'job-2' }]);

const result1 = await getRunJobId('run-1');
const result2 = await getRunJobId('run-2');

expect(result1).toBe('job-1');
expect(result2).toBe('job-2');
expect(mockSelect).toHaveBeenCalledTimes(2);
});
});
1 change: 1 addition & 0 deletions tests/unit/db/runsRepository.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ vi.mock('../../../src/db/schema/index.js', () => ({
error: 'error',
prUrl: 'pr_url',
outputSummary: 'output_summary',
jobId: 'job_id',
},
agentRunLogs: { runId: 'run_id' },
agentRunLlmCalls: {
Expand Down
Loading
Loading