diff --git a/packages/server/src/jwt/jwt.guard.ts b/packages/server/src/jwt/jwt.guard.ts index c431b4d1..d8da9f86 100644 --- a/packages/server/src/jwt/jwt.guard.ts +++ b/packages/server/src/jwt/jwt.guard.ts @@ -3,7 +3,7 @@ import { AuthGuard } from '@nestjs/passport'; import { GqlExecutionContext } from '@nestjs/graphql'; @Injectable() -export class JwtAuthGuard extends AuthGuard('local') { +export class JwtAuthGuard extends AuthGuard('jwt') { getRequest(context: ExecutionContext) { // Return HTTP context if (context.getType() === 'http') { diff --git a/packages/server/src/jwt/jwt.module.ts b/packages/server/src/jwt/jwt.module.ts index e59eac10..a045b434 100644 --- a/packages/server/src/jwt/jwt.module.ts +++ b/packages/server/src/jwt/jwt.module.ts @@ -1,13 +1,39 @@ -import { Module, forwardRef } from '@nestjs/common'; +import { Module, UnauthorizedException, forwardRef } from '@nestjs/common'; import { JwtService } from './jwt.service'; import { HttpModule } from '@nestjs/axios'; import { JwtAuthGuard } from './jwt.guard'; import { JwtStrategy } from './jwt.strategy'; import { OrganizationModule } from '../organization/organization.module'; import { UserOrgModule } from '../userorg/userorg.module'; +import { JwtSecretRequestType, JwtModule as NestJwtModule } from '@nestjs/jwt'; @Module({ - imports: [HttpModule, forwardRef(() => OrganizationModule), UserOrgModule], + imports: [ + HttpModule, + forwardRef(() => OrganizationModule), + UserOrgModule, + NestJwtModule.registerAsync({ + imports: [forwardRef(() => JwtModule)], + inject: [JwtService], + useFactory: (jwtService: JwtService) => ({ + secretOrKeyProvider: async (requestType, rawJwtToken) => { + // Can only verify tokens via the Google public key + switch (requestType) { + case JwtSecretRequestType.SIGN: + throw new Error('Cannot sign tokens'); + case JwtSecretRequestType.VERIFY: + const publicKey = await jwtService.getPublicKey(rawJwtToken); + if (!publicKey) { + throw new UnauthorizedException('No public key found for token'); + } + return publicKey; + default: + throw new Error('Invalid request type'); + } + } + }) + }) + ], providers: [JwtService, JwtAuthGuard, JwtStrategy], exports: [JwtService] }) diff --git a/packages/server/src/jwt/jwt.service.ts b/packages/server/src/jwt/jwt.service.ts index ea651f83..a766f6fa 100644 --- a/packages/server/src/jwt/jwt.service.ts +++ b/packages/server/src/jwt/jwt.service.ts @@ -21,34 +21,34 @@ export class JwtService { return response.data; } - async getPublicKey(kid: string): Promise { - if (!this.publicKeys || !this.publicKeys[kid]) { - this.publicKeys = await this.queryForPublicKey(); + // TODO: Handle when key rotation has taken place + async getPublicKey(rawToken: string | null | Buffer | object): Promise { + // Make sure the tokn is the correct type + if (!rawToken) { + return null; + } + if (typeof rawToken === 'object') { + return null; } - return this.publicKeys[kid] || null; - } - async validate(rawToken: string): Promise { - // Parse out the token - const tokenString = rawToken.split(' ')[1]; - const token = jwt.decode(tokenString, { complete: true }) as any; + // Decode the token to get the kid + const token = jwt.decode(rawToken, { complete: true }); + if (!token) { + return null; + } - // Get the kid to verify the JWT against + // Get the kid from the token const kid = token.header.kid; if (!kid) { return null; } - const publicKey = await this.getPublicKey(kid); - if (!publicKey) { - return null; + // If we don't have the public keys yet or the kid isn't in the public keys, query for the public keys + if (!this.publicKeys || !this.publicKeys[kid]) { + this.publicKeys = await this.queryForPublicKey(); } - try { - jwt.verify(tokenString, publicKey); - return token.payload; - } catch (e) { - return null; - } + // Return the public key + return this.publicKeys[kid] || null; } } diff --git a/packages/server/src/jwt/jwt.strategy.ts b/packages/server/src/jwt/jwt.strategy.ts index 75f597dc..e62ebf8a 100644 --- a/packages/server/src/jwt/jwt.strategy.ts +++ b/packages/server/src/jwt/jwt.strategy.ts @@ -1,45 +1,38 @@ import { Injectable } from '@nestjs/common'; import { PassportStrategy } from '@nestjs/passport'; -import { Strategy } from 'passport-local'; import { TokenPayload } from './token.dto'; -import { Request } from 'express'; -import { ParamsDictionary } from 'express-serve-static-core'; -import { ParsedQs } from 'qs'; import { JwtService } from './jwt.service'; +import { Strategy, ExtractJwt } from 'passport-jwt'; +import { Request } from 'express'; @Injectable() export class JwtStrategy extends PassportStrategy(Strategy) { - constructor(private readonly jwtService: JwtService) { - super(); - } - - async authenticate( - req: Request>, - _options?: any - ): Promise { - // Check if the token is present - const rawToken = req.headers.authorization; - if (!rawToken) { - this.fail({ meessage: 'Invalid Token' }, 400); - return; - } - - // Validate the token - const payload = await this.jwtService.validate(rawToken); - if (!payload) { - this.fail({ message: 'Invalid Token' }, 400); - return; - } - - const result = await this.validate(payload); - this.success(result); + constructor(jwtService: JwtService) { + super({ + jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(), + ignoreExpiration: false, + secretOrKeyProvider: ( + _request: Request, + rawJwtToken: any, + done: (err: any, secretOrKey?: string | Buffer) => void + ) => { + // Can only verify tokens via the Google public key + jwtService + .getPublicKey(rawJwtToken) + .then((publicKey) => { + if (!publicKey) { + done(new Error('No public key found for token')); + return; + } + done(null, publicKey); + }) + .catch((err) => { + done(err); + }); + } + }); } - /** - * Need to add the organization at this step since the organization is - * queried from the database and not part of the JWT token. This allows - * the organization to then be pulled in via the organization context - */ async validate(payload: TokenPayload): Promise { return { ...payload