diff --git a/lib/mcp/tools/index.ts b/lib/mcp/tools/index.ts index 0b60c5f..3ddd45e 100644 --- a/lib/mcp/tools/index.ts +++ b/lib/mcp/tools/index.ts @@ -17,6 +17,7 @@ import { registerAllYouTubeTools } from "./youtube"; import { registerTranscribeTools } from "./transcribe"; import { registerSendEmailTool } from "./registerSendEmailTool"; import { registerAllArtistTools } from "./artists"; +import { registerAllPulseTools } from "./pulse"; /** * Registers all MCP tools on the server. @@ -33,6 +34,7 @@ export const registerAllTools = (server: McpServer): void => { registerAllCatalogTools(server); registerAllFileTools(server); registerAllImageTools(server); + registerAllPulseTools(server); registerAllSora2Tools(server); registerAllSpotifyTools(server); registerAllTaskTools(server); diff --git a/lib/mcp/tools/pulse/__tests__/registerGetPulseTool.test.ts b/lib/mcp/tools/pulse/__tests__/registerGetPulseTool.test.ts new file mode 100644 index 0000000..ceacbf6 --- /dev/null +++ b/lib/mcp/tools/pulse/__tests__/registerGetPulseTool.test.ts @@ -0,0 +1,158 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; + +import { registerGetPulseTool } from "../registerGetPulseTool"; + +const mockSelectPulseAccount = vi.fn(); +const mockCanAccessAccount = vi.fn(); + +vi.mock("@/lib/supabase/pulse_accounts/selectPulseAccount", () => ({ + selectPulseAccount: (...args: unknown[]) => mockSelectPulseAccount(...args), +})); + +vi.mock("@/lib/organizations/canAccessAccount", () => ({ + canAccessAccount: (...args: unknown[]) => mockCanAccessAccount(...args), +})); + +type ServerRequestHandlerExtra = RequestHandlerExtra; + +/** + * Creates a mock extra object with optional authInfo. + * + * @param authInfo + * @param authInfo.accountId + * @param authInfo.orgId + */ +function createMockExtra(authInfo?: { + accountId?: string; + orgId?: string | null; +}): ServerRequestHandlerExtra { + return { + authInfo: authInfo + ? { + token: "test-token", + scopes: ["mcp:tools"], + clientId: authInfo.accountId, + extra: { + accountId: authInfo.accountId, + orgId: authInfo.orgId ?? null, + }, + } + : undefined, + } as unknown as ServerRequestHandlerExtra; +} + +describe("registerGetPulseTool", () => { + let mockServer: McpServer; + let registeredHandler: (args: unknown, extra: ServerRequestHandlerExtra) => Promise; + + beforeEach(() => { + vi.clearAllMocks(); + + mockServer = { + registerTool: vi.fn((name, config, handler) => { + registeredHandler = handler; + }), + } as unknown as McpServer; + + registerGetPulseTool(mockServer); + }); + + it("registers the get_pulse tool", () => { + expect(mockServer.registerTool).toHaveBeenCalledWith( + "get_pulse", + expect.objectContaining({ + description: "Get the pulse status for an account.", + }), + expect.any(Function), + ); + }); + + it("returns pulse with active: false when no record exists", async () => { + mockSelectPulseAccount.mockResolvedValue(null); + + const result = await registeredHandler({}, createMockExtra({ accountId: "account-123" })); + + expect(mockSelectPulseAccount).toHaveBeenCalledWith("account-123"); + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining('"active":false'), + }, + ], + }); + }); + + it("returns pulse with active: true when record exists", async () => { + mockSelectPulseAccount.mockResolvedValue({ + id: "pulse-456", + account_id: "account-123", + active: true, + }); + + const result = await registeredHandler({}, createMockExtra({ accountId: "account-123" })); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining('"active":true'), + }, + ], + }); + }); + + it("allows account_id override for org auth with access", async () => { + mockCanAccessAccount.mockResolvedValue(true); + mockSelectPulseAccount.mockResolvedValue({ + id: "pulse-456", + account_id: "target-account-789", + active: true, + }); + + await registeredHandler( + { account_id: "target-account-789" }, + createMockExtra({ accountId: "org-account-id", orgId: "org-account-id" }), + ); + + expect(mockCanAccessAccount).toHaveBeenCalledWith({ + orgId: "org-account-id", + targetAccountId: "target-account-789", + }); + expect(mockSelectPulseAccount).toHaveBeenCalledWith("target-account-789"); + }); + + it("returns error when org auth lacks access to account_id", async () => { + mockCanAccessAccount.mockResolvedValue(false); + + const result = await registeredHandler( + { account_id: "target-account-789" }, + createMockExtra({ accountId: "org-account-id", orgId: "org-account-id" }), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Access denied"), + }, + ], + }); + }); + + it("returns error when neither auth nor account_id is provided", async () => { + const result = await registeredHandler({}, createMockExtra()); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Authentication required"), + }, + ], + }); + }); +}); diff --git a/lib/mcp/tools/pulse/__tests__/registerUpdatePulseTool.test.ts b/lib/mcp/tools/pulse/__tests__/registerUpdatePulseTool.test.ts new file mode 100644 index 0000000..cea2959 --- /dev/null +++ b/lib/mcp/tools/pulse/__tests__/registerUpdatePulseTool.test.ts @@ -0,0 +1,196 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; + +import { registerUpdatePulseTool } from "../registerUpdatePulseTool"; + +const mockUpsertPulseAccount = vi.fn(); +const mockCanAccessAccount = vi.fn(); + +vi.mock("@/lib/supabase/pulse_accounts/upsertPulseAccount", () => ({ + upsertPulseAccount: (...args: unknown[]) => mockUpsertPulseAccount(...args), +})); + +vi.mock("@/lib/organizations/canAccessAccount", () => ({ + canAccessAccount: (...args: unknown[]) => mockCanAccessAccount(...args), +})); + +type ServerRequestHandlerExtra = RequestHandlerExtra; + +/** + * Creates a mock extra object with optional authInfo. + * + * @param authInfo + * @param authInfo.accountId + * @param authInfo.orgId + */ +function createMockExtra(authInfo?: { + accountId?: string; + orgId?: string | null; +}): ServerRequestHandlerExtra { + return { + authInfo: authInfo + ? { + token: "test-token", + scopes: ["mcp:tools"], + clientId: authInfo.accountId, + extra: { + accountId: authInfo.accountId, + orgId: authInfo.orgId ?? null, + }, + } + : undefined, + } as unknown as ServerRequestHandlerExtra; +} + +describe("registerUpdatePulseTool", () => { + let mockServer: McpServer; + let registeredHandler: (args: unknown, extra: ServerRequestHandlerExtra) => Promise; + + beforeEach(() => { + vi.clearAllMocks(); + + mockServer = { + registerTool: vi.fn((name, config, handler) => { + registeredHandler = handler; + }), + } as unknown as McpServer; + + registerUpdatePulseTool(mockServer); + }); + + it("registers the update_pulse tool", () => { + expect(mockServer.registerTool).toHaveBeenCalledWith( + "update_pulse", + expect.objectContaining({ + description: "Update the pulse status for an account.", + }), + expect.any(Function), + ); + }); + + it("updates pulse with active: true", async () => { + mockUpsertPulseAccount.mockResolvedValue({ + id: "pulse-456", + account_id: "account-123", + active: true, + }); + + const result = await registeredHandler( + { active: true }, + createMockExtra({ accountId: "account-123" }), + ); + + expect(mockUpsertPulseAccount).toHaveBeenCalledWith({ + account_id: "account-123", + active: true, + }); + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining('"active":true'), + }, + ], + }); + }); + + it("updates pulse with active: false", async () => { + mockUpsertPulseAccount.mockResolvedValue({ + id: "pulse-456", + account_id: "account-123", + active: false, + }); + + const result = await registeredHandler( + { active: false }, + createMockExtra({ accountId: "account-123" }), + ); + + expect(mockUpsertPulseAccount).toHaveBeenCalledWith({ + account_id: "account-123", + active: false, + }); + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining('"active":false'), + }, + ], + }); + }); + + it("allows account_id override for org auth with access", async () => { + mockCanAccessAccount.mockResolvedValue(true); + mockUpsertPulseAccount.mockResolvedValue({ + id: "pulse-456", + account_id: "target-account-789", + active: true, + }); + + await registeredHandler( + { active: true, account_id: "target-account-789" }, + createMockExtra({ accountId: "org-account-id", orgId: "org-account-id" }), + ); + + expect(mockCanAccessAccount).toHaveBeenCalledWith({ + orgId: "org-account-id", + targetAccountId: "target-account-789", + }); + expect(mockUpsertPulseAccount).toHaveBeenCalledWith({ + account_id: "target-account-789", + active: true, + }); + }); + + it("returns error when org auth lacks access to account_id", async () => { + mockCanAccessAccount.mockResolvedValue(false); + + const result = await registeredHandler( + { active: true, account_id: "target-account-789" }, + createMockExtra({ accountId: "org-account-id", orgId: "org-account-id" }), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Access denied"), + }, + ], + }); + }); + + it("returns error when neither auth nor account_id is provided", async () => { + const result = await registeredHandler({ active: true }, createMockExtra()); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Authentication required"), + }, + ], + }); + }); + + it("returns error when upsert fails", async () => { + mockUpsertPulseAccount.mockResolvedValue(null); + + const result = await registeredHandler( + { active: true }, + createMockExtra({ accountId: "account-123" }), + ); + + expect(result).toEqual({ + content: [ + { + type: "text", + text: expect.stringContaining("Failed to update pulse status"), + }, + ], + }); + }); +}); diff --git a/lib/mcp/tools/pulse/index.ts b/lib/mcp/tools/pulse/index.ts new file mode 100644 index 0000000..70042af --- /dev/null +++ b/lib/mcp/tools/pulse/index.ts @@ -0,0 +1,13 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import { registerGetPulseTool } from "./registerGetPulseTool"; +import { registerUpdatePulseTool } from "./registerUpdatePulseTool"; + +/** + * Registers all pulse-related MCP tools on the server. + * + * @param server - The MCP server instance to register tools on. + */ +export const registerAllPulseTools = (server: McpServer): void => { + registerGetPulseTool(server); + registerUpdatePulseTool(server); +}; diff --git a/lib/mcp/tools/pulse/registerGetPulseTool.ts b/lib/mcp/tools/pulse/registerGetPulseTool.ts new file mode 100644 index 0000000..d60fda9 --- /dev/null +++ b/lib/mcp/tools/pulse/registerGetPulseTool.ts @@ -0,0 +1,54 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; +import { z } from "zod"; +import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; +import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; +import { selectPulseAccount } from "@/lib/supabase/pulse_accounts/selectPulseAccount"; +import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess"; +import { getToolResultError } from "@/lib/mcp/getToolResultError"; + +const getPulseSchema = z.object({ + account_id: z.string().optional().describe("The account ID to get pulse status for."), +}); + +export type GetPulseArgs = z.infer; + +/** + * Registers the "get_pulse" tool on the MCP server. + * Retrieves the pulse status for an account. + * + * @param server - The MCP server instance to register the tool on. + */ +export function registerGetPulseTool(server: McpServer): void { + server.registerTool( + "get_pulse", + { + description: "Get the pulse status for an account.", + inputSchema: getPulseSchema, + }, + async (args: GetPulseArgs, extra: RequestHandlerExtra) => { + const { account_id } = args; + + const authInfo = extra.authInfo as McpAuthInfo | undefined; + const { accountId, error } = await resolveAccountId({ + authInfo, + accountIdOverride: account_id, + }); + + if (error) { + return getToolResultError(error); + } + + if (!accountId) { + return getToolResultError("Failed to resolve account ID"); + } + + const pulseAccount = await selectPulseAccount(accountId); + + return getToolResultSuccess({ + pulse: pulseAccount ?? { id: null, account_id: accountId, active: false }, + }); + }, + ); +} diff --git a/lib/mcp/tools/pulse/registerUpdatePulseTool.ts b/lib/mcp/tools/pulse/registerUpdatePulseTool.ts new file mode 100644 index 0000000..d96cd1d --- /dev/null +++ b/lib/mcp/tools/pulse/registerUpdatePulseTool.ts @@ -0,0 +1,60 @@ +import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { ServerRequest, ServerNotification } from "@modelcontextprotocol/sdk/types.js"; +import { z } from "zod"; +import type { McpAuthInfo } from "@/lib/mcp/verifyApiKey"; +import { resolveAccountId } from "@/lib/mcp/resolveAccountId"; +import { upsertPulseAccount } from "@/lib/supabase/pulse_accounts/upsertPulseAccount"; +import { getToolResultSuccess } from "@/lib/mcp/getToolResultSuccess"; +import { getToolResultError } from "@/lib/mcp/getToolResultError"; + +const updatePulseSchema = z.object({ + active: z.boolean().describe("Whether pulse is active for this account"), + account_id: z.string().optional().describe("The account ID to update pulse status for."), +}); + +export type UpdatePulseArgs = z.infer; + +/** + * Registers the "update_pulse" tool on the MCP server. + * Updates the pulse status for an account. + * + * @param server - The MCP server instance to register the tool on. + */ +export function registerUpdatePulseTool(server: McpServer): void { + server.registerTool( + "update_pulse", + { + description: "Update the pulse status for an account.", + inputSchema: updatePulseSchema, + }, + async ( + args: UpdatePulseArgs, + extra: RequestHandlerExtra, + ) => { + const { active, account_id } = args; + + const authInfo = extra.authInfo as McpAuthInfo | undefined; + const { accountId, error } = await resolveAccountId({ + authInfo, + accountIdOverride: account_id, + }); + + if (error) { + return getToolResultError(error); + } + + if (!accountId) { + return getToolResultError("Failed to resolve account ID"); + } + + const pulseAccount = await upsertPulseAccount({ account_id: accountId, active }); + + if (!pulseAccount) { + return getToolResultError("Failed to update pulse status"); + } + + return getToolResultSuccess({ pulse: pulseAccount }); + }, + ); +}