Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions src/apps/desktop/src/api/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,17 +402,12 @@ pub async fn test_ai_config_connection(
result.response_time_ms + image_result.response_time_ms;

if !image_result.success {
let image_error = image_result
.error_details
.unwrap_or_else(|| "Unknown image input test error".to_string());
let merged = bitfun_core::util::types::ConnectionTestResult {
success: false,
response_time_ms,
model_response: image_result.model_response.or(result.model_response),
error_details: Some(format!(
"Basic connection passed, but multimodal image input test failed: {}",
image_error
)),
message_code: image_result.message_code,
error_details: image_result.error_details,
};
info!(
"AI config connection test completed: model={}, success={}, response_time={}ms",
Expand All @@ -425,7 +420,8 @@ pub async fn test_ai_config_connection(
success: true,
response_time_ms,
model_response: image_result.model_response.or(result.model_response),
error_details: None,
message_code: result.message_code,
error_details: result.error_details,
};
info!(
"AI config connection test completed: model={}, success={}, response_time={}ms",
Expand Down
90 changes: 11 additions & 79 deletions src/crates/core/src/infrastructure/ai/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,33 +156,6 @@ impl AIClient {
)
}

fn build_test_connection_extra_body(&self) -> Option<serde_json::Value> {
let provider = self.config.format.to_ascii_lowercase();
if !matches!(
provider.as_str(),
"openai" | "response" | "responses" | "nvidia" | "openrouter"
) {
return self.config.custom_request_body.clone();
}

let mut extra_body = self
.config
.custom_request_body
.clone()
.unwrap_or_else(|| serde_json::json!({}));

if let Some(extra_obj) = extra_body.as_object_mut() {
extra_obj
.entry("temperature".to_string())
.or_insert_with(|| serde_json::json!(0));
extra_obj
.entry("tool_choice".to_string())
.or_insert_with(|| serde_json::json!("required"));
}

Some(extra_body)
}

fn is_gemini_api_format(api_format: &str) -> bool {
matches!(
api_format.to_ascii_lowercase().as_str(),
Expand Down Expand Up @@ -1942,8 +1915,8 @@ impl AIClient {
pub async fn test_connection(&self) -> Result<ConnectionTestResult> {
let start_time = std::time::Instant::now();

// Force a tool call to avoid false negatives: some models may answer directly when
// `tool_choice=auto`, even if they support tool calls.
// Reuse the normal chat request path so the test matches real conversations, even when
// a provider rejects stricter tool_choice settings such as "required".
let test_messages = vec![Message::user(
"Call the get_weather tool for city=Beijing. Do not answer with plain text."
.to_string(),
Expand All @@ -1961,14 +1934,7 @@ impl AIClient {
}),
}]);

let extra_body = self.build_test_connection_extra_body();

let result = if extra_body.is_some() {
self.send_message_with_extra_body(test_messages, tools, extra_body)
.await
} else {
self.send_message(test_messages, tools).await
};
let result = self.send_message(test_messages, tools).await;

match result {
Ok(response) => {
Expand All @@ -1978,16 +1944,16 @@ impl AIClient {
success: true,
response_time_ms,
model_response: Some(response.text),
message_code: None,
error_details: None,
})
} else {
Ok(ConnectionTestResult {
success: false,
success: true,
response_time_ms,
model_response: Some(response.text),
error_details: Some(
"Model did not return tool calls (tool_choice=required).".to_string(),
),
message_code: Some(ConnectionTestMessageCode::ToolCallsNotDetected),
error_details: None,
})
}
}
Expand All @@ -1999,6 +1965,7 @@ impl AIClient {
success: false,
response_time_ms,
model_response: None,
message_code: None,
error_details: Some(error_msg),
})
}
Expand Down Expand Up @@ -2059,6 +2026,7 @@ impl AIClient {
success: true,
response_time_ms: start_time.elapsed().as_millis() as u64,
model_response: Some(response.text),
message_code: None,
error_details: None,
})
} else {
Expand All @@ -2071,6 +2039,7 @@ impl AIClient {
success: false,
response_time_ms: start_time.elapsed().as_millis() as u64,
model_response: Some(response.text),
message_code: Some(ConnectionTestMessageCode::ImageInputCheckFailed),
error_details: Some(detail),
})
}
Expand All @@ -2082,6 +2051,7 @@ impl AIClient {
success: false,
response_time_ms: start_time.elapsed().as_millis() as u64,
model_response: None,
message_code: None,
error_details: Some(error_msg),
})
}
Expand Down Expand Up @@ -2130,44 +2100,6 @@ mod tests {
})
}

#[test]
fn build_test_connection_extra_body_merges_custom_body_defaults() {
let client = make_test_client(
"responses",
Some(json!({
"metadata": {
"source": "test"
}
})),
);

let extra_body = client
.build_test_connection_extra_body()
.expect("extra body");

assert_eq!(extra_body["metadata"]["source"], "test");
assert_eq!(extra_body["temperature"], 0);
assert_eq!(extra_body["tool_choice"], "required");
}

#[test]
fn build_test_connection_extra_body_preserves_existing_tool_choice() {
let client = make_test_client(
"response",
Some(json!({
"tool_choice": "auto",
"temperature": 0.3
})),
);

let extra_body = client
.build_test_connection_extra_body()
.expect("extra body");

assert_eq!(extra_body["tool_choice"], "auto");
assert_eq!(extra_body["temperature"], 0.3);
}

#[test]
fn resolves_openai_models_url_from_completion_endpoint() {
let client = AIClient::new(AIConfig {
Expand Down
13 changes: 12 additions & 1 deletion src/crates/core/src/util/types/ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ pub struct GeminiUsage {
pub cached_content_token_count: Option<u32>,
}

/// Structured message codes for localized connection test messaging.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ConnectionTestMessageCode {
ToolCallsNotDetected,
ImageInputCheckFailed,
}

/// AI connection test result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionTestResult {
Expand All @@ -44,7 +52,10 @@ pub struct ConnectionTestResult {
/// Model response content (if successful)
#[serde(skip_serializing_if = "Option::is_none")]
pub model_response: Option<String>,
/// Error details (if failed)
/// Structured message code for localized frontend messaging
#[serde(skip_serializing_if = "Option::is_none")]
pub message_code: Option<ConnectionTestMessageCode>,
/// Raw error or diagnostic details
#[serde(skip_serializing_if = "Option::is_none")]
pub error_details: Option<String>,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { systemAPI } from '@/infrastructure/api';
import { Select, Checkbox, Button, IconButton } from '@/component-library';
import { PROVIDER_TEMPLATES } from '@/infrastructure/config/services/modelConfigs';
import { createLogger } from '@/shared/utils/logger';
import { translateConnectionTestMessage } from '@/shared/utils/aiConnectionTestMessages';

const log = createLogger('ModelConfigStep');

Expand Down Expand Up @@ -39,6 +40,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
);
const [testStatus, setTestStatus] = useState<TestStatus>('idle');
const [testError, setTestError] = useState<string>('');
const [testNotice, setTestNotice] = useState<string>('');
const [remoteModelOptions, setRemoteModelOptions] = useState<RemoteModelOption[]>([]);
const [isFetchingRemoteModels, setIsFetchingRemoteModels] = useState(false);
const [remoteModelsError, setRemoteModelsError] = useState<string>('');
Expand Down Expand Up @@ -250,6 +252,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
setSelectedProviderId(newProviderId);
setTestStatus('idle');
setTestError('');
setTestNotice('');

if (newProviderId === 'custom') {
setBaseUrl('');
Expand All @@ -274,6 +277,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
setModelName(value);
setTestStatus('idle');
setTestError('');
setTestNotice('');
}, []);

// Open help URL
Expand Down Expand Up @@ -309,6 +313,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }

setTestStatus('testing');
setTestError('');
setTestNotice('');

try {
const effectiveBaseUrl = baseUrl || (currentTemplate?.baseUrl || '');
Expand All @@ -321,23 +326,31 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
model_name: effectiveModelName,
provider: format
});
const localizedMessage = translateConnectionTestMessage(result.message_code, tAiModel);

if (result.success) {
setTestStatus('success');
setTestNotice(localizedMessage || result.error_details || '');
log.info('Connection test passed', {
provider: selectedProviderId,
modelName: effectiveModelName
});
} else {
setTestStatus('error');
const errorMsg = result.error_details
? `${t('model.testFailed')}\n${result.error_details}`
setTestNotice('');
const detailLines = [
localizedMessage,
result.error_details ? `${tAiModel('messages.errorDetails')}: ${result.error_details}` : undefined
].filter((line): line is string => Boolean(line));
const errorMsg = detailLines.length > 0
? `${t('model.testFailed')}\n${detailLines.join('\n')}`
: t('model.testFailed');
setTestError(errorMsg);
}
} catch (error) {
log.error('Connection test failed', error);
setTestStatus('error');
setTestNotice('');
const rawMsg = error instanceof Error ? error.message : String(error);
// Tauri command errors often have "Connection test failed: " prefix, extract the actual cause
const cleanMsg = rawMsg.replace(/^Connection test failed:\s*/i, '');
Expand Down Expand Up @@ -511,6 +524,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
setApiKey(e.target.value);
setTestStatus('idle');
setTestError('');
setTestNotice('');
}}
/>
{currentTemplate.helpUrl && (
Expand Down Expand Up @@ -542,6 +556,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
}
setTestStatus('idle');
setTestError('');
setTestNotice('');
}}
placeholder={t('model.baseUrl.placeholder')}
options={currentTemplate.baseUrlOptions.map(opt => ({
Expand All @@ -561,6 +576,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
setBaseUrl(e.target.value);
setTestStatus('idle');
setTestError('');
setTestNotice('');
}}
onFocus={(e) => e.target.select()}
/>
Expand All @@ -587,6 +603,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
setBaseUrl(e.target.value);
setTestStatus('idle');
setTestError('');
setTestNotice('');
}}
/>
</div>
Expand All @@ -602,6 +619,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
setModelName(value as string);
setTestStatus('idle');
setTestError('');
setTestNotice('');
}}
placeholder={t('model.modelName.placeholder')}
options={availableModelOptions}
Expand Down Expand Up @@ -644,6 +662,7 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
setApiKey(e.target.value);
setTestStatus('idle');
setTestError('');
setTestNotice('');
}}
/>
</div>
Expand Down Expand Up @@ -830,6 +849,13 @@ export const ModelConfigStep: React.FC<ModelConfigStepProps> = ({ onSkipForNow }
{testError}
</div>
)}

{testStatus === 'success' && testNotice && (
<div className="bitfun-onboarding-model__warning">
<AlertTriangle size={14} />
<span>{testNotice}</span>
</div>
)}
</>
)}

Expand Down
2 changes: 2 additions & 0 deletions src/web-ui/src/infrastructure/api/service-api/AIApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import { api } from './ApiClient';
import { createTauriCommandError } from '../errors/TauriCommandError';
import type { SendMessageRequest } from './tauri-commands';
import type { ConnectionTestMessageCode } from '@/shared/utils/aiConnectionTestMessages';

export interface CreateAISessionRequest {
session_id?: string;
Expand All @@ -19,6 +20,7 @@ export interface ConnectionTestResult {
success: boolean;
response_time_ms: number;
model_response?: string;
message_code?: ConnectionTestMessageCode;
error_details?: string;
}

Expand Down
Loading
Loading