diff --git a/src/db/client.ts b/src/db/client.ts index b4aea43a..1facb85d 100644 --- a/src/db/client.ts +++ b/src/db/client.ts @@ -3,15 +3,121 @@ import { drizzle } from 'drizzle-orm/node-postgres'; import pg from 'pg'; import * as schema from './schema/index.js'; -let db: ReturnType> | null = null; -let pool: pg.Pool | null = null; -let _testDbOverride: ReturnType> | null = null; +// ============================================================================ +// DatabaseContext class +// ============================================================================ -/** Test-only: override the DB instance returned by getDb(). */ -export function _setTestDb(db: ReturnType> | null): void { - _testDbOverride = db; +export type DrizzleDb = ReturnType>; + +interface DatabaseConfig { + connectionString: string; + max?: number; + ssl: false | { rejectUnauthorized: boolean; ca?: string }; +} + +/** + * Encapsulates a Drizzle database instance and its underlying connection pool. + * Use `createDatabaseContext()` to create instances. + */ +export class DatabaseContext { + private db: DrizzleDb; + private pool: pg.Pool; + + constructor(config: DatabaseConfig) { + this.pool = new pg.Pool({ + connectionString: config.connectionString, + max: config.max ?? 5, + ssl: config.ssl, + }); + this.db = drizzle(this.pool, { schema }); + } + + getDb(): DrizzleDb { + return this.db; + } + + async close(): Promise { + await this.pool.end(); + } +} + +/** + * Factory function that creates a DatabaseContext from environment variables. + */ +export function createDatabaseContext(): DatabaseContext { + return new DatabaseContext({ + connectionString: getDatabaseUrl(), + ssl: getSslConfig(), + }); +} + +// ============================================================================ +// Default global context (lazy singleton) +// ============================================================================ + +let _defaultContext: DatabaseContext | null = null; + +/** + * Set the default DatabaseContext used by `getDb()`. + * Replaces `_setTestDb()` — use this in tests to inject a mock database. + */ +export function setDefaultDatabaseContext(context: DatabaseContext | null): void { + _defaultContext = context; +} + +/** + * @deprecated Use `setDefaultDatabaseContext()` instead. + * Kept for backward compatibility during migration. + */ +export function _setTestDb(db: DrizzleDb | null): void { + if (db === null) { + _defaultContext = null; + } else { + // Wrap the raw db in a minimal DatabaseContext-like object + _defaultContext = { + getDb: () => db, + close: async () => {}, + } as DatabaseContext; + } } +// ============================================================================ +// Module-level API (backward-compatible) +// ============================================================================ + +/** + * Returns the default database instance. + * Lazily initializes a global DatabaseContext on first call. + * If `setDefaultDatabaseContext()` has been called, returns that context's db. + */ +export function getDb(): DrizzleDb { + if (!_defaultContext) { + _defaultContext = createDatabaseContext(); + } + return _defaultContext.getDb(); +} + +/** + * Closes the default database connection pool and resets the context. + * Safe to call even if the db has never been initialized. + */ +export async function closeDb(): Promise { + if (_defaultContext) { + // Only close if it's a real DatabaseContext (has its own pool) + // Skip if it was set via _setTestDb (which wraps a mock) + try { + await _defaultContext.close(); + } catch { + // Ignore errors closing mock contexts + } + _defaultContext = null; + } +} + +// ============================================================================ +// Internal helpers +// ============================================================================ + function getDatabaseUrl(): string { if (process.env.DATABASE_URL) { return process.env.DATABASE_URL; @@ -43,24 +149,3 @@ function getSslConfig(): false | { rejectUnauthorized: boolean; ca?: string } { } return sslConfig; } - -export function getDb(): ReturnType> { - if (_testDbOverride) return _testDbOverride; - if (!db) { - pool = new pg.Pool({ - connectionString: getDatabaseUrl(), - max: 5, - ssl: getSslConfig(), - }); - db = drizzle(pool, { schema }); - } - return db; -} - -export async function closeDb(): Promise { - if (pool) { - await pool.end(); - pool = null; - db = null; - } -} diff --git a/tests/helpers/sharedMocks.ts b/tests/helpers/sharedMocks.ts index b252d888..486a03c1 100644 --- a/tests/helpers/sharedMocks.ts +++ b/tests/helpers/sharedMocks.ts @@ -169,6 +169,8 @@ export const mockGetDb = vi.fn(); export const mockDbClientModule = { getDb: mockGetDb, closeDb: vi.fn(), + setDefaultDatabaseContext: vi.fn(), + _setTestDb: vi.fn(), }; // --------------------------------------------------------------------------- diff --git a/tests/unit/db/client.test.ts b/tests/unit/db/client.test.ts index f0196f56..a2cc6e44 100644 --- a/tests/unit/db/client.test.ts +++ b/tests/unit/db/client.test.ts @@ -31,23 +31,88 @@ vi.mock('node:fs', () => ({ // ── Imports (after mocks) ───────────────────────────────────────────────────── -import { _setTestDb, closeDb, getDb } from '../../../src/db/client.js'; +import { + DatabaseContext, + _setTestDb, + closeDb, + createDatabaseContext, + getDb, + setDefaultDatabaseContext, +} from '../../../src/db/client.js'; // ── Helpers ─────────────────────────────────────────────────────────────────── -/** Reset module-level pool/db singletons between tests. */ +/** Reset module-level context singleton between tests. */ async function resetDbState() { - // closeDb() resets pool + db to null; if pool is null it's a no-op so safe. + // closeDb() resets _defaultContext to null; safe to call when already null. await closeDb(); - // Also clear the test override. - _setTestDb(null); + // Also clear any test override. + setDefaultDatabaseContext(null); } +// ── Tests: setDefaultDatabaseContext ───────────────────────────────────────── + +describe('setDefaultDatabaseContext', () => { + afterEach(async () => { + await resetDbState(); + }); + + it('getDb() returns the db from the injected context', () => { + const fakeDb = { __isFakeDb: true } as unknown as ReturnType; + const fakeContext = { + getDb: () => fakeDb, + close: vi.fn().mockResolvedValue(undefined), + } as unknown as DatabaseContext; + + setDefaultDatabaseContext(fakeContext); + expect(getDb()).toBe(fakeDb); + }); + + it('getDb() returns the latest injected context', () => { + const fakeDb1 = { id: 1 } as unknown as ReturnType; + const fakeDb2 = { id: 2 } as unknown as ReturnType; + + const fakeCtx1 = { + getDb: () => fakeDb1, + close: vi.fn().mockResolvedValue(undefined), + } as unknown as DatabaseContext; + const fakeCtx2 = { + getDb: () => fakeDb2, + close: vi.fn().mockResolvedValue(undefined), + } as unknown as DatabaseContext; + + setDefaultDatabaseContext(fakeCtx1); + expect(getDb()).toBe(fakeDb1); + + setDefaultDatabaseContext(fakeCtx2); + expect(getDb()).toBe(fakeDb2); + }); + + it('setting null causes getDb() to create a new real context', () => { + vi.stubEnv('DATABASE_URL', 'postgresql://user:pass@localhost:5432/testdb'); + + const fakeDb = { __isFakeDb: true } as unknown as ReturnType; + const fakeContext = { + getDb: () => fakeDb, + close: vi.fn().mockResolvedValue(undefined), + } as unknown as DatabaseContext; + + setDefaultDatabaseContext(fakeContext); + expect(getDb()).toBe(fakeDb); + + setDefaultDatabaseContext(null); + // Now getDb() should create a new pool and return the drizzle mock + const db = getDb(); + expect(db).toEqual({ __isMockDrizzle: true }); + expect(mockPoolConstructor).toHaveBeenCalled(); + }); +}); + // ── Tests: _setTestDb (pre-existing coverage, kept for regression) ──────────── describe('_setTestDb', () => { - afterEach(() => { - _setTestDb(null); + afterEach(async () => { + await resetDbState(); }); it('getDb() returns the override when set', () => { @@ -77,6 +142,41 @@ describe('_setTestDb', () => { // Assert: new override wins expect(getDb()).toBe(newDb); }); + + it('_setTestDb(null) clears the override', async () => { + vi.stubEnv('DATABASE_URL', 'postgresql://user:pass@localhost:5432/testdb'); + + const fakeDb = { __isFakeDb: true } as unknown as ReturnType; + _setTestDb(fakeDb); + expect(getDb()).toBe(fakeDb); + + _setTestDb(null); + // After clearing, getDb() should create a real context + const db = getDb(); + expect(db).toEqual({ __isMockDrizzle: true }); + }); +}); + +// ── Tests: createDatabaseContext ────────────────────────────────────────────── + +describe('createDatabaseContext', () => { + beforeEach(async () => { + await resetDbState(); + vi.stubEnv('DATABASE_URL', 'postgresql://user:pass@localhost:5432/testdb'); + }); + + afterEach(async () => { + await resetDbState(); + }); + + it('creates a DatabaseContext with a pool and drizzle db', () => { + vi.stubEnv('DATABASE_SSL', 'false'); + + const ctx = createDatabaseContext(); + expect(ctx).toBeInstanceOf(DatabaseContext); + expect(ctx.getDb()).toEqual({ __isMockDrizzle: true }); + expect(mockPoolConstructor).toHaveBeenCalled(); + }); }); // ── Tests: getDatabaseUrl (tested via getDb internals) ─────────────────────── @@ -232,7 +332,7 @@ describe('closeDb', () => { }); it('calls pool.end() and resets state', async () => { - getDb(); // creates pool + getDb(); // creates pool via DatabaseContext expect(mockPoolConstructor).toHaveBeenCalledTimes(1); await closeDb(); @@ -243,10 +343,19 @@ describe('closeDb', () => { expect(mockPoolConstructor).toHaveBeenCalledTimes(2); }); - it('is a no-op when pool is already null', async () => { - // No getDb() call — pool is null + it('is a no-op when context is null (pool never initialized)', async () => { + // No getDb() call — _defaultContext is null await closeDb(); expect(mockPoolEnd).not.toHaveBeenCalled(); }); + + it('does not throw when close() is called on a mock context', async () => { + // When _setTestDb is used, close() is a no-op and should not throw + const fakeDb = { __isFakeDb: true } as unknown as ReturnType; + _setTestDb(fakeDb); + + await expect(closeDb()).resolves.toBeUndefined(); + expect(mockPoolEnd).not.toHaveBeenCalled(); + }); });