From b2ed06aa9b6298bdde59bcecbfe8e8e9a4c32aef Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 23 Jan 2023 22:54:12 -0800 Subject: [PATCH 1/2] [TVMScript] Consolidate folder structure This PR consolidates the parser folder into Relay, as it is used specifically for the Relay IR. This is the last step for the TVMScript refactoring, where it established the default text format is the roundtrippable TVMScript e-DSL. --- CMakeLists.txt | 2 +- include/tvm/ir/diagnostic.h | 2 - include/tvm/ir/expr.h | 2 +- include/tvm/ir/module.h | 6 +- include/tvm/ir/{span.h => source_map.h} | 96 +++++++++++++- include/tvm/ir/type.h | 2 +- include/tvm/parser/source_map.h | 119 ------------------ include/tvm/relay/base.h | 2 +- include/tvm/relay/error.h | 3 +- include/tvm/{parser => relay}/parser.h | 16 +-- include/tvm/runtime/metadata_base.h | 5 +- python/tvm/ir/base.py | 42 ++++--- python/tvm/parser.py | 47 +++++++ python/tvm/relay/__init__.py | 3 + .../_ffi_api.py => relay/_ffi_api_parser.py} | 5 +- .../{parser/__init__.py => relay/parser.py} | 22 ++-- rust/tvm/src/ir/module.rs | 4 +- src/ir/diagnostic.cc | 4 +- src/ir/module.cc | 15 +-- src/ir/{span.cc => source_map.cc} | 75 ++++++++++- src/ir/transform.cc | 9 +- src/parser/source_map.cc | 97 -------------- src/relay/backend/utils.cc | 2 +- src/relay/backend/vm/compiler.cc | 2 +- src/relay/ir/base.cc | 17 --- src/relay/ir/function.cc | 12 ++ src/{ => relay}/parser/meta_ref.cc | 4 +- src/{ => relay}/parser/meta_ref.h | 14 +-- src/{ => relay}/parser/op_table.h | 15 ++- src/{ => relay}/parser/parser.cc | 51 ++++---- src/{ => relay}/parser/span_check.cc | 6 +- src/{ => relay}/parser/span_check.h | 11 +- src/{ => relay}/parser/token.h | 29 +++-- src/{ => relay}/parser/tokenizer.h | 33 ++--- src/relay/printer/relay_text_printer.cc | 2 +- src/runtime/profiling.cc | 1 - .../relay/backend/aot/aot_lower_main_test.cc | 4 +- .../relay/collage/candidate_partition_test.cc | 4 +- .../cpp/relay/collage/partition_rule_test.cc | 4 +- tests/cpp/relay/df_pattern_rewrite_test.cc | 4 +- tests/cpp/relay/ir/indexed_graph_test.cc | 6 +- .../relay/transforms/device_domains_test.cc | 4 +- tests/cpp/relay/with_fields_test.cc | 6 +- 43 files changed, 395 insertions(+), 414 deletions(-) rename include/tvm/ir/{span.h => source_map.h} (59%) delete mode 100644 include/tvm/parser/source_map.h rename include/tvm/{parser => relay}/parser.h (86%) create mode 100644 python/tvm/parser.py rename python/tvm/{parser/_ffi_api.py => relay/_ffi_api_parser.py} (91%) rename python/tvm/{parser/__init__.py => relay/parser.py} (71%) rename src/ir/{span.cc => source_map.cc} (61%) delete mode 100644 src/parser/source_map.cc rename src/{ => relay}/parser/meta_ref.cc (98%) rename src/{ => relay}/parser/meta_ref.h (92%) rename src/{ => relay}/parser/op_table.h (93%) rename src/{ => relay}/parser/parser.cc (99%) rename src/{ => relay}/parser/span_check.cc (96%) rename src/{ => relay}/parser/span_check.h (93%) rename src/{ => relay}/parser/token.h (93%) rename src/{ => relay}/parser/tokenizer.h (96%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f18d673e4a2..032e0bc2af00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -287,7 +287,6 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/tir/*.cc src/topi/*.cc src/driver/*.cc - src/parser/*.cc src/support/*.cc src/script/*.cc ) @@ -317,6 +316,7 @@ tvm_file_glob(GLOB RELAY_BACKEND_SRCS tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS src/relay/ir/*.cc src/relay/printer/*.cc + src/relay/parser/*.cc ) tvm_file_glob(GLOB_RECURSE RELAY_QNN_SRCS src/relay/qnn/*.cc diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 41130a5be0aa..3b2407491f26 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -27,14 +27,12 @@ #define TVM_IR_DIAGNOSTIC_H_ #include -#include #include #include namespace tvm { -using tvm::parser::SourceMap; using tvm::runtime::TypedPackedFunc; /*! \brief The diagnostic level, controls the printing of the message. */ diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 78c09e81b16f..c8531c88465a 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,7 +24,7 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ -#include +#include #include #include #include diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 0a5bac182fd9..fdb44b11887c 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -27,8 +27,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -60,7 +60,7 @@ class IRModuleNode : public Object { /*! \brief A map from global type vars to ADT type data. */ Map type_definitions; /*! \brief The source map for the module. */ - parser::SourceMap source_map; + SourceMap source_map; /* \brief Additional attributes storing meta-data about the module. */ DictAttrs attrs; /*! @@ -357,7 +357,7 @@ class IRModule : public ObjectRef { */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, - std::unordered_set import_set = {}, parser::SourceMap map = {}, + std::unordered_set import_set = {}, SourceMap map = {}, DictAttrs attrs = {}); /*! \brief default constructor */ diff --git a/include/tvm/ir/span.h b/include/tvm/ir/source_map.h similarity index 59% rename from include/tvm/ir/span.h rename to include/tvm/ir/source_map.h index b53ca2921fe7..536099f3114b 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/source_map.h @@ -16,20 +16,25 @@ * specific language governing permissions and limitations * under the License. */ - /*! - * \file tvm/ir/span.h - * \brief Span information for debugging purposes. + * \file source_map.h + * \brief A map from source names to source code. */ -#ifndef TVM_IR_SPAN_H_ -#define TVM_IR_SPAN_H_ +#ifndef TVM_IR_SOURCE_MAP_H_ +#define TVM_IR_SOURCE_MAP_H_ #include #include +#include +#include +#include #include +#include +#include namespace tvm { + /*! * \brief The source name in the Span * \sa SourceNameNode, Span @@ -122,5 +127,84 @@ class Span : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; +/*! \brief A program source in any language. + * + * Could represent the source from an ML framework or a source + * representing a tvm::IRModule. + */ +class Source; + +class SourceNode : public Object { + public: + /*! \brief The source name. */ + SourceName source_name; + + /*! \brief The raw source. */ + String source; + + /*! \brief A mapping of line breaks into the raw source. */ + std::vector> line_map; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) { + v->Visit("source_name", &source_name); + v->Visit("source", &source); + } + + static constexpr const char* _type_key = "Source"; + TVM_DECLARE_FINAL_OBJECT_INFO(SourceNode, Object); +}; + +class Source : public ObjectRef { + public: + TVM_DLL Source(SourceName src_name, std::string source); + TVM_DLL tvm::String GetLine(int line); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); +}; + +/*! + * \brief A mapping from a unique source name to source fragment. + */ +class SourceMap; +/*! + * \brief Stores locations in frontend source that generated a node. + */ +class SourceMapNode : public Object { + public: + /*! \brief The source mapping. */ + Map source_map; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } + + bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { + return equal(source_map, other->source_map); + } + + static constexpr const char* _type_key = "SourceMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object); +}; + +class SourceMap : public ObjectRef { + public: + explicit SourceMap(Map source_map); + + explicit SourceMap(std::initializer_list> source_map) + : SourceMap(Map(source_map)) {} + + SourceMap() : SourceMap(Map()) {} + + void Add(const Source& source); + + SourceMapNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); +}; + } // namespace tvm -#endif // TVM_IR_SPAN_H_ + +#endif // TVM_IR_SOURCE_MAP_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 62328f6a074a..c6baf5e08be3 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -49,7 +49,7 @@ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ -#include +#include #include #include #include diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h deleted file mode 100644 index a160c22a2a2f..000000000000 --- a/include/tvm/parser/source_map.h +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file source_map.h - * \brief A map from source names to source code. - */ -#ifndef TVM_PARSER_SOURCE_MAP_H_ -#define TVM_PARSER_SOURCE_MAP_H_ - -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace parser { - -/*! \brief A program source in any language. - * - * Could represent the source from an ML framework or a source - * representing a tvm::IRModule. - */ -class Source; - -class SourceNode : public Object { - public: - /*! \brief The source name. */ - SourceName source_name; - - /*! \brief The raw source. */ - String source; - - /*! \brief A mapping of line breaks into the raw source. */ - std::vector> line_map; - - // override attr visitor - void VisitAttrs(AttrVisitor* v) { - v->Visit("source_name", &source_name); - v->Visit("source", &source); - } - - static constexpr const char* _type_key = "Source"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceNode, Object); -}; - -class Source : public ObjectRef { - public: - TVM_DLL Source(SourceName src_name, std::string source); - TVM_DLL tvm::String GetLine(int line); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); -}; - -/*! - * \brief A mapping from a unique source name to source fragment. - */ -class SourceMap; -/*! - * \brief Stores locations in frontend source that generated a node. - */ -class SourceMapNode : public Object { - public: - /*! \brief The source mapping. */ - Map source_map; - - // override attr visitor - void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } - - bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { - return equal(source_map, other->source_map); - } - - static constexpr const char* _type_key = "SourceMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object); -}; - -class SourceMap : public ObjectRef { - public: - TVM_DLL SourceMap(Map source_map); - - TVM_DLL SourceMap(std::initializer_list> source_map) - : SourceMap(Map(source_map)) {} - - TVM_DLL SourceMap() : SourceMap(Map()) {} - - void Add(const Source& source); - - SourceMapNode* operator->() { - ICHECK(get() != nullptr); - return static_cast(get_mutable()); - } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); -}; - -} // namespace parser -} // namespace tvm - -#endif // TVM_PARSER_SOURCE_MAP_H_ diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 2825bcfc659a..a66b8044998b 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_BASE_H_ #define TVM_RELAY_BASE_H_ -#include +#include #include #include diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index be34e2b8ae1a..abe8278f2f5d 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -20,7 +20,6 @@ #define TVM_RELAY_ERROR_H_ #include -#include #include #include @@ -31,7 +30,7 @@ namespace tvm { namespace relay { /*! * \brief A wrapper around std::stringstream to build error. - * + *include/tvm/ir/type.h * Can be consumed by CompileError to construct an error. * * \code diff --git a/include/tvm/parser/parser.h b/include/tvm/relay/parser.h similarity index 86% rename from include/tvm/parser/parser.h rename to include/tvm/relay/parser.h index 0a73e1a2a532..6e33e7873f60 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/relay/parser.h @@ -16,13 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#ifndef TVM_RELAY_PARSER_H_ +#define TVM_RELAY_PARSER_H_ -#ifndef TVM_PARSER_PARSER_H_ -#define TVM_PARSER_PARSER_H_ -/*! - * \file include/tvm/parser/parser.h - * \brief A parser for TVM IR. - */ #include #include #include @@ -32,7 +28,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using MetaTable = Map>; @@ -45,9 +41,9 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for * modules constructed programaticaly rather than textually. */ -transform::Pass AnnotateSpans(); +tvm::transform::Pass AnnotateSpans(); -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_PARSER_H_ +#endif // TVM_RELAY_PARSER_H_ diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index 698f56d46d28..ca412a3b615c 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -24,7 +24,10 @@ #ifndef TVM_RUNTIME_METADATA_BASE_H_ #define TVM_RUNTIME_METADATA_BASE_H_ -#include +#include +#include +#include +#include #include #include diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index b84a83d55843..5df529b0532f 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -17,17 +17,23 @@ """Common base structures.""" import tvm._ffi import tvm.error -import tvm.runtime._ffi_node_api -from tvm.runtime import Object +from tvm._ffi import get_global_func, register_object +from tvm.runtime import Object, _ffi_node_api from . import _ffi_api, json_compact class Node(Object): - """Base class of all IR Nodes, implements astext function.""" + """Base class of all IR Nodes.""" -@tvm._ffi.register_object("SourceName") +@register_object("SourceMap") +class SourceMap(Object): + def add(self, name, content): + return get_global_func("SourceMapAdd")(self, name, content) + + +@register_object("SourceName") class SourceName(Object): """A identifier for a source location. @@ -38,10 +44,10 @@ class SourceName(Object): """ def __init__(self, name): - self.__init_handle_by_constructor__(_ffi_api.SourceName, name) + self.__init_handle_by_constructor__(_ffi_api.SourceName, name) # type: ignore # pylint: disable=no-member -@tvm._ffi.register_object("Span") +@register_object("Span") class Span(Object): """Specifies a location in a source program. @@ -59,11 +65,11 @@ class Span(Object): def __init__(self, source_name, line, end_line, column, end_column): self.__init_handle_by_constructor__( - _ffi_api.Span, source_name, line, end_line, column, end_column + _ffi_api.Span, source_name, line, end_line, column, end_column # type: ignore # pylint: disable=no-member ) -@tvm._ffi.register_object +@register_object class EnvFunc(Object): """Environment function. @@ -71,11 +77,11 @@ class EnvFunc(Object): """ def __call__(self, *args): - return _ffi_api.EnvFuncCall(self, *args) + return _ffi_api.EnvFuncCall(self, *args) # type: ignore # pylint: disable=no-member @property def func(self): - return _ffi_api.EnvFuncGetPackedFunc(self) + return _ffi_api.EnvFuncGetPackedFunc(self) # type: ignore # pylint: disable=no-member @staticmethod def get(name): @@ -86,7 +92,7 @@ def get(name): name : str The name of the function. """ - return _ffi_api.EnvFuncGet(name) + return _ffi_api.EnvFuncGet(name) # type: ignore # pylint: disable=no-member def load_json(json_str) -> Object: @@ -104,10 +110,10 @@ def load_json(json_str) -> Object: """ try: - return tvm.runtime._ffi_node_api.LoadJSON(json_str) + return _ffi_node_api.LoadJSON(json_str) except tvm.error.TVMError: json_str = json_compact.upgrade_json(json_str) - return tvm.runtime._ffi_node_api.LoadJSON(json_str) + return _ffi_node_api.LoadJSON(json_str) def save_json(node) -> str: @@ -123,7 +129,7 @@ def save_json(node) -> str: json_str : str Saved json string. """ - return tvm.runtime._ffi_node_api.SaveJSON(node) + return _ffi_node_api.SaveJSON(node) def structural_equal(lhs, rhs, map_free_vars=False): @@ -175,7 +181,7 @@ def structural_equal(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) + return bool(_ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) # type: ignore # pylint: disable=no-member def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): @@ -201,7 +207,7 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - mismatch = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars) + mismatch = _ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars) # type: ignore # pylint: disable=no-member if mismatch is None: return None else: @@ -233,7 +239,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) + _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) # type: ignore # pylint: disable=no-member def structural_hash(node, map_free_vars=False): @@ -275,4 +281,4 @@ def structural_hash(node, map_free_vars=False): -------- structrual_equal """ - return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars) + return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/parser.py b/python/tvm/parser.py new file mode 100644 index 000000000000..63c40deb2069 --- /dev/null +++ b/python/tvm/parser.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""The legacy TVM parser """ +# pylint: disable=import-outside-toplevel + + +def parse(*args, **kwargs): + """Deprecated, use `tvm.relay.parse` instead""" + from tvm.relay import parse as _impl + + return _impl(*args, **kwargs) + + +def parse_expr(*args, **kwargs): + """Deprecated, use `tvm.relay.parse_expr` instead""" + from tvm.relay import parse_expr as _impl + + return _impl(*args, **kwargs) + + +def fromtext(*args, **kwargs): + """Deprecated, use `tvm.relay.fromtext` instead""" + from tvm.relay import fromtext as _impl + + return _impl(*args, **kwargs) + + +def SpanCheck(*args, **kwargs): + """Deprecated, use `tvm.relay.SpanCheck` instead""" + from tvm.relay import SpanCheck as _impl + + return _impl(*args, **kwargs) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 5e5d1d5f18d8..02eec18d3013 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -65,6 +65,9 @@ # Load Memory Passes from .transform import memory_plan +# Parser +from .parser import parse, parse_expr, fromtext, SpanCheck + # Required to traverse large programs setrecursionlimit(10000) diff --git a/python/tvm/parser/_ffi_api.py b/python/tvm/relay/_ffi_api_parser.py similarity index 91% rename from python/tvm/parser/_ffi_api.py rename to python/tvm/relay/_ffi_api_parser.py index 7fa3b78b72bb..731b926b5655 100644 --- a/python/tvm/parser/_ffi_api.py +++ b/python/tvm/relay/_ffi_api_parser.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.ir""" +"""FFI APIs for Relay parser.""" import tvm._ffi - -tvm._ffi._init_api("parser", __name__) +tvm._ffi._init_api("relay.parser", __name__) diff --git a/python/tvm/parser/__init__.py b/python/tvm/relay/parser.py similarity index 71% rename from python/tvm/parser/__init__.py rename to python/tvm/relay/parser.py index d75ad16ebab2..5e5f00a90eea 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/relay/parser.py @@ -15,25 +15,23 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name -"""The under development unified IR parsing infrastructure.""" -from .. import _ffi, Object -from . import _ffi_api - - -@_ffi.register_object("SourceMap") -class SourceMap(Object): - def add(self, name, content): - return _ffi.get_global_func("SourceMapAdd")(self, name, content) +"""The relay parser.""" +from . import _ffi_api_parser def parse(source, source_name="from_string", init_module=None, init_meta_table=None): if init_meta_table is None: init_meta_table = {} - return _ffi_api.ParseModuleInContext(source_name, source, init_module, init_meta_table) + return _ffi_api_parser.ParseModuleInContext( # type: ignore # pylint: disable=no-member + source_name, + source, + init_module, + init_meta_table, + ) def parse_expr(source): - return _ffi_api.ParseExpr("string", source) + return _ffi_api_parser.ParseExpr("string", source) # type: ignore # pylint: disable=no-member def fromtext(source, source_name="from_string"): @@ -42,4 +40,4 @@ def fromtext(source, source_name="from_string"): def SpanCheck(): """A debugging utility for reporting missing span information.""" - return _ffi_api.SpanCheck() + return _ffi_api_parser.SpanCheck() # type: ignore # pylint: disable=no-member diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index ea257af1ebc0..8f71a8be2c7c 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -57,9 +57,9 @@ pub struct IRModuleNode { external! { // Parser functions - #[name("parser.ParseModule")] + #[name("relay.parser.ParseModule")] fn parse_module(file_name: TVMString, source: TVMString) -> IRModule; - #[name("parser.ParseExpr")] + #[name("relay.parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; #[name("ir.IRModule")] fn module_new(funcs: Map, types: Map) -> IRModule; diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 336575a93e97..6687a28d8c84 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -22,14 +22,12 @@ * \brief Implementation of DiagnosticContext and friends. */ #include -#include +#include #include namespace tvm { -using tvm::parser::Source; - // failed to check to argument arg0.dims[0] != 0 /* Diagnostic */ diff --git a/src/ir/module.cc b/src/ir/module.cc index b6923cd1e60d..22c6faf3d69d 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -16,16 +16,14 @@ * specific language governing permissions and limitations * under the License. */ - /*! * \file module.cc - * \brief The global module in Relay. + * \brief The global module in TVM. */ #include #include #include #include -#include #include #include @@ -36,8 +34,7 @@ namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set, parser::SourceMap source_map, - DictAttrs attrs) { + std::unordered_set import_set, SourceMap source_map, DictAttrs attrs) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -322,12 +319,14 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, const Mapimport_set_.count(path) == 0) { this->import_set_.insert(path); std::fstream src_file(path, std::fstream::in); std::string file_contents{std::istreambuf_iterator(src_file), std::istreambuf_iterator()}; - auto mod_to_import = parser::ParseModule(path, file_contents, GetRef(this)); + auto mod_to_import = (*f)(path, file_contents, GetRef(this)); Update(mod_to_import); } } @@ -342,7 +341,9 @@ void IRModuleNode::ImportFromStd(const String& path) { std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } IRModule IRModule::FromText(const String& text, const String& source_path) { - return tvm::parser::ParseModule(source_path, text); + static const auto* f = runtime::Registry::Get("relay.parser.ParseModule"); + ICHECK(f != nullptr) << "ValueError: Relay parser is not available"; + return (*f)(source_path, text, Optional()); } TVM_REGISTER_NODE_TYPE(IRModuleNode); diff --git a/src/ir/span.cc b/src/ir/source_map.cc similarity index 61% rename from src/ir/span.cc rename to src/ir/source_map.cc index 39f0044d16d3..8b913906ea42 100644 --- a/src/ir/span.cc +++ b/src/ir/source_map.cc @@ -17,11 +17,10 @@ * under the License. */ /*! - * \file span.cc - * \brief The span data structure. + * \file source_map.cc + * \brief The implementation of the source map data structure. */ -#include -#include +#include #include #include @@ -100,4 +99,72 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Span(" << node->source_name << ", " << node->line << ", " << node->end_line << ", " << node->column << ", " << node->end_column << ")"; }); + +TVM_REGISTER_NODE_TYPE(SourceNode); + +/*! \brief Construct a source from a string. */ +Source::Source(SourceName src_name, std::string source) { + auto n = make_object(); + n->source_name = std::move(src_name); + n->source = std::move(source); + + int index = 0; + int length = 0; + n->line_map.push_back({index, length}); + // NB(@jroesch): + std::string source_str = n->source; + for (auto c : source_str) { + if (c == '\n') { + // Record the length of the line. + n->line_map.back().second = length; + // Bump past the newline. + index += 1; + // Record the start of the next line, and put placeholder for length. + n->line_map.push_back({index, 0}); + // Reset length to zero. + length = 0; + } else { + length += 1; + index += 1; + } + } + n->line_map.back().second = length; + + data_ = n; +} + +tvm::String Source::GetLine(int line) { + VLOG(1) << "Source::GetLine: line=" << line; + ICHECK(line - 1 < static_cast((*this)->line_map.size())) + << "requested line: " << line << "at index: " << (line - 1) + << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; + + // Adjust for zero indexing, now have (line_start, line_length); + auto range = (*this)->line_map.at(line - 1); + int line_start = range.first; + int line_length = range.second; + VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; + // TODO(@jroesch): expose substring on tvm::String. + auto line_text = std::string((*this)->source).substr(line_start, line_length); + VLOG(1) << "Source::GetLine: line_text=" << line_text; + return line_text; +} + +TVM_REGISTER_NODE_TYPE(SourceMapNode); + +SourceMap::SourceMap(Map source_map) { + auto n = make_object(); + n->source_map = std::move(source_map); + data_ = std::move(n); +} + +void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } + +TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) { + auto src_name = SourceName::Get(name); + Source source(src_name, content); + map.Add(source); + return src_name; +}); + } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 9a669493ccb7..66b06e6b505d 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -587,11 +587,12 @@ TVM_REGISTER_GLOBAL("transform.OverrideInstruments") Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { - if (const auto* f = runtime::Registry::Get("relay.PrintIR")) { - (*f)(mod, header, show_meta_data); - } else { - LOG(INFO) << "PrintIR(" << header << "):\n" << mod; + if (const auto* f = runtime::Registry::Get("relay.ir.PrintIR")) { + if ((*f)(mod, header, show_meta_data)) { + return mod; + } } + LOG(INFO) << "PrintIR(" << header << "):\n" << mod; return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc deleted file mode 100644 index 3c1329670c40..000000000000 --- a/src/parser/source_map.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file source_map.cc - * \brief The implementation of the source map data structure. - */ -#include -#include - -namespace tvm { -namespace parser { - -TVM_REGISTER_NODE_TYPE(SourceNode); - -/*! \brief Construct a source from a string. */ -Source::Source(SourceName src_name, std::string source) { - auto n = make_object(); - n->source_name = std::move(src_name); - n->source = std::move(source); - - int index = 0; - int length = 0; - n->line_map.push_back({index, length}); - // NB(@jroesch): - std::string source_str = n->source; - for (auto c : source_str) { - if (c == '\n') { - // Record the length of the line. - n->line_map.back().second = length; - // Bump past the newline. - index += 1; - // Record the start of the next line, and put placeholder for length. - n->line_map.push_back({index, 0}); - // Reset length to zero. - length = 0; - } else { - length += 1; - index += 1; - } - } - n->line_map.back().second = length; - - data_ = n; -} - -tvm::String Source::GetLine(int line) { - VLOG(1) << "Source::GetLine: line=" << line; - ICHECK(line - 1 < static_cast((*this)->line_map.size())) - << "requested line: " << line << "at index: " << (line - 1) - << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; - - // Adjust for zero indexing, now have (line_start, line_length); - auto range = (*this)->line_map.at(line - 1); - int line_start = range.first; - int line_length = range.second; - VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; - // TODO(@jroesch): expose substring on tvm::String. - auto line_text = std::string((*this)->source).substr(line_start, line_length); - VLOG(1) << "Source::GetLine: line_text=" << line_text; - return line_text; -} - -TVM_REGISTER_NODE_TYPE(SourceMapNode); - -SourceMap::SourceMap(Map source_map) { - auto n = make_object(); - n->source_map = std::move(source_map); - data_ = std::move(n); -} - -void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } - -TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) { - auto src_name = SourceName::Get(name); - Source source(src_name, content); - map.Add(source); - return src_name; -}); - -} // namespace parser -} // namespace tvm diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 183a3094e473..4ff8a59b349e 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -25,7 +25,7 @@ #include "utils.h" -#include +#include #include #include #include diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index fb23c4cc082a..c29b3195a3fd 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -25,13 +25,13 @@ #include "compiler.h" #include -#include #include #include #include #include #include #include +#include #include #include #include diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 5f913026080d..deedd283c2ff 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -39,22 +39,5 @@ Id::Id(String name_hint) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span sp) { - if (auto* rn = node_ref.as()) { - rn->span = sp; - } else if (auto* rn = node_ref.as()) { - rn->span = sp; - } else if (auto* rn = node_ref.as()) { - rn->span = sp; - } else { - LOG(FATAL) << "Expect Type or RelayNode "; - } -}); - -TVM_REGISTER_GLOBAL("relay.PrintIR") - .set_body_typed([](ObjectRef mod, String header, bool show_metadata) { - LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata); - }); - } // namespace relay } // namespace tvm diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 3ff5eaa059c1..5d743d521777 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -123,6 +123,7 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) { } return nullptr; } + TVM_REGISTER_GLOBAL("relay.ir.PrintRelayModule") .set_body_typed([](IRModule mod) -> Optional { for (const auto& it : mod->functions) { @@ -133,6 +134,17 @@ TVM_REGISTER_GLOBAL("relay.ir.PrintRelayModule") return NullOpt; }); +TVM_REGISTER_GLOBAL("relay.ir.PrintIR") + .set_body_typed([](IRModule mod, String header, bool show_metadata) -> bool { + for (const auto& it : mod->functions) { + if (it.second->IsInstance()) { + LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata); + return true; + } + } + return false; + }); + TVM_REGISTER_GLOBAL("relay.ir.WarnIfMalformed") .set_body_typed([](const IRModule& mod, const BaseFunc& base_func) -> void { if (const auto* relay_func = base_func.as()) { diff --git a/src/parser/meta_ref.cc b/src/relay/parser/meta_ref.cc similarity index 98% rename from src/parser/meta_ref.cc rename to src/relay/parser/meta_ref.cc index 6b0e8d0c5966..cdc6929622dd 100644 --- a/src/parser/meta_ref.cc +++ b/src/relay/parser/meta_ref.cc @@ -30,7 +30,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using tvm::relay::transform::CreateFunctionPass; using tvm::transform::PassContext; @@ -95,5 +95,5 @@ IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) { return pass(mod, PassContext::Create()); } -} // namespace parser +} // namespace relay } // namespace tvm diff --git a/src/parser/meta_ref.h b/src/relay/parser/meta_ref.h similarity index 92% rename from src/parser/meta_ref.h rename to src/relay/parser/meta_ref.h index 483b7f726e07..bed67bea05a4 100644 --- a/src/parser/meta_ref.h +++ b/src/relay/parser/meta_ref.h @@ -22,20 +22,18 @@ * \brief A reference into the metadata section of the Relay text format. */ -#ifndef TVM_PARSER_META_REF_H_ -#define TVM_PARSER_META_REF_H_ +#ifndef TVM_RELAY_PARSER_META_REF_H_ +#define TVM_RELAY_PARSER_META_REF_H_ #include -#include #include #include +#include #include namespace tvm { -namespace parser { - -using namespace relay; +namespace relay { /*! * \brief Options for allocating storage. @@ -78,7 +76,7 @@ Expr MetaRef(std::string type_key, uint64_t node_index); relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func); IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod); -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_META_REF_H_ +#endif // TVM_RELAY_PARSER_META_REF_H_ diff --git a/src/parser/op_table.h b/src/relay/parser/op_table.h similarity index 93% rename from src/parser/op_table.h rename to src/relay/parser/op_table.h index 28c9cd7fc05f..6ff2c05476f4 100644 --- a/src/parser/op_table.h +++ b/src/relay/parser/op_table.h @@ -18,14 +18,13 @@ */ /*! - * \file token.h + * \file op_table.h * \brief A operator table for parsing. - * * Provides symbolic token sequences to map to TVM operators, with a given associativity and arity. */ -#ifndef TVM_PARSER_OP_TABLE_H_ -#define TVM_PARSER_OP_TABLE_H_ +#ifndef TVM_RELAY_PARSER_OP_TABLE_H_ +#define TVM_RELAY_PARSER_OP_TABLE_H_ #include #include @@ -38,7 +37,7 @@ #include "./tokenizer.h" namespace tvm { -namespace parser { +namespace relay { struct Rule { std::vector tokens; @@ -77,7 +76,7 @@ struct OperatorTable { } }; -OperatorTable DefaultOpTable() { +inline OperatorTable DefaultOpTable() { return OperatorTable( {Rule({TokenType::kStar}, Op::Get("multiply"), 12, 2, true), Rule({TokenType::kDivision}, Op::Get("divide"), 12, 2, true), @@ -91,6 +90,6 @@ OperatorTable DefaultOpTable() { Rule({TokenType::kBang, TokenType::kEqual}, Op::Get("not_equal"), 7, 2, true)}); } -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_OP_TABLE_H_ +#endif // TVM_RELAY_PARSER_OP_TABLE_H_ diff --git a/src/parser/parser.cc b/src/relay/parser/parser.cc similarity index 99% rename from src/parser/parser.cc rename to src/relay/parser/parser.cc index fe89857f2709..ae7fc52cbead 100644 --- a/src/parser/parser.cc +++ b/src/relay/parser/parser.cc @@ -23,11 +23,12 @@ */ #include #include -#include #include #include #include +#include #include +#include #include #include #include @@ -35,18 +36,14 @@ #include -#include "../support/scalars.h" +#include "../../support/scalars.h" #include "./meta_ref.h" #include "./op_table.h" #include "./span_check.h" #include "./tokenizer.h" -#include "tvm/runtime/builtin_fp16.h" namespace tvm { -namespace parser { - -using namespace relay; -using Expr = relay::Expr; +namespace relay { /*! \brief The meta table maps from type key to a sequence of objects. */ using MetaTable = Map>; @@ -1948,22 +1945,6 @@ Expr ParseExpr(const std::string& file_name, const std::string& file_content) { return expr; } -TVM_REGISTER_GLOBAL("parser.ParseModuleInContext") - .set_body_typed([](const std::string& file_name, const std::string& file_content, - const Optional& init_module, const MetaTable& init_meta_table) { - return ParseModule(file_name, file_content, init_module, init_meta_table); - }); - -TVM_REGISTER_GLOBAL("parser.ParseModule") - .set_body_typed([](const std::string& file_name, const std::string& file_content) { - return ParseModule(file_name, file_content); - }); - -TVM_REGISTER_GLOBAL("parser.ParseExpr") - .set_body_typed([](tvm::String file_name, tvm::String file_content) { - return ParseExpr(file_name, file_content); - }); - /*! * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for @@ -1978,7 +1959,29 @@ Pass AnnotateSpans() { return CreateModulePass(pass_func, 0, "AnnotateSpans", {}); } +TVM_REGISTER_GLOBAL("relay.parser.ParseModuleInContext") + .set_body_typed([](const std::string& file_name, const std::string& file_content, + const Optional& init_module, const MetaTable& init_meta_table) { + return ParseModule(file_name, file_content, init_module, init_meta_table); + }); + +TVM_REGISTER_GLOBAL("relay.parser.ParseModule").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK(args.size() >= 2 && args.size() <= 4) << "Expected 2-4 arguments, but got " << args.size(); + if (args.size() == 2) { + *ret = ParseModule(args[0], args[1]); + } else if (args.size() == 3) { + *ret = ParseModule(args[0], args[1], args[2]); + } else { + *ret = ParseModule(args[0], args[1], args[2], args[3]); + } +}); + +TVM_REGISTER_GLOBAL("relay.parser.ParseExpr") + .set_body_typed([](tvm::String file_name, tvm::String file_content) { + return ParseExpr(file_name, file_content); + }); + TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans); -} // namespace parser +} // namespace relay } // namespace tvm diff --git a/src/parser/span_check.cc b/src/relay/parser/span_check.cc similarity index 96% rename from src/parser/span_check.cc rename to src/relay/parser/span_check.cc index 7fed3730d926..6bbf6317ad9f 100644 --- a/src/parser/span_check.cc +++ b/src/relay/parser/span_check.cc @@ -25,7 +25,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using tvm::relay::transform::CreateFunctionPass; using tvm::transform::PassContext; @@ -101,7 +101,7 @@ Pass SpanCheck() { 0, "SpanCheck", {}); } -TVM_REGISTER_GLOBAL("parser.SpanCheck").set_body_typed([]() { return SpanCheck(); }); +TVM_REGISTER_GLOBAL("relay.parser.SpanCheck").set_body_typed([]() { return SpanCheck(); }); -} // namespace parser +} // namespace relay } // namespace tvm diff --git a/src/parser/span_check.h b/src/relay/parser/span_check.h similarity index 93% rename from src/parser/span_check.h rename to src/relay/parser/span_check.h index 0074c66d61f4..b85b4a497965 100644 --- a/src/parser/span_check.h +++ b/src/relay/parser/span_check.h @@ -21,9 +21,8 @@ * \file span_check.h * \brief Check that the Relay IR has correctly attached span information. */ - -#ifndef TVM_PARSER_SPAN_CHECK_H_ -#define TVM_PARSER_SPAN_CHECK_H_ +#ifndef TVM_RELAY_PARSER_SPAN_CHECK_H_ +#define TVM_RELAY_PARSER_SPAN_CHECK_H_ #include #include @@ -38,7 +37,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using namespace tvm::relay; using tvm::transform::Pass; @@ -74,6 +73,6 @@ struct SpanChecker : ExprVisitor { Pass SpanCheck(); -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_SPAN_CHECK_H_ +#endif // TVM_RELAY_PARSER_SPAN_CHECK_H_ diff --git a/src/parser/token.h b/src/relay/parser/token.h similarity index 93% rename from src/parser/token.h rename to src/relay/parser/token.h index 48a1bf70a250..7b11e701cf6e 100644 --- a/src/parser/token.h +++ b/src/relay/parser/token.h @@ -22,10 +22,11 @@ * \brief The definition of tokens for the TVM parser. */ -#ifndef TVM_PARSER_TOKEN_H_ -#define TVM_PARSER_TOKEN_H_ +#ifndef TVM_RELAY_PARSER_TOKEN_H_ +#define TVM_RELAY_PARSER_TOKEN_H_ -#include +#include +#include #include #include @@ -33,7 +34,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using namespace runtime; @@ -97,7 +98,7 @@ enum class TokenType { kNull, }; -std::string ToString(const TokenType& token_type) { +inline std::string ToString(const TokenType& token_type) { switch (token_type) { case TokenType::kCommentStart: return "CommentStart"; @@ -219,7 +220,7 @@ std::string ToString(const TokenType& token_type) { } } -std::string Pretty(const TokenType& token_type) { +inline std::string Pretty(const TokenType& token_type) { switch (token_type) { case TokenType::kCommentStart: return "`/*`"; @@ -375,7 +376,7 @@ class Token : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Token, ObjectRef, TokenNode); }; -Token::Token(Span span, TokenType token_type, ObjectRef data) { +inline Token::Token(Span span, TokenType token_type, ObjectRef data) { ObjectPtr n = make_object(); n->span = span; n->token_type = token_type; @@ -383,15 +384,17 @@ Token::Token(Span span, TokenType token_type, ObjectRef data) { data_ = std::move(n); } -Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); } +inline Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); } -int64_t Token::ToNumber() const { +inline int64_t Token::ToNumber() const { return Downcast(this->operator->()->data).IntValue(); } -std::string Token::ToString() const { return Downcast(this->operator->()->data); } +inline std::string Token::ToString() const { + return Downcast(this->operator->()->data); +} -Map> Token::ToMetadata() const { +inline Map> Token::ToMetadata() const { ObjectRef data = this->operator->()->data; if (data.defined()) { return Downcast>>(data); @@ -400,6 +403,6 @@ Map> Token::ToMetadata() const { } } -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_TOKEN_H_ +#endif // TVM_RELAY_PARSER_TOKEN_H_ diff --git a/src/parser/tokenizer.h b/src/relay/parser/tokenizer.h similarity index 96% rename from src/parser/tokenizer.h rename to src/relay/parser/tokenizer.h index 505784e4bf70..04dcd3263e99 100644 --- a/src/parser/tokenizer.h +++ b/src/relay/parser/tokenizer.h @@ -18,11 +18,11 @@ */ /*! - * \file parser.h + * \file tokenizer.h * \brief A parser for TVM IR. */ -#ifndef TVM_PARSER_TOKENIZER_H_ -#define TVM_PARSER_TOKENIZER_H_ +#ifndef TVM_RELAY_PARSER_TOKENIZER_H_ +#define TVM_RELAY_PARSER_TOKENIZER_H_ #include #include @@ -34,12 +34,12 @@ #include #include -#include "../support/scalars.h" +#include "../../support/scalars.h" #include "./meta_ref.h" #include "./token.h" namespace tvm { -namespace parser { +namespace relay { using namespace runtime; @@ -54,20 +54,20 @@ static inline void rtrim(std::string& s) { // NOLINT(*) s.end()); } -bool IsDigit(char c) { return '0' <= c && c <= '9'; } +inline bool IsDigit(char c) { return '0' <= c && c <= '9'; } -bool IsWhitespace(char c) { return ' ' == c || c == '\t' || c == '\n'; } +inline bool IsWhitespace(char c) { return ' ' == c || c == '\t' || c == '\n'; } -bool IsNumeric(char c) { +inline bool IsNumeric(char c) { return (IsDigit(c) || c == '.' || c == 'e' || c == '-' || c == '+' || c == 'E') && !IsWhitespace(c); } -bool IsIdentLetter(char c) { +inline bool IsIdentLetter(char c) { return '_' == c || c == '/' || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'); } -bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } +inline bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } static std::unordered_map KEYWORD_TABLE = { {"let", TokenType::kLet}, {"fn", TokenType::kFn}, @@ -371,7 +371,7 @@ struct Tokenizer { int line = this->line; int col = this->col; auto next = Peek(); - VLOG(9) << "tvm::parser::TokenizeOnce: next=" << next; + VLOG(9) << "tvm::relay::TokenizeOnce: next=" << next; if (next == '\n') { auto token = NewToken(TokenType::kNewline); Next(); @@ -582,7 +582,7 @@ struct Tokenizer { } void Tokenize() { - VLOG(9) << "tvm::parser::Tokenize"; + VLOG(9) << "tvm::relay::Tokenize"; while (this->More()) { auto token = TokenizeOnce(); ICHECK(token.defined()); @@ -601,7 +601,7 @@ struct Tokenizer { tokens() {} }; -std::vector Condense(const std::vector& tokens, Token* table) { +inline std::vector Condense(const std::vector& tokens, Token* table) { std::vector out; bool found_metadata = false; @@ -680,7 +680,8 @@ std::vector Condense(const std::vector& tokens, Token* table) { return out; } -std::pair, Token> Tokenize(const DiagnosticContext& ctx, const Source& source) { +inline std::pair, Token> Tokenize(const DiagnosticContext& ctx, + const Source& source) { auto tokenizer = Tokenizer(ctx, source); tokenizer.Tokenize(); Token meta_table(Span(), TokenType::kUnknown, ObjectRef()); @@ -691,7 +692,7 @@ std::pair, Token> Tokenize(const DiagnosticContext& ctx, cons return {tokens, meta_table}; } -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_TOKENIZER_H_ +#endif // TVM_RELAY_PARSER_TOKENIZER_H_ diff --git a/src/relay/printer/relay_text_printer.cc b/src/relay/printer/relay_text_printer.cc index cc86f9b56435..5b47c262fd48 100644 --- a/src/relay/printer/relay_text_printer.cc +++ b/src/relay/printer/relay_text_printer.cc @@ -41,9 +41,9 @@ #include #include "../../ir/attr_functor.h" -#include "../../parser/meta_ref.h" #include "../../support/scalars.h" #include "../analysis/dependency_graph.h" +#include "../parser/meta_ref.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 168441d1708d..8b6600fbdfa9 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -23,7 +23,6 @@ */ #include -#include #include #include #include diff --git a/tests/cpp/relay/backend/aot/aot_lower_main_test.cc b/tests/cpp/relay/backend/aot/aot_lower_main_test.cc index 31166f1e6bb8..0157f031c214 100644 --- a/tests/cpp/relay/backend/aot/aot_lower_main_test.cc +++ b/tests/cpp/relay/backend/aot/aot_lower_main_test.cc @@ -20,7 +20,7 @@ #include "../../../../../src/relay/backend/aot/aot_lower_main.h" #include -#include +#include namespace tvm { namespace relay { @@ -37,7 +37,7 @@ TEST(AOTLowerMain, ExprAllocatorSkipNestedFunc) { %0(%x) } )"; - IRModule mod = parser::ParseModule("string", mod_text, {}, {}); + IRModule mod = ParseModule("string", mod_text, {}, {}); auto host_target = tvm::Target("llvm"); auto prim_target = tvm::Target(host_target, host_target); auto ctxt = tvm::transform::PassContext::Current(); diff --git a/tests/cpp/relay/collage/candidate_partition_test.cc b/tests/cpp/relay/collage/candidate_partition_test.cc index bc5d2d880a3b..d298a493c11f 100644 --- a/tests/cpp/relay/collage/candidate_partition_test.cc +++ b/tests/cpp/relay/collage/candidate_partition_test.cc @@ -20,9 +20,9 @@ #include "../../../src/relay/collage/candidate_partition.h" #include -#include #include #include +#include #include #include "../../../../src/relay/collage/mock_cost_estimator.h" @@ -37,7 +37,7 @@ namespace { // so not re-tested here. The only other non-trivial code is CandidatePartition::EstimateCost Function MakeTestFunction(const std::string& mod_text) { - IRModule mod = parser::ParseModule("string", mod_text, {}, {}); + IRModule mod = ParseModule("string", mod_text, {}, {}); mod = transform::CapturePostDfsIndexInSpans()(mod); auto func = Downcast(mod->Lookup("main")); LOG(INFO) << "------- input function -------"; diff --git a/tests/cpp/relay/collage/partition_rule_test.cc b/tests/cpp/relay/collage/partition_rule_test.cc index 51a4970c7ec0..effe0b1fa030 100644 --- a/tests/cpp/relay/collage/partition_rule_test.cc +++ b/tests/cpp/relay/collage/partition_rule_test.cc @@ -20,9 +20,9 @@ #include "../../../src/relay/collage/partition_rule.h" #include -#include #include #include +#include #include #include "../../../src/relay/collage/partition_spec.h" @@ -46,7 +46,7 @@ Function MakeTestFunction( } Map> metatable; metatable.Set("relay.Constant", constants); - IRModule mod = parser::ParseModule("string", mod_text, {}, metatable); + IRModule mod = ParseModule("string", mod_text, {}, metatable); mod = transform::CapturePostDfsIndexInSpans()(mod); auto func = Downcast(mod->Lookup("main")); LOG(INFO) << "------- input function -------"; diff --git a/tests/cpp/relay/df_pattern_rewrite_test.cc b/tests/cpp/relay/df_pattern_rewrite_test.cc index af09ae48aafd..374887c12a22 100644 --- a/tests/cpp/relay/df_pattern_rewrite_test.cc +++ b/tests/cpp/relay/df_pattern_rewrite_test.cc @@ -18,11 +18,11 @@ */ #include -#include #include #include #include #include +#include #include "../../../src/relay/transforms/simplify_expr.h" @@ -82,7 +82,7 @@ TEST(DFPatternRewrite, DeeplyNestedWithCallAttributes) { } )"; - IRModule module = parser::ParseModule("string", kModel); + IRModule module = ParseModule("string", kModel); DFPatternRewriteComposer composer; composer.AddRewrite(); Function in_function = Downcast(module->Lookup("main")); diff --git a/tests/cpp/relay/ir/indexed_graph_test.cc b/tests/cpp/relay/ir/indexed_graph_test.cc index 17ec68261684..486d027fbc21 100644 --- a/tests/cpp/relay/ir/indexed_graph_test.cc +++ b/tests/cpp/relay/ir/indexed_graph_test.cc @@ -20,9 +20,9 @@ #include "../../../src/relay/ir/indexed_graph.h" #include -#include #include #include +#include namespace tvm { namespace relay { @@ -81,7 +81,7 @@ IRModule TestRecursiveIRModule() { (%19, %20) // 51 } // 52 )"; - return parser::ParseModule("string", kModel, /*init_module=*/{}, metadata); + return ParseModule("string", kModel, /*init_module=*/{}, metadata); } TEST(IndexedGraph, RecursiveExprRegression) { @@ -179,7 +179,7 @@ IRModule TestUnusedLetBoundIRModule() { } } )"; - return parser::ParseModule("string", kModel); + return ParseModule("string", kModel); } TEST(IndexedGraph, UnusedLetVars) { diff --git a/tests/cpp/relay/transforms/device_domains_test.cc b/tests/cpp/relay/transforms/device_domains_test.cc index c5b2f26315b2..47e303996b3b 100644 --- a/tests/cpp/relay/transforms/device_domains_test.cc +++ b/tests/cpp/relay/transforms/device_domains_test.cc @@ -27,7 +27,7 @@ #include "../../../../src/relay/transforms/device_domains.h" #include -#include +#include #include namespace tvm { @@ -36,7 +36,7 @@ namespace transform { namespace { IRModule TestModule() { - return InferType()(tvm::parser::ParseModule("test", R"( + return InferType()(ParseModule("test", R"( #[version = "0.0.5"] def @f(%x : Tensor[(3, 7), float32], %y : Tensor[(3, 7), float32]) { add(%x, %y) diff --git a/tests/cpp/relay/with_fields_test.cc b/tests/cpp/relay/with_fields_test.cc index 48e04c259bb5..6114fa97a9fd 100644 --- a/tests/cpp/relay/with_fields_test.cc +++ b/tests/cpp/relay/with_fields_test.cc @@ -23,18 +23,18 @@ */ #include -#include #include #include #include +#include namespace tvm { namespace relay { namespace { IRModule TestIRModule() { - return parser::ParseModule("string", - R"( + return ParseModule("string", + R"( #[version = "0.0.5"] def @main(%data : Tensor[(1, 304, 128, 128), float32], %weight1 : Tensor[(304, 1, 3, 3), float32], From eb0c748f8a826613a13242436a126db806729f24 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 24 Jan 2023 20:41:51 -0800 Subject: [PATCH 2/2] retrigger ci