diff --git a/wren-ui/migrations/20250509000000_create_asking_task.js b/wren-ui/migrations/20250509000000_create_asking_task.js
new file mode 100644
index 0000000000..bd0fcd5e0d
--- /dev/null
+++ b/wren-ui/migrations/20250509000000_create_asking_task.js
@@ -0,0 +1,39 @@
+/**
+ * Create the `asking_task` table.
+ *
+ * Columns: auto-increment `id`, required unique `query_id`, `question`
+ * text, JSON `detail` defaulting to '{}', foreign keys `thread_id` and
+ * `thread_response_id` (both ON DELETE CASCADE) and timestamps.
+ *
+ * @param { import("knex").Knex } knex
+ * @returns { Promise }
+ */
+exports.up = (knex) =>
+  knex.schema.createTable('asking_task', (t) => {
+    t.increments('id').primary();
+    t.string('query_id').notNullable().unique();
+    t.text('question');
+    t.jsonb('detail').defaultTo('{}');
+
+    // [column, referenced table] pairs; each gets the same cascade rule
+    const parentLinks = [
+      ['thread_id', 'thread'],
+      ['thread_response_id', 'thread_response'],
+    ];
+    for (const [column, parentTable] of parentLinks) {
+      t.integer(column)
+        .references('id')
+        .inTable(parentTable)
+        .onDelete('CASCADE');
+    }
+
+    t.timestamps(true, true);
+  });
+
+/**
+ * Drop the `asking_task` table.
+ *
+ * @param { import("knex").Knex } knex
+ * @returns { Promise }
+ */
+exports.down = (knex) => knex.schema.dropTable('asking_task');
diff --git a/wren-ui/migrations/20250509000001_add_task_id_to_thread.js b/wren-ui/migrations/20250509000001_add_task_id_to_thread.js
new file mode 100644
index 0000000000..d483cc2bca
--- /dev/null
+++ b/wren-ui/migrations/20250509000001_add_task_id_to_thread.js
@@ -0,0 +1,28 @@
+/**
+ * Add a nullable `asking_task_id` column to `thread_response` referencing
+ * `asking_task.id`; the column is set to NULL when that row is deleted.
+ *
+ * @param { import("knex").Knex } knex
+ * @returns { Promise }
+ */
+exports.up = (knex) =>
+  knex.schema.alterTable('thread_response', (t) => {
+    t.integer('asking_task_id')
+      .nullable()
+      .references('id')
+      .inTable('asking_task')
+      .onDelete('SET NULL');
+  });
+
+/**
+ * Remove `asking_task_id` from `thread_response`: the foreign key is
+ * dropped first, then the column.
+ *
+ * @param { import("knex").Knex } knex
+ * @returns { Promise }
+ */
+exports.down = (knex) =>
+  knex.schema.alterTable('thread_response', (t) => {
+    t.dropForeign('asking_task_id');
+    t.dropColumn('asking_task_id');
+  });
diff --git a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts
index 
fbd5d4d8c6..0fe510c384 100644
--- a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts
+++ b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts
@@ -632,7 +632,6 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
   }
 
   private transformAskResult(body: any): AskResult {
-    const { type, intent_reasoning } = body;
     const { status, error } = this.transformStatusAndError(body);
     const candidates = (body?.response || []).map((candidate: any) => ({
       type: candidate?.type?.toUpperCase() as AskCandidateType,
@@ -641,11 +640,16 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
     }));
 
     return {
-      type,
+      type: body?.type,
       status: status as AskResultStatus,
       error,
       response: candidates,
-      intentReasoning: intent_reasoning,
+      rephrasedQuestion: body?.rephrased_question,
+      intentReasoning: body?.intent_reasoning,
+      sqlGenerationReasoning: body?.sql_generation_reasoning,
+      retrievedTables: body?.retrieved_tables,
+      invalidSql: body?.invalid_sql,
+      traceId: body?.trace_id,
     };
   }
 
diff --git a/wren-ui/src/apollo/server/models/adaptor.ts b/wren-ui/src/apollo/server/models/adaptor.ts
index 4339081a06..cbe6a516ce 100644
--- a/wren-ui/src/apollo/server/models/adaptor.ts
+++ b/wren-ui/src/apollo/server/models/adaptor.ts
@@ -122,7 +122,12 @@ export type AskResult = AskResponse<
   }>,
   AskResultStatus
 > & {
+  rephrasedQuestion?: string;
   intentReasoning?: string;
+  sqlGenerationReasoning?: string;
+  retrievedTables?: string[];
+  invalidSql?: string;
+  traceId?: string;
 };
 
 export enum RecommendationQuestionStatus {
diff --git a/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts b/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts
new file mode 100644
index 0000000000..889bbcf5d0
--- /dev/null
+++ b/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts
@@ -0,0 +1,81 @@
+import { Knex } from 'knex';
+import { BaseRepository, IBasicRepository } from './baseRepository';
+import {
+  camelCase,
+  isPlainObject,
+  mapKeys,
+  mapValues,
+  snakeCase,
+} from 'lodash';
+import { AskResult } from '../models/adaptor';
+
+// A row of the `asking_task` table, with camelCase keys.
+export interface AskingTask {
+  id: number;
+  queryId: string;
+  question?: string;
+  detail?: AskResult;
+  threadId?: number;
+  threadResponseId?: number;
+  createdAt: Date;
+  updatedAt: Date;
+}
+
+export interface IAskingTaskRepository extends IBasicRepository {
+  findByQueryId(queryId: string): Promise;
+}
+
+/**
+ * Repository for the `asking_task` table. Keys are converted between
+ * camelCase (application side) and snake_case (database side), and the
+ * columns listed in `jsonbColumns` are (de)serialized as JSON.
+ */
+export class AskingTaskRepository
+  extends BaseRepository
+  implements IAskingTaskRepository
+{
+  private readonly jsonbColumns = ['detail'];
+
+  constructor(knexPg: Knex) {
+    super({ knexPg, tableName: 'asking_task' });
+  }
+
+  // Look up a single task by its query id.
+  public async findByQueryId(queryId: string): Promise {
+    return this.findOneBy({ queryId });
+  }
+
+  /**
+   * Database row -> `AskingTask`: camelCase every key, then JSON-parse the
+   * JSON columns. Only non-empty strings are parsed; any other value is
+   * passed through untouched.
+   */
+  protected override transformFromDBData = (data: any) => {
+    if (!isPlainObject(data)) {
+      throw new Error('Unexpected dbdata');
+    }
+    const camelCased = mapKeys(data, (_cell, column) => camelCase(column));
+    const parsed = mapValues(camelCased, (cell, field) => {
+      const isJsonText =
+        this.jsonbColumns.includes(field) &&
+        typeof cell === 'string' &&
+        cell !== '';
+      return isJsonText ? JSON.parse(cell) : cell;
+    });
+    return parsed as AskingTask;
+  };
+
+  /**
+   * Application object -> database row: JSON-stringify the JSON columns,
+   * then snake_case every key.
+   */
+  protected override transformToDBData = (data: any) => {
+    if (!isPlainObject(data)) {
+      throw new Error('Unexpected dbdata');
+    }
+    const serialized = mapValues(data, (cell, field) =>
+      this.jsonbColumns.includes(field) ? JSON.stringify(cell) : cell,
+    );
+    return mapKeys(serialized, (_cell, field) => snakeCase(field));
+  };
+}
diff --git a/wren-ui/src/apollo/server/repositories/index.ts b/wren-ui/src/apollo/server/repositories/index.ts
index 85c78ffa74..8badbe9f0c 100644
--- a/wren-ui/src/apollo/server/repositories/index.ts
+++ b/wren-ui/src/apollo/server/repositories/index.ts
@@ -15,3 +15,4 @@ export * from './schemaChangeRepository';
 export * from './dashboardRepository';
 export * from './dashboardItemRepository';
 export * from './sqlPairRepository';
+export * from './askingTaskRepository';
diff --git a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts
index bb593d9b95..87634989ac 100644
--- a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts
+++ b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts
@@ -41,10 +41,11 @@ export interface ThreadResponseChartDetail {
 
 export interface ThreadResponse {
   id: number; // ID
+  askingTaskId?: number; // Reference to asking_task.id
   viewId?: number; // View ID, if the response is from a view
   threadId: number; // Reference to thread.id
   question: string; // Thread response question
-  sql: string; // SQL query generated by AI service
+  sql?: string; // SQL query generated by AI service
   answerDetail?: ThreadResponseAnswerDetail; // AI generated text-based answer detail
   breakdownDetail?: ThreadResponseBreakdownDetail; // Thread response breakdown detail
   chartDetail?: ThreadResponseChartDetail; // Thread response chart detail
@@ -114,6 +115,8 @@ export 
class ThreadResponseRepository id: string | number, data: Partial<{ status: AskResultStatus; + sql: string; + viewId: number; answerDetail: ThreadResponseAnswerDetail; breakdownDetail: ThreadResponseBreakdownDetail; chartDetail: ThreadResponseChartDetail; @@ -122,6 +125,8 @@ export class ThreadResponseRepository ) { const transformedData = { status: data.status ? data.status : undefined, + sql: data.sql ? data.sql : undefined, + viewId: data.viewId ? data.viewId : undefined, answerDetail: data.answerDetail ? JSON.stringify(data.answerDetail) : undefined, diff --git a/wren-ui/src/apollo/server/resolvers.ts b/wren-ui/src/apollo/server/resolvers.ts index e952a1618f..898edc066a 100644 --- a/wren-ui/src/apollo/server/resolvers.ts +++ b/wren-ui/src/apollo/server/resolvers.ts @@ -91,6 +91,7 @@ const resolvers = { cancelAskingTask: askingResolver.cancelAskingTask, createInstantRecommendedQuestions: askingResolver.createInstantRecommendedQuestions, + rerunAskingTask: askingResolver.rerunAskingTask, // Thread createThread: askingResolver.createThread, diff --git a/wren-ui/src/apollo/server/resolvers/askingResolver.ts b/wren-ui/src/apollo/server/resolvers/askingResolver.ts index 645f5d01b9..cb2f292595 100644 --- a/wren-ui/src/apollo/server/resolvers/askingResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/askingResolver.ts @@ -16,6 +16,7 @@ import { IContext } from '../types'; import { getLogger } from '@server/utils'; import { format } from 'sql-formatter'; import { + AskingDetailTaskInput, constructCteSql, ThreadRecommendQuestionResult, } from '../services/askingService'; @@ -25,6 +26,7 @@ import { getSampleAskQuestions, } from '../data'; import { TelemetryEvent, WrenService } from '../telemetry/telemetry'; +import { TrackedAskingResult } from '../services'; const logger = getLogger('AskingResolver'); logger.level = 'debug'; @@ -44,7 +46,13 @@ export interface AskingTask { sql: string; }>; error: WrenAIError | null; + rephrasedQuestion?: string; intentReasoning?: string; 
+ sqlGenerationReasoning?: string; + retrievedTables?: string[]; + invalidSql?: string; + traceId?: string; + queryId?: string; } // DetailedThread is a type that represents a detailed thread, which is a thread with responses. @@ -68,6 +76,7 @@ export class AskingResolver { constructor() { this.createAskingTask = this.createAskingTask.bind(this); this.cancelAskingTask = this.cancelAskingTask.bind(this); + this.rerunAskingTask = this.rerunAskingTask.bind(this); this.getAskingTask = this.getAskingTask.bind(this); this.createThread = this.createThread.bind(this); this.getThread = this.getThread.bind(this); @@ -97,6 +106,7 @@ export class AskingResolver { this.generateThreadResponseChart = this.generateThreadResponseChart.bind(this); this.adjustThreadResponseChart = this.adjustThreadResponseChart.bind(this); + this.transformAskingTask = this.transformAskingTask.bind(this); } public async generateProjectRecommendationQuestions( @@ -184,6 +194,10 @@ export class AskingResolver { const askingService = ctx.askingService; const askResult = await askingService.getAskingTask(taskId); + if (!askResult) { + return null; + } + // telemetry const eventName = TelemetryEvent.HOME_ASK_CANDIDATE; if (askResult.status === AskResultStatus.FINISHED) { @@ -206,27 +220,7 @@ export class AskingResolver { ); } - // construct candidates from response - const candidates = await Promise.all( - (askResult.response || []).map(async (response) => { - const view = response.viewId - ? 
await ctx.viewRepository.findOneBy({ id: response.viewId }) - : null; - return { - type: response.type, - sql: response.sql, - view, - }; - }), - ); - - return { - type: askResult.type, - status: askResult.status, - error: askResult.error, - candidates, - intentReasoning: askResult.intentReasoning, - }; + return this.transformAskingTask(askResult, ctx); } public async createThread( @@ -234,8 +228,9 @@ export class AskingResolver { args: { data: { question?: string; + taskId?: string; + // if we use recommendation questions, sql will be provided sql?: string; - viewId?: number; }; }, ctx: IContext, @@ -243,9 +238,28 @@ export class AskingResolver { const { data } = args; const askingService = ctx.askingService; + + // if taskId is provided, use the result from the asking task + // otherwise, use the input data + let threadInput: AskingDetailTaskInput; + if (data.taskId) { + const askingTask = await askingService.getAskingTask(data.taskId); + if (!askingTask) { + throw new Error(`Asking task ${data.taskId} not found`); + } + + threadInput = { + question: askingTask.question, + trackedAskingResult: askingTask, + }; + } else { + // when we use recommendation questions, there's no task to track + threadInput = data; + } + const eventName = TelemetryEvent.HOME_CREATE_THREAD; try { - const thread = await askingService.createThread(data); + const thread = await askingService.createThread(threadInput); ctx.telemetry.sendEvent(eventName, {}); return thread; } catch (err: any) { @@ -285,6 +299,7 @@ export class AskingResolver { threadId: response.threadId, question: response.question, sql: response.sql, + askingTaskId: response.askingTaskId, breakdownDetail: response.breakdownDetail, answerDetail: response.answerDetail, chartDetail: response.chartDetail, @@ -355,8 +370,9 @@ export class AskingResolver { threadId: number; data: { question?: string; + taskId?: string; + // if we use recommendation questions, sql will be provided sql?: string; - viewId?: number; }; }, ctx: 
IContext,
@@ -365,8 +381,30 @@ export class AskingResolver {
 
     const askingService = ctx.askingService;
     const eventName = TelemetryEvent.HOME_ASK_FOLLOWUP_QUESTION;
+
+    // if taskId is provided, use the result from the asking task
+    // otherwise, use the input data
+    let threadResponseInput: AskingDetailTaskInput;
+    if (data.taskId) {
+      const askingTask = await askingService.getAskingTask(data.taskId);
+      if (!askingTask) {
+        throw new Error(`Asking task ${data.taskId} not found`);
+      }
+
+      threadResponseInput = {
+        question: askingTask.question,
+        trackedAskingResult: askingTask,
+      };
+    } else {
+      // when we use recommendation questions, there's no task to track
+      threadResponseInput = data;
+    }
+
     try {
-      const response = await askingService.createThreadResponse(data, threadId);
+      const response = await askingService.createThreadResponse(
+        threadResponseInput,
+        threadId,
+      );
       ctx.telemetry.sendEvent(eventName, { data });
       return response;
     } catch (err: any) {
@@ -380,6 +418,29 @@ export class AskingResolver {
     }
   }
 
+  /**
+   * Re-run the asking task of an existing thread response.
+   * Delegates to `askingService.rerunAskingTask` using the current project's
+   * language (EN when unmapped), then sends a telemetry event.
+   */
+  public async rerunAskingTask(
+    _root: any,
+    args: { responseId: number },
+    ctx: IContext,
+  ): Promise {
+    const { responseId } = args;
+    const askingService = ctx.askingService;
+    const project = await ctx.projectService.getCurrentProject();
+
+    const task = await askingService.rerunAskingTask(responseId, {
+      language: WrenAILanguage[project.language] || WrenAILanguage.EN,
+    });
+    ctx.telemetry.sendEvent(TelemetryEvent.HOME_RERUN_ASKING_TASK, {
+      responseId,
+    });
+    return task;
+  }
+
   public async generateThreadResponseBreakdown(
     _root: any,
     args: { responseId: number },
@@ -533,7 +594,19 @@ export class AskingResolver {
         // construct sql from breakdownDetail
         return format(constructCteSql(parent.breakdownDetail.steps));
       }
-      return format(parent.sql);
+      return parent.sql ? format(parent.sql) : null;
+    },
+    // expose the asking task bound to this response, if there is one
+    askingTask: async (parent: ThreadResponse, _args: any, ctx: IContext) => {
+      // responses created without a tracked task have no askingTaskId;
+      // skip the lookup instead of querying with an undefined id
+      if (!parent.askingTaskId) return null;
+
+      const askingTask = await ctx.askingService.getAskingTaskById(
+        parent.askingTaskId,
+      );
+      if (!askingTask) return null;
+      return this.transformAskingTask(askingTask, ctx);
     },
   });
 
@@ -561,4 +634,47 @@ export class AskingResolver {
       };
     },
   });
+
+  /**
+   * Convert a tracked asking result into the GraphQL `AskingTask` shape,
+   * resolving each candidate's `viewId` (when set) to its view record.
+   */
+  private async transformAskingTask(
+    askingTask: TrackedAskingResult,
+    ctx: IContext,
+  ): Promise {
+    // construct candidates from response
+    const candidates = await Promise.all(
+      (askingTask.response || []).map(async (response) => {
+        const view = response.viewId
+          ? await ctx.viewRepository.findOneBy({ id: response.viewId })
+          : null;
+        return {
+          type: response.type,
+          sql: response.sql,
+          view,
+        };
+      }),
+    );
+
+    // When the task got cancelled, the type is not set
+    // we set it to TEXT_TO_SQL as default
+    const type =
+      askingTask.status === AskResultStatus.STOPPED && !askingTask.type
+        ? AskResultType.TEXT_TO_SQL
+        : askingTask.type;
+    return {
+      type,
+      status: askingTask.status,
+      error: askingTask.error,
+      candidates,
+      queryId: askingTask.queryId,
+      rephrasedQuestion: askingTask.rephrasedQuestion,
+      intentReasoning: askingTask.intentReasoning,
+      sqlGenerationReasoning: askingTask.sqlGenerationReasoning,
+      retrievedTables: askingTask.retrievedTables,
+      invalidSql: askingTask.invalidSql,
+      traceId: askingTask.traceId,
+    };
+  }
 }
diff --git a/wren-ui/src/apollo/server/schema.ts b/wren-ui/src/apollo/server/schema.ts
index 0e5adcd70f..00d8919d0c 100644
--- a/wren-ui/src/apollo/server/schema.ts
+++ b/wren-ui/src/apollo/server/schema.ts
@@ -584,7 +584,13 @@ export const typeDefs = gql`
     type: AskingTaskType
     error: Error
     candidates: [ResultCandidate!]!
+    rephrasedQuestion: String
     intentReasoning: String
+    sqlGenerationReasoning: String
+    retrievedTables: [String!]
+ invalidSql: String + traceId: String + queryId: String } input InstantRecommendedQuestionsInput { @@ -614,13 +620,13 @@ export const typeDefs = gql` input CreateThreadInput { question: String sql: String - viewId: Int + taskId: String } input CreateThreadResponseInput { question: String sql: String - viewId: Int + taskId: String } input ThreadUniqueWhereInput { @@ -695,11 +701,12 @@ export const typeDefs = gql` id: Int! threadId: Int! question: String! - sql: String! + sql: String view: ViewInfo breakdownDetail: ThreadResponseBreakdownDetail answerDetail: ThreadResponseAnswerDetail chartDetail: ThreadResponseChartDetail + askingTask: AskingTask } # Thread only consists of basic information of a thread @@ -903,7 +910,7 @@ export const typeDefs = gql` view(where: ViewWhereUniqueInput!): ViewInfo! # Ask - askingTask(taskId: String!): AskingTask! + askingTask(taskId: String!): AskingTask suggestedQuestions: SuggestedQuestionResponse! threads: [Thread!]! thread(threadId: Int!): DetailedThread! @@ -982,6 +989,7 @@ export const typeDefs = gql` # Ask createAskingTask(data: AskingTaskInput!): Task! cancelAskingTask(taskId: String!): Boolean! + rerunAskingTask(responseId: Int!): Task! # Thread createThread(data: CreateThreadInput!): Thread! 
diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index 7767b3331d..aeb171d84f 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -1,6 +1,5 @@ import { IWrenAIAdaptor } from '@server/adaptors/wrenAIAdaptor'; import { - AskResult, AskResultStatus, RecommendationQuestionsResult, RecommendationQuestionsInput, @@ -26,7 +25,7 @@ import { TelemetryEvent, WrenService, } from '../telemetry/telemetry'; -import { IViewRepository, Project, View } from '../repositories'; +import { IViewRepository, Project } from '../repositories'; import { IQueryService, PreviewDataResponse } from './queryService'; import { IMDLService } from './mdlService'; import { @@ -36,6 +35,7 @@ import { } from '../backgrounds'; import { getConfig } from '@server/config'; import { TextBasedAnswerBackgroundTracker } from '../backgrounds/textBasedAnswerBackgroundTracker'; +import { IAskingTaskTracker, TrackedAskingResult } from './askingTaskTracker'; const config = getConfig(); @@ -60,7 +60,7 @@ export interface AskingTaskInput { export interface AskingDetailTaskInput { question?: string; sql?: string; - viewId?: number; + trackedAskingResult?: TrackedAskingResult; } export interface AskingDetailTaskUpdateInput { @@ -101,9 +101,22 @@ export interface IAskingService { createAskingTask( input: AskingTaskInput, payload: AskingPayload, + // if the asking task is rerun from a cancelled thread response + rerunFromCancelled?: boolean, + // if the asking task is rerun from a cancelled thread response, + // the previous task id is the task id of the cancelled thread response + previousTaskId?: number, + // if the asking task is rerun from a thread response + // the thread response id is the id of the cancelled thread response + threadResponseId?: number, + ): Promise; + rerunAskingTask( + threadResponseId: number, + payload: AskingPayload, ): Promise; 
cancelAskingTask(taskId: string): Promise; - getAskingTask(taskId: string): Promise; + getAskingTask(taskId: string): Promise; + getAskingTaskById(id: number): Promise; /** * Asking detail task. @@ -361,6 +374,7 @@ export class AskingService implements IAskingService { private queryService: IQueryService; private telemetry: PostHogTelemetry; private mdlService: IMDLService; + private askingTaskTracker: IAskingTaskTracker; constructor({ telemetry, @@ -372,6 +386,7 @@ export class AskingService implements IAskingService { threadResponseRepository, queryService, mdlService, + askingTaskTracker, }: { telemetry: PostHogTelemetry; wrenAIAdaptor: IWrenAIAdaptor; @@ -382,6 +397,7 @@ export class AskingService implements IAskingService { threadResponseRepository: IThreadResponseRepository; queryService: IQueryService; mdlService: IMDLService; + askingTaskTracker: IAskingTaskTracker; }) { this.wrenAIAdaptor = wrenAIAdaptor; this.deployService = deployService; @@ -423,6 +439,7 @@ export class AskingService implements IAskingService { }); this.mdlService = mdlService; + this.askingTaskTracker = askingTaskTracker; } public async getThreadRecommendationQuestions( @@ -520,6 +537,9 @@ export class AskingService implements IAskingService { public async createAskingTask( input: AskingTaskInput, payload: AskingPayload, + rerunFromCancelled?: boolean, + previousTaskId?: number, + threadResponseId?: number, ): Promise { const { threadId, language } = payload; const deployId = await this.getDeployId(); @@ -527,22 +547,60 @@ export class AskingService implements IAskingService { // if it's a follow-up question, then the input will have a threadId // then use the threadId to get the sql and get the steps of last thread response // construct it into AskHistory and pass to ask - const histories = threadId ? await this.getAskingHistory(threadId) : null; - const response = await this.wrenAIAdaptor.ask({ + const histories = threadId + ? 
await this.getAskingHistory(threadId, threadResponseId) + : null; + const response = await this.askingTaskTracker.createAskingTask({ query: input.question, histories, deployId, configurations: { language }, + rerunFromCancelled, + previousTaskId, + threadResponseId, }); return { id: response.queryId, }; } + public async rerunAskingTask( + threadResponseId: number, + payload: AskingPayload, + ): Promise { + const threadResponse = await this.threadResponseRepository.findOneBy({ + id: threadResponseId, + }); + + if (!threadResponse) { + throw new Error(`Thread response ${threadResponseId} not found`); + } + + // get the original question and ask again + const question = threadResponse.question; + const input = { + question, + }; + const askingPayload = { + ...payload, + // it's possible that the threadId is not provided in the payload + // so we'll just use the threadId from the thread response + threadId: threadResponse.threadId, + }; + const task = await this.createAskingTask( + input, + askingPayload, + true, + threadResponse.askingTaskId, + threadResponseId, + ); + return task; + } + public async cancelAskingTask(taskId: string): Promise { const eventName = TelemetryEvent.HOME_CANCEL_ASK; try { - await this.wrenAIAdaptor.cancelAsk(taskId); + await this.askingTaskTracker.cancelAskingTask(taskId); this.telemetry.sendEvent(eventName, {}); } catch (err: any) { this.telemetry.sendEvent(eventName, {}, err.extensions?.service, false); @@ -550,24 +608,26 @@ export class AskingService implements IAskingService { } } - public async getAskingTask(taskId: string): Promise { - return this.wrenAIAdaptor.getAskResult(taskId); + public async getAskingTask( + taskId: string, + ): Promise { + return this.askingTaskTracker.getAskingResult(taskId); + } + + public async getAskingTaskById( + id: number, + ): Promise { + return this.askingTaskTracker.getAskingResultById(id); } /** * Asking detail task. 
* The process of creating a thread is as follows: - * If input contains a viewId, simply create a thread from saved properties of the view. - * Otherwise, create a task on AI service to generate the detail. - * 1. create a task on AI service to generate the detail - * 2. create a thread and the first thread response with question and sql + * 1. create a thread and the first thread response + * 2. create a task on AI service to generate the detail + * 3. update the thread response with the task id */ public async createThread(input: AskingDetailTaskInput): Promise { - // if input contains a viewId, simply create a thread from saved properties of the view - if (input.viewId) { - return this.createThreadFromView(input); - } - // 1. create a thread and the first thread response const { id } = await this.projectService.getCurrentProject(); const thread = await this.threadRepository.createOne({ @@ -575,12 +635,23 @@ export class AskingService implements IAskingService { summary: input.question, }); - await this.threadResponseRepository.createOne({ + const threadResponse = await this.threadResponseRepository.createOne({ threadId: thread.id, question: input.question, sql: input.sql, + askingTaskId: input.trackedAskingResult?.taskId, }); + // if queryId is provided, update asking task + if (input.trackedAskingResult?.taskId) { + await this.askingTaskTracker.bindThreadResponse( + input.trackedAskingResult.taskId, + input.trackedAskingResult.queryId, + thread.id, + threadResponse.id, + ); + } + // return the task id return thread; } @@ -620,29 +691,23 @@ export class AskingService implements IAskingService { throw new Error(`Thread ${threadId} not found`); } - // if input contains a viewId, simply create a thread from saved properties of the view - if (input.viewId) { - const view = await this.viewRepository.findOneBy({ id: input.viewId }); - - if (!view) { - throw new Error(`View ${input.viewId} not found`); - } - - const res = await this.createThreadResponseFromView( - 
input.question, - view.statement, - view, - thread, - ); - return res; - } - const threadResponse = await this.threadResponseRepository.createOne({ threadId: thread.id, question: input.question, sql: input.sql, + askingTaskId: input.trackedAskingResult?.taskId, }); + // if queryId is provided, update asking task + if (input.trackedAskingResult?.taskId) { + await this.askingTaskTracker.bindThreadResponse( + input.trackedAskingResult.taskId, + input.trackedAskingResult.queryId, + thread.id, + threadResponse.id, + ); + } + return threadResponse; } @@ -933,49 +998,28 @@ export class AskingService implements IAskingService { * @param threadId * @returns Promise */ - private async getAskingHistory(threadId: number): Promise { + private async getAskingHistory( + threadId: number, + excludeThreadResponseId?: number, + ): Promise { if (!threadId) { return []; } - return await this.threadResponseRepository.getResponsesWithThread( + let responses = await this.threadResponseRepository.getResponsesWithThread( threadId, 10, ); - } - private async createThreadFromView(input: AskingDetailTaskInput) { - const view = await this.viewRepository.findOneBy({ id: input.viewId }); - if (!view) { - throw new Error(`View ${input.viewId} not found`); + // exclude the thread response if the excludeThreadResponseId is provided + // it's used when rerun the asking task, we don't want include the cancelled thread response + if (excludeThreadResponseId) { + responses = responses.filter( + (response) => response.id !== excludeThreadResponseId, + ); } - const { id } = await this.projectService.getCurrentProject(); - const thread = await this.threadRepository.createOne({ - projectId: id, - summary: input.question, - }); - - await this.createThreadResponseFromView( - input.question, - view.statement, - view, - thread, - ); - return thread; - } - - private async createThreadResponseFromView( - question: string, - sql: string, - view: View, - thread: Thread, - ) { - return 
this.threadResponseRepository.createOne({
-      threadId: thread.id,
-      viewId: view.id,
-      question,
-      sql,
-    });
+    // filter out the thread response with empty sql
+    return responses.filter((response) => response.sql);
   }
 
   private getThreadRecommendationQuestionsConfig(project: Project) {
diff --git a/wren-ui/src/apollo/server/services/askingTaskTracker.ts b/wren-ui/src/apollo/server/services/askingTaskTracker.ts
new file mode 100644
index 0000000000..2ed0274abe
--- /dev/null
+++ b/wren-ui/src/apollo/server/services/askingTaskTracker.ts
@@ -0,0 +1,483 @@
+import { getLogger } from '@server/utils';
+import {
+  AskResult,
+  AskResultType,
+  AskResultStatus,
+  AskInput,
+} from '@server/models/adaptor';
+import {
+  AskingTask,
+  IAskingTaskRepository,
+  IThreadResponseRepository,
+  IViewRepository,
+} from '@server/repositories';
+import { IWrenAIAdaptor } from '../adaptors';
+import * as Errors from '@server/utils/error';
+
+const logger = getLogger('AskingTaskTracker');
+logger.level = 'debug';
+
+interface TrackedTask {
+  queryId: string;
+  taskId?: number;
+  lastPolled: number;
+  question?: string;
+  result?: AskResult;
+  isFinalized: boolean;
+  threadResponseId?: number;
+  rerunFromCancelled?: boolean;
+}
+
+export type TrackedAskingResult = AskResult & {
+  taskId?: number;
+  queryId: string;
+  question: string;
+};
+
+export type CreateAskingTaskInput = AskInput & {
+  rerunFromCancelled?: boolean;
+  previousTaskId?: number;
+  threadResponseId?: number;
+};
+
+export interface IAskingTaskTracker {
+  createAskingTask(input: CreateAskingTaskInput): Promise<{ queryId: string }>;
+  getAskingResult(queryId: string): Promise;
+  getAskingResultById(id: number): Promise;
+  cancelAskingTask(queryId: string): Promise;
+  bindThreadResponse(
+    id: number,
+    queryId: string,
+    threadId: number,
+    threadResponseId: number,
+  ): Promise;
+}
+
+export class AskingTaskTracker implements IAskingTaskTracker {
+  private wrenAIAdaptor: IWrenAIAdaptor;
+  private askingTaskRepository: IAskingTaskRepository;
+  private trackedTasks: Map = new Map();
+  private trackedTasksById: Map = new Map();
+  private pollingInterval: number;
+  private memoryRetentionTime: number;
+  private pollingIntervalId: NodeJS.Timeout;
+  private runningJobs = new Set();
+  private threadResponseRepository: IThreadResponseRepository;
+  private viewRepository: IViewRepository;
+
+  constructor({
+    wrenAIAdaptor,
+    askingTaskRepository,
+    threadResponseRepository,
+    viewRepository,
+    pollingInterval = 1000, // 1 second
+    memoryRetentionTime = 5 * 60 * 1000, // 5 minutes
+  }: {
+    wrenAIAdaptor: IWrenAIAdaptor;
+    askingTaskRepository: IAskingTaskRepository;
+    threadResponseRepository: IThreadResponseRepository;
+    viewRepository: IViewRepository;
+    pollingInterval?: number;
+    memoryRetentionTime?: number;
+  }) {
+    this.wrenAIAdaptor = wrenAIAdaptor;
+    this.askingTaskRepository = askingTaskRepository;
+    this.threadResponseRepository = threadResponseRepository;
+    this.viewRepository = viewRepository;
+    this.pollingInterval = pollingInterval;
+    this.memoryRetentionTime = memoryRetentionTime;
+    this.startPolling();
+  }
+
+  /**
+   * Create a task on the AI service and start tracking it in memory.
+   *
+   * When rerunning from a cancelled thread response, the tracked task is
+   * also registered under `previousTaskId` and that task's stored query id
+   * is replaced with the new one.
+   *
+   * @throws Error when `rerunFromCancelled` is set without both
+   * `previousTaskId` and `threadResponseId` (checked before calling the AI
+   * service), or when the AI service / repository call fails.
+   */
+  public async createAskingTask(
+    input: CreateAskingTaskInput,
+  ): Promise<{ queryId: string }> {
+    try {
+      // validate the input first: failing after `ask` would leave a running
+      // task on the AI service that nothing tracks or cancels
+      if (
+        input.rerunFromCancelled &&
+        (!input.previousTaskId || !input.threadResponseId)
+      ) {
+        throw new Error(
+          'Previous task id and thread response id are required if rerun from cancelled',
+        );
+      }
+
+      // Call the AI service to create a task
+      const response = await this.wrenAIAdaptor.ask(input);
+      const queryId = response.queryId;
+
+      // Start tracking this task
+      const task = {
+        queryId,
+        lastPolled: Date.now(),
+        question: input.query,
+        isFinalized: false,
+        rerunFromCancelled: input.rerunFromCancelled,
+      } as TrackedTask;
+      this.trackedTasks.set(queryId, task);
+
+      // if rerun from cancelled, we update the query id to the previous task
+      if (
+        input.rerunFromCancelled &&
+        input.previousTaskId &&
+        input.threadResponseId
+      ) {
+        // set the thread response id in memory to bind the task to the thread response
+        // we don't have to update to database here because the thread response id is already set in database
+        task.threadResponseId = input.threadResponseId;
+
+        // update the task id in memory
+        this.trackedTasksById.set(input.previousTaskId, task);
+
+        // get the latest result from the AI service
+        // we get the latest result first to make it more responsive to client-side
+        const result = await this.wrenAIAdaptor.getAskResult(queryId);
+
+        // update the result in memory
+        task.result = result;
+
+        // update the query id in database
+        await this.askingTaskRepository.updateOne(input.previousTaskId, {
+          queryId,
+        });
+      }
+
+      logger.info(`Created asking task with queryId: ${queryId}`);
+      return { queryId };
+    } catch (err) {
+      logger.error(`Failed to create asking task: ${err}`);
+      throw err;
+    }
+  }
+
+  public async getAskingResult(
+    queryId: string,
+  ): Promise {
+    // Check if we're tracking this task in memory
+    const trackedTask = this.trackedTasks.get(queryId);
+
+    if (trackedTask && trackedTask.result) {
+      return {
+        ...trackedTask.result,
+        queryId,
+        question: trackedTask.question,
+        taskId: trackedTask.taskId,
+      };
+    }
+
+    // If not in memory or no result yet, check the database
+    return this.getAskingResultFromDB({ queryId });
+  }
+
+  public async getAskingResultById(
+    id: number,
+  ): Promise {
+    const task = this.trackedTasksById.get(id);
+    if (task) {
+      return this.getAskingResult(task.queryId);
+    }
+
+    return this.getAskingResultFromDB({ taskId: id });
+  }
+
+  public async cancelAskingTask(queryId: string): Promise {
+    await this.wrenAIAdaptor.cancelAsk(queryId);
+  }
+
+  public stopPolling(): void {
+    if (this.pollingIntervalId) {
+      clearInterval(this.pollingIntervalId);
+    }
+  }
+
+  public async bindThreadResponse(
+    id: number,
+    queryId: string,
+    threadId: number,
+    threadResponseId: number,
+  ): Promise {
+    const task = this.trackedTasks.get(queryId);
+    if 
(!task) { + throw new Error(`Task ${queryId} not found`); + } + + task.threadResponseId = threadResponseId; + this.trackedTasksById.set(id, task); + await this.askingTaskRepository.updateOne(id, { + threadId, + threadResponseId, + }); + + // check if the task is finalized and has a sql + if (task.isFinalized) { + await this.updateThreadResponseWhenTaskFinalized(task); + } + } + + private startPolling(): void { + this.pollingIntervalId = setInterval(() => { + this.pollTasks(); + }, this.pollingInterval); + } + + private async pollTasks(): Promise { + const now = Date.now(); + const tasksToRemove: string[] = []; + + // Create an array of job functions + const jobs = Array.from(this.trackedTasks.entries()).map( + ([queryId, task]) => + async () => { + try { + // Skip if the job is already running + if (this.runningJobs.has(queryId)) { + return; + } + + // Skip finalized tasks that have been in memory too long + if ( + task.isFinalized && + now - task.lastPolled > this.memoryRetentionTime + ) { + tasksToRemove.push(queryId); + return; + } + + // Skip finalized tasks + if (task.isFinalized) { + return; + } + + // Mark the job as running + this.runningJobs.add(queryId); + + // Poll for updates + logger.info(`Polling for updates for task ${queryId}`); + const result = await this.wrenAIAdaptor.getAskResult(queryId); + task.lastPolled = now; + + // if result is not changed, we don't need to update the database + if (!this.isResultChanged(task.result, result)) { + this.runningJobs.delete(queryId); + return; + } + + // update task in memory if any change + task.result = result; + + // if result is still understanding, we don't need to update the database + if (result.status === AskResultStatus.UNDERSTANDING) { + this.runningJobs.delete(queryId); + return; + } + + // if it's identified as GENERAL or MISLEADING_QUER + // we don't need to update the database and finalize the task + if ( + result.type === AskResultType.GENERAL || + result.type === AskResultType.MISLEADING_QUERY + 
) { + task.isFinalized = true; + // if it's rerun from cancelled, we need to update the task result to failed in db + if (task.rerunFromCancelled) { + const errorCode = + result.type === AskResultType.GENERAL + ? Errors.GeneralErrorCodes.IDENTIED_AS_GENERAL + : Errors.GeneralErrorCodes.IDENTIED_AS_MISLEADING_QUERY; + const error = { + code: errorCode, + message: Errors.errorMessages[errorCode], + shortMessage: Errors.shortMessages[errorCode], + }; + await this.updateTaskInDatabase( + { queryId }, + { + ...task, + // update the status to failed + // and the error message should be "IDENTIED_AS_GENERAL" or "IDENTIED_AS_MISLEADING_QUERY" + result: { + ...task.result, + status: AskResultStatus.FAILED, + error, + }, + }, + ); + } + this.runningJobs.delete(queryId); + return; + } + + // update the database + // note: type could be null if it's still being understood or it's stopped + // we already filtered out the understanding status above + // so we update to database if it's stopped as well here. 
+ logger.info(`Updating task ${queryId} in database`); + await this.updateTaskInDatabase({ queryId }, task); + + // Check if task is now finalized + if (this.isTaskFinalized(result.status)) { + task.isFinalized = true; + // update thread response if threadResponseId is provided + if (task.threadResponseId) { + await this.updateThreadResponseWhenTaskFinalized(task); + } + + logger.info( + `Task ${queryId} is finalized with status: ${result.status}`, + ); + } + + // Mark the job as finished + this.runningJobs.delete(queryId); + } catch (err) { + this.runningJobs.delete(queryId); + logger.error(err.stack); + throw err; + } + }, + ); + + // Run all jobs in parallel + Promise.allSettled(jobs.map((job) => job())).then((results) => { + // Log any rejected promises + results.forEach((result, index) => { + if (result.status === 'rejected') { + logger.error(`Job ${index} failed: ${result.reason}`); + } + }); + + // Clean up tasks that have been in memory too long + if (tasksToRemove.length > 0) { + logger.info( + `Cleaning up tasks that have been in memory too long. 
Tasks: ${tasksToRemove.join( + ', ', + )}`, + ); + } + for (const queryId of tasksToRemove) { + this.trackedTasks.delete(queryId); + } + }); + } + + private async updateThreadResponseWhenTaskFinalized( + task: TrackedTask, + ): Promise { + const response = task?.result?.response?.[0]; + if (!response) { + return; + } + // if the generated response of asking task is not null, update the thread response + if (response.viewId) { + // get sql from the view + const view = await this.viewRepository.findOneBy({ + id: response.viewId, + }); + await this.threadResponseRepository.updateOne(task.threadResponseId, { + sql: view.statement, + viewId: response.viewId, + }); + } else { + await this.threadResponseRepository.updateOne(task.threadResponseId, { + sql: response?.sql, + }); + } + } + + private async getAskingResultFromDB({ + queryId, + taskId, + }: { + queryId?: string; + taskId?: number; + }): Promise { + let taskRecord: AskingTask | null = null; + if (queryId) { + taskRecord = await this.askingTaskRepository.findByQueryId(queryId); + } else if (taskId) { + taskRecord = await this.askingTaskRepository.findOneBy({ id: taskId }); + } + + if (!taskRecord) { + return null; + } + + return { + ...taskRecord?.detail, + queryId: queryId || taskRecord?.queryId, + question: taskRecord?.question, + taskId: taskRecord?.id, + }; + } + + private async updateTaskInDatabase( + filter: { queryId?: string; taskId?: number }, + trackedTask: TrackedTask, + ): Promise { + const { queryId, taskId } = filter; + let taskRecord: AskingTask | null = null; + if (queryId) { + taskRecord = await this.askingTaskRepository.findByQueryId(queryId); + } else if (taskId) { + taskRecord = await this.askingTaskRepository.findOneBy({ id: taskId }); + } + + if (!taskRecord) { + // if record not found, create one + const task = await this.askingTaskRepository.createOne({ + queryId, + question: trackedTask.question, + detail: trackedTask.result, + }); + // update the task id in memory + let existingTask: 
TrackedTask; + if (queryId) { + existingTask = this.trackedTasks.get(queryId); + } else if (taskId) { + existingTask = this.trackedTasksById.get(taskId); + } + if (existingTask) { + existingTask.taskId = task.id; + } + return; + } + + // update the task + await this.askingTaskRepository.updateOne(taskRecord.id, { + detail: trackedTask.result, + }); + } + + private isTaskFinalized(status: AskResultStatus): boolean { + return [ + AskResultStatus.FINISHED, + AskResultStatus.FAILED, + AskResultStatus.STOPPED, + ].includes(status); + } + + private isResultChanged( + previousResult: AskResult, + newResult: AskResult, + ): boolean { + // check status change + if (previousResult?.status !== newResult.status) { + return true; + } + + return false; + } +} diff --git a/wren-ui/src/apollo/server/services/index.ts b/wren-ui/src/apollo/server/services/index.ts index 99149ce9b7..1fccc42ac3 100644 --- a/wren-ui/src/apollo/server/services/index.ts +++ b/wren-ui/src/apollo/server/services/index.ts @@ -6,3 +6,4 @@ export * from './projectService'; export * from './queryService'; export * from './metadataService'; export * from './dashboardService'; +export * from './askingTaskTracker'; diff --git a/wren-ui/src/apollo/server/telemetry/telemetry.ts b/wren-ui/src/apollo/server/telemetry/telemetry.ts index c96cb1264d..c900e7982a 100644 --- a/wren-ui/src/apollo/server/telemetry/telemetry.ts +++ b/wren-ui/src/apollo/server/telemetry/telemetry.ts @@ -51,6 +51,7 @@ export enum TelemetryEvent { HOME_ANSWER_ADJUST_CHART = 'home_answer_adjust_chart', HOME_ASK_FOLLOWUP_QUESTION = 'home_ask_followup_question', HOME_CANCEL_ASK = 'home_cancel_ask', + HOME_RERUN_ASKING_TASK = 'home_rerun_asking_task', HOME_GENERATE_PROJECT_RECOMMENDATION_QUESTIONS = 'home_generate_project_recommendation_questions', HOME_GENERATE_THREAD_RECOMMENDATION_QUESTIONS = 'home_generate_thread_recommendation_questions', diff --git a/wren-ui/src/apollo/server/utils/error.ts b/wren-ui/src/apollo/server/utils/error.ts index 
a87b4fdc0a..ffd7e5343e 100644 --- a/wren-ui/src/apollo/server/utils/error.ts +++ b/wren-ui/src/apollo/server/utils/error.ts @@ -40,6 +40,11 @@ export enum GeneralErrorCodes { DEPLOY_SQL_PAIR_ERROR = 'DEPLOY_SQL_PAIR_ERROR', GENERATE_QUESTIONS_ERROR = 'GENERATE_QUESTIONS_ERROR', INVALID_SQL_ERROR = 'INVALID_SQL_ERROR', + + // asking task error + // when rerun from cancelled, the task is identified as general or misleading query + IDENTIED_AS_GENERAL = 'IDENTIED_AS_GENERAL', + IDENTIED_AS_MISLEADING_QUERY = 'IDENTIED_AS_MISLEADING_QUERY', } export const errorMessages = { @@ -87,6 +92,12 @@ export const errorMessages = { [GeneralErrorCodes.GENERATE_QUESTIONS_ERROR]: 'Generate questions error', [GeneralErrorCodes.INVALID_SQL_ERROR]: 'Invalid SQL, please check your SQL syntax', + + // asking task error + [GeneralErrorCodes.IDENTIED_AS_GENERAL]: + 'The question is identified as a general question, please follow-up ask with more specific questions.', + [GeneralErrorCodes.IDENTIED_AS_MISLEADING_QUERY]: + 'The question is identified as a misleading query, please follow-up ask with more specific questions.', }; export const shortMessages = { @@ -109,6 +120,9 @@ export const shortMessages = { [GeneralErrorCodes.GENERATE_QUESTIONS_ERROR]: 'Generate questions error', [GeneralErrorCodes.INVALID_SQL_ERROR]: 'Invalid SQL, please check your SQL syntax', + [GeneralErrorCodes.IDENTIED_AS_GENERAL]: 'Identified as general question', + [GeneralErrorCodes.IDENTIED_AS_MISLEADING_QUERY]: + 'Identified as misleading query', }; export const create = ( diff --git a/wren-ui/src/common.ts b/wren-ui/src/common.ts index 4a7996a479..162d3c0e19 100644 --- a/wren-ui/src/common.ts +++ b/wren-ui/src/common.ts @@ -15,6 +15,7 @@ import { DashboardItemRepository, DashboardRepository, SqlPairRepository, + AskingTaskRepository, } from '@server/repositories'; import { WrenEngineAdaptor, @@ -29,6 +30,7 @@ import { AskingService, MDLService, DashboardService, + AskingTaskTracker, } from '@server/services'; 
import { PostHogTelemetry } from './apollo/server/telemetry/telemetry'; import { @@ -63,6 +65,7 @@ export const initComponents = () => { const dashboardRepository = new DashboardRepository(knex); const dashboardItemRepository = new DashboardItemRepository(knex); const sqlPairRepository = new SqlPairRepository(knex); + const askingTaskRepository = new AskingTaskRepository(knex); // adaptors const wrenEngineAdaptor = new WrenEngineAdaptor({ @@ -105,6 +108,12 @@ export const initComponents = () => { wrenAIAdaptor, telemetry, }); + const askingTaskTracker = new AskingTaskTracker({ + wrenAIAdaptor, + askingTaskRepository, + threadResponseRepository, + viewRepository, + }); const askingService = new AskingService({ telemetry, wrenAIAdaptor, @@ -115,6 +124,7 @@ export const initComponents = () => { threadResponseRepository, queryService, mdlService, + askingTaskTracker, }); const dashboardService = new DashboardService({ projectService, @@ -158,6 +168,8 @@ export const initComponents = () => { dashboardRepository, dashboardItemRepository, sqlPairRepository, + askingTaskRepository, + // adaptors wrenEngineAdaptor, wrenAIAdaptor, @@ -173,6 +185,7 @@ export const initComponents = () => { dashboardService, sqlPairService, + askingTaskTracker, // background trackers projectRecommendQuestionBackgroundTracker, threadRecommendQuestionBackgroundTracker,