diff --git a/access-control-service/src/main/scala/org/apache/texera/service/AccessControlService.scala b/access-control-service/src/main/scala/org/apache/texera/service/AccessControlService.scala index bed201c713b..63941213726 100644 --- a/access-control-service/src/main/scala/org/apache/texera/service/AccessControlService.scala +++ b/access-control-service/src/main/scala/org/apache/texera/service/AccessControlService.scala @@ -25,7 +25,12 @@ import io.dropwizard.core.setup.{Bootstrap, Environment} import org.apache.amber.config.StorageConfig import org.apache.texera.auth.{JwtAuthFilter, SessionUser} import org.apache.texera.dao.SqlServer -import org.apache.texera.service.resource.{AccessControlResource, HealthCheckResource} +import org.apache.texera.service.resource.{ + AccessControlResource, + HealthCheckResource, + LiteLLMModelsResource, + LiteLLMProxyResource +} import org.eclipse.jetty.server.session.SessionHandler import java.nio.file.Path @@ -54,6 +59,8 @@ class AccessControlService extends Application[AccessControlServiceConfiguration environment.jersey.register(classOf[HealthCheckResource]) environment.jersey.register(classOf[AccessControlResource]) + environment.jersey.register(classOf[LiteLLMProxyResource]) + environment.jersey.register(classOf[LiteLLMModelsResource]) // Register JWT authentication filter environment.jersey.register(new AuthDynamicFeature(classOf[JwtAuthFilter])) diff --git a/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala b/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala index 68c278dc71e..1f2e2bc11f9 100644 --- a/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala +++ b/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala @@ -20,11 +20,13 @@ package org.apache.texera.service.resource import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import com.typesafe.scalalogging.LazyLogging +import jakarta.ws.rs.client.{Client, ClientBuilder, Entity} import jakarta.ws.rs.core._ -import jakarta.ws.rs.{GET, POST, Path, Produces} +import jakarta.ws.rs.{Consumes, GET, POST, Path, Produces} import org.apache.texera.auth.JwtParser.parseToken import org.apache.texera.auth.SessionUser import org.apache.texera.auth.util.{ComputingUnitAccess, HeaderField} +import org.apache.texera.config.{GuiConfig, LLMConfig} import org.apache.texera.dao.jooq.generated.enums.PrivilegeEnum import java.net.URLDecoder @@ -203,3 +205,127 @@ class AccessControlResource extends LazyLogging { AccessControlResource.authorize(uriInfo, headers, Option(body).map(_.trim).filter(_.nonEmpty)) } } + +@Path("/chat") +@Produces(Array(MediaType.APPLICATION_JSON)) +@Consumes(Array(MediaType.APPLICATION_JSON)) +class LiteLLMProxyResource extends LazyLogging { + + private val client: Client = ClientBuilder.newClient() + private val litellmBaseUrl: String = LLMConfig.baseUrl + private val litellmApiKey: String = LLMConfig.masterKey + + @POST + @Path("/{path:.*}") + def proxyPost( + @Context uriInfo: UriInfo, + @Context headers: HttpHeaders, + body: String + ): Response = { + if (!GuiConfig.guiWorkflowWorkspaceCopilotEnabled) { + return Response + .status(Response.Status.FORBIDDEN) + .entity("""{"error": "Copilot feature is disabled"}""") + .build() + } + + // uriInfo.getPath returns "chat/completions" for /api/chat/completions + // We want to forward as "/chat/completions" to LiteLLM + val fullPath = uriInfo.getPath + val targetUrl = s"$litellmBaseUrl/$fullPath" + + logger.info(s"Proxying POST request to LiteLLM: $targetUrl") + + try { + val requestBuilder = client + .target(targetUrl) + .request(MediaType.APPLICATION_JSON) + .header("Authorization", s"Bearer $litellmApiKey") + + // Forward other relevant headers from the original request + headers.getRequestHeaders.asScala.foreach { + case (key, values) + if !key.equalsIgnoreCase("Authorization") && + !key.equalsIgnoreCase("Host") && + !key.equalsIgnoreCase("Content-Length") => + values.asScala.foreach(value => requestBuilder.header(key, value)) + case _ => // Skip Authorization, Host, and Content-Length headers + } + + val response = requestBuilder.post(Entity.json(body)) + + // Build response with same status and body from LiteLLM + val responseBody = response.readEntity(classOf[String]) + val responseBuilder = Response + .status(response.getStatus) + .entity(responseBody) + + // Forward response headers + response.getHeaders.asScala.foreach { + case (key, values) => + values.asScala.foreach(value => responseBuilder.header(key, value)) + } + + responseBuilder.build() + } catch { + case e: Exception => + logger.error(s"Error proxying request to LiteLLM: ${e.getMessage}", e) + Response + .status(Response.Status.BAD_GATEWAY) + .entity(s"""{"error": "Failed to proxy request to LiteLLM: ${e.getMessage}"}""") + .build() + } + } +} + +@Path("/models") +@Produces(Array(MediaType.APPLICATION_JSON)) +class LiteLLMModelsResource extends LazyLogging { + + private val client: Client = ClientBuilder.newClient() + private val litellmBaseUrl: String = LLMConfig.baseUrl + private val litellmApiKey: String = LLMConfig.masterKey + + @GET + def getModels: Response = { + if (!GuiConfig.guiWorkflowWorkspaceCopilotEnabled) { + return Response + .status(Response.Status.FORBIDDEN) + .entity("""{"error": "Copilot feature is disabled"}""") + .build() + } + + val targetUrl = s"$litellmBaseUrl/models" + + logger.info(s"Fetching models from LiteLLM: $targetUrl") + + try { + val response = client + .target(targetUrl) + .request(MediaType.APPLICATION_JSON) + .header("Authorization", s"Bearer $litellmApiKey") + .get() + + // Build response with same status and body from LiteLLM + val responseBody = response.readEntity(classOf[String]) + val responseBuilder = Response + .status(response.getStatus) + .entity(responseBody) + + // Forward response headers + response.getHeaders.asScala.foreach { + case (key, values) => + values.asScala.foreach(value => responseBuilder.header(key, value)) + } + + responseBuilder.build() + } catch { + case e: Exception => + logger.error(s"Error fetching models from LiteLLM: ${e.getMessage}", e) + Response + .status(Response.Status.BAD_GATEWAY) + .entity(s"""{"error": "Failed to fetch models from LiteLLM: ${e.getMessage}"}""") + .build() + } + } +} diff --git a/bin/litellm-config.yaml b/bin/litellm-config.yaml new file mode 100644 index 00000000000..6a75d7e0013 --- /dev/null +++ b/bin/litellm-config.yaml @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# The default configuration file for starting litellm (https://docs.litellm.ai/docs/proxy/quick_start) +# To start the litellm service: +# 1. Install litellm by: +# pip install 'litellm[proxy]' +# 2. Set your API keys as environment variable, e.g. +# export ANTHROPIC_API_KEY= +# 3. Start litellm by: +# litellm --config bin/litellm-config.yaml +# By default, litellm is running on http://0.0.0.0:4000 +model_list: + - model_name: claude-haiku-4.5 + litellm_params: + model: claude-haiku-4-5-20251001 + api_key: "os.environ/ANTHROPIC_API_KEY" + - model_name: gpt-5-mini + litellm_params: + model: gpt-5-mini-2025-08-07 + api_key: "os.environ/OPENAI_API_KEY" \ No newline at end of file diff --git a/common/config/src/main/resources/gui.conf b/common/config/src/main/resources/gui.conf index 601c63e92e8..f73cba82c3e 100644 --- a/common/config/src/main/resources/gui.conf +++ b/common/config/src/main/resources/gui.conf @@ -104,5 +104,9 @@ gui { # amount of time to be elapsed in minutes before user is detected as inactive active-time-in-minutes = 15 active-time-in-minutes = ${?GUI_WORKFLOW_WORKSPACE_ACTIVE_TIME_IN_MINUTES} + + # whether AI copilot feature is enabled + copilot-enabled = false + copilot-enabled = ${?GUI_WORKFLOW_WORKSPACE_COPILOT_ENABLED} } } \ No newline at end of file diff --git a/common/config/src/main/resources/llm.conf b/common/config/src/main/resources/llm.conf new file mode 100644 index 00000000000..23b9360cdab --- /dev/null +++ b/common/config/src/main/resources/llm.conf @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# LLM Configuration +llm { + # Base URL for LiteLLM service + base-url = "http://0.0.0.0:4000" + base-url = ${?LITELLM_BASE_URL} + + # Master key for LiteLLM authentication + master-key = "" + master-key = ${?LITELLM_MASTER_KEY} +} diff --git a/common/config/src/main/scala/org/apache/texera/config/GuiConfig.scala b/common/config/src/main/scala/org/apache/texera/config/GuiConfig.scala index 170ff5e90e5..5e16529bb6f 100644 --- a/common/config/src/main/scala/org/apache/texera/config/GuiConfig.scala +++ b/common/config/src/main/scala/org/apache/texera/config/GuiConfig.scala @@ -67,4 +67,6 @@ object GuiConfig { conf.getBoolean("gui.workflow-workspace.workflow-email-notification-enabled") val guiWorkflowWorkspaceActiveTimeInMinutes: Int = conf.getInt("gui.workflow-workspace.active-time-in-minutes") + val guiWorkflowWorkspaceCopilotEnabled: Boolean = + conf.getBoolean("gui.workflow-workspace.copilot-enabled") } diff --git a/common/config/src/main/scala/org/apache/texera/config/LLMConfig.scala b/common/config/src/main/scala/org/apache/texera/config/LLMConfig.scala new file mode 100644 index 00000000000..a85b734bad6 --- /dev/null +++ b/common/config/src/main/scala/org/apache/texera/config/LLMConfig.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.texera.config + +import com.typesafe.config.{Config, ConfigFactory} + +object LLMConfig { + private val conf: Config = ConfigFactory.parseResources("llm.conf").resolve() + + // LLM Service Configuration + val baseUrl: String = conf.getString("llm.base-url") + val masterKey: String = conf.getString("llm.master-key") +} diff --git a/config-service/src/main/scala/org/apache/texera/service/resource/ConfigResource.scala b/config-service/src/main/scala/org/apache/texera/service/resource/ConfigResource.scala index 8a74ea8d356..aa907ec3030 100644 --- a/config-service/src/main/scala/org/apache/texera/service/resource/ConfigResource.scala +++ b/config-service/src/main/scala/org/apache/texera/service/resource/ConfigResource.scala @@ -55,6 +55,7 @@ class ConfigResource { "password" -> GuiConfig.guiLoginDefaultLocalUserPassword ), "activeTimeInMinutes" -> GuiConfig.guiWorkflowWorkspaceActiveTimeInMinutes, + "copilotEnabled" -> GuiConfig.guiWorkflowWorkspaceCopilotEnabled, // flags from the auth.conf if needed "expirationTimeInMinutes" -> AuthConfig.jwtExpirationMinutes ) diff --git a/frontend/package.json b/frontend/package.json index f1a0e0cc801..c18858f484d 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -8,10 +8,10 @@ "scripts": { "start": "concurrently --kill-others \"npx y-websocket\" \"ng serve\"", "build": "ng build --configuration=production --progress=false --source-map=false", - "build:ci": "nx build --configuration=production --progress=false --source-map=false", + "build:ci": "node --max-old-space-size=8192 ./node_modules/nx/bin/nx build --configuration=production --progress=false --source-map=false", "analyze": "ng build --configuration=production --stats-json && webpack-bundle-analyzer dist/stats.json", "test": "ng test --watch=false", - "test:ci": "node --max-old-space-size=6144 ./node_modules/nx/bin/nx test --watch=false --progress=false", + "test:ci": "node --max-old-space-size=8192 ./node_modules/nx/bin/nx test --watch=false --progress=false", "prettier:fix": "prettier --write ./src", "lint": "ng lint", "eslint:fix": "yarn eslint --fix ./src", @@ -22,6 +22,7 @@ "private": true, "dependencies": { "@abacritt/angularx-social-login": "2.3.0", + "@ai-sdk/openai": "2.0.67", "@ali-hm/angular-tree-component": "12.0.5", "@angular/animations": "16.2.12", "@angular/cdk": "16.2.12", @@ -45,6 +46,7 @@ "@stoplight/json-ref-resolver": "3.1.5", "@types/lodash-es": "4.17.4", "@types/plotly.js-basic-dist-min": "2.12.4", + "ai": "5.0.93", "ajv": "8.10.0", "backbone": "1.4.1", "concaveman": "2.0.0", @@ -94,6 +96,7 @@ "y-quill": "0.1.5", "y-websocket": "1.4.0", "yjs": "13.5.41", + "zod": "3.25.76", "zone.js": "0.13.0" }, "resolutions": { diff --git a/frontend/proxy.config.json b/frontend/proxy.config.json index 3acc9a480c7..db8b690318a 100755 --- a/frontend/proxy.config.json +++ b/frontend/proxy.config.json @@ -1,4 +1,14 @@ { + "/api/models": { + "target": "http://localhost:9096", + "secure": false, + "changeOrigin": true + }, + "/api/chat/completion": { + "target": "http://localhost:9096", + "secure": false, + "changeOrigin": true + }, "/api/compile": { "target": "http://localhost:9090", "secure": false, diff --git a/frontend/src/app/app.module.ts b/frontend/src/app/app.module.ts index 257c2961fa4..73feecfdba3 100644 --- a/frontend/src/app/app.module.ts +++ b/frontend/src/app/app.module.ts @@ -20,7 +20,7 @@ import { DatePipe, registerLocaleData } from "@angular/common"; import { HTTP_INTERCEPTORS, HttpClientModule } from "@angular/common/http"; import en from "@angular/common/locales/en"; -import { APP_INITIALIZER, NgModule } from "@angular/core"; +import { APP_INITIALIZER, NgModule, CUSTOM_ELEMENTS_SCHEMA } from "@angular/core"; import { FormsModule, ReactiveFormsModule } from "@angular/forms"; import { BrowserModule } from "@angular/platform-browser"; import { BrowserAnimationsModule } from "@angular/platform-browser/animations"; @@ -100,8 +100,12 @@ import { NzPopconfirmModule } from "ng-zorro-antd/popconfirm"; import { AdminGuardService } from "./dashboard/service/admin/guard/admin-guard.service"; import { ContextMenuComponent } from "./workspace/component/workflow-editor/context-menu/context-menu/context-menu.component"; import { CoeditorUserIconComponent } from "./workspace/component/menu/coeditor-user-icon/coeditor-user-icon.component"; +import { AgentPanelComponent } from "./workspace/component/agent-panel/agent-panel.component"; +import { AgentChatComponent } from "./workspace/component/agent-panel/agent-chat/agent-chat.component"; +import { AgentRegistrationComponent } from "./workspace/component/agent-panel/agent-registration/agent-registration.component"; import { InputAutoCompleteComponent } from "./workspace/component/input-autocomplete/input-autocomplete.component"; import { CollabWrapperComponent } from "./common/formly/collab-wrapper/collab-wrapper/collab-wrapper.component"; +import { TexeraCopilot } from "./workspace/service/copilot/texera-copilot"; import { NzSwitchModule } from "ng-zorro-antd/switch"; import { AboutComponent } from "./hub/component/about/about.component"; import { NzLayoutModule } from "ng-zorro-antd/layout"; @@ -130,6 +134,7 @@ import { WorkflowRuntimeStatisticsComponent } from "./dashboard/component/user/u import { TimeTravelComponent } from "./workspace/component/left-panel/time-travel/time-travel.component"; import { NzMessageModule } from "ng-zorro-antd/message"; import { NzModalModule } from "ng-zorro-antd/modal"; +import { NzDescriptionsModule } from "ng-zorro-antd/descriptions"; import { OverlayModule } from "@angular/cdk/overlay"; import { HighlightSearchTermsPipe } from "./dashboard/component/user/user-workflow/user-workflow-list-item/highlight-search-terms.pipe"; import { en_US, provideNzI18n } from "ng-zorro-antd/i18n"; @@ -242,6 +247,9 @@ registerLocaleData(en); LocalLoginComponent, ContextMenuComponent, CoeditorUserIconComponent, + AgentPanelComponent, + AgentChatComponent, + AgentRegistrationComponent, InputAutoCompleteComponent, FileSelectionComponent, CollabWrapperComponent, @@ -307,6 +315,7 @@ registerLocaleData(en); NgxJsonViewerModule, NzMessageModule, NzModalModule, + NzDescriptionsModule, NzCardModule, NzTagModule, NzPopconfirmModule, @@ -345,6 +354,7 @@ registerLocaleData(en); GuiConfigService, FileSaverService, ReportGenerationService, + TexeraCopilot, { provide: HTTP_INTERCEPTORS, useClass: BlobErrorHttpInterceptor, @@ -381,5 +391,6 @@ registerLocaleData(en); }, ], bootstrap: [AppComponent], + schemas: [CUSTOM_ELEMENTS_SCHEMA], }) export class AppModule {} diff --git a/frontend/src/app/common/service/gui-config.service.mock.ts b/frontend/src/app/common/service/gui-config.service.mock.ts index 392f8447eec..610169a7862 100644 --- a/frontend/src/app/common/service/gui-config.service.mock.ts +++ b/frontend/src/app/common/service/gui-config.service.mock.ts @@ -48,6 +48,7 @@ export class MockGuiConfigService { defaultLocalUser: { username: "", password: "" }, expirationTimeInMinutes: 2880, activeTimeInMinutes: 15, + copilotEnabled: false, }; get env(): GuiConfig { diff --git a/frontend/src/app/common/type/gui-config.ts b/frontend/src/app/common/type/gui-config.ts index c634ebb6fe0..d9b4ad279ad 100644 --- a/frontend/src/app/common/type/gui-config.ts +++ b/frontend/src/app/common/type/gui-config.ts @@ -39,6 +39,7 @@ export interface GuiConfig { defaultLocalUser?: { username?: string; password?: string }; expirationTimeInMinutes: number; activeTimeInMinutes: number; + copilotEnabled: boolean; } export interface SidebarTabs { diff --git a/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.html b/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.html new file mode 100644 index 00000000000..a723f5789e1 --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.html @@ -0,0 +1,424 @@ + + + +
+ +
+
+ +
+ + Model: {{ agentInfo.modelType }} + +
+ + +
+ + Tokens: {{ getTotalInputTokens() }} in / {{ getTotalOutputTokens() }} out + +
+ + +
+ +
+
+
+ + +
+ +
+
+ +
+ + {{ response.role === 'user' ? 'You' : agentInfo.name }} +
+ + +
+
+ +
+
+ Execute {{ response.toolCalls.length }} tool{{ response.toolCalls.length > 1 ? 's' : '' }} +
+ + + +
+
+ + +
+
+ + {{ agentInfo.name }} +
+
+ + Thinking... +
+
+
+ + +
+ + + +
+
+ + +
+ + Agent is disconnected. Please check your connection. +
+
+ + + + +
+ +
+

+ + Token Usage +

+ + + {{ selectedResponse.usage.inputTokens || 0 }} + + + {{ selectedResponse.usage.outputTokens || 0 }} + + + {{ selectedResponse.usage.totalTokens || 0 }} + + + {{ selectedResponse.usage.cachedInputTokens || 0 }} + + +
+ + +
+

+ + Tool Calls ({{ selectedResponse.toolCalls.length }}) +

+ + +
+
+ Arguments: +
+
{{ formatJson(call.input || call.args) }}
+
+
+
+ Result: +
+
{{ formatJson(getToolResult(selectedResponse, idx)) }}
+
+
+
+ Operator Access: +
+
+
+ + + VIEWED: + + + {{ opId }} + +
+
+ + + MODIFIED: + + + {{ opId }} + +
+
+
+
+
+
+ + +
+ +
No additional details available for this response.
+
+
+
+
+ + + + +
+ +
+

+ + System Prompt +

+
+ +
+
+ + +
+

+ + Available Tools ({{ availableTools.length }}) +

+ + + +
+
+ Description: +
+
{{ tool.description }}
+
+ + +
+
+ Input Schema: +
+
{{ formatJson(tool.inputSchema) }}
+
+
+
+
+
+
+
diff --git a/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.scss b/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.scss new file mode 100644 index 00000000000..5c784194456 --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.scss @@ -0,0 +1,243 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +.agent-chat-container { + display: flex; + flex-direction: column; + height: 100%; + width: 100%; + background: white; +} + +.chat-toolbar { + display: flex; + justify-content: flex-end; + align-items: center; + padding: 8px 12px; + border-bottom: 1px solid #f0f0f0; + background: white; +} + +.chat-toolbar-controls { + display: flex; + width: 100%; + align-items: center; + justify-content: space-between; + gap: 24px; +} + +.toolbar-item { + display: flex; + align-items: center; + flex-shrink: 0; +} + +.chat-toolbar-buttons { + display: flex; + gap: 4px; + align-items: center; + margin-left: 8px; +} + +.chat-content-wrapper { + flex: 1; + overflow: hidden; + display: flex; + flex-direction: column; +} + +.messages-container { + flex: 1; + overflow-y: auto; + padding: 12px; + display: flex; + flex-direction: column; + gap: 8px; +} + +.message { + display: flex; + flex-direction: column; + gap: 2px; + max-width: 85%; + + &.user-message { + align-self: flex-end; + + .message-content { + background: #1890ff; + color: white; + } + } + + &.ai-message { + align-self: flex-start; + + .message-content { + background: #f5f5f5; + color: #262626; + } + } + + &.loading-message { + .message-content { + background: #e6f7ff; + border: 1px solid #91d5ff; + } + } +} + +.message-header { + display: flex; + align-items: center; + gap: 6px; + padding: 0 8px; + font-size: 12px; + color: #8c8c8c; + + i { + font-size: 14px; + } +} + +.message-content { + padding: 8px 12px; + border-radius: 6px; + line-height: 1.5; + word-wrap: break-word; + font-size: 14px; + user-select: text; + + ::ng-deep markdown { + display: block; + + p { + margin: 0 0 8px 0; + + &:last-child { + margin-bottom: 0; + } + } + + code { + background: rgba(0, 0, 0, 0.06); + padding: 2px 6px; + border-radius: 3px; + font-size: 13px; + } + + pre { + background: rgba(0, 0, 0, 0.06); + padding: 12px; + border-radius: 4px; + overflow-x: auto; + margin: 8px 0; + + code { + background: transparent; + padding: 0; + } + } + + ul, + ol { + margin: 8px 0; + padding-left: 24px; + } + + li { + margin: 4px 0; + } + + blockquote { + border-left: 3px solid rgba(0, 0, 0, 0.1); + padding-left: 12px; + margin: 8px 0; + color: rgba(0, 0, 0, 0.65); + } + + a { + color: #1890ff; + text-decoration: none; + + &:hover { + text-decoration: underline; + } + } + } +} + +.user-message .message-content ::ng-deep markdown { + code { + background: rgba(255, 255, 255, 0.2); + color: white; + } + + pre { + background: rgba(255, 255, 255, 0.15); + + code { + color: white; + } + } + + blockquote { + border-left-color: rgba(255, 255, 255, 0.3); + color: rgba(255, 255, 255, 0.9); + } + + a { + color: #e6f7ff; + + &:hover { + color: white; + } + } +} + +.input-area { + display: flex; + gap: 8px; + padding: 8px 12px; + border-top: 1px solid #f0f0f0; + background: #ffffff; + + textarea { + flex: 1; + } + + button { + align-self: flex-end; + } +} + +.connection-warning { + display: flex; + align-items: center; + gap: 8px; + padding: 8px 16px; + background: #fff7e6; + border-top: 1px solid #ffd666; + color: #d46b08; + font-size: 12px; + + i { + color: #faad14; + } +} diff --git a/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.spec.ts b/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.spec.ts new file mode 100644 index 00000000000..a972d71c73f --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.spec.ts @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { ComponentFixture, TestBed } from "@angular/core/testing"; +import { HttpClientTestingModule } from "@angular/common/http/testing"; +import { AgentChatComponent } from "./agent-chat.component"; +import { TexeraCopilotManagerService } from "../../../service/copilot/texera-copilot-manager.service"; +import { NotificationService } from "../../../../common/service/notification/notification.service"; +import { commonTestProviders } from "../../../../common/testing/test-utils"; +import { NO_ERRORS_SCHEMA } from "@angular/core"; + +describe("AgentChatComponent", () => { + let component: AgentChatComponent; + let fixture: ComponentFixture; + let mockCopilotManagerService: jasmine.SpyObj; + let mockNotificationService: jasmine.SpyObj; + + beforeEach(async () => { + mockCopilotManagerService = jasmine.createSpyObj("TexeraCopilotManagerService", [ + "getAgentResponsesObservable", + "getAgentStateObservable", + "sendMessage", + "stopGeneration", + "clearMessages", + "getSystemInfo", + ]); + mockNotificationService = jasmine.createSpyObj("NotificationService", ["info", "error", "success"]); + + await TestBed.configureTestingModule({ + declarations: [AgentChatComponent], + imports: [HttpClientTestingModule], + providers: [ + { provide: TexeraCopilotManagerService, useValue: mockCopilotManagerService }, + { provide: NotificationService, useValue: mockNotificationService }, + ...commonTestProviders, + ], + schemas: [NO_ERRORS_SCHEMA], + }).compileComponents(); + }); + + beforeEach(() => { + fixture = TestBed.createComponent(AgentChatComponent); + component = fixture.componentInstance; + }); + + it("should create", () => { + expect(component).toBeTruthy(); + }); +}); diff --git a/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.ts b/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.ts new file mode 100644 index 00000000000..99acf43effe --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-chat/agent-chat.component.ts @@ -0,0 +1,277 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Component, ViewChild, ElementRef, Input, OnInit, AfterViewChecked } from "@angular/core"; +import { UntilDestroy, untilDestroyed } from "@ngneat/until-destroy"; +import { CopilotState, ReActStep } from "../../../service/copilot/texera-copilot"; +import { AgentInfo, TexeraCopilotManagerService } from "../../../service/copilot/texera-copilot-manager.service"; +import { NotificationService } from "../../../../common/service/notification/notification.service"; + +@UntilDestroy() +@Component({ + selector: "texera-agent-chat", + templateUrl: "agent-chat.component.html", + styleUrls: ["agent-chat.component.scss"], +}) +export class AgentChatComponent implements OnInit, AfterViewChecked { + @Input() agentInfo!: AgentInfo; + @ViewChild("messageContainer", { static: false }) messageContainer?: ElementRef; + @ViewChild("messageInput", { static: false }) messageInput?: ElementRef; + + public responses: ReActStep[] = []; + public currentMessage = ""; + private shouldScrollToBottom = false; + public isDetailsModalVisible = false; + public selectedResponse: ReActStep | null = null; + public hoveredMessageIndex: number | null = null; + public isSystemInfoModalVisible = false; + public systemPrompt: string = ""; + public availableTools: Array<{ name: string; description: string; inputSchema: any }> = []; + public agentState: CopilotState = CopilotState.UNAVAILABLE; + + constructor( + private copilotManagerService: TexeraCopilotManagerService, + private notificationService: NotificationService + ) {} + + ngOnInit(): void { + if (!this.agentInfo) { + return; + } + + // Subscribe to agent responses + this.copilotManagerService + .getReActStepsObservable(this.agentInfo.id) + .pipe(untilDestroyed(this)) + .subscribe(responses => { + this.responses = responses; + this.shouldScrollToBottom = true; + }); + + // Subscribe to agent state changes + this.copilotManagerService + .getAgentStateObservable(this.agentInfo.id) + .pipe(untilDestroyed(this)) + .subscribe(state => { + this.agentState = state; + }); + } + + ngAfterViewChecked(): void { + if (this.shouldScrollToBottom) { + this.scrollToBottom(); + this.shouldScrollToBottom = false; + } + } + + public setHoveredMessage(index: number | null): void { + this.hoveredMessageIndex = index; + } + + public showResponseDetails(response: ReActStep): void { + this.selectedResponse = response; + this.isDetailsModalVisible = true; + } + + public closeDetailsModal(): void { + this.isDetailsModalVisible = false; + this.selectedResponse = null; + } + + public showSystemInfo(): void { + this.copilotManagerService + .getSystemInfo(this.agentInfo.id) + .pipe(untilDestroyed(this)) + .subscribe(systemInfo => { + this.systemPrompt = systemInfo.systemPrompt; + this.availableTools = systemInfo.tools; + this.isSystemInfoModalVisible = true; + }); + } + + public closeSystemInfoModal(): void { + this.isSystemInfoModalVisible = false; + } + + public formatJson(data: any): string { + return JSON.stringify(data, null, 2); + } + + public getToolResult(response: ReActStep, toolCallIndex: number): any { + if (!response.toolResults || toolCallIndex >= response.toolResults.length) { + return null; + } + const toolResult = response.toolResults[toolCallIndex]; + return toolResult.output || toolResult.result || toolResult; + } + + public getReActStepOperatorAccess( + response: ReActStep, + toolCallIndex: number + ): { viewedOperatorIds: string[]; modifiedOperatorIds: string[] } | null { + if (!response.toolResults || toolCallIndex >= response.toolResults.length) { + return null; + } + const toolResult = response.toolResults[toolCallIndex]; + const result = toolResult.output || toolResult.result || toolResult; + + // Check if the result has operator access information + if (result && (result.viewedOperatorIds || result.modifiedOperatorIds)) { + return { + viewedOperatorIds: result.viewedOperatorIds || [], + modifiedOperatorIds: result.modifiedOperatorIds || [], + }; + } + + return null; + } + + public getTotalInputTokens(): number { + // Iterate in reverse to find the most recent usage (already sorted by timestamp) + for (let i = this.responses.length - 1; i >= 0; i--) { + if (this.responses[i].usage?.inputTokens !== undefined) { + return this.responses[i].usage!.inputTokens!; + } + } + return 0; + } + + public getTotalOutputTokens(): number { + // Iterate in reverse to find the most recent usage (already sorted by timestamp) + for (let i = this.responses.length - 1; i >= 0; i--) { + if (this.responses[i].usage?.outputTokens !== undefined) { + return this.responses[i].usage!.outputTokens!; + } + } + return 0; + } + + /** + * Send a message to the agent via the copilot manager service. + */ + public sendMessage(): void { + if (!this.currentMessage.trim() || !this.canSendMessage()) { + return; + } + + const userMessage = this.currentMessage.trim(); + this.currentMessage = ""; + + // Send to copilot via manager service + this.copilotManagerService + .sendMessage(this.agentInfo.id, userMessage) + .pipe(untilDestroyed(this)) + .subscribe({ + error: (error: unknown) => { + this.notificationService.error(`Error sending message: ${error}`); + }, + }); + } + + /** + * Check if messages can be sent (only when agent is available). + */ + public canSendMessage(): boolean { + return this.agentState === CopilotState.AVAILABLE; + } + + /** + * Get the NG-ZORRO icon type based on current agent state. + */ + public getStateIcon(): string { + switch (this.agentState) { + case CopilotState.AVAILABLE: + return "check-circle"; + case CopilotState.GENERATING: + case CopilotState.STOPPING: + return "sync"; + case CopilotState.UNAVAILABLE: + default: + return "close-circle"; + } + } + + /** + * Get the icon color based on current agent state. + */ + public getStateIconColor(): string { + switch (this.agentState) { + case CopilotState.AVAILABLE: + return "#52c41a"; + case CopilotState.GENERATING: + case CopilotState.STOPPING: + return "#1890ff"; + case CopilotState.UNAVAILABLE: + default: + return "#ff4d4f"; + } + } + + /** + * Get the tooltip text for the state icon. + */ + public getStateTooltip(): string { + switch (this.agentState) { + case CopilotState.AVAILABLE: + return "Agent is ready"; + case CopilotState.GENERATING: + return "Agent is generating response..."; + case CopilotState.STOPPING: + return "Agent is stopping..."; + case CopilotState.UNAVAILABLE: + return "Agent is unavailable"; + default: + return "Agent status unknown"; + } + } + + public onEnterPress(event: KeyboardEvent): void { + if (!event.shiftKey) { + event.preventDefault(); + this.sendMessage(); + } + } + + private scrollToBottom(): void { + if (this.messageContainer) { + const element = this.messageContainer.nativeElement; + element.scrollTop = element.scrollHeight; + } + } + + public stopGeneration(): void { + this.copilotManagerService.stopGeneration(this.agentInfo.id).pipe(untilDestroyed(this)).subscribe(); + } + + public clearMessages(): void { + this.copilotManagerService.clearMessages(this.agentInfo.id).pipe(untilDestroyed(this)).subscribe(); + } + + public isGenerating(): boolean { + return this.agentState === CopilotState.GENERATING; + } + + public isAvailable(): boolean { + return this.agentState === CopilotState.AVAILABLE; + } + + public isConnected(): boolean { + return this.agentState !== CopilotState.UNAVAILABLE; + } +} diff --git a/frontend/src/app/workspace/component/agent-panel/agent-panel.component.html b/frontend/src/app/workspace/component/agent-panel/agent-panel.component.html new file mode 100644 index 00000000000..4d419bd5764 --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-panel.component.html @@ -0,0 +1,138 @@ + + + + + +
+ +
    + +
  • + +
  • +
+ +
+

+ AI Agents +

+ + + + + + + + + + + + + +
+ {{ agent.name }} + +
+
+ + + +
+
+ + +
+
+ + + +
+ {{ agents.length }} agent(s) +
+
diff --git a/frontend/src/app/workspace/component/agent-panel/agent-panel.component.scss b/frontend/src/app/workspace/component/agent-panel/agent-panel.component.scss new file mode 100644 index 00000000000..3dc606bba65 --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-panel.component.scss @@ -0,0 +1,164 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +:host { + display: block; + width: 100%; + height: 100%; + position: fixed; + z-index: 3; +} + +#agent-container { + position: absolute; + top: calc(-100% + 80px); + right: 0; + z-index: 3; + background: white; +} + +#title { + padding: 5px 9px; + border-bottom: 1px solid #e0e0e0; + position: absolute; + top: 0; + background: white; + width: 100%; + z-index: 2; +} + +#agent-docked-button { + position: fixed; + bottom: 40px; // Position above the mini-map button + right: 10px; + z-index: 4; + box-shadow: + 0 3px 1px -2px #0003, + 0 2px 2px #00000024, + 0 1px 5px #0000001f; +} + +#return-button { + position: absolute; + top: 0; + right: 0; + z-index: 3; + display: flex; +} + +#content { + width: 100%; + height: 100%; + padding-top: 32px; + display: inline-block; + overflow-y: auto; +} + +.shadow { + border-radius: 5px; + box-shadow: + 0 3px 1px -2px #0003, + 0 2px 2px #00000024, + 0 1px 5px #0000001f; +} + +.ant-menu-item { + margin: 0 !important; + height: 32px; + line-height: 32px; + padding: 0 9px; +} + +.agent-tabs { + height: calc(100% - 32px); // Account for the title bar + display: flex; + flex-direction: column; + overflow: hidden; + + ::ng-deep { + .ant-tabs { + height: 100%; + display: flex; + flex-direction: column; + } + + .ant-tabs-nav { + margin-bottom: 0; + } + + .ant-tabs-content-holder { + flex: 1; + overflow: hidden; + } + + .ant-tabs-content { + height: 100%; + } + + .ant-tabs-tabpane { + height: 100%; + overflow: hidden; + padding: 0; + } + } +} + +.agent-tab-title { + display: flex; + align-items: center; + gap: 8px; + + .agent-tab-name { + flex: 1; + } +} + +.agent-tab-close { + padding: 0 !important; + width: 20px !important; + height: 20px !important; + min-width: 20px !important; + display: flex; + align-items: center; + justify-content: center; + opacity: 0.6; + margin-left: 4px; + + &:hover { + opacity: 1; + color: #ff4d4f !important; + background: rgba(255, 77, 79, 0.1) !important; + } + + i { + font-size: 12px; + } +} + +.tab-bar-extra { + padding-right: 8px; +} + +.agent-count { + font-size: 12px; + color: #8c8c8c; + padding: 4px 8px; + background: #f0f0f0; + border-radius: 4px; +} diff --git a/frontend/src/app/workspace/component/agent-panel/agent-panel.component.ts b/frontend/src/app/workspace/component/agent-panel/agent-panel.component.ts new file mode 100644 index 00000000000..47e936ed31b --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-panel.component.ts @@ -0,0 +1,200 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Component, HostListener, OnDestroy, OnInit } from "@angular/core"; +import { UntilDestroy, untilDestroyed } from "@ngneat/until-destroy"; +import { NzResizeEvent } from "ng-zorro-antd/resizable"; +import { TexeraCopilotManagerService, AgentInfo } from "../../service/copilot/texera-copilot-manager.service"; +import { calculateTotalTranslate3d } from "../../../common/util/panel-dock"; + +@UntilDestroy() +@Component({ + selector: "texera-agent-panel", + templateUrl: "agent-panel.component.html", + styleUrls: ["agent-panel.component.scss"], +}) +export class AgentPanelComponent implements OnInit, OnDestroy { + protected readonly window = window; + private static readonly MIN_PANEL_WIDTH = 400; + private static readonly MIN_PANEL_HEIGHT = 450; + + // Panel dimensions and position + width: number = 0; // Start with 0 to show docked button + height = Math.max(AgentPanelComponent.MIN_PANEL_HEIGHT, window.innerHeight * 0.7); + id = -1; + dragPosition = { x: 0, y: 0 }; + returnPosition = { x: 0, y: 0 }; + isDocked = true; + + // Tab management + selectedTabIndex: number = 0; // 0 = registration tab, 1 = action plans tab, 2+ = agent tabs + agents: AgentInfo[] = []; + + constructor(private copilotManagerService: TexeraCopilotManagerService) {} + + ngOnInit(): void { + this.loadPanelSettings(); + + // Subscribe to agent changes + this.copilotManagerService.agentChange$.pipe(untilDestroyed(this)).subscribe(() => { + this.copilotManagerService + .getAllAgents() + .pipe(untilDestroyed(this)) + .subscribe(agents => { + this.agents = agents; + }); + }); + + // Load initial agents + this.copilotManagerService + .getAllAgents() + .pipe(untilDestroyed(this)) + .subscribe(agents => { + this.agents = agents; + }); + } + + @HostListener("window:beforeunload") + ngOnDestroy(): void { + this.savePanelSettings(); + } + + /** + * Open the panel from docked state + */ + public openPanel(): void { + if (this.width === 0) { + // Open panel + this.width = AgentPanelComponent.MIN_PANEL_WIDTH; + } else { + // Close panel (dock it) + this.width = 0; + this.isDocked = true; + } + } + + /** + * Handle agent creation + */ + public onAgentCreated(agentId: string): void { + // The agent is already added to the agents array by the manager service + // Find the index of the newly created agent and switch to that tab + // Tab index 0 is registration, 1 is action plans, so agent tabs start at index 2 + const agentIndex = this.agents.findIndex(agent => agent.id === agentId); + if (agentIndex !== -1) { + this.selectedTabIndex = agentIndex + 2; // +2 because tab 0 is registration, tab 1 is action plans + } + } + + /** + * Delete an agent + */ + public deleteAgent(agentId: string, event: Event): void { + event.stopPropagation(); // Prevent tab switch + + if (confirm("Are you sure you want to delete this agent?")) { + const agentIndex = this.agents.findIndex(agent => agent.id === agentId); + this.copilotManagerService.deleteAgent(agentId); + + // If we're on the deleted agent's tab, switch to registration + if (agentIndex !== -1 && this.selectedTabIndex === agentIndex + 2) { + this.selectedTabIndex = 0; + } else if (this.selectedTabIndex > agentIndex + 2) { + // Adjust selected index if we deleted a tab before the current one + this.selectedTabIndex--; + } + } + } + + /** + * Handle panel resize + */ + onResize({ width, height }: NzResizeEvent): void { + cancelAnimationFrame(this.id); + this.id = requestAnimationFrame(() => { + this.width = width!; + this.height = height!; + }); + } + + /** + * Reset panel to docked position + */ + resetPanelPosition(): void { + this.dragPosition = { x: this.returnPosition.x, y: this.returnPosition.y }; + this.isDocked = true; + } + + /** + * Handle drag start + */ + handleDragStart(): void { + this.isDocked = false; + } + + /** + * Load panel settings from localStorage + */ + private loadPanelSettings(): void { + const savedWidth = localStorage.getItem("agent-panel-width"); + const savedHeight = localStorage.getItem("agent-panel-height"); + const savedStyle = localStorage.getItem("agent-panel-style"); + const savedDocked = localStorage.getItem("agent-panel-docked"); + + // Only restore width if the panel was not docked + if (savedDocked === "false" && savedWidth) { + const parsedWidth = Number(savedWidth); + if (!isNaN(parsedWidth) && parsedWidth >= AgentPanelComponent.MIN_PANEL_WIDTH) { + this.width = parsedWidth; + } + } + + if (savedHeight) { + const parsedHeight = Number(savedHeight); + if (!isNaN(parsedHeight) && parsedHeight >= AgentPanelComponent.MIN_PANEL_HEIGHT) { + this.height = parsedHeight; + } + } + + if (savedStyle) { + const container = document.getElementById("agent-container"); + if (container) { + container.style.cssText = savedStyle; + const translates = container.style.transform; + const [xOffset, yOffset] = calculateTotalTranslate3d(translates); + this.returnPosition = { x: -xOffset, y: -yOffset }; + this.isDocked = this.dragPosition.x === this.returnPosition.x && this.dragPosition.y === this.returnPosition.y; + } + } + } + + /** + * Save panel settings to localStorage + */ + private savePanelSettings(): void { + localStorage.setItem("agent-panel-width", String(this.width)); + localStorage.setItem("agent-panel-height", String(this.height)); + localStorage.setItem("agent-panel-docked", String(this.width === 0)); + + const container = document.getElementById("agent-container"); + if (container) { + localStorage.setItem("agent-panel-style", container.style.cssText); + } + } +} diff --git a/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.html b/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.html new file mode 100644 index 00000000000..c31e8c3ef92 --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.html @@ -0,0 +1,101 @@ + + +
+
+

Create New Agent

+

Select a model type and create an AI agent to assist with your workflows

+
+ +
+

Select Model Type

+
+ +

Loading available models...

+
+
+ +

No models available

+
+
+
+
+ +
+
+
{{ modelType.name }}
+

{{ modelType.description }}

+
+
+ +
+
+
+
+ +
+

Agent Name (Optional)

+ +
+ +
+ +
+
diff --git a/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.scss b/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.scss new file mode 100644 index 00000000000..639dcb78eb0 --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.scss @@ -0,0 +1,147 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +.agent-registration-container { + display: flex; + flex-direction: column; + gap: 24px; + padding: 20px; + height: 100%; + overflow-y: auto; +} + +.registration-header { + text-align: center; + + h3 { + margin: 0 0 8px 0; + font-size: 20px; + font-weight: 600; + color: #262626; + } + + p { + margin: 0; + font-size: 14px; + color: #8c8c8c; + } +} + +.model-type-selection { + h4 { + margin: 0 0 12px 0; + font-size: 16px; + font-weight: 500; + color: #262626; + } +} + +.model-cards { + display: flex; + flex-direction: column; + gap: 12px; +} + +.model-card { + display: flex; + align-items: center; + gap: 16px; + padding: 16px; + border: 2px solid #d9d9d9; + border-radius: 8px; + cursor: pointer; + transition: all 0.3s ease; + position: relative; + + &:hover { + border-color: #1890ff; + box-shadow: 0 2px 8px rgba(24, 144, 255, 0.2); + } + + &.selected { + border-color: #1890ff; + background: #e6f7ff; + box-shadow: 0 2px 8px rgba(24, 144, 255, 0.3); + } +} + +.model-icon { + display: flex; + align-items: center; + justify-content: center; + width: 48px; + height: 48px; + background: #f0f0f0; + border-radius: 8px; + flex-shrink: 0; + + i { + font-size: 28px; + color: #1890ff; + } +} + +.model-info { + flex: 1; + + h5 { + margin: 0 0 4px 0; + font-size: 15px; + font-weight: 600; + color: #262626; + } + + p { + margin: 0; + font-size: 13px; + color: #8c8c8c; + line-height: 1.4; + } +} + +.selected-indicator { + flex-shrink: 0; + + i { + font-size: 24px; + color: #1890ff; + } +} + +.agent-name-input { + h4 { + margin: 0 0 8px 0; + font-size: 16px; + font-weight: 500; + color: #262626; + } + + input { + width: 100%; + } +} + +.action-buttons { + display: flex; + justify-content: center; + + button { + min-width: 200px; + } +} diff --git a/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.ts b/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.ts new file mode 100644 index 00000000000..4d9b69706e5 --- /dev/null +++ b/frontend/src/app/workspace/component/agent-panel/agent-registration/agent-registration.component.ts @@ -0,0 +1,112 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Component, EventEmitter, OnDestroy, OnInit, Output } from "@angular/core"; +import { TexeraCopilotManagerService, ModelType } from "../../../service/copilot/texera-copilot-manager.service"; +import { NotificationService } from "../../../../common/service/notification/notification.service"; +import { Subject, takeUntil } from "rxjs"; + +@Component({ + selector: "texera-agent-registration", + templateUrl: "agent-registration.component.html", + styleUrls: ["agent-registration.component.scss"], +}) +export class AgentRegistrationComponent implements OnInit, OnDestroy { + @Output() agentCreated = new EventEmitter(); + + public modelTypes: ModelType[] = []; + public selectedModelType: string | null = null; + public customAgentName: string = ""; + public isLoadingModels: boolean = false; + public hasLoadingError: boolean = false; + + private destroy$ = new Subject(); + + constructor( + private copilotManagerService: TexeraCopilotManagerService, + private notificationService: NotificationService + ) {} + + ngOnInit(): void { + this.isLoadingModels = true; + this.hasLoadingError = false; + + this.copilotManagerService + .fetchModelTypes() + .pipe(takeUntil(this.destroy$)) + .subscribe({ + next: models => { + this.modelTypes = models; + this.isLoadingModels = false; + if (models.length === 0) { + this.hasLoadingError = true; + this.notificationService.error("No models available. Please check the LiteLLM configuration."); + } + }, + error: (error: unknown) => { + this.isLoadingModels = false; + this.hasLoadingError = true; + const errorMessage = error instanceof Error ? error.message : String(error); + this.notificationService.error(`Failed to fetch models: ${errorMessage}`); + }, + }); + } + + ngOnDestroy(): void { + this.destroy$.next(); + this.destroy$.complete(); + } + + public selectModelType(modelTypeId: string): void { + this.selectedModelType = modelTypeId; + } + + public isCreating: boolean = false; + + /** + * Create a new agent with the selected model type. + */ + public createAgent(): void { + if (!this.selectedModelType || this.isCreating) { + return; + } + + this.isCreating = true; + + this.copilotManagerService + .createAgent(this.selectedModelType, this.customAgentName || undefined) + .pipe(takeUntil(this.destroy$)) + .subscribe({ + next: agentInfo => { + this.agentCreated.emit(agentInfo.id); + this.selectedModelType = null; + this.customAgentName = ""; + this.isCreating = false; + }, + error: (error: unknown) => { + this.notificationService.error(`Failed to create agent: ${error}`); + this.isCreating = false; + }, + }); + } + + public canCreate(): boolean { + return this.selectedModelType !== null && !this.isCreating; + } +} diff --git a/frontend/src/app/workspace/component/workspace.component.html b/frontend/src/app/workspace/component/workspace.component.html index 2e662831a33..abe3d2a786f 100644 --- a/frontend/src/app/workspace/component/workspace.component.html +++ b/frontend/src/app/workspace/component/workspace.component.html @@ -33,5 +33,6 @@ + diff --git a/frontend/src/app/workspace/component/workspace.component.ts b/frontend/src/app/workspace/component/workspace.component.ts index 8958f08df10..14e1a937a0c 100644 --- a/frontend/src/app/workspace/component/workspace.component.ts +++ b/frontend/src/app/workspace/component/workspace.component.ts @@ -283,4 +283,8 @@ export class WorkspaceComponent implements AfterViewInit, OnInit, OnDestroy { public triggerCenter(): void { this.workflowActionService.getTexeraGraph().triggerCenterEvent(); } + + public get copilotEnabled(): boolean { + return this.config.env.copilotEnabled; + } } diff --git a/frontend/src/app/workspace/service/compile-workflow/workflow-compiling.service.ts b/frontend/src/app/workspace/service/compile-workflow/workflow-compiling.service.ts index 1b45f755994..9648d2b305f 100644 --- a/frontend/src/app/workspace/service/compile-workflow/workflow-compiling.service.ts +++ b/frontend/src/app/workspace/service/compile-workflow/workflow-compiling.service.ts @@ -148,6 +148,17 @@ export class WorkflowCompilingService { ); } + public getOperatorOutputSchemaMap(operatorID: string): OperatorPortSchemaMap | undefined { + if ( + this.currentCompilationStateInfo.state == CompilationState.Uninitialized || + !this.currentCompilationStateInfo.operatorOutputPortSchemaMap + ) { + return undefined; + } + + return this.currentCompilationStateInfo.operatorOutputPortSchemaMap[operatorID]; + } + public getPortInputSchema(operatorID: string, portIndex: number): PortSchema | undefined { return this.getOperatorInputSchemaMap(operatorID)?.[serializePortIdentity({ id: portIndex, internal: false })]; } diff --git a/frontend/src/app/workspace/service/copilot/copilot-prompts.ts b/frontend/src/app/workspace/service/copilot/copilot-prompts.ts new file mode 100644 index 00000000000..48842211308 --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/copilot-prompts.ts @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * System prompts for Texera Copilot + */ + +export const COPILOT_SYSTEM_PROMPT = `# Texera Copilot + +You are Texera Copilot, an AI assistant for helping users do data science using Texera workflows. + +Your job is to leverage tools to help users understand Texera's functionalities, including what operators are available +and how to use them. + +You also need to help users understand the workflow they are currently working on. + +During the process, leverage tool calls whenever needed. Current available tools are all READ-ONLY. Thus you cannot edit +user's workflow. +`; diff --git a/frontend/src/app/workspace/service/copilot/texera-copilot-manager.service.spec.ts b/frontend/src/app/workspace/service/copilot/texera-copilot-manager.service.spec.ts new file mode 100644 index 00000000000..8f3820afce4 --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/texera-copilot-manager.service.spec.ts @@ -0,0 +1,259 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { TestBed } from "@angular/core/testing"; +import { HttpClientTestingModule, HttpTestingController } from "@angular/common/http/testing"; +import { TexeraCopilotManagerService } from "./texera-copilot-manager.service"; +import { CopilotState } from "./texera-copilot"; +import { commonTestProviders } from "../../../common/testing/test-utils"; +import { AppSettings } from "../../../common/app-setting"; + +describe("TexeraCopilotManagerService", () => { + let service: TexeraCopilotManagerService; + let httpMock: HttpTestingController; + + beforeEach(() => { + TestBed.configureTestingModule({ + imports: [HttpClientTestingModule], + providers: [TexeraCopilotManagerService, ...commonTestProviders], + }); + + service = TestBed.inject(TexeraCopilotManagerService); + httpMock = TestBed.inject(HttpTestingController); + }); + + afterEach(() => { + httpMock.verify(); + }); + + it("should be created", () => { + expect(service).toBeTruthy(); + }); + + describe("fetchModelTypes", () => { + it("should fetch and format model types", done => { + const mockResponse = { + data: [ + { id: "gpt-4", object: "model", created: 1234567890, owned_by: "openai" }, + { id: "claude-3", object: "model", created: 1234567891, owned_by: "anthropic" }, + ], + object: "list", + }; + + service.fetchModelTypes().subscribe(models => { + expect(models.length).toBe(2); + expect(models[0].id).toBe("gpt-4"); + expect(models[0].name).toBe("Gpt 4"); + expect(models[1].id).toBe("claude-3"); + expect(models[1].name).toBe("Claude 3"); + done(); + }); + + const req = httpMock.expectOne(`${AppSettings.getApiEndpoint()}/models`); + expect(req.request.method).toBe("GET"); + req.flush(mockResponse); + }); + + it("should handle fetch error gracefully", done => { + service.fetchModelTypes().subscribe(models => { + expect(models).toEqual([]); + done(); + }); + + const req = httpMock.expectOne(`${AppSettings.getApiEndpoint()}/models`); + req.error(new ProgressEvent("error")); + }); + + it("should cache model types with shareReplay", done => { + const mockResponse = { + data: [{ id: "gpt-4", object: "model", created: 1234567890, owned_by: "openai" }], + object: "list", + }; + + service.fetchModelTypes().subscribe(() => { + service.fetchModelTypes().subscribe(models => { + expect(models.length).toBe(1); + done(); + }); + }); + + const req = httpMock.expectOne(`${AppSettings.getApiEndpoint()}/models`); + req.flush(mockResponse); + }); + }); + + describe("getAllAgents", () => { + it("should return empty array initially", done => { + service.getAllAgents().subscribe(agents => { + expect(agents).toEqual([]); + done(); + }); + }); + }); + + describe("getAgentCount", () => { + it("should return 0 initially", done => { + service.getAgentCount().subscribe(count => { + expect(count).toBe(0); + done(); + }); + }); + }); + + describe("getAgent", () => { + it("should throw error when agent not found", done => { + service.getAgent("non-existent").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("isAgentConnected", () => { + it("should return false for non-existent agent", done => { + service.isAgentConnected("non-existent").subscribe(connected => { + expect(connected).toBe(false); + done(); + }); + }); + }); + + describe("agent lifecycle management", () => { + it("should emit agent change event on agent creation", done => { + let eventEmitted = false; + + service.agentChange$.subscribe(() => { + eventEmitted = true; + }); + + setTimeout(() => { + expect(eventEmitted).toBe(false); + done(); + }, 100); + }); + }); + + describe("sendMessage", () => { + it("should throw error for non-existent agent", done => { + service.sendMessage("non-existent", "test message").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("getAgentResponses", () => { + it("should throw error for non-existent agent", done => { + service.getAgentResponses("non-existent").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("getAgentResponsesObservable", () => { + it("should throw error for non-existent agent", done => { + service.getReActStepsObservable("non-existent").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("clearMessages", () => { + it("should throw error for non-existent agent", done => { + service.clearMessages("non-existent").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("stopGeneration", () => { + it("should throw error for non-existent agent", done => { + service.stopGeneration("non-existent").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("getAgentState", () => { + it("should throw error for non-existent agent", done => { + service.getAgentState("non-existent").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("getAgentStateObservable", () => { + it("should throw error for non-existent agent", done => { + service.getAgentStateObservable("non-existent").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("getSystemInfo", () => { + it("should throw error for non-existent agent", done => { + service.getSystemInfo("non-existent").subscribe({ + next: () => fail("Should have thrown error"), + error: (error: unknown) => { + expect((error as Error).message).toContain("not found"); + done(); + }, + }); + }); + }); + + describe("deleteAgent", () => { + it("should return false for non-existent agent", done => { + service.deleteAgent("non-existent").subscribe(deleted => { + expect(deleted).toBe(false); + done(); + }); + }); + }); +}); diff --git a/frontend/src/app/workspace/service/copilot/texera-copilot-manager.service.ts b/frontend/src/app/workspace/service/copilot/texera-copilot-manager.service.ts new file mode 100644 index 00000000000..ab7d393265f --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/texera-copilot-manager.service.ts @@ -0,0 +1,248 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Injectable, Injector } from "@angular/core"; +import { HttpClient } from "@angular/common/http"; +import { TexeraCopilot, ReActStep, CopilotState } from "./texera-copilot"; +import { Observable, Subject, catchError, map, of, shareReplay, tap, defer, throwError, switchMap } from "rxjs"; +import { AppSettings } from "../../../common/app-setting"; + +export interface AgentInfo { + id: string; + name: string; + modelType: string; + instance: TexeraCopilot; + createdAt: Date; +} + +export interface ModelType { + id: string; + name: string; + description: string; + icon: string; +} + +interface LiteLLMModel { + id: string; + object: string; + created: number; + owned_by: string; +} + +interface LiteLLMModelsResponse { + data: LiteLLMModel[]; + object: string; +} + +/** + * Texera Copilot Manager Service manages multiple AI agent instances for workflow assistance. + * + * This service provides centralized management for multiple copilot agents, allowing users to: + * 1. Create and delete multiple agent instances with different LLM models + * 2. Route messages to specific agents + * 3. Track agent states and conversation history + * 4. Query available LLM models from the backend + * + * Each agent is a separate TexeraCopilot instance with its own: + * - Model configuration (e.g., GPT-4, Claude, etc.) + * - Conversation history + * - State (available, generating, stopping, unavailable) + * + * The service acts as a registry and coordinator, ensuring proper lifecycle management + * and providing observable streams for agent changes and state updates. + */ +@Injectable({ + providedIn: "root", +}) +export class TexeraCopilotManagerService { + private agents = new Map(); + private agentCounter = 0; + private agentChangeSubject = new Subject(); + public agentChange$ = this.agentChangeSubject.asObservable(); + + private modelTypes$: Observable | null = null; + + constructor( + private injector: Injector, + private http: HttpClient + ) {} + + public createAgent(modelType: string, customName?: string): Observable { + return defer(() => { + const agentId = `agent-${++this.agentCounter}`; + const agentName = customName || `Agent ${this.agentCounter}`; + + const agentInstance = this.createCopilotInstance(modelType); + agentInstance.setAgentInfo(agentName); + + return agentInstance.initialize().pipe( + map(() => { + const agentInfo: AgentInfo = { + id: agentId, + name: agentName, + modelType, + instance: agentInstance, + createdAt: new Date(), + }; + + this.agents.set(agentId, agentInfo); + this.agentChangeSubject.next(); + + return agentInfo; + }), + catchError((error: unknown) => { + return throwError(() => error); + }) + ); + }); + } + + /** + * Helper method to get an agent and execute a callback with it. + * Handles agent lookup and error throwing if not found. + */ + private withAgent(agentId: string, callback: (agent: AgentInfo) => Observable): Observable { + return defer(() => { + const agent = this.agents.get(agentId); + if (!agent) { + return throwError(() => new Error(`Agent with ID ${agentId} not found`)); + } + return callback(agent); + }); + } + + public getAgent(agentId: string): Observable { + return this.withAgent(agentId, agent => of(agent)); + } + + public getAllAgents(): Observable { + return of(Array.from(this.agents.values())); + } + public deleteAgent(agentId: string): Observable { + return defer(() => { + const agent = this.agents.get(agentId); + if (!agent) { + return of(false); + } + + return agent.instance.disconnect().pipe( + map(() => { + this.agents.delete(agentId); + this.agentChangeSubject.next(); + return true; + }) + ); + }); + } + + public fetchModelTypes(): Observable { + if (!this.modelTypes$) { + this.modelTypes$ = this.http.get(`${AppSettings.getApiEndpoint()}/models`).pipe( + map(response => + response.data.map((model: LiteLLMModel) => ({ + id: model.id, + name: this.formatModelName(model.id), + description: `Model: ${model.id}`, + icon: "robot", + })) + ), + catchError((error: unknown) => { + console.error("Failed to fetch models from API:", error); + return of([]); + }), + shareReplay(1) + ); + } + return this.modelTypes$; + } + + private formatModelName(modelId: string): string { + return modelId + .split("-") + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(" "); + } + + public getAgentCount(): Observable { + return of(this.agents.size); + } + public sendMessage(agentId: string, message: string): Observable { + return this.withAgent(agentId, agent => agent.instance.sendMessage(message)); + } + + public getReActStepsObservable(agentId: string): Observable { + return this.withAgent(agentId, agent => agent.instance.reActSteps$); + } + + public getAgentResponses(agentId: string): Observable { + return this.withAgent(agentId, agent => of(agent.instance.getReActSteps())); + } + + public clearMessages(agentId: string): Observable { + return this.withAgent(agentId, agent => { + agent.instance.clearMessages(); + return of(undefined); + }); + } + + public stopGeneration(agentId: string): Observable { + return this.withAgent(agentId, agent => { + agent.instance.stopGeneration(); + return of(undefined); + }); + } + + public getAgentState(agentId: string): Observable { + return this.withAgent(agentId, agent => of(agent.instance.getState())); + } + + public getAgentStateObservable(agentId: string): Observable { + return this.withAgent(agentId, agent => agent.instance.state$); + } + + public isAgentConnected(agentId: string): Observable { + return this.withAgent(agentId, agent => of(agent.instance.isConnected())).pipe(catchError(() => of(false))); + } + + public getSystemInfo( + agentId: string + ): Observable<{ systemPrompt: string; tools: Array<{ name: string; description: string; inputSchema: any }> }> { + return this.withAgent(agentId, agent => + of({ + systemPrompt: agent.instance.getSystemPrompt(), + tools: agent.instance.getToolsInfo(), + }) + ); + } + private createCopilotInstance(modelType: string): TexeraCopilot { + const childInjector = Injector.create({ + providers: [ + { + provide: TexeraCopilot, + }, + ], + parent: this.injector, + }); + + const copilotInstance = childInjector.get(TexeraCopilot); + copilotInstance.setModelType(modelType); + + return copilotInstance; + } +} diff --git a/frontend/src/app/workspace/service/copilot/texera-copilot.spec.ts b/frontend/src/app/workspace/service/copilot/texera-copilot.spec.ts new file mode 100644 index 00000000000..dd5a7e5bb6a --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/texera-copilot.spec.ts @@ -0,0 +1,189 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { TestBed } from "@angular/core/testing"; +import { TexeraCopilot, CopilotState } from "./texera-copilot"; +import { WorkflowActionService } from "../workflow-graph/model/workflow-action.service"; +import { WorkflowUtilService } from "../workflow-graph/util/workflow-util.service"; +import { OperatorMetadataService } from "../operator-metadata/operator-metadata.service"; +import { WorkflowCompilingService } from "../compile-workflow/workflow-compiling.service"; +import { NotificationService } from "../../../common/service/notification/notification.service"; +import { commonTestProviders } from "../../../common/testing/test-utils"; + +describe("TexeraCopilot", () => { + let service: TexeraCopilot; + let mockWorkflowActionService: jasmine.SpyObj; + let mockWorkflowUtilService: jasmine.SpyObj; + let mockOperatorMetadataService: jasmine.SpyObj; + let mockWorkflowCompilingService: jasmine.SpyObj; + let mockNotificationService: jasmine.SpyObj; + + beforeEach(() => { + mockWorkflowActionService = jasmine.createSpyObj("WorkflowActionService", ["getTexeraGraph"]); + mockWorkflowUtilService = jasmine.createSpyObj("WorkflowUtilService", ["getOperatorTypeList"]); + mockOperatorMetadataService = jasmine.createSpyObj("OperatorMetadataService", ["getOperatorSchema"]); + mockWorkflowCompilingService = jasmine.createSpyObj("WorkflowCompilingService", [ + "getOperatorInputSchemaMap", + "getOperatorOutputSchemaMap", + ]); + mockNotificationService = jasmine.createSpyObj("NotificationService", ["info", "error"]); + + TestBed.configureTestingModule({ + providers: [ + TexeraCopilot, + { provide: WorkflowActionService, useValue: mockWorkflowActionService }, + { provide: WorkflowUtilService, useValue: mockWorkflowUtilService }, + { provide: OperatorMetadataService, useValue: mockOperatorMetadataService }, + { provide: WorkflowCompilingService, useValue: mockWorkflowCompilingService }, + { provide: NotificationService, useValue: mockNotificationService }, + ...commonTestProviders, + ], + }); + + service = TestBed.inject(TexeraCopilot); + }); + + it("should be created", () => { + expect(service).toBeTruthy(); + }); + + it("should set agent info correctly", () => { + service.setAgentInfo("Test Agent"); + expect(service).toBeTruthy(); + }); + + it("should set model type correctly", () => { + service.setModelType("gpt-4"); + expect(service).toBeTruthy(); + }); + + it("should have initial state as UNAVAILABLE", () => { + expect(service.getState()).toBe(CopilotState.UNAVAILABLE); + }); + + it("should update state correctly", done => { + service.state$.subscribe(state => { + if (state === CopilotState.UNAVAILABLE) { + expect(state).toBe(CopilotState.UNAVAILABLE); + done(); + } + }); + }); + + it("should clear messages correctly", () => { + service.clearMessages(); + expect(service.getReActSteps().length).toBe(0); + }); + + it("should stop generation when in GENERATING state", () => { + service.stopGeneration(); + expect(service).toBeTruthy(); + }); + + it("should return system prompt", () => { + const prompt = service.getSystemPrompt(); + expect(prompt).toBeTruthy(); + expect(typeof prompt).toBe("string"); + }); + + it("should return tools info", done => { + // Tools are only created after initialize() is called + service.initialize().subscribe(() => { + const tools = service.getToolsInfo(); + expect(tools).toBeTruthy(); + expect(Array.isArray(tools)).toBe(true); + expect(tools.length).toBeGreaterThan(0); + tools.forEach(tool => { + expect(tool.name).toBeTruthy(); + expect(tool.description).toBeTruthy(); + }); + done(); + }); + }); + + it("should check if connected", () => { + expect(service.isConnected()).toBe(false); + }); + + it("should emit agent responses correctly", done => { + service.reActSteps$.subscribe(responses => { + if (responses.length > 0) { + expect(responses[0].role).toBe("user"); + expect(responses[0].content).toBe("test message"); + done(); + } + }); + + // emitReActStep signature: (messageId, stepId, role, content, isBegin, isEnd, toolCalls?, toolResults?, usage?, operatorAccess?) + (service as any).emitReActStep("test-id", 0, "user", "test message", true, true); + }); + + it("should return empty agent responses initially", () => { + const responses = service.getReActSteps(); + expect(responses).toEqual([]); + }); + + describe("disconnect", () => { + it("should disconnect and clear state", done => { + service.disconnect().subscribe(() => { + expect(service.getState()).toBe(CopilotState.UNAVAILABLE); + expect(service.getReActSteps().length).toBe(0); + done(); + }); + }); + + it("should show notification on disconnect", done => { + service.setAgentInfo("Test Agent"); + service.disconnect().subscribe(() => { + expect(mockNotificationService.info).toHaveBeenCalled(); + done(); + }); + }); + }); + + describe("state management", () => { + it("should transition from UNAVAILABLE to GENERATING to AVAILABLE", done => { + const states: CopilotState[] = []; + + service.state$.subscribe(state => { + states.push(state); + if (states.length === 1) { + expect(states[0]).toBe(CopilotState.UNAVAILABLE); + done(); + } + }); + }); + }); + + describe("workflow tools", () => { + it("should create workflow tools correctly", () => { + const tools = (service as any).createWorkflowTools(); + expect(tools).toBeTruthy(); + expect(typeof tools).toBe("object"); + // Tool names match the constants in the tool files + expect(tools.listAllOperatorTypes).toBeTruthy(); + expect(tools.listOperatorsInCurrentWorkflow).toBeTruthy(); + expect(tools.listCurrentLinks).toBeTruthy(); + expect(tools.getCurrentOperator).toBeTruthy(); + expect(tools.getOperatorPropertiesSchema).toBeTruthy(); + expect(tools.getOperatorPortsInfo).toBeTruthy(); + expect(tools.getOperatorMetadata).toBeTruthy(); + }); + }); +}); diff --git a/frontend/src/app/workspace/service/copilot/texera-copilot.ts b/frontend/src/app/workspace/service/copilot/texera-copilot.ts new file mode 100644 index 00000000000..cb7cc91e111 --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/texera-copilot.ts @@ -0,0 +1,388 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Injectable } from "@angular/core"; +import { BehaviorSubject, Observable, from, of, throwError, defer } from "rxjs"; +import { map, catchError, tap, switchMap, finalize } from "rxjs/operators"; +import { WorkflowActionService } from "../workflow-graph/model/workflow-action.service"; +import { toolWithTimeout } from "./tool/tools-utility"; +import * as CurrentWorkflowTools from "./tool/current-workflow-editing-observing-tools"; +import * as MetadataTools from "./tool/workflow-metadata-tools"; +import { ToolOperatorAccess, parseOperatorAccessFromStep } from "./tool/react-step-operator-parser"; +import { OperatorMetadataService } from "../operator-metadata/operator-metadata.service"; +import { createOpenAI } from "@ai-sdk/openai"; +import { generateText, type ModelMessage, stepCountIs } from "ai"; +import { WorkflowUtilService } from "../workflow-graph/util/workflow-util.service"; +import { AppSettings } from "../../../common/app-setting"; +import { WorkflowCompilingService } from "../compile-workflow/workflow-compiling.service"; +import { COPILOT_SYSTEM_PROMPT } from "./copilot-prompts"; +import { NotificationService } from "../../../common/service/notification/notification.service"; + +export enum CopilotState { + UNAVAILABLE = "Unavailable", + AVAILABLE = "Available", + GENERATING = "Generating", + STOPPING = "Stopping", +} + +/** + * Represents a single step in the ReAct (Reasoning and Acting) conversation flow. + * Each step can be either a user message or an agent response with potential tool calls. + */ +export interface ReActStep { + messageId: string; + stepId: number; + role: "user" | "agent"; + content: string; + isBegin: boolean; + isEnd: boolean; + timestamp: number; + toolCalls?: any[]; + toolResults?: any[]; + usage?: { + inputTokens?: number; + outputTokens?: number; + totalTokens?: number; + cachedInputTokens?: number; + }; + /** + * Map from tool call index to operator access information, which tracks operators were viewed or modified during the tool call. + */ + operatorAccess?: Map; +} + +/** + * Texera Copilot Service provides AI-powered assistance for workflow creation and manipulation. + * + * This service manages a single AI agent instance that can: + * 1. Interact with users through natural language messages + * 2. Execute workflow operations using specialized tools + * 3. Maintain conversation history and state + * + * The service communicates with an LLM backend (via LiteLLM) to generate responses and uses + * workflow tools to perform actions like listing operators, getting operator schemas, and + * manipulating workflow components. + * + * State management includes: + * - UNAVAILABLE: Agent not initialized + * - AVAILABLE: Agent ready to receive messages + * - GENERATING: Agent currently processing and generating response + * - STOPPING: Agent in the process of stopping generation + */ +@Injectable() +export class TexeraCopilot { + /** + * Maximum number of ReAct reasoning/action cycles allowed per generation. + * Prevents infinite loops and excessive token usage. + */ + private static readonly MAX_REACT_STEPS = 50; + + private model: any; + private modelType = ""; + private agentName = ""; + + /** + * Conversation history in LLM API format. + * Used internally to maintain context for generateText() API calls. + * Contains the raw message format expected by the AI model. + */ + private messages: ModelMessage[] = []; + + /** + * Representing a step in ReAct (Reasoning + Acting). + * This is what gets displayed in the UI to show the agent's reasoning process. + * Each step contains messageId (randomly generated UUID) and stepId (incremental from 0). + */ + private reActSteps: ReActStep[] = []; + private reActStepsSubject = new BehaviorSubject([]); + public reActSteps$ = this.reActStepsSubject.asObservable(); + + private state = CopilotState.UNAVAILABLE; + private stateSubject = new BehaviorSubject(CopilotState.UNAVAILABLE); + public state$ = this.stateSubject.asObservable(); + private tools: Record = {}; + + constructor( + private workflowActionService: WorkflowActionService, + private workflowUtilService: WorkflowUtilService, + private operatorMetadataService: OperatorMetadataService, + private workflowCompilingService: WorkflowCompilingService, + private notificationService: NotificationService + ) {} + + public setAgentInfo(agentName: string): void { + this.agentName = agentName; + } + + public setModelType(modelType: string): void { + this.modelType = modelType; + } + + private setState(newState: CopilotState): void { + this.state = newState; + this.stateSubject.next(newState); + } + + private emitReActStep( + messageId: string, + stepId: number, + role: "user" | "agent", + content: string, + isBegin: boolean, + isEnd: boolean, + toolCalls?: any[], + toolResults?: any[], + usage?: ReActStep["usage"], + operatorAccess?: Map + ): void { + this.reActSteps.push({ + messageId, + stepId, + role, + content, + isBegin, + isEnd, + timestamp: Date.now(), + toolCalls, + toolResults, + usage, + operatorAccess, + }); + this.reActStepsSubject.next([...this.reActSteps]); + } + + public initialize(): Observable { + return defer(() => { + try { + this.model = createOpenAI({ + baseURL: new URL(`${AppSettings.getApiEndpoint()}`, document.baseURI).toString(), + // apiKey is required by the library for creating the OpenAI compatible client; + // For security reason, we store the apiKey at the backend, thus the value is dummy here. + apiKey: "dummy", + }).chat(this.modelType); + + // Create tools once during initialization + this.tools = this.createWorkflowTools(); + + this.setState(CopilotState.AVAILABLE); + return of(undefined); + } catch (error: unknown) { + this.setState(CopilotState.UNAVAILABLE); + return throwError(() => error); + } + }); + } + + public sendMessage(message: string): Observable { + return defer(() => { + if (!this.model) { + return throwError(() => new Error("Copilot not initialized")); + } + + if (this.state !== CopilotState.AVAILABLE) { + return throwError(() => new Error(`Cannot send message: agent is ${this.state}`)); + } + + this.setState(CopilotState.GENERATING); + + // Generate unique message ID for this conversation turn + const messageId = crypto.randomUUID(); + let stepId = 0; + + // Emit user message as first step + this.emitReActStep(messageId, stepId++, "user", message, true, true); + this.messages.push({ role: "user", content: message }); + + let isFirstStep = true; + + /** + * Generate text using the AI model with ReAct (Reasoning + Acting) pattern. + * This is the core of the agent lifecycle with several callbacks: + * + * Lifecycle flow: + * 1. generateText() starts the LLM generation + * 2. stopWhen() - checked before each step to determine if generation should stop + * 3. onStepFinish() - called DURING generation after each reasoning/action step (real-time updates) + * 4. pipe operators - executed AFTER generation completes (final processing) + */ + return from( + generateText({ + model: this.model, + messages: this.messages, + tools: this.tools, + system: COPILOT_SYSTEM_PROMPT, + /** + * stopWhen - Determines if generation should stop. + * Called before each step during generation. + * Returns true to stop, false to continue. + */ + stopWhen: ({ steps }) => { + if (this.state === CopilotState.STOPPING) { + this.notificationService.info(`Agent ${this.agentName} has stopped generation`); + return true; + } + // Stop if step count reaches max limit to prevent infinite loops + return stepCountIs(TexeraCopilot.MAX_REACT_STEPS)({ steps }); + }, + /** + * onStepFinish is called DURING generation after each ReAct step completes. + * This provides real-time updates to the UI as the agent reasons and acts. + * + * Each step may include: + * - text: The agent's reasoning or response text + * - toolCalls: Tools the agent decided to call + * - toolResults: Results from executed tools + * - usage: Token usage for this step + * + * Note: This is called multiple times during a single generation, + * once per reasoning/action cycle. + */ + onStepFinish: ({ text, toolCalls, toolResults, usage }) => { + if (this.state === CopilotState.STOPPING) { + return; + } + + // Parse operator access from tool results to track viewed/modified operators + const operatorAccess = parseOperatorAccessFromStep(toolCalls || [], toolResults || []); + + this.emitReActStep( + messageId, + stepId++, + "agent", + text || "", + isFirstStep, + false, + toolCalls, + toolResults, + usage as any, + operatorAccess + ); + isFirstStep = false; + }, + }) + ).pipe( + /** + * To this point, generateText has finished. + * All the responses from AI are recorded in responses variable. + */ + tap(({ response }) => { + this.messages.push(...response.messages); + this.reActStepsSubject.next([...this.reActSteps]); + }), + map(() => undefined), + catchError((err: unknown) => { + const errorText = `Error: ${err instanceof Error ? err.message : String(err)}`; + this.messages.push({ role: "assistant", content: errorText }); + this.emitReActStep(messageId, stepId++, "agent", errorText, false, true); + return throwError(() => err); + }), + /** + * Resets agent state back to AVAILABLE so it can handle new messages. + */ + finalize(() => { + this.setState(CopilotState.AVAILABLE); + }) + ); + }); + } + + private createWorkflowTools(): Record { + const listOperatorsInCurrentWorkflowTool = toolWithTimeout( + CurrentWorkflowTools.createListOperatorsInCurrentWorkflowTool(this.workflowActionService) + ); + const listLinksTool = toolWithTimeout(CurrentWorkflowTools.createListCurrentLinksTool(this.workflowActionService)); + const listAllOperatorTypesTool = toolWithTimeout( + MetadataTools.createListAllOperatorTypesTool(this.workflowUtilService) + ); + const getOperatorTool = toolWithTimeout( + CurrentWorkflowTools.createGetCurrentOperatorTool(this.workflowActionService, this.workflowCompilingService) + ); + const getOperatorPropertiesSchemaTool = toolWithTimeout( + MetadataTools.createGetOperatorPropertiesSchemaTool(this.operatorMetadataService) + ); + const getOperatorPortsInfoTool = toolWithTimeout( + MetadataTools.createGetOperatorPortsInfoTool(this.operatorMetadataService) + ); + const getOperatorMetadataTool = toolWithTimeout( + MetadataTools.createGetOperatorMetadataTool(this.operatorMetadataService) + ); + + return { + [MetadataTools.TOOL_NAME_LIST_ALL_OPERATOR_TYPES]: listAllOperatorTypesTool, + [CurrentWorkflowTools.TOOL_NAME_LIST_OPERATORS_IN_CURRENT_WORKFLOW]: listOperatorsInCurrentWorkflowTool, + [CurrentWorkflowTools.TOOL_NAME_LIST_CURRENT_LINKS]: listLinksTool, + [CurrentWorkflowTools.TOOL_NAME_GET_CURRENT_OPERATOR]: getOperatorTool, + [MetadataTools.TOOL_NAME_GET_OPERATOR_PROPERTIES_SCHEMA]: getOperatorPropertiesSchemaTool, + [MetadataTools.TOOL_NAME_GET_OPERATOR_PORTS_INFO]: getOperatorPortsInfoTool, + [MetadataTools.TOOL_NAME_GET_OPERATOR_METADATA]: getOperatorMetadataTool, + }; + } + + public getReActSteps(): ReActStep[] { + return [...this.reActSteps]; + } + + public stopGeneration(): void { + if (this.state !== CopilotState.GENERATING) { + return; + } + this.setState(CopilotState.STOPPING); + } + + public clearMessages(): void { + this.messages = []; + this.reActSteps = []; + this.reActStepsSubject.next([]); + } + + public getState(): CopilotState { + return this.state; + } + + public disconnect(): Observable { + return defer(() => { + if (this.state === CopilotState.GENERATING) { + this.stopGeneration(); + } + + this.clearMessages(); + this.tools = {}; // Clear tools to free memory + this.setState(CopilotState.UNAVAILABLE); + this.notificationService.info(`Agent ${this.agentName} is removed successfully`); + + return of(undefined); + }); + } + + public isConnected(): boolean { + return this.state !== CopilotState.UNAVAILABLE; + } + + public getSystemPrompt(): string { + return COPILOT_SYSTEM_PROMPT; + } + + public getToolsInfo(): Array<{ name: string; description: string; inputSchema: any }> { + return Object.entries(this.tools).map(([name, tool]) => ({ + name: name, + description: tool.description || "No description available", + inputSchema: tool.parameters || {}, + })); + } +} diff --git a/frontend/src/app/workspace/service/copilot/tool/current-workflow-editing-observing-tools.ts b/frontend/src/app/workspace/service/copilot/tool/current-workflow-editing-observing-tools.ts new file mode 100644 index 00000000000..c67a7b1707c --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/tool/current-workflow-editing-observing-tools.ts @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { z } from "zod"; +import { tool } from "ai"; +import { WorkflowActionService } from "../../workflow-graph/model/workflow-action.service"; +import { OperatorMetadataService } from "../../operator-metadata/operator-metadata.service"; +import { OperatorLink } from "../../../types/workflow-common.interface"; +import { WorkflowUtilService } from "../../workflow-graph/util/workflow-util.service"; +import { WorkflowCompilingService } from "../../compile-workflow/workflow-compiling.service"; +import { ValidationWorkflowService } from "../../validation/validation-workflow.service"; +import { createSuccessResult, createErrorResult } from "./tools-utility"; + +// Tool name constants +export const TOOL_NAME_LIST_OPERATORS_IN_CURRENT_WORKFLOW = "listOperatorsInCurrentWorkflow"; +export const TOOL_NAME_LIST_CURRENT_LINKS = "listCurrentLinks"; +export const TOOL_NAME_GET_CURRENT_OPERATOR = "getCurrentOperator"; + +/** + * Create listLinksInCurrentWorkflow tool for getting all links in the workflow + */ +export function createListCurrentLinksTool(workflowActionService: WorkflowActionService) { + return tool({ + name: TOOL_NAME_LIST_CURRENT_LINKS, + description: "Get all links in the current workflow", + inputSchema: z.object({}), + execute: async () => { + try { + const links = workflowActionService.getTexeraGraph().getAllLinks(); + return createSuccessResult( + { + links: links, + count: links.length, + }, + [], + [] + ); + } catch (error: any) { + return createErrorResult(error.message); + } + }, + }); +} + +export function createListOperatorsInCurrentWorkflowTool(workflowActionService: WorkflowActionService) { + return tool({ + name: TOOL_NAME_LIST_OPERATORS_IN_CURRENT_WORKFLOW, + description: "Get all operator IDs, types and custom names in the current workflow", + inputSchema: z.object({}), + execute: async () => { + try { + const operators = workflowActionService.getTexeraGraph().getAllOperators(); + const operatorIds = operators.map(op => op.operatorID); + return createSuccessResult( + { + operators: operators.map(op => ({ + operatorId: op.operatorID, + operatorType: op.operatorType, + customDisplayName: op.customDisplayName, + })), + count: operators.length, + }, + operatorIds, + [] + ); + } catch (error: any) { + return createErrorResult(error.message); + } + }, + }); +} + +export function createGetCurrentOperatorTool( + workflowActionService: WorkflowActionService, + workflowCompilingService: WorkflowCompilingService +) { + return tool({ + name: TOOL_NAME_GET_CURRENT_OPERATOR, + description: + "Get detailed information about a specific operator in the current workflow, including its input and output schemas", + inputSchema: z.object({ + operatorId: z.string().describe("ID of the operator to retrieve"), + }), + execute: async (args: { operatorId: string }) => { + try { + const operator = workflowActionService.getTexeraGraph().getOperator(args.operatorId); + + // Get input schema (empty map if not available) + const inputSchemaMap = workflowCompilingService.getOperatorInputSchemaMap(args.operatorId); + const inputSchema = inputSchemaMap || {}; + + // Get output schema (empty map if not available) + const outputSchemaMap = workflowCompilingService.getOperatorOutputSchemaMap(args.operatorId); + const outputSchema = outputSchemaMap || {}; + + return createSuccessResult( + { + operator: operator, + inputSchema: inputSchema, + outputSchema: outputSchema, + message: `Retrieved operator ${args.operatorId}`, + }, + [args.operatorId], + [] + ); + } catch (error: any) { + return createErrorResult(error.message || `Operator ${args.operatorId} not found`); + } + }, + }); +} diff --git a/frontend/src/app/workspace/service/copilot/tool/react-step-operator-parser.ts b/frontend/src/app/workspace/service/copilot/tool/react-step-operator-parser.ts new file mode 100644 index 00000000000..2477a21579d --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/tool/react-step-operator-parser.ts @@ -0,0 +1,152 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Central parser module to extract operator access information from tool results. + * Tools should populate viewedOperatorIds and modifiedOperatorIds in their results. + */ + +/** + * Operator access information indicating which operators were viewed or modified. + * Tools should populate these fields in their results to indicate operator interaction. + */ +export interface ToolOperatorAccess { + viewedOperatorIds: string[]; + modifiedOperatorIds: string[]; +} + +/** + * Parse operator access from a tool call's result. + * Tools should populate viewedOperatorIds and modifiedOperatorIds in their results. + * + * @param toolCall - The tool call object containing toolName and args + * @param toolResult - The tool result object containing the output/result with operator IDs + * @returns ToolOperatorAccess object with viewedOperatorIds and modifiedOperatorIds + */ +export function parseOperatorAccessFromToolCall(toolCall: any, toolResult?: any): ToolOperatorAccess { + const access: ToolOperatorAccess = { viewedOperatorIds: [], modifiedOperatorIds: [] }; + + if (!toolResult || !toolResult.output) { + return access; + } + + try { + const output = toolResult.output; + + // Extract viewedOperatorIds from tool result + if (Array.isArray(output.viewedOperatorIds)) { + access.viewedOperatorIds = output.viewedOperatorIds.filter((id: any) => id && typeof id === "string"); + } + + // Extract modifiedOperatorIds from tool result + if (Array.isArray(output.modifiedOperatorIds)) { + access.modifiedOperatorIds = output.modifiedOperatorIds.filter((id: any) => id && typeof id === "string"); + } + + // Remove duplicates + access.viewedOperatorIds = [...new Set(access.viewedOperatorIds)]; + access.modifiedOperatorIds = [...new Set(access.modifiedOperatorIds)]; + } catch (error) { + console.error("Error parsing operator access from tool result:", error); + } + + return access; +} + +/** + * Parse operator access for all tool calls in a step. + * + * @param toolCalls - Array of tool call objects + * @param toolResults - Array of corresponding tool result objects + * @returns Map from tool call index to ToolOperatorAccess + */ +export function parseOperatorAccessFromStep(toolCalls: any[], toolResults?: any[]): Map { + const accessMap = new Map(); + + if (!toolCalls || toolCalls.length === 0) { + return accessMap; + } + + for (let i = 0; i < toolCalls.length; i++) { + const toolCall = toolCalls[i]; + const toolResult = toolResults && toolResults[i] ? toolResults[i] : undefined; + const access = parseOperatorAccessFromToolCall(toolCall, toolResult); + + // Only add to map if there are any viewed or modified operations + if (access.viewedOperatorIds.length > 0 || access.modifiedOperatorIds.length > 0) { + accessMap.set(i, access); + } + } + + return accessMap; +} + +/** + * Extract all viewed operator IDs from a ReActStep. + * + * @param step - The ReActStep to extract from + * @returns Array of unique operator IDs that were viewed + */ +export function getAllViewedOperatorIds(step: { operatorAccess?: Map }): string[] { + if (!step.operatorAccess) { + return []; + } + + const allViewedIds: string[] = []; + for (const access of step.operatorAccess.values()) { + allViewedIds.push(...access.viewedOperatorIds); + } + + // Return unique operator IDs + return [...new Set(allViewedIds)]; +} + +/** + * Extract all modified operator IDs from a ReActStep. + * + * @param step - The ReActStep to extract from + * @returns Array of unique operator IDs that were modified + */ +export function getAllModifiedOperatorIds(step: { operatorAccess?: Map }): string[] { + if (!step.operatorAccess) { + return []; + } + + const allModifiedIds: string[] = []; + for (const access of step.operatorAccess.values()) { + allModifiedIds.push(...access.modifiedOperatorIds); + } + + // Return unique operator IDs + return [...new Set(allModifiedIds)]; +} + +/** + * Extract all operator IDs (both viewed and modified) from a ReActStep. + * + * @param step - The ReActStep to extract from + * @returns Array of unique operator IDs involved in this step + */ +export function getAllOperatorIds(step: { operatorAccess?: Map }): string[] { + const viewedIds = getAllViewedOperatorIds(step); + const modifiedIds = getAllModifiedOperatorIds(step); + + // Combine and return unique IDs + return [...new Set([...viewedIds, ...modifiedIds])]; +} diff --git a/frontend/src/app/workspace/service/copilot/tool/tools-utility.ts b/frontend/src/app/workspace/service/copilot/tool/tools-utility.ts new file mode 100644 index 00000000000..f098bd1d91c --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/tool/tools-utility.ts @@ -0,0 +1,138 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Tool execution timeout in milliseconds (2 minutes) +export const TOOL_TIMEOUT_MS = 120000; + +// Maximum token limit for operator result data to prevent overwhelming LLM context +// Estimated as characters / 4 (common approximation for token counting) +export const MAX_OPERATOR_RESULT_TOKEN_LIMIT = 1000; + +/** + * Base interface for all tool execution results. + * Ensures consistent structure across all tools with required tracking fields. + */ +export interface BaseToolResult { + /** + * Indicates whether the tool execution was successful. + */ + success: boolean; + + /** + * Error message if the tool execution failed. + */ + error?: string; + + /** + * List of operator IDs that were viewed/read during tool execution. + * Empty array if no operators were viewed. + */ + viewedOperatorIds: string[]; + + /** + * List of operator IDs that were modified/written during tool execution. + * Empty array if no operators were modified. + */ + modifiedOperatorIds: string[]; +} + +/** + * Creates a successful tool result with default values for required fields. + * Tools can extend this with additional custom fields. + * + * @param data - Custom data fields for the tool result + * @param viewedOperatorIds - Operator IDs that were viewed (default: []) + * @param modifiedOperatorIds - Operator IDs that were modified (default: []) + * @returns BaseToolResult with success=true and provided data + */ +export function createSuccessResult>( + data: T, + viewedOperatorIds: string[] = [], + modifiedOperatorIds: string[] = [] +): BaseToolResult & T { + return { + success: true, + viewedOperatorIds, + modifiedOperatorIds, + ...data, + }; +} + +/** + * Creates a failed tool result with an error message. + * + * @param error - Error message describing the failure + * @returns BaseToolResult with success=false and error message + */ +export function createErrorResult(error: string): BaseToolResult { + return { + success: false, + error, + viewedOperatorIds: [], + modifiedOperatorIds: [], + }; +} + +/** + * Estimates the number of tokens in a JSON-serializable object + * Uses a common approximation: tokens ≈ characters / 4 + */ +export function estimateTokenCount(data: any): number { + try { + const jsonString = JSON.stringify(data); + return Math.ceil(jsonString.length / 4); + } catch (error) { + // Fallback if JSON.stringify fails + return 0; + } +} + +/** + * Wraps a tool definition to add timeout protection to its execute function + * Uses AbortController to properly cancel operations on timeout + */ +export function toolWithTimeout(toolConfig: any): any { + const originalExecute = toolConfig.execute; + + return { + ...toolConfig, + execute: async (args: any) => { + const abortController = new AbortController(); + + const timeoutPromise = new Promise((_, reject) => { + setTimeout(() => { + abortController.abort(); + reject(new Error("timeout")); + }, TOOL_TIMEOUT_MS); + }); + + try { + const argsWithSignal = { ...args, signal: abortController.signal }; + return await Promise.race([originalExecute(argsWithSignal), timeoutPromise]); + } catch (error: any) { + if (error.message === "timeout") { + return createErrorResult( + "Tool execution timeout - operation took longer than 2 minutes. Please try again later." + ); + } + throw error; + } + }, + }; +} diff --git a/frontend/src/app/workspace/service/copilot/tool/workflow-metadata-tools.ts b/frontend/src/app/workspace/service/copilot/tool/workflow-metadata-tools.ts new file mode 100644 index 00000000000..eca6cb76afc --- /dev/null +++ b/frontend/src/app/workspace/service/copilot/tool/workflow-metadata-tools.ts @@ -0,0 +1,140 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { z } from "zod"; +import { tool } from "ai"; +import { OperatorMetadataService } from "../../operator-metadata/operator-metadata.service"; +import { WorkflowUtilService } from "../../workflow-graph/util/workflow-util.service"; + +// Tool name constants +export const TOOL_NAME_LIST_ALL_OPERATOR_TYPES = "listAllOperatorTypes"; +export const TOOL_NAME_GET_OPERATOR_PROPERTIES_SCHEMA = "getOperatorPropertiesSchema"; +export const TOOL_NAME_GET_OPERATOR_PORTS_INFO = "getOperatorPortsInfo"; +export const TOOL_NAME_GET_OPERATOR_METADATA = "getOperatorMetadata"; + +export function createListAllOperatorTypesTool(workflowUtilService: WorkflowUtilService) { + return tool({ + name: TOOL_NAME_LIST_ALL_OPERATOR_TYPES, + description: "Get all available operator types in the system", + inputSchema: z.object({}), + execute: async () => { + try { + const operatorTypes = workflowUtilService.getOperatorTypeList(); + return { + success: true, + operatorTypes: operatorTypes, + count: operatorTypes.length, + }; + } catch (error: any) { + return { success: false, error: error.message }; + } + }, + }); +} + +/** + * Create getOperatorPropertiesSchema tool for getting just the properties schema + * More token-efficient than getOperatorSchema for property-focused queries + */ +export function createGetOperatorPropertiesSchemaTool(operatorMetadataService: OperatorMetadataService) { + return tool({ + name: TOOL_NAME_GET_OPERATOR_PROPERTIES_SCHEMA, + description: "Get only the properties schema for an operator type. Use this before setting operator properties.", + inputSchema: z.object({ + operatorType: z.string().describe("Type of the operator to get properties schema for"), + }), + execute: async (args: { operatorType: string }) => { + try { + const schema = operatorMetadataService.getOperatorSchema(args.operatorType); + const propertiesSchema = { + properties: schema.jsonSchema.properties, + required: schema.jsonSchema.required, + definitions: schema.jsonSchema.definitions, + }; + + return { + success: true, + propertiesSchema: propertiesSchema, + operatorType: args.operatorType, + message: `Retrieved properties schema for operator type ${args.operatorType}`, + }; + } catch (error: any) { + return { success: false, error: error.message }; + } + }, + }); +} + +export function createGetOperatorPortsInfoTool(operatorMetadataService: OperatorMetadataService) { + return tool({ + name: TOOL_NAME_GET_OPERATOR_PORTS_INFO, + description: + "Get input and output port information for an operator type. This is more token-efficient than getOperatorSchema and returns only port details (display names, multi-input support, etc.).", + inputSchema: z.object({ + operatorType: z.string().describe("Type of the operator to get port information for"), + }), + execute: async (args: { operatorType: string }) => { + try { + const schema = operatorMetadataService.getOperatorSchema(args.operatorType); + const portsInfo = { + inputPorts: schema.additionalMetadata.inputPorts, + outputPorts: schema.additionalMetadata.outputPorts, + dynamicInputPorts: schema.additionalMetadata.dynamicInputPorts, + dynamicOutputPorts: schema.additionalMetadata.dynamicOutputPorts, + }; + + return { + success: true, + portsInfo: portsInfo, + operatorType: args.operatorType, + message: `Retrieved port information for operator type ${args.operatorType}`, + }; + } catch (error: any) { + return { success: false, error: error.message }; + } + }, + }); +} + +export function createGetOperatorMetadataTool(operatorMetadataService: OperatorMetadataService) { + return tool({ + name: TOOL_NAME_GET_OPERATOR_METADATA, + description: + "Get semantic metadata for an operator type, including user-friendly name, description, operator group, and capabilities. This is very useful to understand the semantics and purpose of each operator type - what it does, how it works, and what kind of data transformation it performs.", + inputSchema: z.object({ + operatorType: z.string().describe("Type of the operator to get metadata for"), + }), + execute: async (args: { operatorType: string; signal?: AbortSignal }) => { + try { + const schema = operatorMetadataService.getOperatorSchema(args.operatorType); + + const metadata = schema.additionalMetadata; + return { + success: true, + metadata: metadata, + operatorType: args.operatorType, + operatorVersion: schema.operatorVersion, + message: `Retrieved metadata for operator type ${args.operatorType}`, + }; + } catch (error: any) { + return { success: false, error: error.message }; + } + }, + }); +} diff --git a/frontend/src/app/workspace/service/workflow-graph/util/workflow-util.service.ts b/frontend/src/app/workspace/service/workflow-graph/util/workflow-util.service.ts index f172975e001..8bca014062a 100644 --- a/frontend/src/app/workspace/service/workflow-graph/util/workflow-util.service.ts +++ b/frontend/src/app/workspace/service/workflow-graph/util/workflow-util.service.ts @@ -59,6 +59,13 @@ export class WorkflowUtilService { return this.operatorSchemaListCreatedSubject.asObservable(); } + /** + * Returns a list of all available operator types + */ + public getOperatorTypeList(): string[] { + return this.operatorSchemaList.map(schema => schema.operatorType); + } + /** * Generates a new UUID for operator */ @@ -106,7 +113,7 @@ export class WorkflowUtilService { * @param operatorType type of an Operator * @returns a new OperatorPredicate of the operatorType */ - public getNewOperatorPredicate(operatorType: string): OperatorPredicate { + public getNewOperatorPredicate(operatorType: string, customDisplayName?: string): OperatorPredicate { const operatorSchema = this.operatorSchemaList.find(schema => schema.operatorType === operatorType); if (operatorSchema === undefined) { throw new Error(`operatorType ${operatorType} doesn't exist in operator metadata`); @@ -131,8 +138,8 @@ export class WorkflowUtilService { // by default, the operator is not disabled const isDisabled = false; - // by default, the operator name is the user friendly name - const customDisplayName = operatorSchema.additionalMetadata.userFriendlyName; + // Use provided customDisplayName or default to the user friendly name from schema + const displayName = customDisplayName ?? operatorSchema.additionalMetadata.userFriendlyName; const dynamicInputPorts = operatorSchema.additionalMetadata.dynamicInputPorts ?? false; const dynamicOutputPorts = operatorSchema.additionalMetadata.dynamicOutputPorts ?? false; @@ -160,7 +167,7 @@ export class WorkflowUtilService { outputPorts, showAdvanced, isDisabled, - customDisplayName, + customDisplayName: displayName, dynamicInputPorts, dynamicOutputPorts, }; diff --git a/frontend/yarn.lock b/frontend/yarn.lock index 394dcd78803..910cc1b5f6d 100644 --- a/frontend/yarn.lock +++ b/frontend/yarn.lock @@ -24,6 +24,53 @@ __metadata: languageName: node linkType: hard +"@ai-sdk/gateway@npm:2.0.9": + version: 2.0.9 + resolution: "@ai-sdk/gateway@npm:2.0.9" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.17" + "@vercel/oidc": "npm:3.0.3" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/840f94795b96c0fa6e73897ea8dba95fc78af1f8482f3b7d8439b6233b4f4de6979a8b67206f4bbf32649baf2acfb1153a46792dfa20259ca9f5fd214fb25fa5 + languageName: node + linkType: hard + +"@ai-sdk/openai@npm:2.0.67": + version: 2.0.67 + resolution: "@ai-sdk/openai@npm:2.0.67" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.17" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/7e5c407504d7902c17c816aaccd83f642a3b82012cd8467c8f58aef5f08a49b6c31fff775439d541d40b0c8b5b94cc384f18096d1968e23670e22a56fe82d8bd + languageName: node + linkType: hard + +"@ai-sdk/provider-utils@npm:3.0.17": + version: 3.0.17 + resolution: "@ai-sdk/provider-utils@npm:3.0.17" + dependencies: + "@ai-sdk/provider": "npm:2.0.0" + "@standard-schema/spec": "npm:^1.0.0" + eventsource-parser: "npm:^3.0.6" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/1bae6dc4cacd0305b6aa152f9589bbd61c29f150155482c285a77f83d7ed416d52bc2aa7fdaba2e5764530392d9e8f799baea34a63dce6c72ecd3de364dc62d1 + languageName: node + linkType: hard + +"@ai-sdk/provider@npm:2.0.0": + version: 2.0.0 + resolution: "@ai-sdk/provider@npm:2.0.0" + dependencies: + json-schema: "npm:^0.4.0" + checksum: 10c0/e50e520016c9fc0a8b5009cadd47dae2f1c81ec05c1792b9e312d7d15479f024ca8039525813a33425c884e3449019fed21043b1bfabd6a2626152ca9a388199 + languageName: node + linkType: hard + "@ali-hm/angular-tree-component@npm:12.0.5": version: 12.0.5 resolution: "@ali-hm/angular-tree-component@npm:12.0.5" @@ -4598,6 +4645,13 @@ __metadata: languageName: node linkType: hard +"@opentelemetry/api@npm:1.9.0": + version: 1.9.0 + resolution: "@opentelemetry/api@npm:1.9.0" + checksum: 10c0/9aae2fe6e8a3a3eeb6c1fdef78e1939cf05a0f37f8a4fae4d6bf2e09eb1e06f966ece85805626e01ba5fab48072b94f19b835449e58b6d26720ee19a58298add + languageName: node + linkType: hard + "@parcel/watcher-android-arm64@npm:2.4.1": version: 2.4.1 resolution: "@parcel/watcher-android-arm64@npm:2.4.1" @@ -4916,6 +4970,13 @@ __metadata: languageName: node linkType: hard +"@standard-schema/spec@npm:^1.0.0": + version: 1.0.0 + resolution: "@standard-schema/spec@npm:1.0.0" + checksum: 10c0/a1ab9a8bdc09b5b47aa8365d0e0ec40cc2df6437be02853696a0e377321653b0d3ac6f079a8c67d5ddbe9821025584b1fb71d9cc041a6666a96f1fadf2ece15f + languageName: node + linkType: hard + "@stoplight/json-ref-resolver@npm:3.1.5": version: 3.1.5 resolution: "@stoplight/json-ref-resolver@npm:3.1.5" @@ -5927,6 +5988,13 @@ __metadata: languageName: node linkType: hard +"@vercel/oidc@npm:3.0.3": + version: 3.0.3 + resolution: "@vercel/oidc@npm:3.0.3" + checksum: 10c0/c8eecb1324559435f4ab8a955f5ef44f74f546d11c2ddcf28151cb636d989bd4b34e0673fd8716cb21bb21afb34b3de663bacc30c9506036eeecbcbf2fd86241 + languageName: node + linkType: hard + "@vitejs/plugin-basic-ssl@npm:1.0.1": version: 1.0.1 resolution: "@vitejs/plugin-basic-ssl@npm:1.0.1" @@ -6348,6 +6416,20 @@ __metadata: languageName: node linkType: hard +"ai@npm:5.0.93": + version: 5.0.93 + resolution: "ai@npm:5.0.93" + dependencies: + "@ai-sdk/gateway": "npm:2.0.9" + "@ai-sdk/provider": "npm:2.0.0" + "@ai-sdk/provider-utils": "npm:3.0.17" + "@opentelemetry/api": "npm:1.9.0" + peerDependencies: + zod: ^3.25.76 || ^4.1.8 + checksum: 10c0/64e80a36248ef20d42a0e688a58da6ca4255e2510a47d6ee9c2318a669df8e8ffb5c63ddb03a3cadd841058414bd4b324783f168c761f82421e4a1c2ac958933 + languageName: node + linkType: hard + "ajv-formats@npm:2.1.1, ajv-formats@npm:^2.1.1": version: 2.1.1 resolution: "ajv-formats@npm:2.1.1" @@ -10565,6 +10647,13 @@ __metadata: languageName: node linkType: hard +"eventsource-parser@npm:^3.0.6": + version: 3.0.6 + resolution: "eventsource-parser@npm:3.0.6" + checksum: 10c0/70b8ccec7dac767ef2eca43f355e0979e70415701691382a042a2df8d6a68da6c2fca35363669821f3da876d29c02abe9b232964637c1b6635c940df05ada78a + languageName: node + linkType: hard + "execa@npm:^5.0.0": version: 5.1.1 resolution: "execa@npm:5.1.1" @@ -11615,6 +11704,7 @@ __metadata: resolution: "gui@workspace:." dependencies: "@abacritt/angularx-social-login": "npm:2.3.0" + "@ai-sdk/openai": "npm:2.0.67" "@ali-hm/angular-tree-component": "npm:12.0.5" "@angular-builders/custom-webpack": "npm:16.0.1" "@angular-devkit/build-angular": "npm:16.2.12" @@ -11665,6 +11755,7 @@ __metadata: "@types/validator": "npm:13.12.0" "@typescript-eslint/eslint-plugin": "npm:7.0.2" "@typescript-eslint/parser": "npm:7.0.2" + ai: "npm:5.0.93" ajv: "npm:8.10.0" babel-plugin-dynamic-import-node: "npm:2.3.3" backbone: "npm:1.4.1" @@ -11741,6 +11832,7 @@ __metadata: y-quill: "npm:0.1.5" y-websocket: "npm:1.4.0" yjs: "npm:13.5.41" + zod: "npm:3.25.76" zone.js: "npm:0.13.0" languageName: unknown linkType: soft @@ -13109,6 +13201,13 @@ __metadata: languageName: node linkType: hard +"json-schema@npm:^0.4.0": + version: 0.4.0 + resolution: "json-schema@npm:0.4.0" + checksum: 10c0/d4a637ec1d83544857c1c163232f3da46912e971d5bf054ba44fdb88f07d8d359a462b4aec46f2745efbc57053365608d88bc1d7b1729f7b4fc3369765639ed3 + languageName: node + linkType: hard + "json-stable-stringify-without-jsonify@npm:^1.0.1": version: 1.0.1 resolution: "json-stable-stringify-without-jsonify@npm:1.0.1" @@ -20078,6 +20177,13 @@ __metadata: languageName: node linkType: hard +"zod@npm:3.25.76": + version: 3.25.76 + resolution: "zod@npm:3.25.76" + checksum: 10c0/5718ec35e3c40b600316c5b4c5e4976f7fee68151bc8f8d90ec18a469be9571f072e1bbaace10f1e85cf8892ea12d90821b200e980ab46916a6166a4260a983c + languageName: node + linkType: hard + "zone.js@npm:0.13.0": version: 0.13.0 resolution: "zone.js@npm:0.13.0"