From a5cd7d2dbd450b1847ecd61dca4eaf19931afef9 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 3 Jul 2019 12:29:04 -0700 Subject: [PATCH 1/2] add fatal lint lint lint do make completeness check an error lint drop changes address review comment lint address review add asf header fix --- include/tvm/relay/expr.h | 26 +++++++++++++++ include/tvm/relay/expr_functor.h | 4 +++ include/tvm/relay/feature.h | 5 +-- pytest.ini | 18 ++++++++++ python/tvm/relay/__init__.py | 4 +++ python/tvm/relay/expr.py | 21 ++++++++++++ python/tvm/relay/expr_functor.py | 13 +++++++- python/tvm/relay/feature.py | 5 +-- python/tvm/relay/testing/py_converter.py | 7 ++++ src/relay/backend/interpreter.cc | 5 +++ src/relay/ir/alpha_equal.cc | 8 +++++ src/relay/ir/expr.cc | 25 ++++++++++++++ src/relay/ir/expr_functor.cc | 6 ++++ src/relay/ir/hash.cc | 6 ++++ src/relay/ir/pretty_printer.cc | 6 ++++ src/relay/pass/dependency_graph.cc | 2 ++ src/relay/pass/feature.cc | 1 + src/relay/pass/let_list.h | 12 ++++++- src/relay/pass/partial_eval.cc | 7 ++-- src/relay/pass/to_a_normal_form.cc | 5 +++ src/relay/pass/type_infer.cc | 10 ++++++ tests/lint/check_file_type.py | 3 +- tests/python/relay/test_adt.py | 15 +++++++++ .../python/relay/test_backend_interpreter.py | 27 ++++++++------- tests/python/relay/test_expr_functor.py | 20 ++++------- tests/python/relay/test_pass_alpha_equal.py | 33 +++++++------------ tests/python/relay/test_pass_partial_eval.py | 33 +++++++------------ tests/python/relay/test_py_converter.py | 6 ++++ 28 files changed, 254 insertions(+), 79 deletions(-) create mode 100644 pytest.ini diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index b1b8d6a7154e..5eb6e8055ebe 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -513,6 +513,32 @@ class RefWriteNode : public ExprNode { RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr); +/*! \brief A fatal error has occurred. Stop all execution and report with a message. */ +class Fatal; +class FatalNode : public ExprNode { + public: + /*! \brief The Message. */ + std::string msg; + Type type_annotation; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("msg", &msg); + v->Visit("type_annotation", &type_annotation); + v->Visit("span", &span); + v->Visit("_checked_type_", &checked_type_); + } + + TVM_DLL static Fatal make(std::string msg, Type type_annotation); + + static constexpr const char* _type_key = "relay.Fatal"; + TVM_DECLARE_NODE_TYPE_INFO(FatalNode, ExprNode); +}; + +RELAY_DEFINE_NODE_REF(Fatal, FatalNode, Expr); + +/*! \brief the fatal message for case unhandled in match. */ +TVM_DLL std::string NoMatchMsg(); + /*! * \brief Base class of the temporary expression. * diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e0d940c5d1a5..d85dd0f2e7fa 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -116,6 +116,7 @@ class ExprFunctor { virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const FatalNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->type_key(); throw; @@ -141,6 +142,7 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); + RELAY_EXPR_FUNCTOR_DISPATCH(FatalNode); return vtable; } }; @@ -171,6 +173,7 @@ class ExprVisitor void VisitExpr_(const RefWriteNode* op) override; void VisitExpr_(const ConstructorNode* op) override; void VisitExpr_(const MatchNode* op) override; + void VisitExpr_(const FatalNode* op) override; virtual void VisitType(const Type& t); virtual void VisitClause(const Clause& c); virtual void VisitPattern(const Pattern& c); @@ -213,6 +216,7 @@ class ExprMutator Expr VisitExpr_(const RefWriteNode* op) override; Expr VisitExpr_(const ConstructorNode* op) override; Expr VisitExpr_(const MatchNode* op) override; + Expr VisitExpr_(const FatalNode* op) override; /*! * \brief Used to visit the types inside of expressions. diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index d7b3b394c5cd..992c2ef1f61b 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -48,10 +48,11 @@ enum Feature : int { fRefWrite = 12, fConstructor = 13, fMatch = 14, + fFatal = 15, /*! \brief Whether any non-atom fragment of the program is shared, making the program a graph. */ - fGraph = 15, + fGraph = 16, /*! \brief Whether there is local fixpoint in the program. */ - fLetRec = 16 + fLetRec = 17 }; constexpr size_t feature_count = 17; diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000000..f0e690284f1c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +[pytest] +xfail_strict=true diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index fff9c99e5007..fa2a973217e6 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -99,6 +99,7 @@ RefCreate = expr.RefCreate RefRead = expr.RefRead RefWrite = expr.RefWrite +Fatal = expr.Fatal # ADT PatternWildcard = adt.PatternWildcard @@ -140,3 +141,6 @@ # Feature Feature = feature.Feature + +# Fatal Messages +NO_MATCH_MSG = expr.NO_MATCH_MSG diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 88779dfd76e0..358f9035c7b1 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -395,6 +395,7 @@ def __init__(self, tuple_value, index): @register_relay_node class RefCreate(Expr): """Create a new reference from initial value. + Parameters ---------- value: tvm.relay.Expr @@ -407,6 +408,7 @@ def __init__(self, value): @register_relay_node class RefRead(Expr): """Get the value inside the reference. + Parameters ---------- ref: tvm.relay.Expr @@ -421,6 +423,7 @@ class RefWrite(Expr): """ Update the value inside the reference. The whole expression will evaluate to an empty tuple. + Parameters ---------- ref: tvm.relay.Expr @@ -432,6 +435,24 @@ def __init__(self, ref, value): self.__init_handle_by_constructor__(_make.RefWrite, ref, value) +@register_relay_node +class Fatal(Expr): + """ + Abort the execution with a fatal error message. + + Parameters + ---------- + msg: String + The message + + type_annotation: Optional[tvm.relay.Type] + The type of Fatal. Leave none to be inferred. + """ + def __init__(self, msg, ty=None): + self.__init_handle_by_constructor__(_make.Fatal, msg, ty) + +NO_MATCH_MSG = _expr.NoMatchMsg() + class TempExpr(Expr): """Baseclass of all TempExpr. diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index f492c743173c..cfc2f95b9afd 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -18,7 +18,7 @@ """The expression functor of Relay.""" from .expr import Function, Call, Let, Var, GlobalVar -from .expr import If, Tuple, TupleGetItem, Constant +from .expr import If, Tuple, TupleGetItem, Constant, Fatal from .expr import RefCreate, RefRead, RefWrite from .adt import Constructor, Match, Clause from .op import Op @@ -69,6 +69,8 @@ def visit(self, expr): res = self.visit_constructor(expr) elif isinstance(expr, Match): res = self.visit_match(expr) + elif isinstance(expr, Fatal): + res = self.visit_fatal(expr) else: raise Exception("warning unhandled case: {0}".format(type(expr))) @@ -124,6 +126,9 @@ def visit_constructor(self, _): def visit_match(self, _): raise NotImplementedError() + def visit_fatal(self, _): + raise NotImplementedError() + class ExprVisitor(ExprFunctor): """ @@ -186,6 +191,9 @@ def visit_match(self, m): for c in m.clauses: self.visit(c.rhs) + def visit_fatal(self, r): + pass + class ExprMutator(ExprFunctor): """ @@ -262,3 +270,6 @@ def visit_ref_write(self, r): def visit_ref_read(self, r): return RefRead(self.visit(r.ref)) + + def visit_fatal(self, r): + return Fatal(r.msg, r.type_annotation) diff --git a/python/tvm/relay/feature.py b/python/tvm/relay/feature.py index 68502672682d..38a14855e65d 100644 --- a/python/tvm/relay/feature.py +++ b/python/tvm/relay/feature.py @@ -35,7 +35,8 @@ class Feature(IntEnum): fRefWrite = 12 fConstructor = 13 fMatch = 14 + fFatal = 15 """ Whether any non-atom fragment of the program is shared, making the program a graph. """ - fGraph = 15 + fGraph = 16 """ Whether there is local fixpoint in the program. """ - fLetRec = 16 + fLetRec = 17 diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index d7b59922b89d..a640239d944b 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -411,6 +411,13 @@ def visit_global_var(self, gvar: Expr): return (Name(gvar.name_hint, Load()), []) + def visit_fatal(self, fatal: Expr): + thunk_name = self.generate_function_name('_fatal_thunk') + thunk = self.create_def(thunk_name, [], [ + ast.Raise(ast.Call(Name("Exception", Load()), [ast.Str(fatal.msg)], []), None)]) + return (self.create_call(thunk_name, []), [thunk]) + + def visit_let(self, letexp: Expr): # To properly account for scoping and ensure that the entire node produces an expression, # we translate the let binding as a function that we call with the value we intend to bind. diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 86a4ebb4ebd2..251b93699936 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -692,6 +692,11 @@ class Interpreter : return Value(); } + Value VisitExpr_(const FatalNode* op) final { + LOG(FATAL) << "fatal message recieved: " << op->msg; + return Value(); + } + bool VisitPattern_(const PatternConstructorNode* op, const Value& v) final { const ConstructorValueNode* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 878795d0b9f2..562f88398b29 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -514,6 +514,14 @@ class AlphaEqualHandler: return false; } + bool VisitExpr_(const FatalNode* lhs, const Expr& other) final { + if (const FatalNode* rhs = other.as()) { + return lhs->msg == rhs->msg; + } else { + return false; + } + } + bool ClauseEqual(const Clause& lhs, const Clause& rhs) { return PatternEqual(lhs->lhs, rhs->lhs) && ExprEqual(lhs->rhs, rhs->rhs); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 35e4f2b4ab13..a8804ea488fa 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -330,6 +330,31 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefWriteNode(" << node->ref << ", " << node->value << ")"; }); + +Fatal FatalNode::make(std::string msg, Type type_annotation) { + NodePtr n = make_node(); + n->msg = std::move(msg); + n->type_annotation = std::move(type_annotation); + return Fatal(n); +} + +TVM_REGISTER_NODE_TYPE(FatalNode); + +TVM_REGISTER_API("relay._make.Fatal") +.set_body_typed(FatalNode::make); + +std::string NoMatchMsg() { + return "No case Match"; +} + +TVM_REGISTER_API("relay._expr.NoMatchMsg") +.set_body_typed(NoMatchMsg); + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const FatalNode* node, tvm::IRPrinter* p) { + p->stream << "FatalNode(" << node->msg << ")"; +}); + TVM_REGISTER_API("relay._expr.TempExprRealize") .set_body_typed([](TempExpr temp) { return temp->Realize(); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 6a2db6b46d64..1b67eeb69ab0 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -215,6 +215,10 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) { return MatchNode::make(VisitExpr(m->data), clauses, m->complete); } +Expr ExprMutator::VisitExpr_(const FatalNode* f) { + return GetRef(f); +} + Clause ExprMutator::VisitClause(const Clause& c) { Pattern p = VisitPattern(c->lhs); return ClauseNode::make(p, VisitExpr(c->rhs)); @@ -318,6 +322,8 @@ void ExprVisitor::VisitExpr_(const MatchNode* op) { } } +void ExprVisitor::VisitExpr_(const FatalNode* op) { } + void ExprVisitor::VisitClause(const Clause& op) { this->VisitPattern(op->lhs); this->VisitExpr(op->rhs); diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index d39253372830..a198f1967645 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -351,6 +351,12 @@ class RelayHashHandler: return hash; } + size_t VisitExpr_(const FatalNode* fn) final { + size_t hash = std::hash()(FatalNode::_type_key); + hash = Combine(hash, std::hash()(fn->msg)); + return hash; + } + size_t VisitType_(const TypeCallNode* tcn) final { size_t hash = std::hash()(TypeCallNode::_type_key); hash = Combine(hash, TypeHash(tcn->func)); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 31218be4a6d4..0fb9e11c6d3f 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -633,6 +633,12 @@ class PrettyPrinter : return printed_pattern; } + Doc VisitExpr_(const FatalNode* op) final { + Doc doc; + doc << "Fatal(" << PrintString(op->msg) << ")"; + return doc; + } + Doc VisitPattern_(const PatternConstructorNode* p) final { Doc doc; doc << p->constructor->name_hint; diff --git a/src/relay/pass/dependency_graph.cc b/src/relay/pass/dependency_graph.cc index a9018266589a..81de1995a323 100644 --- a/src/relay/pass/dependency_graph.cc +++ b/src/relay/pass/dependency_graph.cc @@ -174,6 +174,8 @@ class DependencyGraph::Creator : private ExprFunctor { void VisitExpr_(const OpNode* o) final { } void VisitExpr_(const ConstructorNode* c) final { } + + void VisitExpr_(const FatalNode* c) final { } }; DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) { diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc index 2c5e7ab3b984..353bd14d52c2 100644 --- a/src/relay/pass/feature.cc +++ b/src/relay/pass/feature.cc @@ -84,6 +84,7 @@ FeatureSet DetectFeature(const Expr& expr) { DETECT_DEFAULT_CONSTRUCT(RefWrite) DETECT_DEFAULT_CONSTRUCT(Constructor) DETECT_DEFAULT_CONSTRUCT(Match) + DETECT_DEFAULT_CONSTRUCT(Fatal) #undef DETECT_DEFAULT_CONSTRUCT } fd; fd(expr); diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 94b5ea3ad42a..dc6a53e061ee 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -41,6 +41,10 @@ namespace tvm { namespace relay { +struct EmitFatal : dmlc::Error { + explicit EmitFatal(const std::string& msg) : dmlc::Error(msg) { } +}; + /*! * \brief LetList allow you to transform expression into variables, so you can copy them around. * one can insert into the LetList by calling Push, and wrap an expression with bindings with Get. @@ -134,7 +138,13 @@ class LetList { template static Expr With(F&& f) { LetList ll; - return ll.Get(f(&ll)); + Expr ret; + try { + ret = f(&ll); + } catch (const EmitFatal& ef) { + ret = FatalNode::make(ef.what(), Type()); + } + return ll.Get(ret); } static Expr Let(const Expr& e, const std::function& f) { diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index 906d245e4601..ee6fa743c05b 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -881,6 +881,10 @@ class PartialEvaluator : public ExprFunctor ReflectError() : dmlc::Error("static value not found") { } }; + PStatic VisitExpr_(const FatalNode* op, LetList* ll) final { + throw EmitFatal(op->msg); + } + Expr Reflect(const PStatic& st) { if (!st->pstatic.defined()) { throw ReflectError(); @@ -1010,8 +1014,7 @@ class PartialEvaluator : public ExprFunctor throw; } } - LOG(FATAL) << "No case Match"; - throw; + throw EmitFatal(NoMatchMsg()); }); } diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index d5869a5806a4..9b84e1f80f4e 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -264,6 +264,11 @@ class Fill : ExprFunctor { } return Compound(e, MatchNode::make(data, clauses, m->complete), v); } + + Expr VisitExpr_(const FatalNode* f, const Var& v) final { + Expr e = GetRef(f); + return Compound(e, e, v); + } }; Expr ToANormalFormAux(const Expr& e) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 5b9b25bd61f9..86a29b752bca 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -323,6 +323,12 @@ class TypeInferencer : private ExprFunctor, return op->op_type; } + Type VisitExpr_(const FatalNode* op) final { + return op->type_annotation.defined() ? + op->type_annotation : + IncompleteTypeNode::make(Kind::kType); + } + Type VisitExpr_(const LetNode* let) final { // if the definition is a function literal, permit recursion bool is_functional_literal = let->value.as() != nullptr; @@ -653,6 +659,10 @@ class TypeInferencer::Resolver : public ExprMutator, PatternMutator { return AttachCheckedType(op); } + Expr VisitExpr_(const FatalNode* op) final { + return AttachCheckedType(op); + } + Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py index e5f2dc7e2aa6..345122e4e0f1 100644 --- a/tests/lint/check_file_type.py +++ b/tests/lint/check_file_type.py @@ -110,6 +110,7 @@ "docs/_static/css/tvm_theme.css", "docs/_static/img/tvm-logo-small.png", "docs/_static/img/tvm-logo-square.png", + "pytest.ini", } @@ -163,7 +164,7 @@ def main(): report += "\n".join(error_list) report += "\nFound %d files that are now allowed\n" % len(error_list) report += ("We do not check in binary files into the repo.\n" - "If necessary, please discuss with committers and" + "If necessary, please discuss with committers and " "modify tests/lint/check_file_type.py to enable the file you need.\n") sys.stderr.write(report) sys.stderr.flush() diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 390d3cd9f3c4..25802a7a439b 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -20,6 +20,8 @@ from tvm.relay import create_executor from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr +from tvm.relay import TypeVar, Var, Fatal, Function, Call +from tvm.relay import PatternConstructor, Clause, PatternVar, Match import numpy as np @@ -826,6 +828,19 @@ def run(dtype): run('float32') run('int32') +def test_head_fatal(): + a = TypeVar("a") + x = Var("x", l(a)) + y = Var("y") + z = Var("z") + nil_case = Clause(PatternConstructor(nil, []), Fatal("cannot pass nil into head")) + cons_case = Clause(PatternConstructor(cons, [PatternVar(y), PatternVar(z)]), y) + hd_fatal = Function([x], Match(x, [cons_case, nil_case]), a, [a]) + expr = Call(hd_fatal, [cons(make_nat_expr(12), nil())]) + res = intrp.evaluate(expr) + assert count(res) == 12 + + if __name__ == "__main__": test_nat_constructor() test_double() diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index c1a19c4d9bb1..74bdfa455c8b 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm import relay @@ -242,18 +243,16 @@ def test_tuple_passing(): out = f(value_tuple) tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) + +@pytest.mark.xfail() +def test_fatal(): + mod = relay.Module() + ctx = tvm.cpu() + target = tvm.target.create('llvm') + exec = relay.create_executor(mod=mod, ctx=ctx, target=target) + f = exec.evaluate(relay.Fatal("msg", relay.TupleType([]))) + f() + + if __name__ == "__main__": - test_id() - test_add_const() - test_equal() - test_subtract() - test_simple_loop() - test_loop() - test_binds() - test_kwargs_params() - test_ref() - test_tensor_value() - test_tuple_value() - test_tuple_getitem() - test_function_taking_adt_ref_tuple() - test_tuple_passing() + pytest.main([__file__]) diff --git a/tests/python/relay/test_expr_functor.py b/tests/python/relay/test_expr_functor.py index 5c923655d7b7..c2c4eacc824f 100644 --- a/tests/python/relay/test_expr_functor.py +++ b/tests/python/relay/test_expr_functor.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import pytest import tvm from tvm import relay from tvm.relay import ExprFunctor, ExprMutator, ExprVisitor @@ -135,18 +136,9 @@ def test_match_completeness(): assert result_expr.complete == completeness +def test_fatal(): + check_visit(relay.Fatal("meow")) + + if __name__ == "__main__": - test_constant() - test_tuple() - test_var() - test_global() - test_function() - test_call() - test_let() - test_ite() - test_ref_create() - test_ref_read() - test_ref_write() - test_memo() - test_match() - test_match_completeness() + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index b240daf962d5..5f3a7593f34c 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm import relay from tvm.relay import analysis @@ -626,25 +627,15 @@ def test_tuple_match(): assert analysis.structural_hash(x) == analysis.structural_hash(y) +def test_fatal(): + x = relay.Fatal("msg") + y = relay.Fatal("msg") + z = relay.Fatal("dio") + assert analysis.alpha_equal(x, y) + assert analysis.structural_hash(x) == analysis.structural_hash(y) + assert not analysis.alpha_equal(x, z) + assert not analysis.structural_hash(x) == analysis.structural_hash(z) + + if __name__ == "__main__": - test_tensor_type_alpha_equal() - test_incomplete_type_alpha_equal() - test_constant_alpha_equal() - test_func_type_alpha_equal() - test_tuple_type_alpha_equal() - test_type_relation_alpha_equal() - test_type_call_alpha_equal() - test_constant_alpha_equal() - test_global_var_alpha_equal() - test_tuple_alpha_equal() - test_tuple_get_item_alpha_equal() - test_function_alpha_equal() - test_call_alpha_equal() - test_let_alpha_equal() - test_if_alpha_equal() - test_constructor_alpha_equal() - test_match_alpha_equal() - test_op_alpha_equal() - test_var_alpha_equal() - test_graph_equal() - test_hash_unequal() + pytest.main([__file__]) diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index cf4f8f6cee74..d5b261a90449 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -16,6 +16,7 @@ # under the License. import numpy as np +import pytest import tvm from tvm import relay from tvm.relay.analysis import alpha_equal, assert_alpha_equal @@ -23,7 +24,7 @@ from tvm.relay import op, create_executor, transform from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate from tvm.relay import TensorType, Tuple, If, Module, Clause, PatternConstructor, PatternVar, Match -from tvm.relay import GlobalVar, Call +from tvm.relay import GlobalVar, Call, Fatal, TupleType from tvm.relay.transform import gradient from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type @@ -339,24 +340,14 @@ def test_tuple_match(): assert_alpha_equal(dcpe(x), const(2)) +def test_fatal(): + msg = "user-defined fatal message" + assert alpha_equal(dcpe(Fatal(msg)), Fatal(msg)) + mod = Module() + p = Prelude(mod) + orig = Function([], p.hd(p.nil()), TupleType([])) + assert alpha_equal(dcpe(orig, mod=mod).body, Fatal(relay.NO_MATCH_MSG)) + + if __name__ == '__main__': - test_nat_update() - test_ref() - test_tuple() - test_empty_ad() - test_const_inline() - test_ad() - test_if_ref() - test_function_invalidate() - test_head_cons() - test_map() - test_loop() - test_swap_loop() - test_abs_diff() - test_double() - test_nat_id() - test_global_match_nat_id() - test_match_nat_id() - test_concat() - test_triangle_number() - test_tuple_match() + pytest.main() diff --git a/tests/python/relay/test_py_converter.py b/tests/python/relay/test_py_converter.py index 49a2219dcd04..0e32f0f9e98a 100644 --- a/tests/python/relay/test_py_converter.py +++ b/tests/python/relay/test_py_converter.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm import relay from tvm.relay.testing import to_python, run_as_python @@ -553,3 +554,8 @@ def reference(x, gamma, beta, moving_mean, moving_var): verify_batch_norm([(20, 10), (10,), (10,), (10,), (10,)]) verify_batch_norm([(10, 50), (50,), (50,), (50,), (50,)]) verify_batch_norm([(30, 40), (40,), (40,), (40,), (40,)]) + + +def test_fatal(): + with pytest.raises(Exception, match="msg"): + run_as_python(relay.Fatal("msg", relay.TupleType([]))) From 66c94725d52edb98a67bd8cc8497cff90fa87e68 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Sat, 19 Oct 2019 04:08:29 +0000 Subject: [PATCH 2/2] save --- include/tvm/relay/feature.h | 2 +- tests/python/relay/test_pass_partial_eval.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h index 992c2ef1f61b..1cb2e94e0bfe 100644 --- a/include/tvm/relay/feature.h +++ b/include/tvm/relay/feature.h @@ -55,7 +55,7 @@ enum Feature : int { fLetRec = 17 }; -constexpr size_t feature_count = 17; +constexpr size_t feature_count = 18; /*! * \brief A finite set of Feature. diff --git a/tests/python/relay/test_pass_partial_eval.py b/tests/python/relay/test_pass_partial_eval.py index d5b261a90449..afbb368f135a 100644 --- a/tests/python/relay/test_pass_partial_eval.py +++ b/tests/python/relay/test_pass_partial_eval.py @@ -342,7 +342,8 @@ def test_tuple_match(): def test_fatal(): msg = "user-defined fatal message" - assert alpha_equal(dcpe(Fatal(msg)), Fatal(msg)) + UnitType = TupleType([]) + assert alpha_equal(dcpe(Fatal(msg, UnitType)), Fatal(msg, UnitType)) mod = Module() p = Prelude(mod) orig = Function([], p.hd(p.nil()), TupleType([])) @@ -350,4 +351,4 @@ def test_fatal(): if __name__ == '__main__': - pytest.main() + pytest.main([__file__])