diff --git a/lib/DBSQLClient.ts b/lib/DBSQLClient.ts index 5c25d540..c8bc1e10 100644 --- a/lib/DBSQLClient.ts +++ b/lib/DBSQLClient.ts @@ -128,6 +128,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I azureTenantId: options.azureTenantId, clientId: options.oauthClientId, clientSecret: options.oauthClientSecret, + useDatabricksOAuthInAzure: options.useDatabricksOAuthInAzure, context: this, }); case 'custom': diff --git a/lib/connection/auth/DatabricksOAuth/OAuthManager.ts b/lib/connection/auth/DatabricksOAuth/OAuthManager.ts index de805d1a..07c6a0c9 100644 --- a/lib/connection/auth/DatabricksOAuth/OAuthManager.ts +++ b/lib/connection/auth/DatabricksOAuth/OAuthManager.ts @@ -13,6 +13,7 @@ export interface OAuthManagerOptions { clientId?: string; azureTenantId?: string; clientSecret?: string; + useDatabricksOAuthInAzure?: boolean; context: IClientContext; } @@ -189,24 +190,35 @@ export default abstract class OAuthManager { // normalize const host = options.host.toLowerCase().replace('https://', '').split('/')[0]; - // eslint-disable-next-line @typescript-eslint/no-use-before-define - const managers = [AWSOAuthManager, AzureOAuthManager]; + const awsDomains = ['.cloud.databricks.com', '.dev.databricks.com']; + const isAWSDomain = awsDomains.some((domain) => host.endsWith(domain)); + if (isAWSDomain) { + // eslint-disable-next-line @typescript-eslint/no-use-before-define + return new DatabricksOAuthManager(options); + } - for (const OAuthManagerClass of managers) { - for (const domain of OAuthManagerClass.domains) { - if (host.endsWith(domain)) { - return new OAuthManagerClass(options); - } + if (options.useDatabricksOAuthInAzure) { + const domains = ['.azuredatabricks.net']; + const isSupportedDomain = domains.some((domain) => host.endsWith(domain)); + if (isSupportedDomain) { + // eslint-disable-next-line @typescript-eslint/no-use-before-define + return new DatabricksOAuthManager(options); } } + const azureDomains = ['.azuredatabricks.net', '.databricks.azure.us', '.databricks.azure.cn']; + const isAzureDomain = azureDomains.some((domain) => host.endsWith(domain)); + if (isAzureDomain) { + // eslint-disable-next-line @typescript-eslint/no-use-before-define + return new AzureOAuthManager(options); + } + throw new Error(`OAuth is not supported for ${options.host}`); } } -export class AWSOAuthManager extends OAuthManager { - public static domains = ['.cloud.databricks.com', '.dev.databricks.com']; - +// Databricks InHouse OAuth Manager +export class DatabricksOAuthManager extends OAuthManager { public static defaultClientId = 'databricks-sql-connector'; public static defaultCallbackPorts = [8030]; @@ -220,17 +232,15 @@ export class AWSOAuthManager extends OAuthManager { } protected getClientId(): string { - return this.options.clientId ?? AWSOAuthManager.defaultClientId; + return this.options.clientId ?? DatabricksOAuthManager.defaultClientId; } protected getCallbackPorts(): Array { - return this.options.callbackPorts ?? AWSOAuthManager.defaultCallbackPorts; + return this.options.callbackPorts ?? DatabricksOAuthManager.defaultCallbackPorts; } } export class AzureOAuthManager extends OAuthManager { - public static domains = ['.azuredatabricks.net', '.databricks.azure.cn', '.databricks.azure.us']; - public static defaultClientId = '96eecda7-19ea-49cc-abb5-240097d554f5'; public static defaultCallbackPorts = [8030]; diff --git a/lib/contracts/IDBSQLClient.ts b/lib/contracts/IDBSQLClient.ts index 0a14f435..0c0cee89 100644 --- a/lib/contracts/IDBSQLClient.ts +++ b/lib/contracts/IDBSQLClient.ts @@ -19,6 +19,7 @@ type AuthOptions = azureTenantId?: string; oauthClientId?: string; oauthClientSecret?: string; + useDatabricksOAuthInAzure?: boolean; } | { authType: 'custom'; diff --git a/tests/unit/DBSQLClient.test.js b/tests/unit/DBSQLClient.test.js index b1a1f3f2..2ec1ca29 100644 --- a/tests/unit/DBSQLClient.test.js +++ b/tests/unit/DBSQLClient.test.js @@ -5,9 +5,13 @@ const DBSQLSession = require('../../dist/DBSQLSession').default; const PlainHttpAuthentication = require('../../dist/connection/auth/PlainHttpAuthentication').default; const DatabricksOAuth = require('../../dist/connection/auth/DatabricksOAuth').default; -const { AWSOAuthManager, AzureOAuthManager } = require('../../dist/connection/auth/DatabricksOAuth/OAuthManager'); +const { + DatabricksOAuthManager, + AzureOAuthManager, +} = require('../../dist/connection/auth/DatabricksOAuth/OAuthManager'); const HttpConnectionModule = require('../../dist/connection/connections/HttpConnection'); + const { default: HttpConnection } = HttpConnectionModule; class AuthProviderMock { @@ -343,7 +347,7 @@ describe('DBSQLClient.initAuthProvider', () => { }); expect(provider).to.be.instanceOf(DatabricksOAuth); - expect(provider.manager).to.be.instanceOf(AWSOAuthManager); + expect(provider.manager).to.be.instanceOf(DatabricksOAuthManager); }); it('should use Databricks OAuth method (Azure)', () => { @@ -359,6 +363,34 @@ describe('DBSQLClient.initAuthProvider', () => { expect(provider.manager).to.be.instanceOf(AzureOAuthManager); }); + it('should use Databricks InHouse OAuth method (Azure)', () => { + const client = new DBSQLClient(); + + case1: { + const provider = client.initAuthProvider({ + authType: 'databricks-oauth', + // host is used when creating OAuth manager, so make it look like a real Azure instance + host: 'example.azuredatabricks.net', + useDatabricksOAuthInAzure: true, + }); + + expect(provider).to.be.instanceOf(DatabricksOAuth); + expect(provider.manager).to.be.instanceOf(DatabricksOAuthManager); + } + + case2: { + const provider = client.initAuthProvider({ + authType: 'databricks-oauth', + // host is used when creating OAuth manager, so make it look like a real Azure instance + host: 'example.databricks.azure.us', + useDatabricksOAuthInAzure: true, + }); + + expect(provider).to.be.instanceOf(DatabricksOAuth); + expect(provider.manager).to.be.instanceOf(AzureOAuthManager); + } + }); + it('should throw error when OAuth not supported for host', () => { const client = new DBSQLClient(); diff --git a/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js index 814f3faa..aab82975 100644 --- a/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js +++ b/tests/unit/connection/auth/DatabricksOAuth/OAuthManager.test.js @@ -3,7 +3,7 @@ const sinon = require('sinon'); const openidClientLib = require('openid-client'); const { DBSQLLogger, LogLevel } = require('../../../../../dist'); const { - AWSOAuthManager, + DatabricksOAuthManager, AzureOAuthManager, } = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthManager'); const OAuthToken = require('../../../../../dist/connection/auth/DatabricksOAuth/OAuthToken').default; @@ -110,7 +110,7 @@ class OAuthClientMock { } } -[AWSOAuthManager, AzureOAuthManager].forEach((OAuthManagerClass) => { +[DatabricksOAuthManager, AzureOAuthManager].forEach((OAuthManagerClass) => { function prepareTestInstances(options) { const oauthClient = new OAuthClientMock(); sinon.stub(oauthClient, 'grant').callThrough();