Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions python/tvm/micro/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,16 @@ class TemplateProject:
"""Defines a glue interface to interact with a template project through the API Server."""

@classmethod
def from_directory(cls, template_project_dir, options):
return cls(client.instantiate_from_dir(template_project_dir), options)
def from_directory(cls, template_project_dir):
return cls(client.instantiate_from_dir(template_project_dir))

def __init__(self, api_client, options):
def __init__(self, api_client):
self._api_client = api_client
self._options = options
self._info = self._api_client.server_info_query(__version__)
if not self._info["is_template"]:
raise NotATemplateProjectError()

def generate_project(self, graph_executor_factory, project_dir):
def generate_project(self, graph_executor_factory, project_dir, options):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just noting for posterity that originally i didn't do this because i wanted all calls to Project API server to use the same options, but since there is just one call used with TemplateProject, and you might actually want to examine self._info before constructing options, this makes more sense to me

Copy link
Contributor Author

@gromero gromero Sep 8, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point and that was exactly what I thought when I was going for that change. Thanks your comment, it really helps me to understand the design better and increases my confidence that I'm following the motivations behind it correctly :)

"""Generate a project given GraphRuntimeFactory."""
model_library_dir = utils.tempdir()
model_library_format_path = model_library_dir.relpath("model.tar")
Expand All @@ -112,10 +111,13 @@ def generate_project(self, graph_executor_factory, project_dir):
model_library_format_path=model_library_format_path,
standalone_crt_dir=get_standalone_crt_dir(),
project_dir=project_dir,
options=self._options,
options=options,
)

return GeneratedProject.from_directory(project_dir, self._options)
return GeneratedProject.from_directory(project_dir, options)

def info(self):
return self._info


def generate_project(
Expand Down Expand Up @@ -147,5 +149,5 @@ def generate_project(
GeneratedProject :
A class that wraps the generated project and which can be used to further interact with it.
"""
template = TemplateProject.from_directory(str(template_project_dir), options)
return template.generate_project(module, str(generated_project_dir))
template = TemplateProject.from_directory(str(template_project_dir))
return template.generate_project(module, str(generated_project_dir), options)