Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions packages/cloud/src/CloudService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ export class CloudService extends EventEmitter<CloudServiceEvents> implements Di

// AuthService

public async login(landingPageSlug?: string): Promise<void> {
public async login(landingPageSlug?: string, useProviderSignup: boolean = false): Promise<void> {
this.ensureInitialized()
return this.authService!.login(landingPageSlug)
return this.authService!.login(landingPageSlug, useProviderSignup)
}

public async logout(): Promise<void> {
Expand Down Expand Up @@ -245,9 +245,10 @@ export class CloudService extends EventEmitter<CloudServiceEvents> implements Di
code: string | null,
state: string | null,
organizationId?: string | null,
providerModel?: string | null,
): Promise<void> {
this.ensureInitialized()
return this.authService!.handleCallback(code, state, organizationId)
return this.authService!.handleCallback(code, state, organizationId, providerModel)
}

public async switchOrganization(organizationId: string | null): Promise<void> {
Expand Down
3 changes: 2 additions & 1 deletion packages/cloud/src/StaticTokenAuthService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export class StaticTokenAuthService extends EventEmitter<AuthServiceEvents> impl
this.emit("user-info", { userInfo: this.userInfo })
}

public async login(): Promise<void> {
public async login(_landingPageSlug?: string, _useProviderSignup?: boolean): Promise<void> {
throw new Error("Authentication methods are disabled in StaticTokenAuthService")
}

Expand All @@ -59,6 +59,7 @@ export class StaticTokenAuthService extends EventEmitter<AuthServiceEvents> impl
_code: string | null,
_state: string | null,
_organizationId?: string | null,
_providerModel?: string | null,
): Promise<void> {
throw new Error("Authentication methods are disabled in StaticTokenAuthService")
}
Expand Down
17 changes: 14 additions & 3 deletions packages/cloud/src/WebAuthService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,9 @@ export class WebAuthService extends EventEmitter<AuthServiceEvents> implements A
* and opening the browser to the authorization URL.
*
* @param landingPageSlug Optional slug of a specific landing page (e.g., "supernova", "special-offer", etc.)
* @param useProviderSignup If true, uses provider signup flow (/extension/provider-sign-up). If false, uses standard sign-in (/extension/sign-in). Defaults to false.
*/
public async login(landingPageSlug?: string): Promise<void> {
public async login(landingPageSlug?: string, useProviderSignup: boolean = false): Promise<void> {
try {
const vscode = await importVscode()

Expand All @@ -272,10 +273,12 @@ export class WebAuthService extends EventEmitter<AuthServiceEvents> implements A
auth_redirect: `${vscode.env.uriScheme}://${publisher}.${name}`,
})

// Use landing page URL if slug is provided, otherwise use default sign-in URL
// Use landing page URL if slug is provided, otherwise use provider sign-up or sign-in URL based on parameter
const url = landingPageSlug
? `${getRooCodeApiUrl()}/l/${landingPageSlug}?${params.toString()}`
: `${getRooCodeApiUrl()}/extension/sign-in?${params.toString()}`
: useProviderSignup
? `${getRooCodeApiUrl()}/extension/provider-sign-up?${params.toString()}`
: `${getRooCodeApiUrl()}/extension/sign-in?${params.toString()}`

await vscode.env.openExternal(vscode.Uri.parse(url))
} catch (error) {
Expand All @@ -294,11 +297,13 @@ export class WebAuthService extends EventEmitter<AuthServiceEvents> implements A
* @param code The authorization code from the callback
* @param state The state parameter from the callback
* @param organizationId The organization ID from the callback (null for personal accounts)
* @param providerModel The model ID selected during signup (optional)
*/
public async handleCallback(
code: string | null,
state: string | null,
organizationId?: string | null,
providerModel?: string | null,
): Promise<void> {
if (!code || !state) {
const vscode = await importVscode()
Expand Down Expand Up @@ -326,6 +331,12 @@ export class WebAuthService extends EventEmitter<AuthServiceEvents> implements A

await this.storeCredentials(credentials)

// Store the provider model if provided
if (providerModel) {
await this.context.globalState.update("roo-provider-model", providerModel)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing validation: The providerModel parameter is stored without validation. If an invalid model ID is passed (malformed format, non-existent model, or incompatible with the Roo provider), it will be stored and later applied in extension.ts without checks, potentially breaking the API configuration. Consider validating the model ID format and/or existence before storing.

Fix it with Roo Code or mention @roomote and request a fix.

this.log(`[auth] Stored provider model: ${providerModel}`)
}
Comment on lines +334 to +338
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing error cleanup: If authentication fails after storing the provider model but before credential storage completes, the stored model persists in global state. On the next successful login, this orphaned model could be unexpectedly applied. The stored model should be cleaned up in the catch block to prevent this scenario.

Fix it with Roo Code or mention @roomote and request a fix.


const vscode = await importVscode()

if (vscode) {
Expand Down
14 changes: 12 additions & 2 deletions packages/cloud/src/__tests__/CloudService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,22 @@ describe("CloudService", () => {

it("should delegate handleAuthCallback to AuthService", async () => {
await cloudService.handleAuthCallback("code", "state")
expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", undefined)
expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", undefined, undefined)
})

it("should delegate handleAuthCallback with organizationId to AuthService", async () => {
await cloudService.handleAuthCallback("code", "state", "org_123")
expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", "org_123")
expect(mockAuthService.handleCallback).toHaveBeenCalledWith("code", "state", "org_123", undefined)
})

it("should delegate handleAuthCallback with providerModel to AuthService", async () => {
await cloudService.handleAuthCallback("code", "state", "org_123", "xai/grok-code-fast-1")
expect(mockAuthService.handleCallback).toHaveBeenCalledWith(
"code",
"state",
"org_123",
"xai/grok-code-fast-1",
)
})

it("should return stored organization ID from AuthService", () => {
Expand Down
49 changes: 48 additions & 1 deletion packages/cloud/src/__tests__/WebAuthService.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ describe("WebAuthService", () => {
)
})

it("should use package.json values for redirect URI", async () => {
it("should use package.json values for redirect URI with default sign-in endpoint", async () => {
const mockOpenExternal = vi.fn()
const vscode = await import("vscode")
vi.mocked(vscode.env.openExternal).mockImplementation(mockOpenExternal)
Expand All @@ -281,6 +281,26 @@ describe("WebAuthService", () => {
expect(calledUri.toString()).toBe(expectedUrl)
})

it("should use provider signup URL when useProviderSignup is true", async () => {
const mockOpenExternal = vi.fn()
const vscode = await import("vscode")
vi.mocked(vscode.env.openExternal).mockImplementation(mockOpenExternal)

await authService.login(undefined, true)

const expectedUrl =
"https://api.test.com/extension/provider-sign-up?state=746573742d72616e646f6d2d6279746573&auth_redirect=vscode%3A%2F%2FRooVeterinaryInc.roo-cline"
expect(mockOpenExternal).toHaveBeenCalledWith(
expect.objectContaining({
toString: expect.any(Function),
}),
)

// Verify the actual URL
const calledUri = mockOpenExternal.mock.calls[0]?.[0]
expect(calledUri.toString()).toBe(expectedUrl)
})

it("should handle errors during login", async () => {
vi.mocked(crypto.randomBytes).mockImplementation(() => {
throw new Error("Crypto error")
Expand Down Expand Up @@ -351,6 +371,33 @@ describe("WebAuthService", () => {
expect(mockShowInfo).toHaveBeenCalledWith("Successfully authenticated with Roo Code Cloud")
})

it("should store provider model when provided in callback", async () => {
const storedState = "valid-state"
mockContext.globalState.get.mockReturnValue(storedState)

// Mock successful Clerk sign-in response
const mockResponse = {
ok: true,
json: () =>
Promise.resolve({
response: { created_session_id: "session-123" },
}),
headers: {
get: (header: string) => (header === "authorization" ? "Bearer token-123" : null),
},
}
mockFetch.mockResolvedValue(mockResponse)

const vscode = await import("vscode")
const mockShowInfo = vi.fn()
vi.mocked(vscode.window.showInformationMessage).mockImplementation(mockShowInfo)

await authService.handleCallback("auth-code", storedState, null, "xai/grok-code-fast-1")

expect(mockContext.globalState.update).toHaveBeenCalledWith("roo-provider-model", "xai/grok-code-fast-1")
expect(mockLog).toHaveBeenCalledWith("[auth] Stored provider model: xai/grok-code-fast-1")
})

it("should handle Clerk API errors", async () => {
const storedState = "valid-state"
mockContext.globalState.get.mockReturnValue(storedState)
Expand Down
9 changes: 7 additions & 2 deletions packages/types/src/cloud.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,14 @@ export interface AuthService extends EventEmitter<AuthServiceEvents> {
broadcast(): void

// Authentication methods
login(landingPageSlug?: string): Promise<void>
login(landingPageSlug?: string, useProviderSignup?: boolean): Promise<void>
logout(): Promise<void>
handleCallback(code: string | null, state: string | null, organizationId?: string | null): Promise<void>
handleCallback(
code: string | null,
state: string | null,
organizationId?: string | null,
providerModel?: string | null,
): Promise<void>
switchOrganization(organizationId: string | null): Promise<void>

// State methods
Expand Down
2 changes: 2 additions & 0 deletions src/activate/handleUri.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ export const handleUri = async (uri: vscode.Uri) => {
const code = query.get("code")
const state = query.get("state")
const organizationId = query.get("organizationId")
const providerModel = query.get("provider_model")

await CloudService.instance.handleAuthCallback(
code,
state,
organizationId === "null" ? null : organizationId,
providerModel,
)
break
}
Expand Down
3 changes: 2 additions & 1 deletion src/core/webview/webviewMessageHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2146,7 +2146,8 @@ export const webviewMessageHandler = async (
case "rooCloudSignIn": {
try {
TelemetryService.instance.captureEvent(TelemetryEventName.AUTHENTICATION_INITIATED)
await CloudService.instance.login()
// Use provider signup flow if useProviderSignup is explicitly true
await CloudService.instance.login(undefined, message.useProviderSignup ?? false)
} catch (error) {
provider.log(`AuthService#login failed: ${error}`)
vscode.window.showErrorMessage("Sign in failed.")
Expand Down
25 changes: 25 additions & 0 deletions src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,31 @@ export async function activate(context: vscode.ExtensionContext) {

if (data.state === "active-session" || data.state === "logged-out") {
await handleRooModelsCache()

// Apply stored provider model to API configuration if present
if (data.state === "active-session") {
try {
const storedModel = context.globalState.get<string>("roo-provider-model")
if (storedModel) {
cloudLogger(`[authStateChangedHandler] Applying stored provider model: ${storedModel}`)
// Get the current API configuration name
const currentConfigName =
provider.contextProxy.getGlobalState("currentApiConfigName") || "default"
// Update it with the stored model using upsertProviderProfile
await provider.upsertProviderProfile(currentConfigName, {
apiProvider: "roo",
apiModelId: storedModel,
})
// Clear the stored model after applying
await context.globalState.update("roo-provider-model", undefined)
cloudLogger(`[authStateChangedHandler] Applied and cleared stored provider model`)
}
Comment on lines +173 to +189
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Race condition: The stored provider model is applied asynchronously during auth state change, but there's no guarantee it completes before the user dismisses the welcome screen or makes other configuration changes. If authentication completes quickly, the model could be applied after the user has already selected a different model, overwriting their choice. Consider using a promise or flag to ensure the model application completes before allowing further configuration changes.

Fix it with Roo Code or mention @roomote and request a fix.

} catch (error) {
cloudLogger(
`[authStateChangedHandler] Failed to apply stored provider model: ${error instanceof Error ? error.message : String(error)}`,
)
}
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/shared/WebviewMessage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ export interface WebviewMessage {
upsellId?: string // For dismissUpsell
list?: string[] // For dismissedUpsells response
organizationId?: string | null // For organization switching
useProviderSignup?: boolean // For rooCloudSignIn to use provider signup flow
codeIndexSettings?: {
// Global state settings
codebaseIndexEnabled: boolean
Expand Down
18 changes: 17 additions & 1 deletion webview-ui/src/App.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import React, { useCallback, useEffect, useRef, useState, useMemo } from "react"
import { useEvent } from "react-use"
import { QueryClient, QueryClientProvider } from "@tanstack/react-query"
import posthog from "posthog-js"

import { ExtensionMessage } from "@roo/ExtensionMessage"
import TranslationProvider from "./i18n/TranslationContext"
Expand All @@ -15,6 +16,7 @@ import ChatView, { ChatViewRef } from "./components/chat/ChatView"
import HistoryView from "./components/history/HistoryView"
import SettingsView, { SettingsViewRef } from "./components/settings/SettingsView"
import WelcomeView from "./components/welcome/WelcomeView"
import WelcomeViewProvider from "./components/welcome/WelcomeViewProvider"
import McpView from "./components/mcp/McpView"
import { MarketplaceView } from "./components/marketplace/MarketplaceView"
import ModesView from "./components/modes/ModesView"
Expand Down Expand Up @@ -81,6 +83,16 @@ const App = () => {
mdmCompliant,
} = useExtensionState()

const [useProviderSignupView, setUseProviderSignupView] = useState(false)

// Check PostHog feature flag for provider signup view
useEffect(() => {
posthog.onFeatureFlags(function () {
// Feature flag for new provider-focused welcome view
setUseProviderSignupView(posthog?.getFeatureFlag("welcome-provider-signup") === "test")
})
}, [])

// Create a persistent state manager
const marketplaceStateManager = useMemo(() => new MarketplaceViewStateManager(), [])

Expand Down Expand Up @@ -247,7 +259,11 @@ const App = () => {
// Do not conditionally load ChatView, it's expensive and there's state we
// don't want to lose (user input, disableInput, askResponse promise, etc.)
return showWelcome ? (
<WelcomeView />
useProviderSignupView ? (
<WelcomeViewProvider />
) : (
<WelcomeView />
)
) : (
<>
{tab === "modes" && <ModesView onDone={() => switchTab("chat")} />}
Expand Down
4 changes: 2 additions & 2 deletions webview-ui/src/components/welcome/WelcomeView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ const WelcomeView = () => {
<Tab>
<TabContent className="flex flex-col gap-4 p-6">
<RooHero />
<h2 className="mt-0 mb-4 text-xl text-center">{t("welcome:greeting")}</h2>
<h2 className="mt-0 mb-4 text-xl">{t("welcome:greeting")}</h2>

<div className="text-base text-vscode-foreground py-2 px-2 mb-4">
<div className="text-base text-vscode-foreground py-2 mb-4">
<p className="mb-3 leading-relaxed">
<Trans i18nKey="welcome:introduction" />
</p>
Expand Down
Loading
Loading