Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 112 additions & 27 deletions src/db/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,121 @@ import { drizzle } from 'drizzle-orm/node-postgres';
import pg from 'pg';
import * as schema from './schema/index.js';

let db: ReturnType<typeof drizzle<typeof schema>> | null = null;
let pool: pg.Pool | null = null;
let _testDbOverride: ReturnType<typeof drizzle<typeof schema>> | null = null;
// ============================================================================
// DatabaseContext class
// ============================================================================

/** Test-only: override the DB instance returned by getDb(). */
export function _setTestDb(db: ReturnType<typeof drizzle<typeof schema>> | null): void {
_testDbOverride = db;
export type DrizzleDb = ReturnType<typeof drizzle<typeof schema>>;

interface DatabaseConfig {
connectionString: string;
max?: number;
ssl: false | { rejectUnauthorized: boolean; ca?: string };
}

/**
* Encapsulates a Drizzle database instance and its underlying connection pool.
* Use `createDatabaseContext()` to create instances.
*/
export class DatabaseContext {
private db: DrizzleDb;
private pool: pg.Pool;

constructor(config: DatabaseConfig) {
this.pool = new pg.Pool({
connectionString: config.connectionString,
max: config.max ?? 5,
ssl: config.ssl,
});
this.db = drizzle(this.pool, { schema });
}

getDb(): DrizzleDb {
return this.db;
}

async close(): Promise<void> {
await this.pool.end();
}
}

/**
* Factory function that creates a DatabaseContext from environment variables.
*/
export function createDatabaseContext(): DatabaseContext {
return new DatabaseContext({
connectionString: getDatabaseUrl(),
ssl: getSslConfig(),
});
}

// ============================================================================
// Default global context (lazy singleton)
// ============================================================================

let _defaultContext: DatabaseContext | null = null;

/**
* Set the default DatabaseContext used by `getDb()`.
* Replaces `_setTestDb()` — use this in tests to inject a mock database.
*/
export function setDefaultDatabaseContext(context: DatabaseContext | null): void {
_defaultContext = context;
}

/**
* @deprecated Use `setDefaultDatabaseContext()` instead.
* Kept for backward compatibility during migration.
*/
export function _setTestDb(db: DrizzleDb | null): void {
if (db === null) {
_defaultContext = null;
} else {
// Wrap the raw db in a minimal DatabaseContext-like object
_defaultContext = {
getDb: () => db,
close: async () => {},
} as DatabaseContext;
}
}

// ============================================================================
// Module-level API (backward-compatible)
// ============================================================================

/**
* Returns the default database instance.
* Lazily initializes a global DatabaseContext on first call.
* If `setDefaultDatabaseContext()` has been called, returns that context's db.
*/
export function getDb(): DrizzleDb {
if (!_defaultContext) {
_defaultContext = createDatabaseContext();
}
return _defaultContext.getDb();
}

/**
* Closes the default database connection pool and resets the context.
* Safe to call even if the db has never been initialized.
*/
export async function closeDb(): Promise<void> {
if (_defaultContext) {
// Only close if it's a real DatabaseContext (has its own pool)
// Skip if it was set via _setTestDb (which wraps a mock)
try {
await _defaultContext.close();
} catch {
// Ignore errors closing mock contexts
}
_defaultContext = null;
}
}

// ============================================================================
// Internal helpers
// ============================================================================

function getDatabaseUrl(): string {
if (process.env.DATABASE_URL) {
return process.env.DATABASE_URL;
Expand Down Expand Up @@ -43,24 +149,3 @@ function getSslConfig(): false | { rejectUnauthorized: boolean; ca?: string } {
}
return sslConfig;
}

export function getDb(): ReturnType<typeof drizzle<typeof schema>> {
if (_testDbOverride) return _testDbOverride;
if (!db) {
pool = new pg.Pool({
connectionString: getDatabaseUrl(),
max: 5,
ssl: getSslConfig(),
});
db = drizzle(pool, { schema });
}
return db;
}

export async function closeDb(): Promise<void> {
if (pool) {
await pool.end();
pool = null;
db = null;
}
}
2 changes: 2 additions & 0 deletions tests/helpers/sharedMocks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ export const mockGetDb = vi.fn();
export const mockDbClientModule = {
getDb: mockGetDb,
closeDb: vi.fn(),
setDefaultDatabaseContext: vi.fn(),
_setTestDb: vi.fn(),
};

// ---------------------------------------------------------------------------
Expand Down
129 changes: 119 additions & 10 deletions tests/unit/db/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,88 @@ vi.mock('node:fs', () => ({

// ── Imports (after mocks) ─────────────────────────────────────────────────────

import { _setTestDb, closeDb, getDb } from '../../../src/db/client.js';
import {
DatabaseContext,
_setTestDb,
closeDb,
createDatabaseContext,
getDb,
setDefaultDatabaseContext,
} from '../../../src/db/client.js';

// ── Helpers ───────────────────────────────────────────────────────────────────

/** Reset module-level pool/db singletons between tests. */
/** Reset module-level context singleton between tests. */
async function resetDbState() {
// closeDb() resets pool + db to null; if pool is null it's a no-op so safe.
// closeDb() resets _defaultContext to null; safe to call when already null.
await closeDb();
// Also clear the test override.
_setTestDb(null);
// Also clear any test override.
setDefaultDatabaseContext(null);
}

// ── Tests: setDefaultDatabaseContext ─────────────────────────────────────────

describe('setDefaultDatabaseContext', () => {
afterEach(async () => {
await resetDbState();
});

it('getDb() returns the db from the injected context', () => {
const fakeDb = { __isFakeDb: true } as unknown as ReturnType<typeof getDb>;
const fakeContext = {
getDb: () => fakeDb,
close: vi.fn().mockResolvedValue(undefined),
} as unknown as DatabaseContext;

setDefaultDatabaseContext(fakeContext);
expect(getDb()).toBe(fakeDb);
});

it('getDb() returns the latest injected context', () => {
const fakeDb1 = { id: 1 } as unknown as ReturnType<typeof getDb>;
const fakeDb2 = { id: 2 } as unknown as ReturnType<typeof getDb>;

const fakeCtx1 = {
getDb: () => fakeDb1,
close: vi.fn().mockResolvedValue(undefined),
} as unknown as DatabaseContext;
const fakeCtx2 = {
getDb: () => fakeDb2,
close: vi.fn().mockResolvedValue(undefined),
} as unknown as DatabaseContext;

setDefaultDatabaseContext(fakeCtx1);
expect(getDb()).toBe(fakeDb1);

setDefaultDatabaseContext(fakeCtx2);
expect(getDb()).toBe(fakeDb2);
});

it('setting null causes getDb() to create a new real context', () => {
vi.stubEnv('DATABASE_URL', 'postgresql://user:pass@localhost:5432/testdb');

const fakeDb = { __isFakeDb: true } as unknown as ReturnType<typeof getDb>;
const fakeContext = {
getDb: () => fakeDb,
close: vi.fn().mockResolvedValue(undefined),
} as unknown as DatabaseContext;

setDefaultDatabaseContext(fakeContext);
expect(getDb()).toBe(fakeDb);

setDefaultDatabaseContext(null);
// Now getDb() should create a new pool and return the drizzle mock
const db = getDb();
expect(db).toEqual({ __isMockDrizzle: true });
expect(mockPoolConstructor).toHaveBeenCalled();
});
});

// ── Tests: _setTestDb (pre-existing coverage, kept for regression) ────────────

describe('_setTestDb', () => {
afterEach(() => {
_setTestDb(null);
afterEach(async () => {
await resetDbState();
});

it('getDb() returns the override when set', () => {
Expand Down Expand Up @@ -77,6 +142,41 @@ describe('_setTestDb', () => {
// Assert: new override wins
expect(getDb()).toBe(newDb);
});

it('_setTestDb(null) clears the override', async () => {
vi.stubEnv('DATABASE_URL', 'postgresql://user:pass@localhost:5432/testdb');

const fakeDb = { __isFakeDb: true } as unknown as ReturnType<typeof getDb>;
_setTestDb(fakeDb);
expect(getDb()).toBe(fakeDb);

_setTestDb(null);
// After clearing, getDb() should create a real context
const db = getDb();
expect(db).toEqual({ __isMockDrizzle: true });
});
});

// ── Tests: createDatabaseContext ──────────────────────────────────────────────

describe('createDatabaseContext', () => {
beforeEach(async () => {
await resetDbState();
vi.stubEnv('DATABASE_URL', 'postgresql://user:pass@localhost:5432/testdb');
});

afterEach(async () => {
await resetDbState();
});

it('creates a DatabaseContext with a pool and drizzle db', () => {
vi.stubEnv('DATABASE_SSL', 'false');

const ctx = createDatabaseContext();
expect(ctx).toBeInstanceOf(DatabaseContext);
expect(ctx.getDb()).toEqual({ __isMockDrizzle: true });
expect(mockPoolConstructor).toHaveBeenCalled();
});
});

// ── Tests: getDatabaseUrl (tested via getDb internals) ───────────────────────
Expand Down Expand Up @@ -232,7 +332,7 @@ describe('closeDb', () => {
});

it('calls pool.end() and resets state', async () => {
getDb(); // creates pool
getDb(); // creates pool via DatabaseContext
expect(mockPoolConstructor).toHaveBeenCalledTimes(1);

await closeDb();
Expand All @@ -243,10 +343,19 @@ describe('closeDb', () => {
expect(mockPoolConstructor).toHaveBeenCalledTimes(2);
});

it('is a no-op when pool is already null', async () => {
// No getDb() call — pool is null
it('is a no-op when context is null (pool never initialized)', async () => {
// No getDb() call — _defaultContext is null
await closeDb();

expect(mockPoolEnd).not.toHaveBeenCalled();
});

it('does not throw when close() is called on a mock context', async () => {
// When _setTestDb is used, close() is a no-op and should not throw
const fakeDb = { __isFakeDb: true } as unknown as ReturnType<typeof getDb>;
_setTestDb(fakeDb);

await expect(closeDb()).resolves.toBeUndefined();
expect(mockPoolEnd).not.toHaveBeenCalled();
});
});
Loading