Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,30 @@ import javax.ws.rs.Consumes
import javax.ws.rs.core.MediaType
import play.api.libs.json.Json
import kong.unirest.Unirest
import java.util.Base64
import scala.sys.process._
import play.api.libs.json._
import java.nio.file.Paths
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

case class AIAssistantRequest(code: String, lineNumber: Int, allcode: String)
case class LocateUnannotatedRequest(selectedCode: String, startLine: Int)
case class UnannotatedArgument(
name: String,
startLine: Int,
startColumn: Int,
endLine: Int,
endColumn: Int
)
object UnannotatedArgument {
implicit val format: Format[UnannotatedArgument] = Json.format[UnannotatedArgument]
}

@Path("/aiassistant")
class AIAssistantResource {
val objectMapper = new ObjectMapper()
objectMapper.registerModule(DefaultScalaModule)
final private lazy val isEnabled = AiAssistantManager.validAIAssistant
@GET
@RolesAllowed(Array("REGULAR", "ADMIN"))
Expand Down Expand Up @@ -86,4 +105,45 @@ class AIAssistantResource {
|- For the second situation: you return strictly according to the format " -> type", without adding any extra characters. No need for an explanation, just the result -> type is enough!
""".stripMargin
}

@POST
@RolesAllowed(Array("REGULAR", "ADMIN"))
@Path("/annotate-argument")
@Consumes(Array(MediaType.APPLICATION_JSON))
def locateUnannotated(request: LocateUnannotatedRequest, @Auth user: SessionUser): Response = {
// Encoding the code to transmit multi-line code as a single command-line argument
val encodedCode = Base64.getEncoder.encodeToString(request.selectedCode.getBytes("UTF-8"))
val pythonScriptPath =
Paths
.get(
"src",
"main",
"scala",
"edu",
"uci",
"ics",
"texera",
"web",
"resource",
"aiassistant",
"type_annotation_visitor.py"
)
.toString

try {
val command = s"""python $pythonScriptPath "$encodedCode" ${request.startLine}"""
val result = command.!!
val parsedResult = objectMapper.readValue(result, classOf[List[List[Any]]]).map {
case List(name: String, startLine: Int, startColumn: Int, endLine: Int, endColumn: Int) =>
UnannotatedArgument(name, startLine, startColumn, endLine, endColumn)
case _ =>
throw new RuntimeException("Unexpected format in Python script result")
}
Response.ok(Json.obj("result" -> Json.toJson(parsedResult))).build()
} catch {
case e: Exception =>
e.printStackTrace()
Response.status(500).entity(s"Error executing the Python code: ${e.getMessage}").build()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import pytest
from type_annotation_visitor import find_untyped_variables

class TestFunctionsAndMethods:

@pytest.fixture
def global_functions_code(self):
"""This is the test for global function"""
return """def global_function(a, b=2, /, c=3, *, d, e=5, **kwargs):
pass

def global_function_no_return(a, b):
return a + b

def global_function_with_return(a: int, b: int) -> int:
return a + b
"""

def test_global_functions(self, global_functions_code):
expected_result = [
["c", 1, 32, 1, 33],
["a", 1, 21, 1, 22],
["b", 1, 24, 1, 25],
["d", 1, 40, 1, 41],
["e", 1, 43, 1, 44],
["kwargs", 1, 50, 1, 56],
["a", 4, 31, 4, 32],
["b", 4, 34, 4, 35]

]
untyped_vars = find_untyped_variables(global_functions_code, 1)
assert untyped_vars == expected_result

@pytest.fixture
def class_methods_code(self):
"""This is the test for class methods and static methods"""
return """class MyClass:
def instance_method_no_annotation(self, x, y):
pass

@staticmethod
def static_method(a, b, /, c=3, *, d, **kwargs):
pass

@staticmethod
def static_method_with_annotation(a: int, b: int, /, *, c: int = 5) -> int:
return a + b + c

@classmethod
def class_method(cls, value, /, *, option=True):
pass

@classmethod
def class_method_with_annotation(cls, value: str, /, *, flag: bool = False) -> str:
return value.upper()
"""

def test_class_methods(self, class_methods_code):
expected_result = [
["x", 2, 45, 2, 46],
["y", 2, 48, 2, 49],
["c", 6, 32, 6, 33],
["a", 6, 23, 6, 24],
["b", 6, 26, 6, 27],
["d", 6, 40, 6, 41],
["kwargs", 6, 45, 6, 51],
["value", 14, 27, 14, 32],
["option", 14, 40, 14, 46]
]
untyped_vars = find_untyped_variables(class_methods_code, 1)
assert untyped_vars == expected_result

@pytest.fixture
def lambda_code(self):
"""This is the test for lambda function"""
return """lambda_function = lambda x, y, /, z=0, *, w=1: x + y + z + w
lambda_function_with_annotation = lambda x: x * 2
"""

def test_lambda_functions(self, lambda_code):
with pytest.raises(ValueError) as exc_info:
find_untyped_variables(lambda_code, 1)
assert "Lambda functions do not support type annotation" in str(exc_info.value)

@pytest.fixture
def comprehensive_functions_code(self):
"""This is the test for comprehensive function"""
return """def default_args_function(a, b=2, /, c=3, *, d=4):
pass

def args_kwargs_function(*args, **kwargs):
pass

def function_with_return_annotation(a: int, b: int, /, *, c: int = 0) -> int:
return a + b + c

def function_without_return_annotation(a, b):
return a + b
"""

def test_comprehensive(self, comprehensive_functions_code):
expected_result = [
["c", 1, 38, 1, 39],
["a", 1, 27, 1, 28],
["b", 1, 30, 1, 31],
["d", 1, 46, 1, 47],
["args", 4, 27, 4, 31],
["kwargs", 4, 35, 4, 41],
["a", 10, 40, 10, 41],
["b", 10, 43, 10, 44]
]
untyped_vars = find_untyped_variables(comprehensive_functions_code, 1)
assert untyped_vars == expected_result

@pytest.fixture
def multi_line_function_code(self):
"""This is the test for multi-line function"""
return """def multi_line_function(
a,
b: int = 10,
/,
c: str = "hello",
*,
d,
e=20,
**kwargs
):
pass
"""

def test_multi_lines_argument(self, multi_line_function_code):
expected_result = [
["a", 2, 5, 2, 6],
["d", 7, 5, 7, 6],
["e", 8, 5, 8, 6],
["kwargs", 9, 7, 9, 13]
]
untyped_vars = find_untyped_variables(multi_line_function_code, 1)
assert untyped_vars == expected_result
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import ast
import json
import sys
import base64

class ParentNodeVisitor(ast.NodeVisitor):
def __init__(self):
self.parent = None

def generic_visit(self, node):
node.parent = self.parent
previous_parent = self.parent
self.parent = node
super().generic_visit(node)
self.parent = previous_parent

class TypeAnnotationVisitor(ast.NodeVisitor):
def __init__(self, start_line_offset=0):
self.untyped_args = []
self.start_line_offset = start_line_offset

def visit(self, node):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
self.process_function(node)
elif isinstance(node, ast.Lambda):
raise ValueError("Lambda functions do not support type annotation")
self.generic_visit(node)

def process_function(self, node):
# Boolean to determine if it's a global function or a method
is_method = isinstance(node.parent, ast.ClassDef)
# Boolean to determine if it's a static method
is_staticmethod = False
if is_method and hasattr(node, 'decorator_list'):
for decorator in node.decorator_list:
if isinstance(decorator, ast.Name) and decorator.id == 'staticmethod':
is_staticmethod = True
elif isinstance(decorator, ast.Attribute) and decorator.attr == 'staticmethod':
is_staticmethod = True
args = node.args

all_args = []
all_args.extend(args.args)
# Positional-only
all_args.extend(args.posonlyargs)
# Keyword-only
all_args.extend(args.kwonlyargs)
# *args
if args.vararg:
all_args.append(args.vararg)
# **kwargs
if args.kwarg:
all_args.append(args.kwarg)

start_index = 0
# Skip the "self" or "cls"
if is_method and not is_staticmethod:
start_index = 1
for i, arg in enumerate(all_args[start_index:]):
if not arg.annotation:
self.add_untyped_arg(arg)


def add_untyped_arg(self, arg):
start_line = arg.lineno + self.start_line_offset - 1
start_col = arg.col_offset + 1
end_line = start_line
end_col = start_col + len(arg.arg)
self.untyped_args.append([arg.arg, start_line, start_col, end_line, end_col])

def find_untyped_variables(source_code, start_line):
tree = ast.parse(source_code)
ParentNodeVisitor().visit(tree)
visitor = TypeAnnotationVisitor(start_line_offset=start_line)
visitor.visit(tree)
return visitor.untyped_args

if __name__ == "__main__":
encoded_code = sys.argv[1]
start_line = int(sys.argv[2])
# Encoding the code to transmit multi-line code as a single command-line argument before, so we need to decode it here
source_code = base64.b64decode(encoded_code).decode('utf-8')
untyped_variables = find_untyped_variables(source_code, start_line)
print(json.dumps(untyped_variables))
Loading