Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/server/src/jwt/jwt.guard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { AuthGuard } from '@nestjs/passport';
import { GqlExecutionContext } from '@nestjs/graphql';

@Injectable()
export class JwtAuthGuard extends AuthGuard('local') {
export class JwtAuthGuard extends AuthGuard('jwt') {
getRequest(context: ExecutionContext) {
// Return HTTP context
if (context.getType() === 'http') {
Expand Down
30 changes: 28 additions & 2 deletions packages/server/src/jwt/jwt.module.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,39 @@
import { Module, forwardRef } from '@nestjs/common';
import { Module, UnauthorizedException, forwardRef } from '@nestjs/common';
import { JwtService } from './jwt.service';
import { HttpModule } from '@nestjs/axios';
import { JwtAuthGuard } from './jwt.guard';
import { JwtStrategy } from './jwt.strategy';
import { OrganizationModule } from '../organization/organization.module';
import { UserOrgModule } from '../userorg/userorg.module';
import { JwtSecretRequestType, JwtModule as NestJwtModule } from '@nestjs/jwt';

@Module({
imports: [HttpModule, forwardRef(() => OrganizationModule), UserOrgModule],
imports: [
HttpModule,
forwardRef(() => OrganizationModule),
UserOrgModule,
NestJwtModule.registerAsync({
imports: [forwardRef(() => JwtModule)],
inject: [JwtService],
useFactory: (jwtService: JwtService) => ({
secretOrKeyProvider: async (requestType, rawJwtToken) => {
// Can only verify tokens via the Google public key
switch (requestType) {
case JwtSecretRequestType.SIGN:
throw new Error('Cannot sign tokens');
case JwtSecretRequestType.VERIFY:
const publicKey = await jwtService.getPublicKey(rawJwtToken);
if (!publicKey) {
throw new UnauthorizedException('No public key found for token');
}
return publicKey;
default:
throw new Error('Invalid request type');
}
}
})
})
],
providers: [JwtService, JwtAuthGuard, JwtStrategy],
exports: [JwtService]
})
Expand Down
38 changes: 19 additions & 19 deletions packages/server/src/jwt/jwt.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,34 @@ export class JwtService {
return response.data;
}

async getPublicKey(kid: string): Promise<string | null> {
if (!this.publicKeys || !this.publicKeys[kid]) {
this.publicKeys = await this.queryForPublicKey();
// TODO: Handle when key rotation has taken place
async getPublicKey(rawToken: string | null | Buffer | object): Promise<string | null> {
// Make sure the tokn is the correct type
if (!rawToken) {
return null;
}
if (typeof rawToken === 'object') {
return null;
}
return this.publicKeys[kid] || null;
}

async validate(rawToken: string): Promise<any | null> {
// Parse out the token
const tokenString = rawToken.split(' ')[1];
const token = jwt.decode(tokenString, { complete: true }) as any;
// Decode the token to get the kid
const token = jwt.decode(rawToken, { complete: true });
if (!token) {
return null;
}

// Get the kid to verify the JWT against
// Get the kid from the token
const kid = token.header.kid;
if (!kid) {
return null;
}

const publicKey = await this.getPublicKey(kid);
if (!publicKey) {
return null;
// If we don't have the public keys yet or the kid isn't in the public keys, query for the public keys
if (!this.publicKeys || !this.publicKeys[kid]) {
this.publicKeys = await this.queryForPublicKey();
}

try {
jwt.verify(tokenString, publicKey);
return token.payload;
} catch (e) {
return null;
}
// Return the public key
return this.publicKeys[kid] || null;
}
}
59 changes: 26 additions & 33 deletions packages/server/src/jwt/jwt.strategy.ts
Original file line number Diff line number Diff line change
@@ -1,45 +1,38 @@
import { Injectable } from '@nestjs/common';
import { PassportStrategy } from '@nestjs/passport';
import { Strategy } from 'passport-local';
import { TokenPayload } from './token.dto';
import { Request } from 'express';
import { ParamsDictionary } from 'express-serve-static-core';
import { ParsedQs } from 'qs';
import { JwtService } from './jwt.service';
import { Strategy, ExtractJwt } from 'passport-jwt';
import { Request } from 'express';

@Injectable()
export class JwtStrategy extends PassportStrategy(Strategy) {
constructor(private readonly jwtService: JwtService) {
super();
}

async authenticate(
req: Request<ParamsDictionary, any, any, ParsedQs, Record<string, any>>,
_options?: any
): Promise<void> {
// Check if the token is present
const rawToken = req.headers.authorization;
if (!rawToken) {
this.fail({ meessage: 'Invalid Token' }, 400);
return;
}

// Validate the token
const payload = await this.jwtService.validate(rawToken);
if (!payload) {
this.fail({ message: 'Invalid Token' }, 400);
return;
}

const result = await this.validate(payload);
this.success(result);
constructor(jwtService: JwtService) {
super({
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
ignoreExpiration: false,
secretOrKeyProvider: (
_request: Request,
rawJwtToken: any,
done: (err: any, secretOrKey?: string | Buffer) => void
) => {
// Can only verify tokens via the Google public key
jwtService
.getPublicKey(rawJwtToken)
.then((publicKey) => {
if (!publicKey) {
done(new Error('No public key found for token'));
return;
}
done(null, publicKey);
})
.catch((err) => {
done(err);
});
}
});
}

/**
* Need to add the organization at this step since the organization is
* queried from the database and not part of the JWT token. This allows
* the organization to then be pulled in via the organization context
*/
async validate(payload: TokenPayload): Promise<TokenPayload> {
return {
...payload
Expand Down