Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/tvm/micro/project_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""MicroTVM Project API Client and Server"""
16 changes: 11 additions & 5 deletions python/tvm/micro/project_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Project API client.
"""
import base64
import io
import json
import logging
import platform
import os
import pathlib
import subprocess
Expand Down Expand Up @@ -56,6 +59,7 @@ class UnsupportedProtocolVersionError(ProjectAPIErrorBase):

class RPCError(ProjectAPIErrorBase):
def __init__(self, request, error):
ProjectAPIErrorBase.__init__()
self.request = request
self.error = error

Expand Down Expand Up @@ -129,7 +133,8 @@ def _request_reply(self, method, params):

if "error" in reply:
raise server.JSONRPCError.from_json(f"calling method {method}", reply["error"])
elif "result" not in reply:

if "result" not in reply:
raise MalformedReplyError(f"Expected 'result' key in server reply, got {reply!r}")

return reply["result"]
Expand Down Expand Up @@ -189,15 +194,16 @@ def write_transport(self, data, timeout_sec):

# NOTE: windows support untested
SERVER_LAUNCH_SCRIPT_FILENAME = (
f"launch_microtvm_api_server.{'sh' if os.system != 'win32' else '.bat'}"
f"launch_microtvm_api_server.{'sh' if platform.system() != 'Windows' else '.bat'}"
)


SERVER_PYTHON_FILENAME = "microtvm_api_server.py"


def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bool = False):
"""Launch server located in project_dir, and instantiate a Project API Client connected to it."""
"""Launch server located in project_dir, and instantiate a Project API Client
connected to it."""
args = None

project_dir = pathlib.Path(project_dir)
Expand All @@ -224,7 +230,7 @@ def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bo
if debug:
args.append("--debug")

api_server_proc = subprocess.Popen(
api_server_proc = subprocess.Popen( # pylint: disable=unused-variable
args, bufsize=0, pass_fds=(api_server_read_fd, api_server_write_fd), cwd=project_dir
)
os.close(api_server_read_fd)
Expand Down
60 changes: 36 additions & 24 deletions python/tvm/micro/project_api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import re
import select
import sys
import textwrap
import time
import traceback
import typing
Expand Down Expand Up @@ -100,6 +99,7 @@ class JSONRPCError(Exception):
"""An error class with properties that meet the JSON-RPC error spec."""

def __init__(self, code, message, data, client_context=None):
Exception.__init__(self)
self.code = code
self.message = message
self.data = data
Expand All @@ -123,9 +123,7 @@ def __str__(self):

@classmethod
def from_json(cls, client_context, json_error):
# Subclasses of ServerError capture exceptions that occur in the Handler, and thus return a
# traceback. The encoding in `json_error` is also slightly different to allow the specific subclass
# to be identified.
"""Convert an encapsulated ServerError into JSON-RPC compliant format."""
found_server_error = False
try:
if ErrorCode(json_error["code"]) == ErrorCode.SERVER_ERROR:
Expand All @@ -145,6 +143,8 @@ def from_json(cls, client_context, json_error):


class ServerError(JSONRPCError):
"""Superclass for JSON-RPC errors which occur while processing valid requests."""

@classmethod
def from_exception(cls, exc, **kw):
to_return = cls(**kw)
Expand All @@ -168,21 +168,25 @@ def __str__(self):
super_str = super(ServerError, self).__str__()
return context_str + super_str

def set_traceback(self, traceback):
def set_traceback(self, traceback): # pylint: disable=redefined-outer-name
"""Format a traceback to be embedded in the JSON-RPC format."""

if self.data is None:
self.data = {}

if "traceback" not in self.data:
# NOTE: TVM's FFI layer reorders Python stack traces several times and strips
# intermediary lines that start with "Traceback". This logic adds a comment to the first
# stack frame to explicitly identify the first stack frame line that occurs on the server.
# stack frame to explicitly identify the first stack frame line that occurs on the
# server.
traceback_list = list(traceback)

# The traceback list contains one entry per stack frame, and each entry contains 1-2 lines:
# The traceback list contains one entry per stack frame, and each entry contains 1-2
# lines:
# File "path/to/file", line 123, in <method>:
# <copy of the line>
# We want to place a comment on the first line of the outermost frame to indicate this is the
# server-side stack frame.
# We want to place a comment on the first line of the outermost frame to indicate this
# is the server-side stack frame.
first_frame_list = traceback_list[1].split("\n")
self.data["traceback"] = (
traceback_list[0]
Expand Down Expand Up @@ -307,7 +311,8 @@ def flash(self, options: dict):
def open_transport(self, options: dict) -> TransportTimeouts:
"""Open resources needed for the transport layer.

This function might e.g. open files or serial ports needed in write_transport or read_transport.
This function might e.g. open files or serial ports needed in write_transport or
read_transport.

Calling this function enables the write_transport and read_transport calls. If the
transport is not open, this method is a no-op.
Expand All @@ -323,14 +328,16 @@ def open_transport(self, options: dict) -> TransportTimeouts:
def close_transport(self):
"""Close resources needed to operate the transport layer.

This function might e.g. close files or serial ports needed in write_transport or read_transport.
This function might e.g. close files or serial ports needed in write_transport or
read_transport.

Calling this function disables the write_transport and read_transport calls. If the
transport is not open, this method is a no-op.
"""
raise NotImplementedError()

@abc.abstractmethod
# pylint: disable=unidiomatic-typecheck
def read_transport(self, n: int, timeout_sec: typing.Union[float, type(None)]) -> bytes:
"""Read data from the transport.

Expand Down Expand Up @@ -389,7 +396,8 @@ def write_transport(self, data: bytes, timeout_sec: float):
class ProjectAPIServer:
"""Base class for Project API Servers.

This API server implements communication using JSON-RPC 2.0: https://www.jsonrpc.org/specification
This API server implements communication using JSON-RPC 2.0:
https://www.jsonrpc.org/specification

Suggested use of this class is to import this module or copy this file into Project Generator
implementations, then instantiate it with server.start().
Expand Down Expand Up @@ -451,7 +459,7 @@ def serve_one_request(self):
_LOG.error("EOF")
return False

except Exception as exc:
except Exception as exc: # pylint: disable=broad-except
_LOG.error("Caught error reading request", exc_info=1)
return False

Expand All @@ -466,7 +474,7 @@ def serve_one_request(self):
request_id = None if not did_validate else request.get("id")
self._reply_error(request_id, exc)
return did_validate
except Exception as exc:
except Exception as exc: # pylint: disable=broad-except
message = "validating request"
if did_validate:
message = f"calling method {request['method']}"
Expand All @@ -481,7 +489,7 @@ def serve_one_request(self):
VALID_METHOD_RE = re.compile("^[a-zA-Z0-9_]+$")

def _validate_request(self, request):
if type(request) is not dict:
if not isinstance(request, dict):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST, f"request: want dict; got {request!r}", None
)
Expand All @@ -493,26 +501,28 @@ def _validate_request(self, request):
)

method = request.get("method")
if type(method) != str:
if not isinstance(method, str):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST, f'request["method"]: want str; got {method!r}', None
)

if not self.VALID_METHOD_RE.match(method):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST,
f'request["method"]: should match regex {self.VALID_METHOD_RE.pattern}; got {method!r}',
f'request["method"]: should match regex {self.VALID_METHOD_RE.pattern}; '
f"got {method!r}",
None,
)

params = request.get("params")
if type(params) != dict:
if not isinstance(params, dict):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST, f'request["params"]: want dict; got {type(params)}', None
)

request_id = request.get("id")
if type(request_id) not in (str, int, type(None)):
# pylint: disable=unidiomatic-typecheck
if not isinstance(request_id, (str, int, type(None))):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST,
f'request["id"]: want str, number, null; got {request_id!r}',
Expand All @@ -538,10 +548,11 @@ def _dispatch_request(self, request):
params = {}

for var_name, var_type in typing.get_type_hints(interface_method).items():
if var_name == "self" or var_name == "return":
if var_name in ("self", "return"):
continue

# NOTE: types can only be JSON-compatible types, so var_type is expected to be of type 'type'.
# NOTE: types can only be JSON-compatible types, so var_type is expected to be of type
# 'type'.
if var_name not in request_params:
raise JSONRPCError(
ErrorCode.INVALID_PARAMS,
Expand All @@ -553,7 +564,8 @@ def _dispatch_request(self, request):
if not has_preprocessing and not isinstance(param, var_type):
raise JSONRPCError(
ErrorCode.INVALID_PARAMS,
f'method {request["method"]}: parameter {var_name}: want {var_type!r}, got {type(param)!r}',
f'method {request["method"]}: parameter {var_name}: want {var_type!r}, '
f"got {type(param)!r}",
None,
)

Expand Down Expand Up @@ -636,7 +648,7 @@ def _await_nonblocking_ready(rlist, wlist, timeout_sec=None, end_time=None):
return True


def read_with_timeout(fd, n, timeout_sec):
def read_with_timeout(fd, n, timeout_sec): # pylint: disable=invalid-name
"""Read data from a file descriptor, with timeout.

This function is intended as a helper function for implementations of ProjectAPIHandler
Expand Down Expand Up @@ -683,7 +695,7 @@ def read_with_timeout(fd, n, timeout_sec):
return to_return


def write_with_timeout(fd, data, timeout_sec):
def write_with_timeout(fd, data, timeout_sec): # pylint: disable=invalid-name
"""Write data to a file descriptor, with timeout.

This function is intended as a helper function for implementations of ProjectAPIHandler
Expand Down