Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/contrib/msc/core/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
"""tvm.contrib.msc.core.codegen"""

from .codegen import *
from .sources import *
53 changes: 42 additions & 11 deletions python/tvm/contrib/msc/core/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
"""tvm.contrib.msc.core.codegen.codegen"""

import os
import subprocess
from typing import Dict, List, Optional, Any, Callable

import tvm
Expand All @@ -40,6 +42,8 @@ class CodeGen(object):
The config to print code.
build_folder: MSCDirectory
The codegen folder.
code_format: str
The code format, cpp | python.
"""

def __init__(
Expand All @@ -49,26 +53,34 @@ def __init__(
codegen_config: Optional[Dict[str, str]] = None,
print_config: Optional[Dict[str, str]] = None,
build_folder: msc_utils.MSCDirectory = None,
code_format: str = "python",
):
self._graph = graph
self._source_getter = source_getter
self._codegen_config = msc_utils.dump_dict(codegen_config)
self._print_config = msc_utils.dump_dict(print_config)
self._build_folder = build_folder or msc_utils.msc_dir(keep_history=False, cleanup=True)
self._code_format = code_format

def load(
self,
inputs: Optional[List[Any]] = None,
weights_binder: Optional[Callable[[MSCGraph, Any, msc_utils.MSCDirectory], Any]] = None,
pre_load: Optional[Callable[[msc_utils.MSCDirectory], Any]] = None,
post_load: Optional[Callable[[Any, msc_utils.MSCDirectory], Any]] = None,
build_model: bool = True,
) -> Any:
"""Generate source and load the model

Parameters
-------
inputs: list<any>
The inputs to build the model.
weights_binder: Callable
The method for binding weights to the model.
pre_load: Callable
The pre processing method before load.
post_load: Callable
The post processing method after load.
build_model: bool
Whether to build the model.

Returns
-------
Expand All @@ -79,13 +91,33 @@ def load(
sources = self._source_getter(self._graph, self._codegen_config, self._print_config)
inputs = inputs or []
with self._build_folder as folder:
# pre processing
if pre_load:
pre_load(folder)
for name, source in sources.items():
folder.add_file(name, source)
builder = msc_utils.load_callable(self._graph.name + ".py:" + self._graph.name)
obj = builder(*inputs)
# load weights
if weights_binder:
obj = weights_binder(obj, folder)
if build_model:
if self._code_format == "cpp":
with folder.create_dir("build"):
command = "cmake ../ && make && mv {} ../".format(self._graph.name)
process = subprocess.Popen(command, shell=True)
process.wait()
assert process.returncode == 0, "Failed to build {} under {}".format(
self._graph.name, os.getcwd()
)
obj = self._graph.name
elif self._code_format == "python":
builder = msc_utils.load_callable(self._graph.name + ".py:" + self._graph.name)
obj = builder(*inputs)
else:
raise NotImplementedError(
"Code format {} is not supported".format(self._code_format)
)
# post processing
if post_load:
obj = post_load(obj, folder)
else:
obj = None
return obj


Expand Down Expand Up @@ -125,8 +157,7 @@ def relay_to_relax(
opt_config=opt_config,
)
source_getter = tvm.get_global_func("msc.framework.tvm.GetRelaxSources")
codegen_config = {"from_relay": True}
codegen = CodeGen(graph, source_getter, codegen_config)
codegen = CodeGen(graph, source_getter, codegen_config={"from_relay": True})
inputs = [
tvm.relax.Var(i.alias, tvm.relax.TensorStructInfo(i.get_shape(), i.dtype_name))
for i in graph.get_inputs()
Expand All @@ -136,4 +167,4 @@ def relay_to_relax(
def _bind_weights(mod: tvm.IRModule, folder: msc_utils.MSCDirectory) -> tvm.IRModule:
return BindParams("main", weights)(mod)

return codegen.load(inputs, _bind_weights)
return codegen.load(inputs, post_load=_bind_weights)
184 changes: 184 additions & 0 deletions python/tvm/contrib/msc/core/codegen/sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""tvm.contrib.msc.core.codegen.sources"""

from typing import Dict


def get_base_h_code() -> str:
    """Create base header file codes

    The returned text is the verbatim content of ``base.h``: a C++ header
    declaring ``FileUtils`` (file-existence check and binary read-to-buffer
    helpers) and ``DatasetReader`` (sequential reader of per-tensor
    ``batch_<i>.bin`` files under a dataset folder).  The declarations here
    must stay in sync with the definitions emitted by ``get_base_cc_code``.

    Returns
    -------
    source: str
        The base header source.
    """

    # NOTE: this literal is emitted as-is into the generated project
    # (mapped to "base.h" by get_base_sources); do not reformat it.
    return """#ifndef TVM_CONTRIB_MSC_UTILS_BASE_H_
#define TVM_CONTRIB_MSC_UTILS_BASE_H_

#include <string>
#include <utility>
#include <vector>

namespace tvm {
namespace contrib {
namespace msc {

class FileUtils {
 public:
  static inline bool FileExist(const std::string& file);

  template <typename T>
  static bool ReadToBuffer(const std::string& file, T* buffer, size_t size);
};

class DatasetReader {
 public:
  DatasetReader(const std::string& folder, int max_size = -1);

  void Reset();

  bool ReadNext(void* buffers[], int num_datas = -1);

 private:
  std::string folder_;
  size_t max_size_;
  size_t cur_cnt_;
  std::vector<std::pair<std::string, size_t>> tensor_info_;
};

}  // namespace msc
}  // namespace contrib
}  // namespace tvm

#endif  // TVM_CONTRIB_MSC_UTILS_BASE_H_
"""


def get_base_cc_code() -> str:
    """Create base cc file codes

    The returned text is the verbatim content of ``base.cc``, implementing
    the classes declared by ``get_base_h_code``:

    * ``FileUtils::FileExist`` - open-and-close probe for a file's existence.
    * ``FileUtils::ReadToBuffer`` - binary read of ``size`` elements of ``T``
      into a caller-provided buffer.
    * ``DatasetReader`` - parses ``<folder>/tensor_info`` lines of the form
      ``name:byte_size``, counts how many consecutive
      ``<folder>/<name>/batch_<i>.bin`` files exist for every tensor, and
      serves them batch by batch through ``ReadNext``.

    Returns
    -------
    source: str
        The base cc source.
    """

    # NOTE: this literal is emitted as-is into the generated project
    # (mapped to "base.cc" by get_base_sources); do not reformat it.
    return """#include "base.h"

#include <algorithm>
#include <fstream>

namespace tvm {
namespace contrib {
namespace msc {

bool FileUtils::FileExist(const std::string& file) {
  std::ifstream in_file(file, std::ifstream::binary);
  if (in_file.is_open()) {
    in_file.close();
    return true;
  }
  return false;
}

template <typename T>
bool FileUtils::ReadToBuffer(const std::string& file, T* buffer, size_t size) {
  std::ifstream in_file(file, std::ifstream::binary);
  if (!in_file.is_open()) {
    return false;
  }
  try {
    in_file.read((char*)(&buffer[0]), size * sizeof(T));
  } catch (std::exception const& e) {
    in_file.close();
    return false;
  }
  in_file.close();
  return true;
}

DatasetReader::DatasetReader(const std::string& folder, int max_size) {
  folder_ = folder;
  const std::string info_file = folder_ + "/tensor_info";
  std::ifstream input(info_file, std::ios::binary);
  assert(input.is_open() && ("Failed to open file " + info_file).c_str());
  std::string line;
  while (getline(input, line)) {
    int pos = line.find(":");
    assert(pos > 0 && ("Can not find : in line " + line).c_str());
    const auto& name = line.substr(0, pos);
    const auto& byte_size = line.substr(pos + 1, line.size());
    tensor_info_.push_back(std::make_pair(name, static_cast<size_t>(std::stoi(byte_size))));
  }
  size_t file_cnt = 0;
  while (true) {
    bool all_exists = true;
    for (const auto& pair : tensor_info_) {
      const auto& d_file =
          folder_ + "/" + pair.first + "/batch_" + std::to_string(file_cnt) + ".bin";
      if (!FileUtils::FileExist(d_file)) {
        all_exists = false;
        break;
      }
    }
    if (!all_exists) {
      break;
    }
    file_cnt++;
  }
  max_size_ = max_size > 0 ? static_cast<size_t>(max_size) : file_cnt;
  max_size_ = std::min(max_size_, file_cnt);
  Reset();
}

void DatasetReader::Reset() { cur_cnt_ = 0; }

bool DatasetReader::ReadNext(void* buffers[], int num_datas) {
  if (cur_cnt_ >= max_size_) {
    return false;
  }
  size_t max_num = num_datas > 0 ? static_cast<size_t>(num_datas) : tensor_info_.size();
  max_num = std::min(max_num, tensor_info_.size());
  for (size_t i = 0; i < max_num; i++) {
    const auto& pair = tensor_info_[i];
    const auto& d_file = folder_ + "/" + pair.first + "/batch_" + std::to_string(cur_cnt_) + ".bin";
    if (!FileUtils::ReadToBuffer(d_file, (char*)buffers[i], pair.second)) {
      return false;
    }
  }
  cur_cnt_++;
  return true;
}

}  // namespace msc
}  // namespace contrib
}  // namespace tvm
"""


def get_base_sources() -> Dict[str, str]:
    """Create base sources for cpp codegen

    Returns
    -------
    sources: dict<str,str>
        The base utils sources, keyed by the file name to write.
    """

    # Collect the utility file contents under their target file names.
    sources = {}
    sources["base.h"] = get_base_h_code()
    sources["base.cc"] = get_base_cc_code()
    return sources
50 changes: 49 additions & 1 deletion python/tvm/contrib/msc/core/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,22 @@ def output_at(self, idx: int) -> MSCTensor:

return _ffi_api.MSCJointOutputAt(self, idx)

def weight_at(self, wtype: str) -> MSCTensor:
    """Get the weight tensor referenced by a weight type.

    Parameters
    ----------
    wtype: str
        The type of weight.

    Returns
    -------
    weight: MSCTensor
        The weight Tensor.
    """

    # Delegate the lookup to the C++ side of the graph.
    weight = _ffi_api.MSCJointWeightAt(self, wtype)
    return weight

def get_inputs(self) -> List[MSCTensor]:
"""Get all the inputs.

Expand Down Expand Up @@ -242,7 +258,7 @@ def get_weights(self) -> Dict[str, MSCTensor]:
"""

src_weights = _ffi_api.MSCJointGetWeights(self)
return {ref: src_weights[ref] for ref in src_weights}
return {wtype: src_weights[wtype] for wtype in src_weights}

def get_attrs(self) -> Dict[str, str]:
"""Get all the attributes from node
Expand Down Expand Up @@ -406,6 +422,22 @@ def __init__(
output_names,
)

def has_node(self, name: str) -> bool:
    """Check whether a node with the given name exists in the graph.

    Parameters
    ----------
    name: string
        The name of the node.

    Returns
    -------
    has_node: bool
        Whether the node is in the graph
    """

    # The FFI call may return a non-bool truthy value; normalize it.
    found = _ffi_api.MSCGraphHasNode(self, name)
    return bool(found)

def find_node(self, name: str) -> MSCJoint:
"""Find node by name.

Expand All @@ -422,6 +454,22 @@ def find_node(self, name: str) -> MSCJoint:

return _ffi_api.MSCGraphFindNode(self, name)

def has_tensor(self, name: str) -> bool:
    """Check whether a tensor with the given name exists in the graph.

    Parameters
    ----------
    name: string
        The name of the tensor.

    Returns
    -------
    has_tensor: bool
        Whether the tensor is in the graph
    """

    # The FFI call may return a non-bool truthy value; normalize it.
    found = _ffi_api.MSCGraphHasTensor(self, name)
    return bool(found)

def find_tensor(self, name: str) -> MSCTensor:
"""Find tensor by name.

Expand Down
Loading