Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 167 additions & 24 deletions src/main/presenter/githubCopilotDeviceFlow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,26 @@ export interface AccessTokenResponse {
error_description?: string
}

export interface CopilotTokenResponse {
token: string
expires_at: number
refresh_in?: number
}

export interface ApiToken {
apiKey: string
expiresAt: Date
}

export interface CopilotConfig {
oauthToken?: string
apiToken?: ApiToken
}

export class GitHubCopilotDeviceFlow {
private config: DeviceFlowConfig
private pollingInterval: NodeJS.Timeout | null = null
private oauthToken: string | null = null

constructor(config: DeviceFlowConfig) {
this.config = config
Expand All @@ -49,8 +66,84 @@ export class GitHubCopilotDeviceFlow {

return accessToken
} catch (error) {
console.error('Failed to start device flow', error)
throw new Error('Failed to start device flow')
console.error('[GitHub Copilot] Device flow failed:', error)
throw new Error(
`Device flow authentication failed: ${error instanceof Error ? error.message : 'Unknown error'}`
)
}
}

/**
* 获取 Copilot API token
* 使用OAuth token交换Copilot API token
*/
public async getCopilotToken(): Promise<string> {
if (!this.oauthToken) {
throw new Error('No OAuth token available')
}

// 使用OAuth token从GitHub API获取Copilot token
const tokenUrl = 'https://api.github.com/copilot_internal/v2/token'

try {
const response = await fetch(tokenUrl, {
method: 'GET',
headers: {
Authorization: `Bearer ${this.oauthToken}`,
Accept: 'application/json',
'User-Agent': 'DeepChat/1.0.0'
}
})

if (!response.ok) {
const errorText = await response.text().catch(() => '')
throw new Error(
`Failed to get Copilot token: ${response.status} ${response.statusText} - ${errorText}`
)
}

const data = (await response.json()) as { token: string; expires_at: number }
return data.token
} catch (error) {
console.error('[GitHub Copilot][DeviceFlow] Failed to get Copilot token:', error)
throw error
}
}

/**
* 检查是否已经有有效的认证状态
*/
public async checkExistingAuth(externalToken?: string): Promise<string | null> {
try {
// 如果提供了外部 token,使用它
if (externalToken) {
this.oauthToken = externalToken

// 尝试获取 API token 来验证认证状态
try {
await this.getCopilotToken()
return this.oauthToken
} catch {
this.oauthToken = null
return null
}
}

// 检查内部存储的 token
if (this.oauthToken) {
// 尝试获取 API token 来验证认证状态
try {
await this.getCopilotToken()
return this.oauthToken!
} catch {
this.oauthToken = null
}
}

return null
} catch (error) {
console.warn('[GitHub Copilot][DeviceFlow] Error checking existing auth:', error)
return null
}
}

Expand All @@ -64,23 +157,29 @@ export class GitHubCopilotDeviceFlow {
scope: this.config.scope
}

const response = await fetch(url, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
'User-Agent': 'DeepChat/1.0.0'
},
body: JSON.stringify(body)
})

if (!response.ok) {
throw new Error(`Failed to request device code: ${response.status} ${response.statusText}`)
}
try {
const response = await fetch(url, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
'User-Agent': 'DeepChat/1.0.0'
},
body: JSON.stringify(body)
})

const data = (await response.json()) as DeviceCodeResponse
if (!response.ok) {
const errorText = await response.text().catch(() => '')
throw new Error(
`Failed to request device code: ${response.status} ${response.statusText} - ${errorText}`
)
}

return data
const data = (await response.json()) as DeviceCodeResponse
return data
} catch (error) {
throw error
}
Comment on lines +180 to +182
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't pass linting here.

}

/**
Expand Down Expand Up @@ -364,11 +463,18 @@ export class GitHubCopilotDeviceFlow {
const startTime = Date.now()
const expiresAt = startTime + deviceCodeResponse.expires_in * 1000
let pollCount = 0
let currentInterval = deviceCodeResponse.interval

const poll = async () => {
pollCount++
if (pollCount > 50) {
reject(new Error('Poll count exceeded'))

if (pollCount > 100) {
// 增加最大轮询次数
if (this.pollingInterval) {
clearInterval(this.pollingInterval)
this.pollingInterval = null
}
reject(new Error('Maximum polling attempts exceeded'))
return
}

Expand Down Expand Up @@ -409,10 +515,11 @@ export class GitHubCopilotDeviceFlow {
return // Continue polling

case 'slow_down':
// Increase polling interval
// Increase polling interval by at least 5 seconds as per OAuth 2.0 spec
currentInterval += 5
if (this.pollingInterval) {
clearInterval(this.pollingInterval)
this.pollingInterval = setInterval(poll, (deviceCodeResponse.interval + 5) * 1000)
this.pollingInterval = setInterval(poll, currentInterval * 1000)
}
return

Expand All @@ -421,18 +528,22 @@ export class GitHubCopilotDeviceFlow {
clearInterval(this.pollingInterval)
this.pollingInterval = null
}
reject(new Error('Device code expired'))
reject(new Error('Device code expired during polling'))
return

case 'access_denied':
if (this.pollingInterval) {
clearInterval(this.pollingInterval)
this.pollingInterval = null
}
reject(new Error('User denied access'))
reject(new Error('User denied access to GitHub Copilot'))
return

default:
if (this.pollingInterval) {
clearInterval(this.pollingInterval)
this.pollingInterval = null
}
reject(new Error(`OAuth error: ${data.error_description || data.error}`))
return
}
Expand All @@ -447,7 +558,8 @@ export class GitHubCopilotDeviceFlow {
return
}
} catch {
// ignore
// Continue polling, network errors may be temporary
return
}
}

Expand All @@ -468,6 +580,13 @@ export class GitHubCopilotDeviceFlow {
this.pollingInterval = null
}
}

/**
* 清理资源
*/
public dispose(): void {
this.stopPolling()
}
}

// GitHub Copilot Device Flow configuration
Expand All @@ -486,5 +605,29 @@ export function createGitHubCopilotDeviceFlow(): GitHubCopilotDeviceFlow {
scope: 'read:user read:org'
}

console.log('[GitHub Copilot][DeviceFlow] Creating device flow with config:', {
clientIdConfigured: !!clientId,
scope: config.scope
})

return new GitHubCopilotDeviceFlow(config)
}

/**
* 创建一个全局的 GitHub Copilot Device Flow 实例
*/
let globalDeviceFlowInstance: GitHubCopilotDeviceFlow | null = null

export function getGlobalGitHubCopilotDeviceFlow(): GitHubCopilotDeviceFlow {
if (!globalDeviceFlowInstance) {
globalDeviceFlowInstance = createGitHubCopilotDeviceFlow()
}
return globalDeviceFlowInstance
}

export function disposeGlobalGitHubCopilotDeviceFlow(): void {
if (globalDeviceFlowInstance) {
globalDeviceFlowInstance.dispose()
globalDeviceFlowInstance = null
}
}
9 changes: 3 additions & 6 deletions src/main/presenter/githubCopilotOAuth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export class GitHubCopilotOAuth {
this.closeWindow()
reject(new Error(`GitHub authorization failed: ${error}`))
} else if (code) {
console.log('OAuth success, received code:', code)
console.log('OAuth success, received authorization code')
this.closeWindow()
resolve(code)
} else {
Expand Down Expand Up @@ -266,13 +266,10 @@ export function createGitHubCopilotOAuth(): GitHubCopilotOAuth {
}
if (is.dev) {
console.log('Final OAuth config:', {
clientId:
config.clientId.substring(0, 4) +
'****' +
config.clientId.substring(config.clientId.length - 4),
clientIdConfigured: !!config.clientId,
redirectUri: config.redirectUri,
scope: config.scope,
clientSecretLength: config.clientSecret.length
clientSecretConfigured: !!config.clientSecret
})
}

Expand Down
13 changes: 11 additions & 2 deletions src/main/presenter/llmProviderPresenter/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1531,8 +1531,17 @@ export class LLMProviderPresenter implements ILlmProviderPresenter {
return { isOk: false, errorMsg: `Model test failed: ${errorMessage}` }
}
} else {
// 如果没有提供modelId,使用provider的check方法进行基础验证
return await provider.check()
// 如果没有提供modelId,使用provider自己的check方法进行基本验证
console.log(
`[LLMProviderPresenter] No modelId provided, using provider's own check method for ${providerId}`
)
try {
return await provider.check()
} catch (error) {
console.error(`Provider ${providerId} check failed:`, error)
const errorMessage = error instanceof Error ? error.message : String(error)
return { isOk: false, errorMsg: `Provider check failed: ${errorMessage}` }
}
}
} catch (error) {
console.error(`Provider ${providerId} check failed:`, error)
Expand Down
2 changes: 1 addition & 1 deletion src/main/presenter/llmProviderPresenter/oauthHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ export class OAuthHelper {
eventBus.send(CONFIG_EVENTS.OAUTH_LOGIN_ERROR, SendTarget.ALL_WINDOWS, error)
reject(new Error(`OAuth授权失败: ${error}`))
} else if (code) {
console.log('OAuth success, received code:', code)
console.log('OAuth success, received authorization code')
eventBus.send(CONFIG_EVENTS.OAUTH_LOGIN_SUCCESS, SendTarget.ALL_WINDOWS, code)
resolve(code)
} else {
Expand Down
Loading