diff --git a/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/AiAssistantResource.scala b/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/AiAssistantResource.scala index f784165dc65..8a390b8cf8f 100644 --- a/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/AiAssistantResource.scala +++ b/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/AiAssistantResource.scala @@ -9,11 +9,30 @@ import javax.ws.rs.Consumes import javax.ws.rs.core.MediaType import play.api.libs.json.Json import kong.unirest.Unirest +import java.util.Base64 +import scala.sys.process._ +import play.api.libs.json._ +import java.nio.file.Paths +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule case class AIAssistantRequest(code: String, lineNumber: Int, allcode: String) +case class LocateUnannotatedRequest(selectedCode: String, startLine: Int) +case class UnannotatedArgument( + name: String, + startLine: Int, + startColumn: Int, + endLine: Int, + endColumn: Int +) +object UnannotatedArgument { + implicit val format: Format[UnannotatedArgument] = Json.format[UnannotatedArgument] +} @Path("/aiassistant") class AIAssistantResource { + val objectMapper = new ObjectMapper() + objectMapper.registerModule(DefaultScalaModule) final private lazy val isEnabled = AiAssistantManager.validAIAssistant @GET @RolesAllowed(Array("REGULAR", "ADMIN")) @@ -86,4 +105,45 @@ class AIAssistantResource { |- For the second situation: you return strictly according to the format " -> type", without adding any extra characters. No need for an explanation, just the result -> type is enough! """.stripMargin } + + @POST + @RolesAllowed(Array("REGULAR", "ADMIN")) + @Path("/annotate-argument") + @Consumes(Array(MediaType.APPLICATION_JSON)) + def locateUnannotated(request: LocateUnannotatedRequest, @Auth user: SessionUser): Response = { + // Encoding the code to transmit multi-line code as a single command-line argument + val encodedCode = Base64.getEncoder.encodeToString(request.selectedCode.getBytes("UTF-8")) + val pythonScriptPath = + Paths + .get( + "src", + "main", + "scala", + "edu", + "uci", + "ics", + "texera", + "web", + "resource", + "aiassistant", + "type_annotation_visitor.py" + ) + .toString + + try { + val command = s"""python $pythonScriptPath "$encodedCode" ${request.startLine}""" + val result = command.!! + val parsedResult = objectMapper.readValue(result, classOf[List[List[Any]]]).map { + case List(name: String, startLine: Int, startColumn: Int, endLine: Int, endColumn: Int) => + UnannotatedArgument(name, startLine, startColumn, endLine, endColumn) + case _ => + throw new RuntimeException("Unexpected format in Python script result") + } + Response.ok(Json.obj("result" -> Json.toJson(parsedResult))).build() + } catch { + case e: Exception => + e.printStackTrace() + Response.status(500).entity(s"Error executing the Python code: ${e.getMessage}").build() + } + } } diff --git a/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/test_type_annotation_visitor.py b/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/test_type_annotation_visitor.py new file mode 100644 index 00000000000..4f6a4b0d3f6 --- /dev/null +++ b/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/test_type_annotation_visitor.py @@ -0,0 +1,139 @@ +import pytest +from type_annotation_visitor import find_untyped_variables + +class TestFunctionsAndMethods: + + @pytest.fixture + def global_functions_code(self): + """This is the test for global function""" + return """def global_function(a, b=2, /, c=3, *, d, e=5, **kwargs): + pass + +def global_function_no_return(a, b): + return a + b + +def global_function_with_return(a: int, b: int) -> int: + return a + b +""" + + def test_global_functions(self, global_functions_code): + expected_result = [ + ["c", 1, 32, 1, 33], + ["a", 1, 21, 1, 22], + ["b", 1, 24, 1, 25], + ["d", 1, 40, 1, 41], + ["e", 1, 43, 1, 44], + ["kwargs", 1, 50, 1, 56], + ["a", 4, 31, 4, 32], + ["b", 4, 34, 4, 35] + + ] + untyped_vars = find_untyped_variables(global_functions_code, 1) + assert untyped_vars == expected_result + + @pytest.fixture + def class_methods_code(self): + """This is the test for class methods and static methods""" + return """class MyClass: + def instance_method_no_annotation(self, x, y): + pass + + @staticmethod + def static_method(a, b, /, c=3, *, d, **kwargs): + pass + + @staticmethod + def static_method_with_annotation(a: int, b: int, /, *, c: int = 5) -> int: + return a + b + c + + @classmethod + def class_method(cls, value, /, *, option=True): + pass + + @classmethod + def class_method_with_annotation(cls, value: str, /, *, flag: bool = False) -> str: + return value.upper() +""" + + def test_class_methods(self, class_methods_code): + expected_result = [ + ["x", 2, 45, 2, 46], + ["y", 2, 48, 2, 49], + ["c", 6, 32, 6, 33], + ["a", 6, 23, 6, 24], + ["b", 6, 26, 6, 27], + ["d", 6, 40, 6, 41], + ["kwargs", 6, 45, 6, 51], + ["value", 14, 27, 14, 32], + ["option", 14, 40, 14, 46] + ] + untyped_vars = find_untyped_variables(class_methods_code, 1) + assert untyped_vars == expected_result + + @pytest.fixture + def lambda_code(self): + """This is the test for lambda function""" + return """lambda_function = lambda x, y, /, z=0, *, w=1: x + y + z + w +lambda_function_with_annotation = lambda x: x * 2 +""" + + def test_lambda_functions(self, lambda_code): + with pytest.raises(ValueError) as exc_info: + find_untyped_variables(lambda_code, 1) + assert "Lambda functions do not support type annotation" in str(exc_info.value) + + @pytest.fixture + def comprehensive_functions_code(self): + """This is the test for comprehensive function""" + return """def default_args_function(a, b=2, /, c=3, *, d=4): + pass + +def args_kwargs_function(*args, **kwargs): + pass + +def function_with_return_annotation(a: int, b: int, /, *, c: int = 0) -> int: + return a + b + c + +def function_without_return_annotation(a, b): + return a + b +""" + + def test_comprehensive(self, comprehensive_functions_code): + expected_result = [ + ["c", 1, 38, 1, 39], + ["a", 1, 27, 1, 28], + ["b", 1, 30, 1, 31], + ["d", 1, 46, 1, 47], + ["args", 4, 27, 4, 31], + ["kwargs", 4, 35, 4, 41], + ["a", 10, 40, 10, 41], + ["b", 10, 43, 10, 44] + ] + untyped_vars = find_untyped_variables(comprehensive_functions_code, 1) + assert untyped_vars == expected_result + + @pytest.fixture + def multi_line_function_code(self): + """This is the test for multi-line function""" + return """def multi_line_function( + a, + b: int = 10, + /, + c: str = "hello", + *, + d, + e=20, + **kwargs +): + pass +""" + + def test_multi_lines_argument(self, multi_line_function_code): + expected_result = [ + ["a", 2, 5, 2, 6], + ["d", 7, 5, 7, 6], + ["e", 8, 5, 8, 6], + ["kwargs", 9, 7, 9, 13] + ] + untyped_vars = find_untyped_variables(multi_line_function_code, 1) + assert untyped_vars == expected_result diff --git a/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/type_annotation_visitor.py b/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/type_annotation_visitor.py new file mode 100644 index 00000000000..39089960554 --- /dev/null +++ b/core/amber/src/main/scala/edu/uci/ics/texera/web/resource/aiassistant/type_annotation_visitor.py @@ -0,0 +1,84 @@ +import ast +import json +import sys +import base64 + +class ParentNodeVisitor(ast.NodeVisitor): + def __init__(self): + self.parent = None + + def generic_visit(self, node): + node.parent = self.parent + previous_parent = self.parent + self.parent = node + super().generic_visit(node) + self.parent = previous_parent + +class TypeAnnotationVisitor(ast.NodeVisitor): + def __init__(self, start_line_offset=0): + self.untyped_args = [] + self.start_line_offset = start_line_offset + + def visit(self, node): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + self.process_function(node) + elif isinstance(node, ast.Lambda): + raise ValueError("Lambda functions do not support type annotation") + self.generic_visit(node) + + def process_function(self, node): + # Boolean to determine if it's a global function or a method + is_method = isinstance(node.parent, ast.ClassDef) + # Boolean to determine if it's a static method + is_staticmethod = False + if is_method and hasattr(node, 'decorator_list'): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == 'staticmethod': + is_staticmethod = True + elif isinstance(decorator, ast.Attribute) and decorator.attr == 'staticmethod': + is_staticmethod = True + args = node.args + + all_args = [] + all_args.extend(args.args) + # Positional-only + all_args.extend(args.posonlyargs) + # Keyword-only + all_args.extend(args.kwonlyargs) + # *args + if args.vararg: + all_args.append(args.vararg) + # **kwargs + if args.kwarg: + all_args.append(args.kwarg) + + start_index = 0 + # Skip the "self" or "cls" + if is_method and not is_staticmethod: + start_index = 1 + for i, arg in enumerate(all_args[start_index:]): + if not arg.annotation: + self.add_untyped_arg(arg) + + + def add_untyped_arg(self, arg): + start_line = arg.lineno + self.start_line_offset - 1 + start_col = arg.col_offset + 1 + end_line = start_line + end_col = start_col + len(arg.arg) + self.untyped_args.append([arg.arg, start_line, start_col, end_line, end_col]) + +def find_untyped_variables(source_code, start_line): + tree = ast.parse(source_code) + ParentNodeVisitor().visit(tree) + visitor = TypeAnnotationVisitor(start_line_offset=start_line) + visitor.visit(tree) + return visitor.untyped_args + +if __name__ == "__main__": + encoded_code = sys.argv[1] + start_line = int(sys.argv[2]) + # Encoding the code to transmit multi-line code as a single command-line argument before, so we need to decode it here + source_code = base64.b64decode(encoded_code).decode('utf-8') + untyped_variables = find_untyped_variables(source_code, start_line) + print(json.dumps(untyped_variables)) diff --git a/core/gui/src/app/workspace/component/code-editor-dialog/code-editor.component.ts b/core/gui/src/app/workspace/component/code-editor-dialog/code-editor.component.ts index c772ad7786c..037f23eceec 100644 --- a/core/gui/src/app/workspace/component/code-editor-dialog/code-editor.component.ts +++ b/core/gui/src/app/workspace/component/code-editor-dialog/code-editor.component.ts @@ -4,7 +4,7 @@ import { WorkflowActionService } from "../../service/workflow-graph/model/workfl import { WorkflowVersionService } from "../../../dashboard/service/user/workflow-version/workflow-version.service"; import { YText } from "yjs/dist/src/types/YText"; import { MonacoBinding } from "y-monaco"; -import { Subject } from "rxjs"; +import { Subject, take } from "rxjs"; import { takeUntil } from "rxjs/operators"; import { MonacoLanguageClient } from "monaco-languageclient"; import { toSocket, WebSocketMessageReader, WebSocketMessageWriter } from "vscode-ws-jsonrpc"; @@ -59,6 +59,10 @@ export class CodeEditorComponent implements AfterViewInit, SafeStyle, OnDestroy public currentRange: monaco.Range | undefined; public suggestionTop: number = 0; public suggestionLeft: number = 0; + // For "Add All Type Annotation" to show the UI individually + private userResponseSubject?: Subject; + private isMultipleVariables: boolean = false; + private componentDestroy = new Subject(); private generateLanguageTitle(language: string): string { return `${language.charAt(0).toUpperCase()}${language.slice(1)} UDF`; @@ -221,6 +225,103 @@ export class CodeEditorComponent implements AfterViewInit, SafeStyle, OnDestroy }, }); } + + // "Add All Type Annotation" Button + editor.addAction({ + id: "all-type-annotation-action", + label: "Add All Type Annotations", + contextMenuGroupId: "1_modification", + contextMenuOrder: 1.1, + run: ed => { + const selection = ed.getSelection(); + const model = ed.getModel(); + if (!model || !selection) { + return; + } + + const selectedCode = model.getValueInRange(selection); + const allCode = model.getValue(); + + this.aiAssistantService + .locateUnannotated(selectedCode, selection.startLineNumber) + .pipe(takeUntil(this.componentDestroy)) + .subscribe(variablesWithoutAnnotations => { + // If no unannotated variable, then do nothing. + if (variablesWithoutAnnotations.length == 0) { + return; + } + + let offset = 0; + let lastLine: number | undefined; + + this.isMultipleVariables = true; + this.userResponseSubject = new Subject(); + + const processNextVariable = (index: number) => { + if (index >= variablesWithoutAnnotations.length) { + this.isMultipleVariables = false; + this.userResponseSubject = undefined; + return; + } + + const currVariable = variablesWithoutAnnotations[index]; + + const variableCode = currVariable.name; + const variableLineNumber = currVariable.startLine; + + // Update range + if (lastLine !== undefined && lastLine === variableLineNumber) { + offset += this.currentSuggestion.length; + } else { + offset = 0; + } + + const variableRange = new monaco.Range( + currVariable.startLine, + currVariable.startColumn + offset, + currVariable.endLine, + currVariable.endColumn + offset + ); + + const highlight = editor.createDecorationsCollection([ + { + range: variableRange, + options: { + hoverMessage: { value: "Argument without Annotation" }, + isWholeLine: false, + className: "annotation-highlight", + }, + }, + ]); + + this.handleTypeAnnotation( + variableCode, + variableRange, + ed as monaco.editor.IStandaloneCodeEditor, + variableLineNumber, + allCode + ); + + lastLine = variableLineNumber; + + // Make sure the currVariable will not go to the next one until the user click the accept/decline button + if (this.userResponseSubject !== undefined) { + const userResponseSubject = this.userResponseSubject; + // Only take one response (accept/decline) + const subscription = userResponseSubject + .pipe(take(1)) + .pipe(takeUntil(this.componentDestroy)) + .subscribe(() => { + highlight.clear(); + subscription.unsubscribe(); + processNextVariable(index + 1); + }); + } + }; + processNextVariable(0); + }); + }, + }); }, }); if (this.language == "python") { @@ -237,7 +338,7 @@ export class CodeEditorComponent implements AfterViewInit, SafeStyle, OnDestroy ): void { this.aiAssistantService .getTypeAnnotations(code, lineNumber, allcode) - .pipe(takeUntil(this.workflowVersionStreamSubject)) + .pipe(takeUntil(this.componentDestroy)) .subscribe({ next: (response: TypeAnnotationResponse) => { const choices = response.choices || []; @@ -285,6 +386,11 @@ export class CodeEditorComponent implements AfterViewInit, SafeStyle, OnDestroy this.currentRange.endColumn ); this.insertTypeAnnotations(this.editor, selection, this.currentSuggestion); + + // Only for "Add All Type Annotation" + if (this.isMultipleVariables && this.userResponseSubject) { + this.userResponseSubject.next(); + } } // close the UI after adding the annotation this.showAnnotationSuggestion = false; @@ -296,9 +402,13 @@ export class CodeEditorComponent implements AfterViewInit, SafeStyle, OnDestroy this.showAnnotationSuggestion = false; this.currentCode = ""; this.currentSuggestion = ""; + + // Only for "Add All Type Annotation" + if (this.isMultipleVariables && this.userResponseSubject) { + this.userResponseSubject.next(); + } } - // Add the type annotation into monaco editor private insertTypeAnnotations( editor: monaco.editor.IStandaloneCodeEditor, selection: monaco.Selection, @@ -306,20 +416,9 @@ export class CodeEditorComponent implements AfterViewInit, SafeStyle, OnDestroy ) { const endLineNumber = selection.endLineNumber; const endColumn = selection.endColumn; - const range = new monaco.Range( - // Insert the content to the end of the selected code - endLineNumber, - endColumn, - endLineNumber, - endColumn - ); - const text = `${annotations}`; - const op = { - range: range, - text: text, - forceMoveMarkers: true, - }; - editor.executeEdits("add annotation", [op]); + const insertPosition = new monaco.Position(endLineNumber, endColumn); + const insertOffset = editor.getModel()?.getOffsetAt(insertPosition) || 0; + this.code?.insert(insertOffset, annotations); } private connectLanguageServer() { diff --git a/core/gui/src/app/workspace/service/ai-assistant/ai-assistant.service.ts b/core/gui/src/app/workspace/service/ai-assistant/ai-assistant.service.ts index 0463b6af309..1c378896128 100644 --- a/core/gui/src/app/workspace/service/ai-assistant/ai-assistant.service.ts +++ b/core/gui/src/app/workspace/service/ai-assistant/ai-assistant.service.ts @@ -13,6 +13,33 @@ export type TypeAnnotationResponse = { }>; }; +export interface UnannotatedArgument + extends Readonly<{ + name: string; + startLine: number; + startColumn: number; + endLine: number; + endColumn: number; + }> {} + +interface UnannotatedArgumentItem { + readonly underlying: { + readonly name: { readonly value: string }; + readonly startLine: { readonly value: number }; + readonly startColumn: { readonly value: number }; + readonly endLine: { readonly value: number }; + readonly endColumn: { readonly value: number }; + }; +} + +interface UnannotatedArgumentResponse { + readonly underlying: { + readonly result: { + readonly value: ReadonlyArray; + }; + }; +} + // Define AI model type export const AI_ASSISTANT_API_BASE_URL = `${AppSettings.getApiEndpoint()}/aiassistant`; export const AI_MODEL = { @@ -66,4 +93,44 @@ export class AIAssistantService { const requestBody = { code, lineNumber, allcode }; return this.http.post(`${AI_ASSISTANT_API_BASE_URL}/annotationresult`, requestBody, {}); } + + public locateUnannotated(selectedCode: string, startLine: number): Observable { + const requestBody = { selectedCode, startLine }; + + return this.http + .post(`${AI_ASSISTANT_API_BASE_URL}/annotate-argument`, requestBody) + .pipe( + map(response => { + if (response) { + const result = response.underlying.result.value.map( + (item: UnannotatedArgumentItem): UnannotatedArgument => ({ + name: item.underlying.name.value, + startLine: item.underlying.startLine.value, + startColumn: item.underlying.startColumn.value, + endLine: item.underlying.endLine.value, + endColumn: item.underlying.endColumn.value, + }) + ); + console.log("Unannotated Arguments:", result); + + return response.underlying.result.value.map( + (item: UnannotatedArgumentItem): UnannotatedArgument => ({ + name: item.underlying.name.value, + startLine: item.underlying.startLine.value, + startColumn: item.underlying.startColumn.value, + endLine: item.underlying.endLine.value, + endColumn: item.underlying.endColumn.value, + }) + ); + } else { + console.error("Unexpected response format:", response); + return []; + } + }), + catchError((error: unknown) => { + console.error("Request to backend failed:", error); + throw new Error("Request to backend failed"); + }) + ); + } } diff --git a/core/gui/src/styles.scss b/core/gui/src/styles.scss index d413e250da0..a8109d9439f 100644 --- a/core/gui/src/styles.scss +++ b/core/gui/src/styles.scss @@ -72,3 +72,11 @@ hr { .ant-tabs-tabpane { padding-right: 24px; } + +body { + overflow: hidden; +} + +.annotation-highlight { + background-color: #6a5acd; +}