From d4a51ffd69bd29684a279da5339d0163b03554bd Mon Sep 17 00:00:00 2001
From: Andrew Liu
Date: Sun, 7 Jun 2020 23:31:21 -0700
Subject: [PATCH 01/19] resolved cherry pick

---
 src/relay/op/tensor/transform.cc           |   3 +
 src/relay/op/tensor/transform.h            |   3 +
 src/relay/transforms/lazy_gradient_init.cc | 847 +++++++++++++++---
 .../relay/test_pass_lazy_gradient_init.py  |   2 +-
 4 files changed, 750 insertions(+), 105 deletions(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index f1d5b7ae5e27..457dc2bf5a34 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -184,6 +184,9 @@ bool ExpandDimsRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
   const auto* param = attrs.as<ExpandDimsAttrs>();
+  if (param == nullptr) {
+    return false;
+  }
   const int ndim = static_cast<int>(data->shape.size());
   const int axis = param->axis;
   const int num_newaxis = param->num_newaxis;
diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h
index 4e5677a1af6d..99dc5904be56 100644
--- a/src/relay/op/tensor/transform.h
+++ b/src/relay/op/tensor/transform.h
@@ -59,6 +59,9 @@ bool ConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
   }
 
   const auto* param = attrs.as<ConcatenateAttrs>();
+  if (param == nullptr) {
+    return false;
+  }
   if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
     return false;
   }
diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc
index f06246667a8b..f15f383568f1 100644
--- a/src/relay/transforms/lazy_gradient_init.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -35,10 +35,10 @@
  *
  * It also overloads + and * operation which can increase performance when doing
  * operations involving tensors with values of only 0 or 1.
- *
- * Note: this pass can only be used with functions where the input/output types are
- * a combination of TupleTypes and TensorTypes
- *
+ *
+ * Note: this pass can only be used with functions where the input/output types are a
+ * combination of TupleTypes, TensorTypes, non mutually-recursive ADTs, and non-nested FuncTypes
+ *
  * This pass optimizes 6 ops:
  * - add
  * - multiply
@@ -46,143 +46,667 @@
  * - ones_like
 * - zeros
 * - zeros_like
- *
- * This pass makes use of three visitor. The most important one visits the entire function,
- * one is used for wrap inputs and one to unwrap outputs.
- *
- * For example:
- * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32]
- *
- * After this pass
- * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]]
- *
- * Thus, it is necessary to wrap this outer function so that the input/output types remain the same
+ *
+ * This module level pass adds a new "GradCell" datatype for each existing datatype.
+ * This is done to propagate the new GradCell datatype through ADTs such as Lists.
+ * For each function, a new function is created that accepts the "GradCell-version" of the
+ * arguments of the original function. That is, inputs to the function are converted to their
+ * GradCell-version and passed to the newly created "GradCell_Function".
+ * The output is then converted from the GradCell version back to the original return type.
+ *
+ * To support ADTs, we use functions that convert an instance of an ADT to its
+ * respective GradCell version
+ * by matching constructors to the constructors of the "GradCell" datatype.
+ *
+ * A transformation function is required for different type arguments.
+ * For example the ADT List may be List[int32] or List[List[int32]], which should be handled separately.
+ *
+ * This pass uses 4 primary mutators:
+ * - LazyGradientInitializer to create the "GradCell_Function" of a given function.
+ * - GradCellWrapper mutates expr into its respective GradCell expr
+ * - GradCellUnWrapper mutates expr into its respective non-GradCell expr
+ * - ADTTransform creates an ADT for each unique ADT
  */
 #include 
 #include 
 #include 
 #include 
+#include 
+#include 
 #include 
-
+#include 
 #include "let_list.h"
 
 namespace tvm {
 namespace relay {
 
+const std::string GradCell_Header = "_GradCell_";
+const std::string GradCell_TransFunc = "_GradCell_TransFunc_";
+const std::string GradCell_ReverseTransFunc = "_GradCell_ReverseTransFunc_";
+const std::string GradCell_Func = "_GradCell_Func_";
+
+struct TypeCallHash {
+  size_t operator()(const TypeCall& typecall) const {
+    return ObjectHash()(typecall->func);
+  }
+};
+
 /*!
- * \brief Visitor appropriately wraps tensors with Raw constructor
- *
- * Recursively looks at the type of the expression (TensorType or TupleType are only supported for
- * now) and either call the GradCell constructor if TensorType or unfold and recursively visit if
- * TupleType
+ * \brief Check if two ADT instances are equal,
+ * checking for dataflow equivalence while allowing a mapping between TypeVars,
+ * i.e. GradCell[TypeVar(A)] = GradCell[TypeVar(B)]
  */
-class InputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
+struct TypeCallEqual {
+  bool operator()(const TypeCall& l, const TypeCall& r) const {
+    if (!(l->func.same_as(r->func))) {
+      return false;
+    }
+
+    if (l->args.size() != r->args.size()) {
+      return false;
+    }
+
+    for (size_t i = 0; i < l->args.size(); i++) {
+      if (!GraphEqual(l->args[i], r->args[i])) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+};
+
+/*!
+ * \brief ADTTransform creates a new ADT named
+ * GradCell_Header + name_hint for each unique ADT.
+ *
+ * This is necessary to handle ADTs.
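+ *
+ * For example, assuming the Prelude definition of List
+ *   Cons(A, List[A]) | Nil
+ * this mutator would generate a datatype roughly like
+ *   _GradCell_Cons(A, _GradCell_List[A]) | _GradCell_Nil
+ * with any TensorType constructor input rewritten to GradCell[TensorType].
+ * (Illustrative sketch only; the List names come from the Relay Prelude,
+ * not from this file.)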
+ */ +class ADTTransform: public TypeMutator, public PatternMutator { public: - explicit InputVisitor(IRModule module) : module_(module) {} + explicit ADTTransform(IRModule module): + module_(module) { } + + Type VisitType(const Type& t) final { + return TypeMutator::VisitType(t); + } + + Type VisitType_(const TensorTypeNode* op) final { + GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); + tvm::Array args; + args.push_back(GetRef(op)); + return TypeCall(gradCell, args); + } + + Type VisitType_(const GlobalTypeVarNode* op) final { + GlobalTypeVar t = GetRef(op); + if (op->kind == kAdtHandle) { + + if (adt_mapping_.count(t) != 0) { + return adt_mapping_.at(t); + } + + TypeData adt = module_->LookupTypeDef(t); + this->VisitType(adt); + + return adt_mapping_.at(t); + } + + return GetRef(op); + } + + Type VisitType_(const TypeDataNode* op) final { + auto type_data = GetRef(op); + std::string transformed_adt_name = GradCell_Header + op->header->name_hint; + + // add new ADT to map to handle recursive definitions + GlobalTypeVar new_adt = GlobalTypeVar(transformed_adt_name, + op->header->kind); + adt_mapping_[type_data->header] = new_adt; + reverse_adt_mapping_[new_adt] = type_data->header; + + // define transformed ADT + Array constructors; + for (Constructor con : op->constructors) { + Array inputs; + for (Type t : con->inputs) { + inputs.push_back(this->VisitType(t)); + } + Constructor transformed_cons = Constructor(GradCell_Header + con->name_hint, + inputs, new_adt); + constructors.push_back(transformed_cons); + } + + TypeData new_datatype = TypeData(new_adt, op->type_vars, constructors); + module_->AddTypeDef(new_adt, new_datatype); + return new_datatype; + } + + Pattern VisitPattern(const Pattern& c) final { + return PatternMutator::VisitPattern(c); + } - Expr VisitExpr_(const VarNode* op, const Type& t) final { - std::cout << op->type_annotation << std::endl; - return WrapExpr(GetRef(op), op->type_annotation); + Constructor VisitConstructor(const Constructor& c) final { + this->VisitType(c->belong_to); + return module_->GetConstructor(GradCell_Header + c->belong_to->name_hint, + GradCell_Header + c->name_hint); } - Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { - return WrapExpr(GetRef(op), t); + /*! + * \brief Given a transformed ADT, returned the original ADT. + * Useful for GradCellUnWrapper which needs to map transformed ADT constructors + * to the original ADT constructors. + * + * \param transformed_adt_handle GlobalTypeVar of "GradCell-version" of ADT + * \return ADT + */ + GlobalTypeVar GetReverseADT(GlobalTypeVar transformed_adt_handle) { + auto it = reverse_adt_mapping_.find(transformed_adt_handle); + + // reverse mapping should always be found + CHECK(it != reverse_adt_mapping_.end()) << "Reverse mapping of ADT transformation not found"; + return it->second; } private: + // Module IRModule module_; + // ADT -> transformed ADT + std::unordered_map adt_mapping_; + // transformed ADT -> ADT + std::unordered_map reverse_adt_mapping_; +}; - Expr WrapExpr(const Expr expr, const Type& type) { - if (type.as()) { - return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); - } else if (auto* type_anno = type.as()) { - tvm::Array fields; - for (size_t i = 0; i < type_anno->fields.size(); i++) { - const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); - } - Expr tuple = Tuple(fields); - return tuple; +/*! + * \brief Helper for TypeCallMutator. 
+ * Replace TypeVar with type arguments + */ +class TypeVarSolver: public TypeMutator { + public: + explicit TypeVarSolver(std::unordered_map &type_var_map, + std::unordered_map &type_call_map): + type_var_map_(type_var_map), type_call_map_(type_call_map) {} + Type VisitType_(const TypeVarNode* op) final { + TypeVar type = GetRef(op); + + if (type_call_map_.count(type) != 0) { + // recursively visit Type argument to replace possible nested TypeVar + return VisitType(type_call_map_.at(type)); + } + + if (type_var_map_.count(type) != 0) { + return type_var_map_.at(type); } - return expr; + return type; } + private: + // + std::unordered_map &type_var_map_; + std::unordered_map &type_call_map_; }; /*! - * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors - * - * Recursively looks at the type of the expression - * and either use the FromGradCell function if TypeCall to GradCell - * or unfold and recursively visit if TupleType + * \brief Find and replace TypeVars within arguments of a TypeCall + * This is used in GradCellWrapper and GradCellUnwrapper to extract the type_args + * to be passed in to the function that converts the ADT and the type_params of the + * new function. */ -class OutputVisitor : public ExprFunctor { +class TypeCallMutator: public TypeVisitor { public: - explicit OutputVisitor(IRModule module) : module_(module) {} + // TypeVars that have to be passed to function as type_args + Array args; + // TypeVar params of new function + Array params; + explicit TypeCallMutator(IRModule module, const TypeCallNode* op): module_(module) { + for (Type t : op->args) { + // visit each type argument + VisitType(t); + } + for (auto const& x : type_var_map) { + args.push_back(x.first); + params.push_back(x.second); + } + } + + /*! + * \brief Solve for input type to function. Replace TypeVars of ADT with argumuents from TypeCall + * and replace TypeVars with newly created TypeVars of function. + * + * \param t TypeCall + * \param map TypeVar of ADT -> type argument + * + * \return type after replacing ADT TypeVar with arguments and TypeVar of arguments with TypeVar + * of functions + */ - Expr VisitExpr_(const CallNode* op, const Type& t) final { - return UnwrapExpr(GetRef(op), t); + Type InputType(Type t, std::unordered_map& map) { + return TypeVarSolver(type_var_map, map).VisitType(t); } - Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { - return UnwrapExpr(GetRef(op), t); + void VisitType_(const TypeVarNode* op) final { + TypeVar tv = GetRef(op); + if (type_var_map.count(tv) == 0) { + TypeVar replacement = TypeVar(tv->name_hint + "_", tv->kind); + type_var_map.insert({tv, replacement}); + } } private: IRModule module_; + // TypeVar in argument -> TypeVar of polymorphic function + std::unordered_map type_var_map; +}; - Expr UnwrapExpr(const Expr expr, const Type& type) { - if (auto* type_call = type.as()) { - if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { - return Call(module_->GetGlobalVar("FromGradCell"), {expr}); - } - return expr; - } else if (auto* type_anno = type.as()) { - tvm::Array fields; - for (size_t i = 0; i < type_anno->fields.size(); i++) { - const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); - } - Expr tuple = Tuple(fields); - return tuple; - } +typedef class GradCellUnWrapper GradCellUnWrapper; - return expr; - } +/*! + * \brief Mutate a given expression into its "GradCell-version". + * TensorTypes are wrapped with the Raw constructor of GradCell. 
+ * TupleTypes are recursively visited. + * ADTTypes are converted to its appropriate transformed ADT + * FuncTypes are wrapped with a function that appropriates wraps/unwraps input and output + */ +class GradCellWrapper: public ExprFunctor, + public TypeMutator { + public: + explicit GradCellWrapper(IRModule module, ADTTransform* adt_transformer): + module_(module), adt_transformer_(adt_transformer), unique(0) {} + Expr VisitExpr_(const VarNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; + Expr VisitExpr_(const TupleGetItemNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; + Expr VisitExpr_(const CallNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; + private: + // Module + IRModule module_; + // ADTTransform + ADTTransform* adt_transformer_; + // TypeCall -> Function to transform an ADT Instance into GradCell version + std::unordered_map adt_wrapper_map_; + // TypeVar of ADT call -> Type argument + std::unordered_map type_var_map; + // create unique strings for ADT wrapper functions + unsigned long unique; + + Expr WrapExpr(const Expr expr, const Type& type, GradCellUnWrapper* unwrapper); + // Return function to wrap ADT + Expr GetADTFunction(const TypeCallNode* op, TypeCallMutator& type_args, + GradCellUnWrapper* unwrapper); + Type VisitType_(const GlobalTypeVarNode* op) final; + Type VisitType_(const TensorTypeNode* op) final; }; -class LazyGradientInitializer : public ExprMutator, public TypeMutator { +/*! + * \brief Mutate a given "GradCell-version" expression into its nonGradCell-version. + * TypeCalls to GradCell are wrapped with FromGradCell function + * TupleTypes are recursively visited. + * Transformed ADTs are converted to its appropriate normal ADT + */ +class GradCellUnWrapper: public ExprFunctor, + public TypeMutator { public: - explicit LazyGradientInitializer(IRModule module) : module_(module) { - module_->ImportFromStd("gradient.rly"); + explicit GradCellUnWrapper(IRModule module, ADTTransform* adt_transformer): + module_(module), adt_transformer_(adt_transformer), unique(0) {} + Expr VisitExpr_(const VarNode* op, const Type& t) final; + Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final; + Expr VisitExpr_(const CallNode* op, const Type& t) final; + Expr VisitExpr_(const TupleNode* op, const Type& t) final; + Expr VisitExpr_(const ConstantNode* op, const Type& t) final; + private: + // Module + IRModule module_; + // ADTTransform + ADTTransform* adt_transformer_; + // TypeCall -> Function an GradCell_ADT into ADT + std::unordered_map adt_unwrapper_map_; + // TypeVar of GradCell_ADT call -> Type argument + std::unordered_map type_var_map; + // create unique strings for ADT unwrapper functions + unsigned long unique; + + Expr UnwrapExpr(const Expr expr, const Type& type); + // Return function to unwrap ADT + Expr GetReverseADTFunction(const TypeCallNode* op, TypeCallMutator& type_args); + Type VisitType_(const TypeCallNode* op) final; + Type VisitType_(const GlobalTypeVarNode* op) final; +}; + +/* GradCellWrapper */ +Expr GradCellWrapper::VisitExpr_(const VarNode* op, const Type& t, + GradCellUnWrapper* unwrapper) { + return WrapExpr(GetRef(op), op->type_annotation, unwrapper); +} + +Expr GradCellWrapper::VisitExpr_(const TupleGetItemNode* op, const Type& t, + GradCellUnWrapper* unwrapper) { + return WrapExpr(GetRef(op), t, unwrapper); +} + +Expr GradCellWrapper::VisitExpr_(const CallNode* op, const Type& t, + GradCellUnWrapper* unwrapper) { + return WrapExpr(GetRef(op), t, unwrapper); +} + +Expr 
GradCellWrapper::WrapExpr(const Expr expr, const Type& type, + GradCellUnWrapper* unwrapper) { + if (type.as()) { + return Call(module_->GetConstructor("GradCell", "Raw"), + {expr}, Attrs(), {type}); + } + + if (auto* type_anno = type.as()) { + tvm::Array fields; + for (size_t i = 0; i < type_anno->fields.size(); i++) { + const Type& t = type_anno->fields[i]; + // recursively visit each item of tuple + fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t, unwrapper)); + } + Expr tuple = Tuple(fields); + return tuple; + } + + if (auto* type_anno = type.as()) { + // create GradCell_ADT if not already created + adt_transformer_->VisitType(type_anno->func); + auto tvs = TypeCallMutator(module_, type_anno); + + return Call(GetADTFunction(type_anno, tvs, unwrapper), {expr}, Attrs(), tvs.args); + } + + if (auto* type_anno = type.as()) { + Array funcVars; + Array args; + for (Type t : type_anno->arg_types) { + Type visited = this->VisitType(t); + Var v = Var("v", visited); + funcVars.push_back(v); + // unwrap arguments + args.push_back(unwrapper->VisitExpr(v, visited)); + } + // call original expr with unwrapped arguments + Call call = Call(expr, args); + // wrap results of the call + Expr result = this->WrapExpr(call, type_anno->ret_type, unwrapper); + // return new function with GradCell-version types, wrapping original function + return Function(funcVars, result, + this->VisitType(type_anno->ret_type), type_anno->type_params); } - /*! - * \brief apply LazyGradientInit transformation and wrap function - * so that function type stays the same - * - * input/output types should only be a combination of TupleTypes and TensorTypes + return expr; +} + +Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator &type_args, + GradCellUnWrapper* unwrapper) { + auto type = GetRef(op); + GlobalTypeVar adt_handle = Downcast(op->func); + if (adt_wrapper_map_.count(type) != 0) { + // ADT already wrapped previously + return adt_wrapper_map_.at(type); + } + + // handle recursive ADT which require recursive calls to transform + GlobalVar func_var = GlobalVar(GradCell_Header + GradCell_TransFunc + adt_handle->name_hint + + std::to_string(unique++)); + adt_wrapper_map_[type] = func_var; + + TypeData adt_data = module_->LookupTypeDef(adt_handle); + TypeData new_adt_data = module_->LookupTypeDef(GradCell_Header + adt_handle->name_hint); + + // solve for input type wrap ADT function + for (size_t i = 0; i < adt_data->type_vars.size(); i++) { + type_var_map[adt_data->type_vars[i]] = op->args[i]; + } + auto input_type = type_args.InputType(type, type_var_map); + + CHECK(adt_data->constructors.size() == new_adt_data->constructors.size()) << + "ADT and transformed ADT have different number of constructors"; + + /* + * Pattern match each constructor of the ADT to the respective constructor + * in the transformed ADT. 
PatternVars then need to be recursively wrapped, + * and passed as argument to the constructor of the transformed ADT */ - Expr Transform(const Expr& e) { - auto* f = (e).as(); - auto* transformed = this->Mutate(e).as(); + Array clauses; + for (size_t i = 0; i < adt_data->constructors.size(); i++) { + // get Constructor to pattern match against + Array patternVars; + Array c_args; + Constructor c = adt_data->constructors[i]; + for (Type t : c->inputs) { + // solve for type of PatternVar + Type pattern_var_type = type_args.InputType(t, type_var_map); + Var v = Var("var", pattern_var_type); + patternVars.push_back(PatternVar(v)); + // recursively wrap + c_args.push_back(this->VisitExpr(v, pattern_var_type, unwrapper)); + } + Pattern p = PatternConstructor(c, patternVars); + // return Constructor of new ADT with wrapped arguments + Expr e = Call(new_adt_data->constructors[i], c_args); + + clauses.push_back(Clause(p, e)); + } + + Var v = Var("v", input_type); + Expr match = Match(v, clauses); + + Function func = Function({v}, match, this->VisitType(input_type), type_args.params); + module_->Add(func_var, func); + return func; +} + +Type GradCellWrapper::VisitType_(const GlobalTypeVarNode* op) { + GlobalTypeVar t = GetRef(op); + if (op->kind == kAdtHandle) { + return adt_transformer_->VisitType(t); + } + + return GetRef(op); +} + +Type GradCellWrapper::VisitType_(const TensorTypeNode* op) { + GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); + tvm::Array args; + args.push_back(GetRef(op)); + return TypeCall(gradCell, args); +} + +/* GradCellUnWrapper */ +Expr GradCellUnWrapper::VisitExpr_(const CallNode* op, const Type& t) { + return UnwrapExpr(GetRef(op), t); +} + +Expr GradCellUnWrapper::VisitExpr_(const TupleGetItemNode* op, const Type& t) { + return UnwrapExpr(GetRef(op), t); +} + +Expr GradCellUnWrapper::VisitExpr_(const VarNode* op, const Type& t) { + return UnwrapExpr(GetRef(op), op->type_annotation); +} + +Expr GradCellUnWrapper::VisitExpr_(const TupleNode* op, const Type& t) { + return UnwrapExpr(GetRef(op), t); +} + +Expr GradCellUnWrapper::VisitExpr_(const ConstantNode* op, const Type& t) { + return UnwrapExpr(GetRef(op), t); +} - if (e.same_as(GetRef(transformed))) { - return GetRef(transformed); +Expr GradCellUnWrapper::UnwrapExpr(const Expr expr, const Type& type) { + if (auto* type_call = type.as()) { + if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { + // if TypeCall to GradCell, simply wrap with FromGradCell function + return Call(module_->GetGlobalVar("FromGradCell"), {expr}, Attrs(), type_call->args); } - // wrap inputs of Tensor type using InputVisitor class - tvm::Array args; - for (Var var : f->params) { - Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type()); - args.push_back(wrappedInput); + // convert transformed ADT to ADT + auto tvs = TypeCallMutator(module_, type_call); + return Call(GetReverseADTFunction(type_call, tvs), {expr}, Attrs(), tvs.args); + } + + if (auto* type_anno = type.as()) { + tvm::Array fields; + for (size_t i = 0; i < type_anno->fields.size(); i++) { + // recursively unwrap items of tuple + const Type& t = type_anno->fields[i]; + fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); + } + Expr tuple = Tuple(fields); + return tuple; + } + return expr; +} + +Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, + TypeCallMutator& type_args) { + TypeCall type = GetRef(op); + GlobalTypeVar transformed_adt_handle = Downcast(op->func); + GlobalTypeVar adt_handle = 
adt_transformer_->GetReverseADT(transformed_adt_handle); + + // sanity check + CHECK(transformed_adt_handle->name_hint.rfind(GradCell_Header, 0) == 0) + << "Output ADT is not a transformed ADT"; + + if (adt_unwrapper_map_.count(type)) { + // transformed ADT unwrapped previously + return adt_unwrapper_map_.at(type); + } + + // handle recursive ADTs + GlobalVar func_var = GlobalVar(GradCell_Header + GradCell_ReverseTransFunc + + adt_handle->name_hint + std::to_string(unique++)); + adt_unwrapper_map_[type] = func_var; + + TypeData adt_data = module_->LookupTypeDef(adt_handle); + TypeData transformed_adt_data = module_->LookupTypeDef(transformed_adt_handle); + + CHECK(adt_data->type_vars.size() == transformed_adt_data->type_vars.size()) + << "ADT and transformed ADT have different # of type args"; + + // solve for TypeVars of ADT to solve for input type of function + for (size_t i = 0; i < transformed_adt_data->type_vars.size(); i++) { + type_var_map[adt_data->type_vars[i]] = op->args[i]; + } + auto input_type = type_args.InputType(type, type_var_map); + + CHECK(adt_data->constructors.size() == transformed_adt_data->constructors.size()) << + "ADT and transformed ADT have different number of constructors"; + + // use same logic as wrapping expression + // Pattern match with each Constructor of the transformed ADT, + // return respective Constructor with arguments of unwrapped PatternVars + Array clauses; + for (size_t i = 0; i < transformed_adt_data->constructors.size(); i++) { + // Get Constructor of transformed ADT + Array patternVars; + Array c_args; + Constructor c = transformed_adt_data->constructors[i]; + for (Type t : c->inputs) { + // solve for type of pattern var + Type pattern_var_type = type_args.InputType(t, type_var_map); + Var v = Var("var", pattern_var_type); + // bind PatternVar to Var passed to constructor + patternVars.push_back(PatternVar(v)); + // recursively unwrap + c_args.push_back(this->VisitExpr(v, pattern_var_type)); } - Expr transformedExpr = Call(GetRef(transformed), args); + Pattern p = PatternConstructor(c, patternVars); + // Call appropriate Constructor + Expr e = Call(adt_data->constructors[i], c_args); + + clauses.push_back(Clause(p, e)); + } + + Var v = Var("v", input_type); + Expr match = Match(v, clauses); + + Function func = Function({v}, match, this->VisitType(input_type), type_args.params); + module_->Add(func_var, func); + return func; +} - // unwrap outputs of GradCell type into Tensor type using OutputVisitor class - Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type); - return Function(f->params, tensorOutput, f->ret_type, Array()); +Type GradCellUnWrapper::VisitType_(const TypeCallNode* op) { + if (op->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { + return op->args[0]; + } + return TypeMutator::VisitType_(op); +} + +Type GradCellUnWrapper::VisitType_(const GlobalTypeVarNode* op) { + GlobalTypeVar t = GetRef(op); + if (op->kind == kAdtHandle) { + return adt_transformer_->GetReverseADT(t); + } + + return GetRef(op); +} + + +class LazyGradientInitializer: public ExprMutator, + public TypeMutator, + public PatternMutator { + public: + explicit LazyGradientInitializer(IRModule module): + module_(module) { + // setup + adt_transformer_ = new ADTTransform(module_); + grad_cell_wrapper_ = new GradCellWrapper(module_, adt_transformer_); + grad_cell_unwrapper_ = new GradCellUnWrapper(module_, adt_transformer_); + + // import GradCell and GradCell functions + module_->ImportFromStd("gradient.rly"); + + // ignore 
these functions when transforming
    GlobalVar from_grad_cell = module_->GetGlobalVar("FromGradCell");
    GlobalVar mul_grad_cell = module_->GetGlobalVar("MultiplyGradCell");
    GlobalVar add_grad_cell = module_->GetGlobalVar("AddGradCell");

    func_map_[from_grad_cell] = from_grad_cell;
    func_map_[mul_grad_cell] = mul_grad_cell;
    func_map_[add_grad_cell] = add_grad_cell;
  }

  /*!
   * \brief Given a global function, create a new global function
   * that mirrors the functionality but uses the GradCell type.
   * The original function will wrap its inputs, call the mirrored function,
   * unwrap the output, and return.
   */
  BaseFunc VisitGlobalVar(const GlobalVar& gv) {
    auto base_func = module_->Lookup(gv);
    if (auto* e = base_func.as<FunctionNode>()) {
      auto f = GetRef<Function>(e);
      if (func_map_.count(gv) == 0) {
        // create GlobalVar handle for function
        func_map_[gv] = GlobalVar(GradCell_Func + gv->name_hint);
      }
      GlobalVar func_var = func_map_.at(gv);
      if (module_->ContainGlobalVar(func_var->name_hint)) {
        // transformed function already contained in IRModule, return
        return module_->Lookup(func_var);
      }
      // create transformed function and add definition to IRModule
      auto* transformed = ExprMutator::Mutate(f).as<FunctionNode>();
      module_->Add(func_var, GetRef<Function>(transformed));

      // sanity check
      CHECK(f->params.size() == transformed->params.size())
          << "Transformed function doesn't have same number of args";

      // wrap inputs of Tensor type using GradCellWrapper class
      tvm::Array<Expr> args;
      for (Var var : f->params) {
        Expr wrappedInput = grad_cell_wrapper_->VisitExpr(var, var->checked_type(),
                                                          grad_cell_unwrapper_);
        args.push_back(wrappedInput);
      }
      Expr transformedExpr = Call(func_var, args);

      // unwrap outputs of GradCell type into Tensor type using GradCellUnWrapper class
      Expr tensorOutput = grad_cell_unwrapper_->VisitExpr(transformedExpr, transformed->ret_type);
      return Function(f->params, tensorOutput, f->ret_type, f->type_params);
    }
    return module_->Lookup(gv);
  }

  Expr VisitExpr_(const ConstantNode* op) final {
@@ -226,26 +750,131 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
     // handle all other ops
     Expr result = CallPrimitiveOp(call_node);
     // wrap result with Raw constructor
-    return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(),
-                {call_node->checked_type()});
+    return grad_cell_wrapper_->VisitExpr(result, call_node->checked_type(), grad_cell_unwrapper_);
   }
 
-  // not an op
+  if (auto* op = (call_node->op).as<ConstructorNode>()) {
+    // create "GradCell-version" of ADT if not already created
+    adt_transformer_->VisitType(op->belong_to);
+    // call Constructor of transformed ADT
+    Constructor c = module_->GetConstructor(GradCell_Header + op->belong_to->name_hint,
+                                            GradCell_Header + op->name_hint);
+    Array<Expr> args;
+    for (Expr e : call_node->args) {
+      args.push_back(this->VisitExpr(e));
+    }
+
+    Array<Type> type_args;
+    for (Type t : call_node->type_args) {
+      type_args.push_back(this->VisitType(t));
+    }
+    return Call(c, args, Attrs(), type_args);
+  }
+
   return ExprMutator::VisitExpr_(call_node);
 }
 
-Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); }
+Expr VisitExpr_(const ConstructorNode* op) final {
+  Constructor c = module_->GetConstructor(GradCell_Header + op->belong_to->name_hint,
+                                          GradCell_Header + op->name_hint);
+  return c;
+}
+
+Expr VisitExpr_(const IfNode* op) final {
+  auto true_b = VisitExpr(op->true_branch);
+  auto false_b = VisitExpr(op->false_branch);
+
+  // guard is bool type which will become GradCell[bool], so necessary to unwrap
auto guard = grad_cell_unwrapper_->VisitExpr(VisitExpr(op->cond), + VisitType(op->cond->checked_type())); + return If(guard, true_b, false_b); + } + + Expr VisitExpr_(const VarNode* op) final { + auto var = GetRef(op); + if (var_map_.count(var) != 0) { + return var_map_.at(var); + } - Type VisitType_(const TensorTypeNode* op) { + return ExprMutator::VisitExpr_(op); + } + + Expr VisitExpr_(const GlobalVarNode* op) final { + // GlobalVar is a handle to a global function + GlobalVar gv = GetRef(op); + if (func_map_.count(gv) == 0) { + // create handle to transformed function + func_map_[gv] = GlobalVar(GradCell_Func + op->name_hint); + // define transformed function + this->VisitGlobalVar(gv); + } + return func_map_.at(gv); + } + + Type VisitType(const Type& t) final { + return TypeMutator::VisitType(t); + } + + Type VisitType_(const GlobalTypeVarNode* op) final { + GlobalTypeVar t = GetRef(op); + if (module_->GetGlobalTypeVar("GradCell").same_as(t)) { + return t; + } + if (op->kind == kAdtHandle) { + // handle to ADT, define GradCell version of ADT is not already created + return adt_transformer_->VisitType(t); + } + + return t; + } + + Var VisitVar(const Var& v) final { + // used for PatternMutator + if (var_map_.count(v) == 0) { + var_map_.insert(std::pair(v, + Var(v->name_hint(), + VisitType(v->type_annotation)))); + } + return var_map_.at(v); + } + + Type VisitType_(const TensorTypeNode* op) final { GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); tvm::Array args; args.push_back(GetRef(op)); return TypeCall(gradCell, args); } + Pattern VisitPattern(const Pattern& c) final { + return PatternMutator::VisitPattern(c); + } + + Constructor VisitConstructor(const Constructor& c) final { + adt_transformer_->VisitType(c->belong_to); + return module_->GetConstructor(GradCell_Header + c->belong_to->name_hint, + GradCell_Header + c->name_hint); + } + + ~LazyGradientInitializer() { + // destructors + delete grad_cell_wrapper_; + delete grad_cell_unwrapper_; + delete adt_transformer_; + } + private: // Module IRModule module_; + // pass single instance of ADTTransform to save state of ADTs transformed + ADTTransform* adt_transformer_; + // pass single instance of ADTTransform to save state of ADTs wrapped + GradCellWrapper* grad_cell_wrapper_; + // pass single instance of ADTTransform to save state of ADTs unwrapped + GradCellUnWrapper* grad_cell_unwrapper_; + // var map used for transforming a Clause + std::unordered_map var_map_; + // handle of function -> handle of transformed function + std::unordered_map func_map_; /*! 
 * \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type
 */
@@ -283,26 +912,36 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
   Expr CallPrimitiveOp(const CallNode* call_node) {
     const auto fromFunc = module_->GetGlobalVar("FromGradCell");
     tvm::Array<Expr> args;
-    // use FromGradCell to convert args to Tensor
+
+    // unwrap arguments
     for (Expr expr : call_node->args) {
-      args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()}));
+      args.push_back(grad_cell_unwrapper_->VisitExpr(VisitExpr(expr),
+                                                     VisitType(expr->checked_type())));
     }
     // result of operation
     return Call(call_node->op, args, call_node->attrs);
   }
 };
 
-Expr LazyGradientInit(const Expr& e, IRModule mod) {
-  return LazyGradientInitializer(mod).Transform(e);
+IRModule LazyGradientInit(const IRModule& m) {
+  LazyGradientInitializer lgi = LazyGradientInitializer(m);
+  std::vector<GlobalVar> gvs;
+  for (const auto& p : m->functions) {
+    gvs.push_back(p.first);
+  }
+  for (const auto& gv : gvs) {
+    m->Add(gv, lgi.VisitGlobalVar(gv));
+  }
+  return m;
 }
 
 namespace transform {
 
 Pass LazyGradientInit() {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-      [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(LazyGradientInit(f, m));
-      };
-  return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {});
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) {
+        return relay::LazyGradientInit(m);
+      };
+  return CreateModulePass(pass_func, 1, "LazyGradientInit", {});
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit);
diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py
index 414926802870..998b9c06ca7c 100644
--- a/tests/python/relay/test_pass_lazy_gradient_init.py
+++ b/tests/python/relay/test_pass_lazy_gradient_init.py
@@ -21,6 +21,7 @@
 from tvm.relay import create_executor, transform
 from tvm.relay.testing import rand, run_infer_type
 from tvm.testing import assert_allclose
+from tvm.relay.prelude import Prelude
 import pytest
 
 def test_tc():
@@ -80,7 +81,6 @@ def test_add_tuple():
     mod["main"] = y
     mod = transform.LazyGradientInit()(mod)
-    mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
     y = mod["main"]
 
     assert mod["main"].checked_type == relay.FuncType([t], tensor_type)

From 25660c8d8732201e1c52b77d97dacd32d9f6eee2 Mon Sep 17 00:00:00 2001
From: Andrew Liu
Date: Sun, 14 Jun 2020 17:49:05 -0700
Subject: [PATCH 02/19] support ADT

---
 include/tvm/ir/module.h                    |  5 ++
 include/tvm/node/structural_equal.h        |  2 +
 src/ir/module.cc                           | 39 ++++++++++
 src/node/structural_equal.cc               |  6 ++
 src/relay/transforms/lazy_gradient_init.cc | 73 ++++++++++---------
 .../relay/test_pass_lazy_gradient_init.py  | 36 +++++++++
 6 files changed, 126 insertions(+), 35 deletions(-)

diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 7af84b687f5f..63640c3d65df 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -85,6 +85,11 @@ class IRModuleNode : public Object {
    */
   TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func);
 
+  /*!
+   * \brief Infer the type of all global functions.
+   */
+  TVM_DLL void Check();
+
   /*!
    * \brief Add a type-level definition to the global environment.
    * \param var The var of the global type definition.
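The new Check() hook pairs with AddUnchecked(): a module-level pass can stage rewritten
functions without immediate type checking, then restore the module's typed invariant in a
single sweep. A minimal sketch of the intended calling pattern follows (hypothetical pass
code, not part of this patch; RewriteSomehow is a placeholder for whatever mutation the
pass performs), mirroring what LazyGradientInit itself does below:

    // Collect the globals first, since adding functions mutates the map being iterated.
    IRModule RunSomePass(IRModule mod) {
      std::vector<GlobalVar> gvs;
      for (const auto& kv : mod->functions) {
        gvs.push_back(kv.first);
      }
      for (const auto& gv : gvs) {
        mod->AddUnchecked(gv, RewriteSomehow(mod->Lookup(gv)));  // placeholder rewrite
      }
      mod->Check();  // one type-inference pass over every global function
      return mod;
    }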
diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h
index 9424f6dc30f2..8c8aee7f0a3c 100644
--- a/include/tvm/node/structural_equal.h
+++ b/include/tvm/node/structural_equal.h
@@ -89,6 +89,8 @@ class StructuralEqual : public BaseValueEqual {
    * \return The comparison result.
    */
   TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;
+
+  TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) const;
 };
 
 /*!
diff --git a/src/ir/module.cc b/src/ir/module.cc
index bcab39aabf32..99b0c13ec287 100644
--- a/src/ir/module.cc
+++ b/src/ir/module.cc
@@ -231,6 +231,45 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
   global_var_map_.Set(var->name_hint, var);
 }
 
+void IRModuleNode::Check() {
+  const auto& mod = GetRef<IRModule>(this);
+  const auto& globalvars = this->GetGlobalVars();
+
+  // first pass: fill in type for all functions
+  for (const auto& var : globalvars) {
+    relay::Function f = Downcast<relay::Function>(this->Lookup(var));
+    auto func = Downcast<relay::Function>(relay::DeDup(std::move(f)));
+
+    auto fv = relay::FreeVars(func);
+    auto ftv = relay::FreeTypeVars(func, mod);
+    if (fv.size() != 0) {
+      LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false)
+                   << std::endl;
+    }
+    if (ftv.size() != 0) {
+      LOG(WARNING) << "There are free type variables: " << ftv
+                   << " in function: " << AsText(func, false) << std::endl;
+    }
+    auto func_copy = relay::Function(concat(func->params, fv), func->body, func->ret_type,
+                                     concat(func->type_params, ftv), func->attrs);
+
+    func_copy->checked_type_ = func_copy->func_type_annotation();
+    mod->AddUnchecked(var, func_copy);
+  }
+
+  // second pass: type inference on every function
+  for (const auto& var : globalvars) {
+    auto func = Downcast<relay::Function>(this->Lookup(var));
+    relay::Function checked_func = InferType(func, mod, var);
+
+    Type type = checked_func->checked_type();
+    CHECK(type.as<IncompleteTypeNode>() == nullptr)
+        << "checked type of " << var->name_hint << " is incomplete";
+
+    var->checked_type_ = type;
+    mod->AddUnchecked(var, func);
+  }
+}
+
 void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
   // We hash the global type var name to use as a globally unique prefix for tags.
   // The hash will be used as the most significant byte of the tag, with the index of
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index e05cbbb60d1f..8cb2019a993e 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -231,4 +231,10 @@ bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) con
   return RemapVarSEqualHandler(false).Equal(lhs, rhs, false);
 }
 
+bool StructuralEqual::operator()(const ObjectRef& lhs,
+                                 const ObjectRef& rhs,
+                                 bool map_free_vars) const {
+  return RemapVarSEqualHandler(false).Equal(lhs, rhs, map_free_vars);
+}
+
 }  // namespace tvm
diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc
index f15f383568f1..a2ea04306c8c 100644
--- a/src/relay/transforms/lazy_gradient_init.cc
+++ b/src/relay/transforms/lazy_gradient_init.cc
@@ -37,7 +37,7 @@
  * operations involving tensors with values of only 0 or 1.
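 *
 * (For reference, the GradCell datatype imported from gradient.rly looks roughly like
 *    type GradCell[A] { Raw(A), One(fn() -> A), Zero(fn() -> A) }
 *  where One and Zero carry thunks, so ones/zeros tensors are only materialized when a
 *  concrete tensor is required; this is a sketch inferred from the Raw/One/Zero
 *  constructors the pass uses, not a verbatim copy of gradient.rly.)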
* * Note: this pass can only be used with functions where the input/output types are a - * combination of TupleTypes, TensorTypes, non mutually-recursive ADTs, and non-nested FuncTypes + * combination of TupleTypes, TensorTypes, ADTs, and non-nested FuncTypes * * This pass optimizes 6 ops: * - add @@ -47,9 +47,9 @@ * - zeros * - zeros_like * - * This module level pass adds a new "GradCell" datatype for each existing datatype. + * This module level pass adds a new "GradCell" version datatype for each existing datatype. * This is the case to propogate the new GradCell datatype through ADTs such as Lists. - * For each function, a new function is created that accepts the "GradCell-version" of the arguments + * For each function, a new function is created that accepts the "GradCell" type of the arguments * of the original function. That is, inputs to the function are converted to their GradCell-version, * passed to the newly created "GradCell_Function". * The output is then necessarily converted from the GradCell version to the original return type. @@ -78,12 +78,18 @@ #include #include "let_list.h" +#include + namespace tvm { namespace relay { +// prefix of name of GradCell version ADT const std::string GradCell_Header = "_GradCell_"; +// prefix of transformation function for converting ADT to GradCell version const std::string GradCell_TransFunc = "_GradCell_TransFunc_"; +// prefix of transformation function for converting GradCell version ADT to normal const std::string GradCell_ReverseTransFunc = "_GradCell_ReverseTransFunc_"; +// prefix of copy of function that operates on GradCell types const std::string GradCell_Func = "_GradCell_Func_"; struct TypeCallHash { @@ -108,7 +114,7 @@ struct TypeCallEqual { } for (size_t i = 0; i < l->args.size(); i++) { - if (!GraphEqual(l->args[i], r->args[i])) { + if (!tvm::StructuralEqual()(l->args[i], r->args[i], true)) { return false; } } @@ -120,8 +126,6 @@ struct TypeCallEqual { /*! * \brief ADTTransform creates a new ADT named * GradCell_Header + name_hint for each unique ADT. - * - * This is necessary to handle ADTs. */ class ADTTransform: public TypeMutator, public PatternMutator { public: @@ -142,7 +146,6 @@ class ADTTransform: public TypeMutator, public PatternMutator { Type VisitType_(const GlobalTypeVarNode* op) final { GlobalTypeVar t = GetRef(op); if (op->kind == kAdtHandle) { - if (adt_mapping_.count(t) != 0) { return adt_mapping_.at(t); } @@ -242,22 +245,21 @@ class TypeVarSolver: public TypeMutator { return type; } private: - // - std::unordered_map &type_var_map_; - std::unordered_map &type_call_map_; + // type vars to unique type vars + std::unordered_map type_var_map_; + // TypeCall arguments to ADT + std::unordered_map type_call_map_; }; /*! - * \brief Find and replace TypeVars within arguments of a TypeCall - * This is used in GradCellWrapper and GradCellUnwrapper to extract the type_args - * to be passed in to the function that converts the ADT and the type_params of the - * new function. + * \brief Find all TypeVars within the arguments of a TypeCallNode and create a mapping + * of the TypeVars to new TypeVars */ class TypeCallMutator: public TypeVisitor { public: - // TypeVars that have to be passed to function as type_args + // TypeVars within TypeCallNode Array args; - // TypeVar params of new function + // unique TypeVars Array params; explicit TypeCallMutator(IRModule module, const TypeCallNode* op): module_(module) { for (Type t : op->args) { @@ -271,14 +273,14 @@ class TypeCallMutator: public TypeVisitor { } /*! 
- * \brief Solve for input type to function. Replace TypeVars of ADT with argumuents from TypeCall - * and replace TypeVars with newly created TypeVars of function. + * \brief Replace ADT type vars with TypeCall arguments + * and replace type vars with unique typevars * * \param t TypeCall * \param map TypeVar of ADT -> type argument * - * \return type after replacing ADT TypeVar with arguments and TypeVar of arguments with TypeVar - * of functions + * \return type after replacing ADT TypeVar with arguments and replacing any + * free type vars with uniquely generated typevars */ Type InputType(Type t, std::unordered_map& map) { @@ -306,7 +308,7 @@ typedef class GradCellUnWrapper GradCellUnWrapper; * TensorTypes are wrapped with the Raw constructor of GradCell. * TupleTypes are recursively visited. * ADTTypes are converted to its appropriate transformed ADT - * FuncTypes are wrapped with a function that appropriates wraps/unwraps input and output + * FuncTypes are wrapped with a function that appropriately wraps/unwraps input and output */ class GradCellWrapper: public ExprFunctor, public TypeMutator { @@ -325,7 +327,7 @@ class GradCellWrapper: public ExprFunctor adt_wrapper_map_; // TypeVar of ADT call -> Type argument std::unordered_map type_var_map; - // create unique strings for ADT wrapper functions + // append to prefix to create unique function names for ADT wrapper functions unsigned long unique; Expr WrapExpr(const Expr expr, const Type& type, GradCellUnWrapper* unwrapper); @@ -408,12 +410,16 @@ Expr GradCellWrapper::WrapExpr(const Expr expr, const Type& type, if (auto* type_anno = type.as()) { // create GradCell_ADT if not already created adt_transformer_->VisitType(type_anno->func); + // find all type vars within type_anno + // to handle polymorphic functions auto tvs = TypeCallMutator(module_, type_anno); return Call(GetADTFunction(type_anno, tvs, unwrapper), {expr}, Attrs(), tvs.args); } if (auto* type_anno = type.as()) { + // to handle functions, we need to create a new function + // that handles GradCell version input and outputs GradCell version types Array funcVars; Array args; for (Type t : type_anno->arg_types) { @@ -452,7 +458,7 @@ Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator &ty TypeData adt_data = module_->LookupTypeDef(adt_handle); TypeData new_adt_data = module_->LookupTypeDef(GradCell_Header + adt_handle->name_hint); - // solve for input type wrap ADT function + // solve for input type to wrap ADT function for (size_t i = 0; i < adt_data->type_vars.size(); i++) { type_var_map[adt_data->type_vars[i]] = op->args[i]; } @@ -491,7 +497,7 @@ Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator &ty Expr match = Match(v, clauses); Function func = Function({v}, match, this->VisitType(input_type), type_args.params); - module_->Add(func_var, func); + module_->AddUnchecked(func_var, func); return func; } @@ -564,7 +570,7 @@ Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, GlobalTypeVar adt_handle = adt_transformer_->GetReverseADT(transformed_adt_handle); // sanity check - CHECK(transformed_adt_handle->name_hint.rfind(GradCell_Header, 0) == 0) + CHECK(std::string(transformed_adt_handle->name_hint).rfind(GradCell_Header, 0) == 0) << "Output ADT is not a transformed ADT"; if (adt_unwrapper_map_.count(type)) { @@ -621,7 +627,7 @@ Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, Expr match = Match(v, clauses); Function func = Function({v}, match, this->VisitType(input_type), 
type_args.params); - module_->Add(func_var, func); + module_->AddUnchecked(func_var, func); return func; } @@ -687,11 +693,7 @@ class LazyGradientInitializer: public ExprMutator, } // create transformed function and add definition to IRModule auto* transformed = ExprMutator::Mutate(f).as(); - module_->Add(func_var, GetRef(transformed)); - - // sanity check - CHECK(f->params.size() == transformed->params.size()) - << "Transformed function doesn't have same number of args"; + module_->AddUnchecked(func_var, GetRef(transformed)); // wrap inputs of Tensor type using GradCellWrapper class tvm::Array args; @@ -706,7 +708,7 @@ class LazyGradientInitializer: public ExprMutator, Expr tensorOutput = grad_cell_unwrapper_->VisitExpr(transformedExpr, transformed->ret_type); return Function(f->params, tensorOutput, f->ret_type, f->type_params); } - return module_->Lookup(gv); + LOG(FATAL) << "GlobalVar does not map to a function"; } Expr VisitExpr_(const ConstantNode* op) final { @@ -752,6 +754,7 @@ class LazyGradientInitializer: public ExprMutator, // wrap result with Raw constructor return grad_cell_wrapper_->VisitExpr(result, call_node->checked_type(), grad_cell_unwrapper_); } + if (auto* op = (call_node->op).as()) { // create "GradCell-version" of ADT if not already created adt_transformer_->VisitType(op->belong_to); @@ -804,8 +807,6 @@ class LazyGradientInitializer: public ExprMutator, if (func_map_.count(gv) == 0) { // create handle to transformed function func_map_[gv] = GlobalVar(GradCell_Func + op->name_hint); - // define transformed function - this->VisitGlobalVar(gv); } return func_map_.at(gv); } @@ -817,6 +818,7 @@ class LazyGradientInitializer: public ExprMutator, Type VisitType_(const GlobalTypeVarNode* op) final { GlobalTypeVar t = GetRef(op); if (module_->GetGlobalTypeVar("GradCell").same_as(t)) { + // if GradCell type, do nothing return t; } if (op->kind == kAdtHandle) { @@ -930,8 +932,9 @@ IRModule LazyGradientInit(const IRModule& m) { gvs.push_back(p.first); } for (const auto& gv : gvs) { - m->Add(gv, lgi.VisitGlobalVar(gv)); + m->AddUnchecked(gv, lgi.VisitGlobalVar(gv)); } + m->Check(); return m; } diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index 998b9c06ca7c..7f1ca3ca6719 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -391,5 +391,41 @@ def test_ones_like(): y = ex.evaluate(y)(x) assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) +def test_list_adt(): + """test prelude functions on list ADT. 
which is a recursive ADT""" + mod = tvm.IRModule() + p = Prelude(mod) + + cons = p.cons + nil = p.nil + + mod = transform.LazyGradientInit()(mod) + + ex = create_executor(mod=mod) + + def to_list_adt(list): + l = nil() + for x in list: + l = cons(relay.const(x), l) + return ex.evaluate(l) + + def from_list_adt(list): + l = [] + def rec(x): + if x.constructor.tag == cons.tag: + l.insert(0, x.fields[0].asnumpy().tolist()) + rec(x.fields[1]) + rec(list) + return l + + # test sum + x = np.random.randint(1,101,10) + assert sum(x) == ex.evaluate(mod['sum'])(to_list_adt(x)).asnumpy() + + # test reverse + x = np.random.rand(10) + actual = from_list_adt(ex.evaluate(mod['rev'])(to_list_adt(x))) + assert_allclose(x[::-1], actual) + if __name__ == "__main__": pytest.main([__file__]) From 7b5c9a7d82d166f0a7ad9f78227d57c8f79c5517 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 14 Jun 2020 18:43:18 -0700 Subject: [PATCH 03/19] lint --- src/relay/transforms/lazy_gradient_init.cc | 284 ++++++++++----------- 1 file changed, 130 insertions(+), 154 deletions(-) diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index a2ea04306c8c..8af32261aece 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -35,10 +35,10 @@ * * It also overloads + and * operation which can increase performance when doing * operations involving tensors with values of only 0 or 1. - * + * * Note: this pass can only be used with functions where the input/output types are a * combination of TupleTypes, TensorTypes, ADTs, and non-nested FuncTypes - * + * * This pass optimizes 6 ops: * - add * - multiply @@ -46,21 +46,22 @@ * - ones_like * - zeros * - zeros_like - * + * * This module level pass adds a new "GradCell" version datatype for each existing datatype. * This is the case to propogate the new GradCell datatype through ADTs such as Lists. * For each function, a new function is created that accepts the "GradCell" type of the arguments - * of the original function. That is, inputs to the function are converted to their GradCell-version, - * passed to the newly created "GradCell_Function". - * The output is then necessarily converted from the GradCell version to the original return type. - * - * To support ADTs, we use functions that convert between an instance of an ADT to its + * of the original function. That is, inputs to the function are converted to their + * GradCell-version, passed to the newly created "GradCell_Function". The output is then necessarily + * converted from the GradCell version to the original return type. + * + * To support ADTs, we use functions that convert between an instance of an ADT to its * respective GradCell version * by matching constructors to the constructor of the "GradCell" datatype. - * + * * A transformation function is required for different type arguments. - * For example the ADT List may be List[int32] or List[List[int32]], which should be handled separately. - * + * For example the ADT List may be List[int32] or List[List[int32]], which should be handled + * separately. + * * This pass uses 4 primary mutators: * - LazyGradientInitializer to create the "GradCell_Function" of a given function. 
* - GradCellWrapper mutates expr into its respective GradCell expr @@ -72,30 +73,27 @@ #include #include #include -#include #include #include -#include -#include "let_list.h" #include +#include "let_list.h" + namespace tvm { namespace relay { // prefix of name of GradCell version ADT -const std::string GradCell_Header = "_GradCell_"; +const char GradCell_Header[] = "_GradCell_"; // prefix of transformation function for converting ADT to GradCell version -const std::string GradCell_TransFunc = "_GradCell_TransFunc_"; +const char GradCell_TransFunc[] = "_GradCell_TransFunc_"; // prefix of transformation function for converting GradCell version ADT to normal -const std::string GradCell_ReverseTransFunc = "_GradCell_ReverseTransFunc_"; +const char GradCell_ReverseTransFunc[] = "_GradCell_ReverseTransFunc_"; // prefix of copy of function that operates on GradCell types -const std::string GradCell_Func = "_GradCell_Func_"; +const char GradCell_Func[] = "_GradCell_Func_"; struct TypeCallHash { - size_t operator()(const TypeCall& typecall) const { - return ObjectHash()(typecall->func); - } + size_t operator()(const TypeCall& typecall) const { return ObjectHash()(typecall->func); } }; /*! @@ -124,17 +122,14 @@ struct TypeCallEqual { }; /*! - * \brief ADTTransform creates a new ADT named + * \brief ADTTransform creates a new ADT named * GradCell_Header + name_hint for each unique ADT. - */ -class ADTTransform: public TypeMutator, public PatternMutator { + */ +class ADTTransform : public TypeMutator, public PatternMutator { public: - explicit ADTTransform(IRModule module): - module_(module) { } - - Type VisitType(const Type& t) final { - return TypeMutator::VisitType(t); - } + explicit ADTTransform(IRModule module) : module_(module) {} + + Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); } Type VisitType_(const TensorTypeNode* op) final { GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); @@ -142,14 +137,14 @@ class ADTTransform: public TypeMutator, public PatternMutator { args.push_back(GetRef(op)); return TypeCall(gradCell, args); } - + Type VisitType_(const GlobalTypeVarNode* op) final { GlobalTypeVar t = GetRef(op); if (op->kind == kAdtHandle) { if (adt_mapping_.count(t) != 0) { return adt_mapping_.at(t); } - + TypeData adt = module_->LookupTypeDef(t); this->VisitType(adt); @@ -163,9 +158,8 @@ class ADTTransform: public TypeMutator, public PatternMutator { auto type_data = GetRef(op); std::string transformed_adt_name = GradCell_Header + op->header->name_hint; - // add new ADT to map to handle recursive definitions - GlobalTypeVar new_adt = GlobalTypeVar(transformed_adt_name, - op->header->kind); + // add new ADT to map to handle recursive definitions + GlobalTypeVar new_adt = GlobalTypeVar(transformed_adt_name, op->header->kind); adt_mapping_[type_data->header] = new_adt; reverse_adt_mapping_[new_adt] = type_data->header; @@ -176,31 +170,28 @@ class ADTTransform: public TypeMutator, public PatternMutator { for (Type t : con->inputs) { inputs.push_back(this->VisitType(t)); } - Constructor transformed_cons = Constructor(GradCell_Header + con->name_hint, - inputs, new_adt); + Constructor transformed_cons = Constructor(GradCell_Header + con->name_hint, inputs, new_adt); constructors.push_back(transformed_cons); } - + TypeData new_datatype = TypeData(new_adt, op->type_vars, constructors); module_->AddTypeDef(new_adt, new_datatype); return new_datatype; } - Pattern VisitPattern(const Pattern& c) final { - return PatternMutator::VisitPattern(c); - } + Pattern 
VisitPattern(const Pattern& c) final { return PatternMutator::VisitPattern(c); } Constructor VisitConstructor(const Constructor& c) final { this->VisitType(c->belong_to); return module_->GetConstructor(GradCell_Header + c->belong_to->name_hint, - GradCell_Header + c->name_hint); + GradCell_Header + c->name_hint); } /*! * \brief Given a transformed ADT, returned the original ADT. * Useful for GradCellUnWrapper which needs to map transformed ADT constructors * to the original ADT constructors. - * + * * \param transformed_adt_handle GlobalTypeVar of "GradCell-version" of ADT * \return ADT */ @@ -225,14 +216,15 @@ class ADTTransform: public TypeMutator, public PatternMutator { * \brief Helper for TypeCallMutator. * Replace TypeVar with type arguments */ -class TypeVarSolver: public TypeMutator { +class TypeVarSolver : public TypeMutator { public: - explicit TypeVarSolver(std::unordered_map &type_var_map, - std::unordered_map &type_call_map): - type_var_map_(type_var_map), type_call_map_(type_call_map) {} + explicit TypeVarSolver( + const std::unordered_map& type_var_map, + const std::unordered_map& type_call_map) + : type_var_map_(type_var_map), type_call_map_(type_call_map) {} Type VisitType_(const TypeVarNode* op) final { TypeVar type = GetRef(op); - + if (type_call_map_.count(type) != 0) { // recursively visit Type argument to replace possible nested TypeVar return VisitType(type_call_map_.at(type)); @@ -244,6 +236,7 @@ class TypeVarSolver: public TypeMutator { return type; } + private: // type vars to unique type vars std::unordered_map type_var_map_; @@ -255,13 +248,13 @@ class TypeVarSolver: public TypeMutator { * \brief Find all TypeVars within the arguments of a TypeCallNode and create a mapping * of the TypeVars to new TypeVars */ -class TypeCallMutator: public TypeVisitor { +class TypeCallMutator : public TypeVisitor { public: // TypeVars within TypeCallNode Array args; // unique TypeVars Array params; - explicit TypeCallMutator(IRModule module, const TypeCallNode* op): module_(module) { + explicit TypeCallMutator(IRModule module, const TypeCallNode* op) : module_(module) { for (Type t : op->args) { // visit each type argument VisitType(t); @@ -275,15 +268,15 @@ class TypeCallMutator: public TypeVisitor { /*! 
* \brief Replace ADT type vars with TypeCall arguments * and replace type vars with unique typevars - * + * * \param t TypeCall * \param map TypeVar of ADT -> type argument - * + * * \return type after replacing ADT TypeVar with arguments and replacing any * free type vars with uniquely generated typevars */ - Type InputType(Type t, std::unordered_map& map) { + Type InputType(Type t, const std::unordered_map& map) { return TypeVarSolver(type_var_map, map).VisitType(t); } @@ -310,14 +303,15 @@ typedef class GradCellUnWrapper GradCellUnWrapper; * ADTTypes are converted to its appropriate transformed ADT * FuncTypes are wrapped with a function that appropriately wraps/unwraps input and output */ -class GradCellWrapper: public ExprFunctor, - public TypeMutator { +class GradCellWrapper : public ExprFunctor, + public TypeMutator { public: - explicit GradCellWrapper(IRModule module, ADTTransform* adt_transformer): - module_(module), adt_transformer_(adt_transformer), unique(0) {} + explicit GradCellWrapper(IRModule module, ADTTransform* adt_transformer) + : module_(module), adt_transformer_(adt_transformer), unique(0) {} Expr VisitExpr_(const VarNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; Expr VisitExpr_(const TupleGetItemNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; Expr VisitExpr_(const CallNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; + private: // Module IRModule module_; @@ -328,11 +322,11 @@ class GradCellWrapper: public ExprFunctor Type argument std::unordered_map type_var_map; // append to prefix to create unique function names for ADT wrapper functions - unsigned long unique; + int64_t unique; Expr WrapExpr(const Expr expr, const Type& type, GradCellUnWrapper* unwrapper); // Return function to wrap ADT - Expr GetADTFunction(const TypeCallNode* op, TypeCallMutator& type_args, + Expr GetADTFunction(const TypeCallNode* op, TypeCallMutator& type_args, GradCellUnWrapper* unwrapper); Type VisitType_(const GlobalTypeVarNode* op) final; Type VisitType_(const TensorTypeNode* op) final; @@ -344,16 +338,16 @@ class GradCellWrapper: public ExprFunctor, - public TypeMutator { +class GradCellUnWrapper : public ExprFunctor, public TypeMutator { public: - explicit GradCellUnWrapper(IRModule module, ADTTransform* adt_transformer): - module_(module), adt_transformer_(adt_transformer), unique(0) {} + explicit GradCellUnWrapper(IRModule module, ADTTransform* adt_transformer) + : module_(module), adt_transformer_(adt_transformer), unique(0) {} Expr VisitExpr_(const VarNode* op, const Type& t) final; Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final; Expr VisitExpr_(const CallNode* op, const Type& t) final; Expr VisitExpr_(const TupleNode* op, const Type& t) final; Expr VisitExpr_(const ConstantNode* op, const Type& t) final; + private: // Module IRModule module_; @@ -364,7 +358,7 @@ class GradCellUnWrapper: public ExprFunctor, // TypeVar of GradCell_ADT call -> Type argument std::unordered_map type_var_map; // create unique strings for ADT unwrapper functions - unsigned long unique; + int64_t unique; Expr UnwrapExpr(const Expr expr, const Type& type); // Return function to unwrap ADT @@ -374,28 +368,24 @@ class GradCellUnWrapper: public ExprFunctor, }; /* GradCellWrapper */ -Expr GradCellWrapper::VisitExpr_(const VarNode* op, const Type& t, - GradCellUnWrapper* unwrapper) { +Expr GradCellWrapper::VisitExpr_(const VarNode* op, const Type& t, GradCellUnWrapper* unwrapper) { return WrapExpr(GetRef(op), op->type_annotation, unwrapper); } -Expr 
GradCellWrapper::VisitExpr_(const TupleGetItemNode* op, const Type& t, - GradCellUnWrapper* unwrapper) { +Expr GradCellWrapper::VisitExpr_(const TupleGetItemNode* op, const Type& t, + GradCellUnWrapper* unwrapper) { return WrapExpr(GetRef(op), t, unwrapper); } -Expr GradCellWrapper::VisitExpr_(const CallNode* op, const Type& t, - GradCellUnWrapper* unwrapper) { +Expr GradCellWrapper::VisitExpr_(const CallNode* op, const Type& t, GradCellUnWrapper* unwrapper) { return WrapExpr(GetRef(op), t, unwrapper); } -Expr GradCellWrapper::WrapExpr(const Expr expr, const Type& type, - GradCellUnWrapper* unwrapper) { +Expr GradCellWrapper::WrapExpr(const Expr expr, const Type& type, GradCellUnWrapper* unwrapper) { if (type.as()) { - return Call(module_->GetConstructor("GradCell", "Raw"), - {expr}, Attrs(), {type}); - } - + return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); + } + if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { @@ -405,8 +395,8 @@ Expr GradCellWrapper::WrapExpr(const Expr expr, const Type& type, } Expr tuple = Tuple(fields); return tuple; - } - + } + if (auto* type_anno = type.as()) { // create GradCell_ADT if not already created adt_transformer_->VisitType(type_anno->func); @@ -415,8 +405,8 @@ Expr GradCellWrapper::WrapExpr(const Expr expr, const Type& type, auto tvs = TypeCallMutator(module_, type_anno); return Call(GetADTFunction(type_anno, tvs, unwrapper), {expr}, Attrs(), tvs.args); - } - + } + if (auto* type_anno = type.as()) { // to handle functions, we need to create a new function // that handles GradCell version input and outputs GradCell version types @@ -434,15 +424,14 @@ Expr GradCellWrapper::WrapExpr(const Expr expr, const Type& type, // wrap results of the call Expr result = this->WrapExpr(call, type_anno->ret_type, unwrapper); // return new function with GradCell-version types, wrapping original function - return Function(funcVars, result, - this->VisitType(type_anno->ret_type), type_anno->type_params); + return Function(funcVars, result, this->VisitType(type_anno->ret_type), type_anno->type_params); } return expr; } -Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator &type_args, - GradCellUnWrapper* unwrapper) { +Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator& type_args, + GradCellUnWrapper* unwrapper) { auto type = GetRef(op); GlobalTypeVar adt_handle = Downcast(op->func); if (adt_wrapper_map_.count(type) != 0) { @@ -451,8 +440,8 @@ Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator &ty } // handle recursive ADT which require recursive calls to transform - GlobalVar func_var = GlobalVar(GradCell_Header + GradCell_TransFunc + adt_handle->name_hint + - std::to_string(unique++)); + GlobalVar func_var = GlobalVar(GradCell_Header + std::string(GradCell_TransFunc) + + adt_handle->name_hint + std::to_string(unique++)); adt_wrapper_map_[type] = func_var; TypeData adt_data = module_->LookupTypeDef(adt_handle); @@ -464,8 +453,8 @@ Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator &ty } auto input_type = type_args.InputType(type, type_var_map); - CHECK(adt_data->constructors.size() == new_adt_data->constructors.size()) << - "ADT and transformed ADT have different number of constructors"; + CHECK(adt_data->constructors.size() == new_adt_data->constructors.size()) + << "ADT and transformed ADT have different number of constructors"; /* * Pattern match each constructor of the ADT to the 
respective constructor @@ -496,7 +485,7 @@ Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator &ty Var v = Var("v", input_type); Expr match = Match(v, clauses); - Function func = Function({v}, match, this->VisitType(input_type), type_args.params); + Function func = Function({v}, match, this->VisitType(input_type), type_args.params); module_->AddUnchecked(func_var, func); return func; } @@ -548,8 +537,8 @@ Expr GradCellUnWrapper::UnwrapExpr(const Expr expr, const Type& type) { // convert transformed ADT to ADT auto tvs = TypeCallMutator(module_, type_call); return Call(GetReverseADTFunction(type_call, tvs), {expr}, Attrs(), tvs.args); - } - + } + if (auto* type_anno = type.as()) { tvm::Array fields; for (size_t i = 0; i < type_anno->fields.size(); i++) { @@ -563,15 +552,14 @@ Expr GradCellUnWrapper::UnwrapExpr(const Expr expr, const Type& type) { return expr; } -Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, - TypeCallMutator& type_args) { +Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, TypeCallMutator& type_args) { TypeCall type = GetRef(op); GlobalTypeVar transformed_adt_handle = Downcast(op->func); GlobalTypeVar adt_handle = adt_transformer_->GetReverseADT(transformed_adt_handle); // sanity check - CHECK(std::string(transformed_adt_handle->name_hint).rfind(GradCell_Header, 0) == 0) - << "Output ADT is not a transformed ADT"; + CHECK(std::string(transformed_adt_handle->name_hint).rfind(GradCell_Header, 0) == 0) + << "Output ADT is not a transformed ADT"; if (adt_unwrapper_map_.count(type)) { // transformed ADT unwrapped previously @@ -579,24 +567,24 @@ Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, } // handle recursive ADTs - GlobalVar func_var = GlobalVar(GradCell_Header + GradCell_ReverseTransFunc + - adt_handle->name_hint + std::to_string(unique++)); + GlobalVar func_var = GlobalVar(GradCell_Header + std::string(GradCell_ReverseTransFunc) + + adt_handle->name_hint + std::to_string(unique++)); adt_unwrapper_map_[type] = func_var; TypeData adt_data = module_->LookupTypeDef(adt_handle); TypeData transformed_adt_data = module_->LookupTypeDef(transformed_adt_handle); - CHECK(adt_data->type_vars.size() == transformed_adt_data->type_vars.size()) - << "ADT and transformed ADT have different # of type args"; - + CHECK(adt_data->type_vars.size() == transformed_adt_data->type_vars.size()) + << "ADT and transformed ADT have different # of type args"; + // solve for TypeVars of ADT to solve for input type of function for (size_t i = 0; i < transformed_adt_data->type_vars.size(); i++) { type_var_map[adt_data->type_vars[i]] = op->args[i]; } auto input_type = type_args.InputType(type, type_var_map); - CHECK(adt_data->constructors.size() == transformed_adt_data->constructors.size()) << - "ADT and transformed ADT have different number of constructors"; + CHECK(adt_data->constructors.size() == transformed_adt_data->constructors.size()) + << "ADT and transformed ADT have different number of constructors"; // use same logic as wrapping expression // Pattern match with each Constructor of the transformed ADT, @@ -626,7 +614,7 @@ Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, Var v = Var("v", input_type); Expr match = Match(v, clauses); - Function func = Function({v}, match, this->VisitType(input_type), type_args.params); + Function func = Function({v}, match, this->VisitType(input_type), type_args.params); module_->AddUnchecked(func_var, func); return func; } @@ -647,37 +635,33 @@ Type 
GradCellUnWrapper::VisitType_(const GlobalTypeVarNode* op) { return GetRef(op); } - -class LazyGradientInitializer: public ExprMutator, - public TypeMutator, - public PatternMutator { +class LazyGradientInitializer : public ExprMutator, public TypeMutator, public PatternMutator { public: - explicit LazyGradientInitializer(IRModule module): - module_(module) { - // setup - adt_transformer_ = new ADTTransform(module_); - grad_cell_wrapper_ = new GradCellWrapper(module_, adt_transformer_); - grad_cell_unwrapper_ = new GradCellUnWrapper(module_, adt_transformer_); - - // import GradCell and GradCell functions - module_->ImportFromStd("gradient.rly"); - - // ignore these functions when transforming - GlobalVar from_grad_cell = module_->GetGlobalVar("FromGradCell"); - GlobalVar mul_grad_cell = module_->GetGlobalVar("MultiplyGradCell"); - GlobalVar add_grad_cell = module_->GetGlobalVar("AddGradCell"); - - func_map_[from_grad_cell] = from_grad_cell; - func_map_[mul_grad_cell] = mul_grad_cell; - func_map_[add_grad_cell] = add_grad_cell; - } - + explicit LazyGradientInitializer(IRModule module) : module_(module) { + // setup + adt_transformer_ = new ADTTransform(module_); + grad_cell_wrapper_ = new GradCellWrapper(module_, adt_transformer_); + grad_cell_unwrapper_ = new GradCellUnWrapper(module_, adt_transformer_); + + // import GradCell and GradCell functions + module_->ImportFromStd("gradient.rly"); + + // ignore these functions when transforming + GlobalVar from_grad_cell = module_->GetGlobalVar("FromGradCell"); + GlobalVar mul_grad_cell = module_->GetGlobalVar("MultiplyGradCell"); + GlobalVar add_grad_cell = module_->GetGlobalVar("AddGradCell"); + + func_map_[from_grad_cell] = from_grad_cell; + func_map_[mul_grad_cell] = mul_grad_cell; + func_map_[add_grad_cell] = add_grad_cell; + } + /*! - * \brief Given a global function, create a new global function - * that mirrors the functionality but using the GradCell type. - * Original function will wrap inputs, call the mirrored function, unwrap the output, - * and return. 
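+   *
+   * For example, a global fn : (Tensor) -> Tensor gets a mirrored global
+   * _GradCell_Func_fn : (GradCell[Tensor]) -> GradCell[Tensor], and the body of
+   * the original fn becomes unwrap(_GradCell_Func_fn(wrap(inputs))).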
+ */ BaseFunc VisitGlobalVar(const GlobalVar& gv) { auto base_func = module_->Lookup(gv); if (auto* e = base_func.as()) { @@ -698,8 +682,8 @@ class LazyGradientInitializer: public ExprMutator, // wrap inputs of Tensor type using GradCellWrapper class tvm::Array args; for (Var var : f->params) { - Expr wrappedInput = grad_cell_wrapper_->VisitExpr(var, var->checked_type(), - grad_cell_unwrapper_); + Expr wrappedInput = + grad_cell_wrapper_->VisitExpr(var, var->checked_type(), grad_cell_unwrapper_); args.push_back(wrappedInput); } Expr transformedExpr = Call(func_var, args); @@ -756,7 +740,7 @@ class LazyGradientInitializer: public ExprMutator, } if (auto* op = (call_node->op).as()) { - // create "GradCell-version" of ADT if not already created + // create "GradCell-version" of ADT if not already created adt_transformer_->VisitType(op->belong_to); // call Constructor of transformed ADT Constructor c = module_->GetConstructor(GradCell_Header + op->belong_to->name_hint, @@ -778,7 +762,7 @@ class LazyGradientInitializer: public ExprMutator, Expr VisitExpr_(const ConstructorNode* op) final { Constructor c = module_->GetConstructor(GradCell_Header + op->belong_to->name_hint, - GradCell_Header + op->name_hint); + GradCell_Header + op->name_hint); return c; } @@ -787,8 +771,8 @@ class LazyGradientInitializer: public ExprMutator, auto false_b = VisitExpr(op->false_branch); // guard is bool type which will become GradCell[bool], so necessary to unwrap - auto guard = grad_cell_unwrapper_->VisitExpr(VisitExpr(op->cond), - VisitType(op->cond->checked_type())); + auto guard = + grad_cell_unwrapper_->VisitExpr(VisitExpr(op->cond), VisitType(op->cond->checked_type())); return If(guard, true_b, false_b); } @@ -811,9 +795,7 @@ class LazyGradientInitializer: public ExprMutator, return func_map_.at(gv); } - Type VisitType(const Type& t) final { - return TypeMutator::VisitType(t); - } + Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); } Type VisitType_(const GlobalTypeVarNode* op) final { GlobalTypeVar t = GetRef(op); @@ -832,9 +814,7 @@ class LazyGradientInitializer: public ExprMutator, Var VisitVar(const Var& v) final { // used for PatternMutator if (var_map_.count(v) == 0) { - var_map_.insert(std::pair(v, - Var(v->name_hint(), - VisitType(v->type_annotation)))); + var_map_.insert(std::pair(v, Var(v->name_hint(), VisitType(v->type_annotation)))); } return var_map_.at(v); } @@ -846,14 +826,12 @@ class LazyGradientInitializer: public ExprMutator, return TypeCall(gradCell, args); } - Pattern VisitPattern(const Pattern& c) final { - return PatternMutator::VisitPattern(c); - } + Pattern VisitPattern(const Pattern& c) final { return PatternMutator::VisitPattern(c); } Constructor VisitConstructor(const Constructor& c) final { adt_transformer_->VisitType(c->belong_to); return module_->GetConstructor(GradCell_Header + c->belong_to->name_hint, - GradCell_Header + c->name_hint); + GradCell_Header + c->name_hint); } ~LazyGradientInitializer() { @@ -917,8 +895,8 @@ class LazyGradientInitializer: public ExprMutator, // unwrap arguments for (Expr expr : call_node->args) { - args.push_back(grad_cell_unwrapper_->VisitExpr(VisitExpr(expr), - VisitType(expr->checked_type()))); + args.push_back( + grad_cell_unwrapper_->VisitExpr(VisitExpr(expr), VisitType(expr->checked_type()))); } // result of operation return Call(call_node->op, args, call_node->attrs); @@ -941,10 +919,8 @@ IRModule LazyGradientInit(const IRModule& m) { namespace transform { Pass LazyGradientInit() { runtime::TypedPackedFunc pass_func = - 
[=](IRModule m, PassContext pc) { - return relay::LazyGradientInit(m); - }; - return CreateModulePass(pass_func, 1, "LazyGradientInit", {}); + [=](IRModule m, PassContext pc) { return relay::LazyGradientInit(m); }; + return CreateModulePass(pass_func, 1, "LazyGradientInit", {}); } TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit); From 42cc5fc6f62f93e0092650ba404d8d0d7aff6f99 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 15 Jun 2020 10:50:56 -0700 Subject: [PATCH 04/19] format --- src/ir/module.cc | 2 +- src/node/structural_equal.cc | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ir/module.cc b/src/ir/module.cc index 99b0c13ec287..25cba917143c 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -251,7 +251,7 @@ void IRModuleNode::Check() { << " in function: " << AsText(func, false) << std::endl; } auto func_copy = relay::Function(concat(func->params, fv), func->body, func->ret_type, - concat(func->type_params, ftv), func->attrs); + concat(func->type_params, ftv), func->attrs); func_copy->checked_type_ = func_copy->func_type_annotation(); mod->AddUnchecked(var, func_copy); diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 8cb2019a993e..280624b36b02 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -231,8 +231,7 @@ bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) con return RemapVarSEqualHandler(false).Equal(lhs, rhs, false); } -bool StructuralEqual::operator()(const ObjectRef& lhs, - const ObjectRef& rhs, +bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) const { return RemapVarSEqualHandler(false).Equal(lhs, rhs, map_free_vars); } From 581b5ca0c05618a2f71537e37531d130fb604331 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 15 Jun 2020 11:09:28 -0700 Subject: [PATCH 05/19] fix warning --- src/relay/transforms/lazy_gradient_init.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index 8af32261aece..0e2d93eb597e 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -692,7 +692,7 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator, public P Expr tensorOutput = grad_cell_unwrapper_->VisitExpr(transformedExpr, transformed->ret_type); return Function(f->params, tensorOutput, f->ret_type, f->type_params); } - LOG(FATAL) << "GlobalVar does not map to a function"; + throw std::runtime_error("GlobalVar does not map to a function"); } Expr VisitExpr_(const ConstantNode* op) final { From addce052fcd916396c8c224122508e4d23e87c30 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 15 Jun 2020 15:41:32 -0700 Subject: [PATCH 06/19] add move semantics to avoid copying --- src/relay/transforms/lazy_gradient_init.cc | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index 0e2d93eb597e..54752e387989 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -176,7 +176,7 @@ class ADTTransform : public TypeMutator, public PatternMutator { TypeData new_datatype = TypeData(new_adt, op->type_vars, constructors); module_->AddTypeDef(new_adt, new_datatype); - return new_datatype; + return std::move(new_datatype); } Pattern VisitPattern(const Pattern& c) 
final { return PatternMutator::VisitPattern(c); } @@ -234,7 +234,7 @@ class TypeVarSolver : public TypeMutator { return type_var_map_.at(type); } - return type; + return std::move(type); } private: @@ -487,7 +487,7 @@ Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator& ty Function func = Function({v}, match, this->VisitType(input_type), type_args.params); module_->AddUnchecked(func_var, func); - return func; + return std::move(func); } Type GradCellWrapper::VisitType_(const GlobalTypeVarNode* op) { @@ -616,7 +616,7 @@ Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, TypeCallMu Function func = Function({v}, match, this->VisitType(input_type), type_args.params); module_->AddUnchecked(func_var, func); - return func; + return std::move(func); } Type GradCellUnWrapper::VisitType_(const TypeCallNode* op) { @@ -761,9 +761,8 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator, public P } Expr VisitExpr_(const ConstructorNode* op) final { - Constructor c = module_->GetConstructor(GradCell_Header + op->belong_to->name_hint, - GradCell_Header + op->name_hint); - return c; + return module_->GetConstructor(GradCell_Header + op->belong_to->name_hint, + GradCell_Header + op->name_hint); } Expr VisitExpr_(const IfNode* op) final { @@ -801,14 +800,14 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator, public P GlobalTypeVar t = GetRef(op); if (module_->GetGlobalTypeVar("GradCell").same_as(t)) { // if GradCell type, do nothing - return t; + return std::move(t); } if (op->kind == kAdtHandle) { // handle to ADT, define GradCell version of ADT if not already created return adt_transformer_->VisitType(t); } - return t; + return std::move(t); } Var VisitVar(const Var& v) final { From 3f8381dc250d11eba2c3b75e7f36190b7b9b7252 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 21 Jun 2020 16:19:20 -0700 Subject: [PATCH 07/19] WIP --- include/tvm/relay/transform.h | 2 + python/tvm/ir/module.py | 4 ++ python/tvm/relay/transform/transform.py | 4 +- src/ir/module.cc | 2 + src/relay/analysis/type_solver.h | 5 ++ src/relay/transforms/type_infer.cc | 82 ++++++++++++++++++++++--- tests/python/relay/test_type_infer.py | 64 +++++++++++++++++++ 7 files changed, 155 insertions(+), 8 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index d322710ec95a..c83beb133670 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -210,6 +210,8 @@ TVM_DLL Pass FastMath(); */ TVM_DLL Pass InferType(); +TVM_DLL Pass InferTypeAll(); + /*! * \brief Search and eliminate common subexpression. For example, if there are * two expressions evaluated to an identical value, a single variable is created diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 8d75d8e8ee21..9b13f6e34862 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -87,6 +87,10 @@ def _add(self, var, val, update=False): var = _ty.GlobalTypeVar(var) _ffi_api.Module_AddDef(self, var, val, update) + def add_unchecked(self, var, val): + assert isinstance(val, _expr.RelayExpr) + _ffi_api.Module_AddUnchecked(self, var, val) + def __getitem__(self, var): """Lookup a global definition by name or by variable. 
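Taken together, the add_unchecked hook above and the InferTypeAll pass below support the workflow the new tests exercise: register mutually recursive definitions without per-function checking, then run whole-module inference so their constraints are solved jointly. A minimal sketch of that workflow (the f/g bindings here are illustrative, following the tests added later in this patch):

import tvm
from tvm import relay
from tvm.relay import transform

# f(x) = g(x) and g(y) = f(y); neither global type checks on its own,
# so both are added without checking and inferred together afterwards.
t = relay.TensorType((), "float32")
x = relay.Var("x", t)
y = relay.Var("y", t)
f_gv = relay.GlobalVar("f")
g_gv = relay.GlobalVar("g")
f = relay.Function([x], relay.Call(g_gv, [x]), t)
g = relay.Function([y], relay.Call(f_gv, [y]), t)

mod = tvm.IRModule()
mod.add_unchecked(f_gv, f)  # skips type checking at insertion time
mod.add_unchecked(g_gv, g)
mod = transform.InferTypeAll()(mod)  # constraints for f and g solved jointly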
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index cc92141b73db..dd154de61d8d 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -98,7 +98,9 @@ def InferType(): """ return _ffi_api.InferType() - +def InferTypeAll(): + return _ffi_api.InferTypeAll() + def FoldScaleAxis(): """Fold the scaling of axis into weights of conv2d/dense. This pass will invoke both forward and backward scale folding. diff --git a/src/ir/module.cc b/src/ir/module.cc index 25cba917143c..faa5e6a5a71a 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -442,6 +442,8 @@ TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); +TVM_REGISTER_GLOBAL("ir.Module_AddUnchecked").set_body_method(&IRModuleNode::AddUnchecked); + TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") .set_body_method(&IRModuleNode::GetGlobalVar); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index dcd8de075854..ab1593415b04 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -65,6 +65,11 @@ class TypeSolver { public: TypeSolver(const GlobalVar& current_func, const IRModule& _mod, ErrorReporter* err_reporter); ~TypeSolver(); + + void SetCurrentFunc(GlobalVar current_func) { + this->current_func = current_func; + } + /*! * \brief Add a type constraint to the solver. * \param constraint The constraint to be added. diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 7182f0e96f0f..05fe4b458523 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -109,6 +109,15 @@ class TypeInferencer : private ExprFunctor, // inference the type of expr. 
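+  // Note: Infer below runs constraint generation, solving, and resolution for a
+  // single function. The whole-module InferTypeAll path instead calls
+  // GenerateConstraints for every global function first and then solves once,
+  // which is what lets mutually recursive definitions type check.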
Expr Infer(Expr expr); + void SetCurrentFunc(GlobalVar current_func) { + this->current_func_ = current_func; + this->solver_.SetCurrentFunc(current_func); + } + + void GenerateConstraints(Expr expr); + void Solve(); + Expr ResolveType(Expr expr); + private: // type resolver that maps back to type class Resolver; @@ -544,13 +553,6 @@ class TypeInferencer : private ExprFunctor, return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {}); } - void Solve() { - solver_.Solve(); - - if (err_reporter.AnyErrors()) { - err_reporter.RenderErrors(mod_); - } - } }; class TypeInferencer::Resolver : public ExprMutator, PatternMutator { @@ -698,6 +700,24 @@ Expr TypeInferencer::Infer(Expr expr) { return resolved_expr; } +void TypeInferencer::GenerateConstraints(Expr expr) { + GetType(expr); +} + +void TypeInferencer::Solve() { + solver_.Solve(); + + if (err_reporter.AnyErrors()) { + err_reporter.RenderErrors(mod_); + } +} + +Expr TypeInferencer::ResolveType(Expr expr) { + auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr); + CHECK(WellFormed(resolved_expr)); + return resolved_expr; +} + struct AllCheckTypePopulated : ExprVisitor { void VisitExpr(const Expr& e) { if (e.as()) { @@ -742,6 +762,45 @@ Function InferType(const Function& func, const IRModule& mod, const GlobalVar& v return Downcast(func_ret); } +IRModule InferTypeAll(const IRModule& mod) { + CHECK(mod.defined()) << "internal error: module must be set for type inference"; + const auto& globalvars = mod->GetGlobalVars(); + + // first pass: fill in type for all functions + for (const auto& var : globalvars) { + relay::Function func = Downcast(mod->Lookup(var)); + func->checked_type_ = func->func_type_annotation(); + mod->AddUnchecked(var, func); + } + + TypeInferencer ti = TypeInferencer(mod, GlobalVar("all")); + + // second pass, fill in constraints + for (const auto& var : globalvars) { + relay::Function func = Downcast(mod->Lookup(var)); + ti.SetCurrentFunc(var); + ti.GenerateConstraints(func); + // ti.Infer(func); + } + + std::cout << "done generating constraints" << std::endl; + + ti.Solve(); + + std::cout << "done solving" << std::endl; + + // third pass + for (const auto& var : globalvars) { + relay::Function func = Downcast(mod->Lookup(var)); + ti.SetCurrentFunc(var); + Expr func_ret = ti.ResolveType(func); + //Expr func_ret = ti.Infer(func); + mod->AddUnchecked(var, Downcast(func_ret)); + } + + return mod; +} + namespace transform { Pass InferType() { @@ -750,8 +809,17 @@ Pass InferType() { return CreateFunctionPass(pass_func, 0, "InferType", {}); } +Pass InferTypeAll() { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relay::InferTypeAll(m); }; + return CreateModulePass(pass_func, 0, "InferTypeAll", {}); +} + + TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); }); +TVM_REGISTER_GLOBAL("relay._transform.InferTypeAll").set_body_typed([]() { return InferTypeAll(); }); + } // namespace transform } // namespace relay diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index cc4748c92b00..a49646588c75 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -20,6 +20,8 @@ import pytest import tvm from tvm import te +import numpy as np +from tvm.relay.prelude import Prelude from tvm import relay from tvm.relay import op, transform, analysis from tvm.relay import Any @@ -362,6 +364,68 @@ def test_let_polymorphism(): int32 = relay.TensorType((), "int32") 
tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) +def test_mutual_recursion2(): + # f(x) = if x > 0 then g(x - 1) else 0 + # g(y) = if y > 0 then f(y - 1) else 0 + tensortype = relay.TensorType((), 'float32') + + x = relay.Var("x", tensortype) + y = relay.Var("y", tensortype) + + zero = relay.Constant(tvm.nd.array(np.array(0, dtype='float32'))) + one = relay.Constant(tvm.nd.array(np.array(1, dtype='float32'))) + + f_gv = relay.GlobalVar('f') + g_gv = relay.GlobalVar('g') + + def body(var, call_func): + subtract_one = relay.op.subtract(var, one) + cond = relay.If(relay.op.greater(var, zero), + relay.Call(call_func, [subtract_one]), + zero) + func = relay.Function([var], cond) + return func + + f = body(x, g_gv) + g = body(y, f_gv) + + mod = tvm.IRModule() + p = Prelude(mod) + mod.add_unchecked(f_gv, f) + mod.add_unchecked(g_gv, g) + mod = transform.InferTypeAll()(mod) + + expected = relay.FuncType([tensortype], tensortype) + tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected) + tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected) + +def test_id_mutual(): + # f(x) = g(x) + # g(y) = f(y) + + tx = relay.TypeVar("x") + ty = relay.TypeVar("y") + + x = relay.Var("x", tx) + y = relay.Var("y", ty) + + f_gv = relay.GlobalVar('f') + g_gv = relay.GlobalVar('g') + + def body(var, call_func, type_param): + body = relay.Call(call_func, [var]) + func = relay.Function([var], body, type_params=type_param) + return func + + f = body(x, g_gv, [tx]) + g = body(y, f_gv, [ty]) + + mod = tvm.IRModule() + p = Prelude(mod) + mod.add_unchecked(f_gv, f) + mod.add_unchecked(g_gv, g) + mod = transform.InferTypeAll()(mod) + def test_if(): choice_t = relay.FuncType([], relay.scalar_type('bool')) From c788cca6957611e1bd574edd62a0fac91be12516 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 21 Jun 2020 16:24:39 -0700 Subject: [PATCH 08/19] undo lgi changes --- include/tvm/node/structural_equal.h | 2 - src/node/structural_equal.cc | 5 - src/relay/op/tensor/transform.cc | 3 - src/relay/op/tensor/transform.h | 3 - src/relay/transforms/lazy_gradient_init.cc | 807 +++--------------- .../relay/test_pass_lazy_gradient_init.py | 38 +- 6 files changed, 96 insertions(+), 762 deletions(-) diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 8c8aee7f0a3c..9424f6dc30f2 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -89,8 +89,6 @@ class StructuralEqual : public BaseValueEqual { * \return The comparison result. */ TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; - - TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) const; }; /*! 
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 280624b36b02..e05cbbb60d1f 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -231,9 +231,4 @@ bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) con return RemapVarSEqualHandler(false).Equal(lhs, rhs, false); } -bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs, - bool map_free_vars) const { - return RemapVarSEqualHandler(false).Equal(lhs, rhs, map_free_vars); -} - } // namespace tvm diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 457dc2bf5a34..f1d5b7ae5e27 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -184,9 +184,6 @@ bool ExpandDimsRel(const Array& types, int num_inputs, const Attrs& attrs, return false; } const auto* param = attrs.as(); - if (param == nullptr) { - return false; - } const int ndim = static_cast(data->shape.size()); const int axis = param->axis; const int num_newaxis = param->num_newaxis; diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 99dc5904be56..4e5677a1af6d 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -59,9 +59,6 @@ bool ConcatenateRel(const Array& types, int num_inputs, const Attrs& attrs } const auto* param = attrs.as(); - if (param == nullptr) { - return false; - } if (tensor_tuple->fields[0].as()) { return false; } diff --git a/src/relay/transforms/lazy_gradient_init.cc b/src/relay/transforms/lazy_gradient_init.cc index 54752e387989..f06246667a8b 100644 --- a/src/relay/transforms/lazy_gradient_init.cc +++ b/src/relay/transforms/lazy_gradient_init.cc @@ -36,8 +36,8 @@ * It also overloads + and * operation which can increase performance when doing * operations involving tensors with values of only 0 or 1. * - * Note: this pass can only be used with functions where the input/output types are a - * combination of TupleTypes, TensorTypes, ADTs, and non-nested FuncTypes + * Note: this pass can only be used with functions where the input/output types are + * a combination of TupleTypes and TensorTypes * * This pass optimizes 6 ops: * - add @@ -47,652 +47,142 @@ * - zeros * - zeros_like * - * This module level pass adds a new "GradCell" version datatype for each existing datatype. - * This is the case to propogate the new GradCell datatype through ADTs such as Lists. - * For each function, a new function is created that accepts the "GradCell" type of the arguments - * of the original function. That is, inputs to the function are converted to their - * GradCell-version, passed to the newly created "GradCell_Function". The output is then necessarily - * converted from the GradCell version to the original return type. + * This pass makes use of three visitor. The most important one visits the entire function, + * one is used for wrap inputs and one to unwrap outputs. * - * To support ADTs, we use functions that convert between an instance of an ADT to its - * respective GradCell version - * by matching constructors to the constructor of the "GradCell" datatype. + * For example: + * fn: TensorType[(10,10), float32] -> TensorType[(10,10), float32] * - * A transformation function is required for different type arguments. - * For example the ADT List may be List[int32] or List[List[int32]], which should be handled - * separately. 
+ * After this pass + * fn: GradCell[TensorType[(10,10), float32]] -> GradCell[TensorType[(10,10), float32]] * - * This pass uses 4 primary mutators: - * - LazyGradientInitializer to create the "GradCell_Function" of a given function. - * - GradCellWrapper mutates expr into its respective GradCell expr - * - GradCellWrapper mutates expr into its respective non-GradCell expr - * - ADTTransform creates a ADT for each unique ADT + * Thus, it is necessary to wrap this outer function so that the input/output types remain the same */ #include #include #include #include -#include #include -#include - #include "let_list.h" namespace tvm { namespace relay { -// prefix of name of GradCell version ADT -const char GradCell_Header[] = "_GradCell_"; -// prefix of transformation function for converting ADT to GradCell version -const char GradCell_TransFunc[] = "_GradCell_TransFunc_"; -// prefix of transformation function for converting GradCell version ADT to normal -const char GradCell_ReverseTransFunc[] = "_GradCell_ReverseTransFunc_"; -// prefix of copy of function that operates on GradCell types -const char GradCell_Func[] = "_GradCell_Func_"; - -struct TypeCallHash { - size_t operator()(const TypeCall& typecall) const { return ObjectHash()(typecall->func); } -}; - -/*! - * \brief Check if two ADT instances are equal, - * check for dataflow equivalence allow for mapping between TypeVars - * i.e GradCell[TypeVar(A)] = GradCell[TypeVar(B)] - */ -struct TypeCallEqual { - bool operator()(const TypeCall& l, const TypeCall& r) const { - if (!(l->func.same_as(r->func))) { - return false; - } - - if (l->args.size() != r->args.size()) { - return false; - } - - for (size_t i = 0; i < l->args.size(); i++) { - if (!tvm::StructuralEqual()(l->args[i], r->args[i], true)) { - return false; - } - } - - return true; - } -}; - /*! - * \brief ADTTransform creates a new ADT named - * GradCell_Header + name_hint for each unique ADT. 
+ * \brief Visitor appropriately wraps tensors with Raw constructor + * + * Recursively looks at the type of the expression (TensorType or TupleType are only supported for + * now) and either call the GradCell constructor if TensorType or unfold and recursively visit if + * TupleType */ -class ADTTransform : public TypeMutator, public PatternMutator { +class InputVisitor : public ExprFunctor { public: - explicit ADTTransform(IRModule module) : module_(module) {} - - Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); } - - Type VisitType_(const TensorTypeNode* op) final { - GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); - tvm::Array args; - args.push_back(GetRef(op)); - return TypeCall(gradCell, args); - } - - Type VisitType_(const GlobalTypeVarNode* op) final { - GlobalTypeVar t = GetRef(op); - if (op->kind == kAdtHandle) { - if (adt_mapping_.count(t) != 0) { - return adt_mapping_.at(t); - } + explicit InputVisitor(IRModule module) : module_(module) {} - TypeData adt = module_->LookupTypeDef(t); - this->VisitType(adt); - - return adt_mapping_.at(t); - } - - return GetRef(op); + Expr VisitExpr_(const VarNode* op, const Type& t) final { + std::cout << op->type_annotation << std::endl; + return WrapExpr(GetRef(op), op->type_annotation); } - Type VisitType_(const TypeDataNode* op) final { - auto type_data = GetRef(op); - std::string transformed_adt_name = GradCell_Header + op->header->name_hint; - - // add new ADT to map to handle recursive definitions - GlobalTypeVar new_adt = GlobalTypeVar(transformed_adt_name, op->header->kind); - adt_mapping_[type_data->header] = new_adt; - reverse_adt_mapping_[new_adt] = type_data->header; - - // define transformed ADT - Array constructors; - for (Constructor con : op->constructors) { - Array inputs; - for (Type t : con->inputs) { - inputs.push_back(this->VisitType(t)); - } - Constructor transformed_cons = Constructor(GradCell_Header + con->name_hint, inputs, new_adt); - constructors.push_back(transformed_cons); - } - - TypeData new_datatype = TypeData(new_adt, op->type_vars, constructors); - module_->AddTypeDef(new_adt, new_datatype); - return std::move(new_datatype); - } - - Pattern VisitPattern(const Pattern& c) final { return PatternMutator::VisitPattern(c); } - - Constructor VisitConstructor(const Constructor& c) final { - this->VisitType(c->belong_to); - return module_->GetConstructor(GradCell_Header + c->belong_to->name_hint, - GradCell_Header + c->name_hint); - } - - /*! - * \brief Given a transformed ADT, returned the original ADT. - * Useful for GradCellUnWrapper which needs to map transformed ADT constructors - * to the original ADT constructors. - * - * \param transformed_adt_handle GlobalTypeVar of "GradCell-version" of ADT - * \return ADT - */ - GlobalTypeVar GetReverseADT(GlobalTypeVar transformed_adt_handle) { - auto it = reverse_adt_mapping_.find(transformed_adt_handle); - - // reverse mapping should always be found - CHECK(it != reverse_adt_mapping_.end()) << "Reverse mapping of ADT transformation not found"; - return it->second; + Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { + return WrapExpr(GetRef(op), t); } private: - // Module IRModule module_; - // ADT -> transformed ADT - std::unordered_map adt_mapping_; - // transformed ADT -> ADT - std::unordered_map reverse_adt_mapping_; -}; -/*! - * \brief Helper for TypeCallMutator. 
- * Replace TypeVar with type arguments - */ -class TypeVarSolver : public TypeMutator { - public: - explicit TypeVarSolver( - const std::unordered_map& type_var_map, - const std::unordered_map& type_call_map) - : type_var_map_(type_var_map), type_call_map_(type_call_map) {} - Type VisitType_(const TypeVarNode* op) final { - TypeVar type = GetRef(op); - - if (type_call_map_.count(type) != 0) { - // recursively visit Type argument to replace possible nested TypeVar - return VisitType(type_call_map_.at(type)); - } - - if (type_var_map_.count(type) != 0) { - return type_var_map_.at(type); + Expr WrapExpr(const Expr expr, const Type& type) { + if (type.as()) { + return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); + } else if (auto* type_anno = type.as()) { + tvm::Array fields; + for (size_t i = 0; i < type_anno->fields.size(); i++) { + const Type& t = type_anno->fields[i]; + fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); + } + Expr tuple = Tuple(fields); + return tuple; } - return std::move(type); + return expr; } - - private: - // type vars to unique type vars - std::unordered_map type_var_map_; - // TypeCall arguments to ADT - std::unordered_map type_call_map_; }; /*! - * \brief Find all TypeVars within the arguments of a TypeCallNode and create a mapping - * of the TypeVars to new TypeVars + * \brief Visitor appropriately unwraps expressions with GradCell type into Tensors + * + * Recursively looks at the type of the expression + * and either use the FromGradCell function if TypeCall to GradCell + * or unfold and recursively visit if TupleType */ -class TypeCallMutator : public TypeVisitor { +class OutputVisitor : public ExprFunctor { public: - // TypeVars within TypeCallNode - Array args; - // unique TypeVars - Array params; - explicit TypeCallMutator(IRModule module, const TypeCallNode* op) : module_(module) { - for (Type t : op->args) { - // visit each type argument - VisitType(t); - } - for (auto const& x : type_var_map) { - args.push_back(x.first); - params.push_back(x.second); - } - } + explicit OutputVisitor(IRModule module) : module_(module) {} - /*! - * \brief Replace ADT type vars with TypeCall arguments - * and replace type vars with unique typevars - * - * \param t TypeCall - * \param map TypeVar of ADT -> type argument - * - * \return type after replacing ADT TypeVar with arguments and replacing any - * free type vars with uniquely generated typevars - */ - - Type InputType(Type t, const std::unordered_map& map) { - return TypeVarSolver(type_var_map, map).VisitType(t); + Expr VisitExpr_(const CallNode* op, const Type& t) final { + return UnwrapExpr(GetRef(op), t); } - void VisitType_(const TypeVarNode* op) final { - TypeVar tv = GetRef(op); - if (type_var_map.count(tv) == 0) { - TypeVar replacement = TypeVar(tv->name_hint + "_", tv->kind); - type_var_map.insert({tv, replacement}); - } + Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final { + return UnwrapExpr(GetRef(op), t); } private: IRModule module_; - // TypeVar in argument -> TypeVar of polymorphic function - std::unordered_map type_var_map; -}; - -typedef class GradCellUnWrapper GradCellUnWrapper; - -/*! - * \brief Mutate a given expression into its "GradCell-version". - * TensorTypes are wrapped with the Raw constructor of GradCell. - * TupleTypes are recursively visited. 
- * ADTTypes are converted to its appropriate transformed ADT - * FuncTypes are wrapped with a function that appropriately wraps/unwraps input and output - */ -class GradCellWrapper : public ExprFunctor, - public TypeMutator { - public: - explicit GradCellWrapper(IRModule module, ADTTransform* adt_transformer) - : module_(module), adt_transformer_(adt_transformer), unique(0) {} - Expr VisitExpr_(const VarNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; - Expr VisitExpr_(const TupleGetItemNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; - Expr VisitExpr_(const CallNode* op, const Type& t, GradCellUnWrapper* unwrapper) final; - - private: - // Module - IRModule module_; - // ADTTransform - ADTTransform* adt_transformer_; - // TypeCall -> Function to transform an ADT Instance into GradCell version - std::unordered_map adt_wrapper_map_; - // TypeVar of ADT call -> Type argument - std::unordered_map type_var_map; - // append to prefix to create unique function names for ADT wrapper functions - int64_t unique; - - Expr WrapExpr(const Expr expr, const Type& type, GradCellUnWrapper* unwrapper); - // Return function to wrap ADT - Expr GetADTFunction(const TypeCallNode* op, TypeCallMutator& type_args, - GradCellUnWrapper* unwrapper); - Type VisitType_(const GlobalTypeVarNode* op) final; - Type VisitType_(const TensorTypeNode* op) final; -}; - -/*! - * \brief Mutate a given "GradCell-version" expression into its nonGradCell-version. - * TypeCalls to GradCell are wrapped with FromGradCell function - * TupleTypes are recursively visited. - * Transformed ADTs are converted to its appropriate normal ADT - */ -class GradCellUnWrapper : public ExprFunctor, public TypeMutator { - public: - explicit GradCellUnWrapper(IRModule module, ADTTransform* adt_transformer) - : module_(module), adt_transformer_(adt_transformer), unique(0) {} - Expr VisitExpr_(const VarNode* op, const Type& t) final; - Expr VisitExpr_(const TupleGetItemNode* op, const Type& t) final; - Expr VisitExpr_(const CallNode* op, const Type& t) final; - Expr VisitExpr_(const TupleNode* op, const Type& t) final; - Expr VisitExpr_(const ConstantNode* op, const Type& t) final; - - private: - // Module - IRModule module_; - // ADTTransform - ADTTransform* adt_transformer_; - // TypeCall -> Function an GradCell_ADT into ADT - std::unordered_map adt_unwrapper_map_; - // TypeVar of GradCell_ADT call -> Type argument - std::unordered_map type_var_map; - // create unique strings for ADT unwrapper functions - int64_t unique; - - Expr UnwrapExpr(const Expr expr, const Type& type); - // Return function to unwrap ADT - Expr GetReverseADTFunction(const TypeCallNode* op, TypeCallMutator& type_args); - Type VisitType_(const TypeCallNode* op) final; - Type VisitType_(const GlobalTypeVarNode* op) final; -}; - -/* GradCellWrapper */ -Expr GradCellWrapper::VisitExpr_(const VarNode* op, const Type& t, GradCellUnWrapper* unwrapper) { - return WrapExpr(GetRef(op), op->type_annotation, unwrapper); -} - -Expr GradCellWrapper::VisitExpr_(const TupleGetItemNode* op, const Type& t, - GradCellUnWrapper* unwrapper) { - return WrapExpr(GetRef(op), t, unwrapper); -} - -Expr GradCellWrapper::VisitExpr_(const CallNode* op, const Type& t, GradCellUnWrapper* unwrapper) { - return WrapExpr(GetRef(op), t, unwrapper); -} - -Expr GradCellWrapper::WrapExpr(const Expr expr, const Type& type, GradCellUnWrapper* unwrapper) { - if (type.as()) { - return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type}); - } - - if (auto* type_anno = 
type.as()) { - tvm::Array fields; - for (size_t i = 0; i < type_anno->fields.size(); i++) { - const Type& t = type_anno->fields[i]; - // recursively visit each item of tuple - fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t, unwrapper)); - } - Expr tuple = Tuple(fields); - return tuple; - } - - if (auto* type_anno = type.as()) { - // create GradCell_ADT if not already created - adt_transformer_->VisitType(type_anno->func); - // find all type vars within type_anno - // to handle polymorphic functions - auto tvs = TypeCallMutator(module_, type_anno); - - return Call(GetADTFunction(type_anno, tvs, unwrapper), {expr}, Attrs(), tvs.args); - } - - if (auto* type_anno = type.as()) { - // to handle functions, we need to create a new function - // that handles GradCell version input and outputs GradCell version types - Array funcVars; - Array args; - for (Type t : type_anno->arg_types) { - Type visited = this->VisitType(t); - Var v = Var("v", visited); - funcVars.push_back(v); - // unwrap arguments - args.push_back(unwrapper->VisitExpr(v, visited)); - } - // call original expr with unwrapped arguments - Call call = Call(expr, args); - // wrap results of the call - Expr result = this->WrapExpr(call, type_anno->ret_type, unwrapper); - // return new function with GradCell-version types, wrapping original function - return Function(funcVars, result, this->VisitType(type_anno->ret_type), type_anno->type_params); - } - - return expr; -} - -Expr GradCellWrapper::GetADTFunction(const TypeCallNode* op, TypeCallMutator& type_args, - GradCellUnWrapper* unwrapper) { - auto type = GetRef(op); - GlobalTypeVar adt_handle = Downcast(op->func); - if (adt_wrapper_map_.count(type) != 0) { - // ADT already wrapped previously - return adt_wrapper_map_.at(type); - } - - // handle recursive ADT which require recursive calls to transform - GlobalVar func_var = GlobalVar(GradCell_Header + std::string(GradCell_TransFunc) + - adt_handle->name_hint + std::to_string(unique++)); - adt_wrapper_map_[type] = func_var; - - TypeData adt_data = module_->LookupTypeDef(adt_handle); - TypeData new_adt_data = module_->LookupTypeDef(GradCell_Header + adt_handle->name_hint); - // solve for input type to wrap ADT function - for (size_t i = 0; i < adt_data->type_vars.size(); i++) { - type_var_map[adt_data->type_vars[i]] = op->args[i]; - } - auto input_type = type_args.InputType(type, type_var_map); - - CHECK(adt_data->constructors.size() == new_adt_data->constructors.size()) - << "ADT and transformed ADT have different number of constructors"; - - /* - * Pattern match each constructor of the ADT to the respective constructor - * in the transformed ADT. 
PatternVars then need to be recursively wrapped, - * and passed as argument to the constructor of the transformed ADT - */ - Array clauses; - for (size_t i = 0; i < adt_data->constructors.size(); i++) { - // get Constructor to pattern match against - Array patternVars; - Array c_args; - Constructor c = adt_data->constructors[i]; - for (Type t : c->inputs) { - // solve for type of PatternVar - Type pattern_var_type = type_args.InputType(t, type_var_map); - Var v = Var("var", pattern_var_type); - patternVars.push_back(PatternVar(v)); - // recursively wrap - c_args.push_back(this->VisitExpr(v, pattern_var_type, unwrapper)); - } - Pattern p = PatternConstructor(c, patternVars); - // return Constructor of new ADT with wrapped arguments - Expr e = Call(new_adt_data->constructors[i], c_args); - - clauses.push_back(Clause(p, e)); - } - - Var v = Var("v", input_type); - Expr match = Match(v, clauses); - - Function func = Function({v}, match, this->VisitType(input_type), type_args.params); - module_->AddUnchecked(func_var, func); - return std::move(func); -} - -Type GradCellWrapper::VisitType_(const GlobalTypeVarNode* op) { - GlobalTypeVar t = GetRef(op); - if (op->kind == kAdtHandle) { - return adt_transformer_->VisitType(t); - } - - return GetRef(op); -} - -Type GradCellWrapper::VisitType_(const TensorTypeNode* op) { - GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); - tvm::Array args; - args.push_back(GetRef(op)); - return TypeCall(gradCell, args); -} - -/* GradCellUnWrapper */ -Expr GradCellUnWrapper::VisitExpr_(const CallNode* op, const Type& t) { - return UnwrapExpr(GetRef(op), t); -} - -Expr GradCellUnWrapper::VisitExpr_(const TupleGetItemNode* op, const Type& t) { - return UnwrapExpr(GetRef(op), t); -} - -Expr GradCellUnWrapper::VisitExpr_(const VarNode* op, const Type& t) { - return UnwrapExpr(GetRef(op), op->type_annotation); -} - -Expr GradCellUnWrapper::VisitExpr_(const TupleNode* op, const Type& t) { - return UnwrapExpr(GetRef(op), t); -} - -Expr GradCellUnWrapper::VisitExpr_(const ConstantNode* op, const Type& t) { - return UnwrapExpr(GetRef(op), t); -} - -Expr GradCellUnWrapper::UnwrapExpr(const Expr expr, const Type& type) { - if (auto* type_call = type.as()) { - if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { - // if TypeCall to GradCell, simply wrap with FromGradCell function - return Call(module_->GetGlobalVar("FromGradCell"), {expr}, Attrs(), type_call->args); - } - - // convert transformed ADT to ADT - auto tvs = TypeCallMutator(module_, type_call); - return Call(GetReverseADTFunction(type_call, tvs), {expr}, Attrs(), tvs.args); - } - - if (auto* type_anno = type.as()) { - tvm::Array fields; - for (size_t i = 0; i < type_anno->fields.size(); i++) { - // recursively unwrap items of tuple - const Type& t = type_anno->fields[i]; - fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); - } - Expr tuple = Tuple(fields); - return tuple; - } - return expr; -} - -Expr GradCellUnWrapper::GetReverseADTFunction(const TypeCallNode* op, TypeCallMutator& type_args) { - TypeCall type = GetRef(op); - GlobalTypeVar transformed_adt_handle = Downcast(op->func); - GlobalTypeVar adt_handle = adt_transformer_->GetReverseADT(transformed_adt_handle); - - // sanity check - CHECK(std::string(transformed_adt_handle->name_hint).rfind(GradCell_Header, 0) == 0) - << "Output ADT is not a transformed ADT"; - - if (adt_unwrapper_map_.count(type)) { - // transformed ADT unwrapped previously - return adt_unwrapper_map_.at(type); - } - - // handle recursive ADTs - 
GlobalVar func_var = GlobalVar(GradCell_Header + std::string(GradCell_ReverseTransFunc) + - adt_handle->name_hint + std::to_string(unique++)); - adt_unwrapper_map_[type] = func_var; - - TypeData adt_data = module_->LookupTypeDef(adt_handle); - TypeData transformed_adt_data = module_->LookupTypeDef(transformed_adt_handle); - - CHECK(adt_data->type_vars.size() == transformed_adt_data->type_vars.size()) - << "ADT and transformed ADT have different # of type args"; - - // solve for TypeVars of ADT to solve for input type of function - for (size_t i = 0; i < transformed_adt_data->type_vars.size(); i++) { - type_var_map[adt_data->type_vars[i]] = op->args[i]; - } - auto input_type = type_args.InputType(type, type_var_map); - - CHECK(adt_data->constructors.size() == transformed_adt_data->constructors.size()) - << "ADT and transformed ADT have different number of constructors"; - - // use same logic as wrapping expression - // Pattern match with each Constructor of the transformed ADT, - // return respective Constructor with arguments of unwrapped PatternVars - Array clauses; - for (size_t i = 0; i < transformed_adt_data->constructors.size(); i++) { - // Get Constructor of transformed ADT - Array patternVars; - Array c_args; - Constructor c = transformed_adt_data->constructors[i]; - for (Type t : c->inputs) { - // solve for type of pattern var - Type pattern_var_type = type_args.InputType(t, type_var_map); - Var v = Var("var", pattern_var_type); - // bind PatternVar to Var passed to constructor - patternVars.push_back(PatternVar(v)); - // recursively unwrap - c_args.push_back(this->VisitExpr(v, pattern_var_type)); + Expr UnwrapExpr(const Expr expr, const Type& type) { + if (auto* type_call = type.as()) { + if (type_call->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { + return Call(module_->GetGlobalVar("FromGradCell"), {expr}); + } + return expr; + } else if (auto* type_anno = type.as()) { + tvm::Array fields; + for (size_t i = 0; i < type_anno->fields.size(); i++) { + const Type& t = type_anno->fields[i]; + fields.push_back(this->VisitExpr(TupleGetItem(expr, i), t)); + } + Expr tuple = Tuple(fields); + return tuple; } - Pattern p = PatternConstructor(c, patternVars); - // Call appropriate Constructor - Expr e = Call(adt_data->constructors[i], c_args); - clauses.push_back(Clause(p, e)); + return expr; } +}; - Var v = Var("v", input_type); - Expr match = Match(v, clauses); - - Function func = Function({v}, match, this->VisitType(input_type), type_args.params); - module_->AddUnchecked(func_var, func); - return std::move(func); -} - -Type GradCellUnWrapper::VisitType_(const TypeCallNode* op) { - if (op->func.same_as(module_->GetGlobalTypeVar("GradCell"))) { - return op->args[0]; - } - return TypeMutator::VisitType_(op); -} - -Type GradCellUnWrapper::VisitType_(const GlobalTypeVarNode* op) { - GlobalTypeVar t = GetRef(op); - if (op->kind == kAdtHandle) { - return adt_transformer_->GetReverseADT(t); - } - - return GetRef(op); -} - -class LazyGradientInitializer : public ExprMutator, public TypeMutator, public PatternMutator { +class LazyGradientInitializer : public ExprMutator, public TypeMutator { public: explicit LazyGradientInitializer(IRModule module) : module_(module) { - // setup - adt_transformer_ = new ADTTransform(module_); - grad_cell_wrapper_ = new GradCellWrapper(module_, adt_transformer_); - grad_cell_unwrapper_ = new GradCellUnWrapper(module_, adt_transformer_); - - // import GradCell and GradCell functions module_->ImportFromStd("gradient.rly"); - - // ignore these functions when 
transforming - GlobalVar from_grad_cell = module_->GetGlobalVar("FromGradCell"); - GlobalVar mul_grad_cell = module_->GetGlobalVar("MultiplyGradCell"); - GlobalVar add_grad_cell = module_->GetGlobalVar("AddGradCell"); - - func_map_[from_grad_cell] = from_grad_cell; - func_map_[mul_grad_cell] = mul_grad_cell; - func_map_[add_grad_cell] = add_grad_cell; } /*! - * \brief Given a global function, create new global function - * that mirrors the functionality however using GradCell type. - * Original function will wrap inputs, call the mirrored function, unwrap the ouput, - * and return. + * \brief apply LazyGradientInit transformation and wrap function + * so that function type stays the same + * + * input/output types should only be a combination of TupleTypes and TensorTypes */ - BaseFunc VisitGlobalVar(const GlobalVar& gv) { - auto base_func = module_->Lookup(gv); - if (auto* e = base_func.as()) { - auto f = GetRef(e); - if (func_map_.count(gv) == 0) { - // create GlobalVar handle for function - func_map_[gv] = GlobalVar(GradCell_Func + gv->name_hint); - } - GlobalVar func_var = func_map_.at(gv); - if (module_->ContainGlobalVar(func_var->name_hint)) { - // transformed function already contained in IRModule, return - return module_->Lookup(func_var); - } - // create transformed function and add definition to IRModule - auto* transformed = ExprMutator::Mutate(f).as(); - module_->AddUnchecked(func_var, GetRef(transformed)); - - // wrap inputs of Tensor type using GradCellWrapper class - tvm::Array args; - for (Var var : f->params) { - Expr wrappedInput = - grad_cell_wrapper_->VisitExpr(var, var->checked_type(), grad_cell_unwrapper_); - args.push_back(wrappedInput); - } - Expr transformedExpr = Call(func_var, args); + Expr Transform(const Expr& e) { + auto* f = (e).as(); + auto* transformed = this->Mutate(e).as(); + + if (e.same_as(GetRef(transformed))) { + return GetRef(transformed); + } - // unwrap outputs of GradCell type into Tensor type using OutputVisitor class - Expr tensorOutput = grad_cell_unwrapper_->VisitExpr(transformedExpr, transformed->ret_type); - return Function(f->params, tensorOutput, f->ret_type, f->type_params); + // wrap inputs of Tensor type using InputVisitor class + tvm::Array args; + for (Var var : f->params) { + Expr wrappedInput = InputVisitor(module_).VisitExpr(var, var->checked_type()); + args.push_back(wrappedInput); } - throw std::runtime_error("GlobalVar does not map to a function"); + Expr transformedExpr = Call(GetRef(transformed), args); + + // unwrap outputs of GradCell type into Tensor type using OutputVisitor class + Expr tensorOutput = OutputVisitor(module_).VisitExpr(transformedExpr, transformed->ret_type); + return Function(f->params, tensorOutput, f->ret_type, Array()); } Expr VisitExpr_(const ConstantNode* op) final { @@ -736,124 +226,26 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator, public P // handle all other ops Expr result = CallPrimitiveOp(call_node); // wrap result with Raw constructor - return grad_cell_wrapper_->VisitExpr(result, call_node->checked_type(), grad_cell_unwrapper_); - } - - if (auto* op = (call_node->op).as()) { - // create "GradCell-version" of ADT if not already created - adt_transformer_->VisitType(op->belong_to); - // call Constructor of transformed ADT - Constructor c = module_->GetConstructor(GradCell_Header + op->belong_to->name_hint, - GradCell_Header + op->name_hint); - Array args; - for (Expr e : call_node->args) { - args.push_back(this->VisitExpr(e)); - } - - Array type_args; - for (Type t : 
call_node->type_args) { - type_args.push_back(this->VisitType(t)); - } - return Call(c, args, Attrs(), type_args); + return Call(module_->GetConstructor("GradCell", "Raw"), {result}, Attrs(), + {call_node->checked_type()}); } - + // not an op return ExprMutator::VisitExpr_(call_node); } - Expr VisitExpr_(const ConstructorNode* op) final { - return module_->GetConstructor(GradCell_Header + op->belong_to->name_hint, - GradCell_Header + op->name_hint); - } - - Expr VisitExpr_(const IfNode* op) final { - auto true_b = VisitExpr(op->true_branch); - auto false_b = VisitExpr(op->false_branch); - - // guard is bool type which will become GradCell[bool], so necessary to unwrap - auto guard = - grad_cell_unwrapper_->VisitExpr(VisitExpr(op->cond), VisitType(op->cond->checked_type())); - return If(guard, true_b, false_b); - } - - Expr VisitExpr_(const VarNode* op) final { - auto var = GetRef(op); - if (var_map_.count(var) != 0) { - return var_map_.at(var); - } - - return ExprMutator::VisitExpr_(op); - } - - Expr VisitExpr_(const GlobalVarNode* op) final { - // GlobalVar is a handle to a global function - GlobalVar gv = GetRef(op); - if (func_map_.count(gv) == 0) { - // create handle to transformed function - func_map_[gv] = GlobalVar(GradCell_Func + op->name_hint); - } - return func_map_.at(gv); - } - Type VisitType(const Type& t) final { return TypeMutator::VisitType(t); } - Type VisitType_(const GlobalTypeVarNode* op) final { - GlobalTypeVar t = GetRef(op); - if (module_->GetGlobalTypeVar("GradCell").same_as(t)) { - // if GradCell type, do nothing - return std::move(t); - } - if (op->kind == kAdtHandle) { - // handle to ADT, define GradCell version of ADT is not already created - return adt_transformer_->VisitType(t); - } - - return std::move(t); - } - - Var VisitVar(const Var& v) final { - // used for PatternMutator - if (var_map_.count(v) == 0) { - var_map_.insert(std::pair(v, Var(v->name_hint(), VisitType(v->type_annotation)))); - } - return var_map_.at(v); - } - - Type VisitType_(const TensorTypeNode* op) final { + Type VisitType_(const TensorTypeNode* op) { GlobalTypeVar gradCell = module_->GetGlobalTypeVar("GradCell"); tvm::Array args; args.push_back(GetRef(op)); return TypeCall(gradCell, args); } - Pattern VisitPattern(const Pattern& c) final { return PatternMutator::VisitPattern(c); } - - Constructor VisitConstructor(const Constructor& c) final { - adt_transformer_->VisitType(c->belong_to); - return module_->GetConstructor(GradCell_Header + c->belong_to->name_hint, - GradCell_Header + c->name_hint); - } - - ~LazyGradientInitializer() { - // destructors - delete grad_cell_wrapper_; - delete grad_cell_unwrapper_; - delete adt_transformer_; - } - private: // Module IRModule module_; - // pass single instance of ADTTransform to save state of ADTs transformed - ADTTransform* adt_transformer_; - // pass single instance of ADTTransform to save state of ADTs wrapped - GradCellWrapper* grad_cell_wrapper_; - // pass single instance of ADTTransform to save state of ADTs unwrapped - GradCellUnWrapper* grad_cell_unwrapper_; - // var map used for transforming a Clause - std::unordered_map var_map_; - // handle of function -> handle of transformed function - std::unordered_map func_map_; /*! 
* \brief Convert call_node to add/multiply op to use overloaded functions for GradCell type */ @@ -891,35 +283,26 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator, public P Expr CallPrimitiveOp(const CallNode* call_node) { const auto fromFunc = module_->GetGlobalVar("FromGradCell"); tvm::Array args; - - // unwrap arguments + // use FromGradCell to convert args to Tensor for (Expr expr : call_node->args) { - args.push_back( - grad_cell_unwrapper_->VisitExpr(VisitExpr(expr), VisitType(expr->checked_type()))); + args.push_back(Call(fromFunc, {VisitExpr(expr)}, Attrs(), {expr->checked_type()})); } // result of operation return Call(call_node->op, args, call_node->attrs); } }; -IRModule LazyGradientInit(const IRModule& m) { - LazyGradientInitializer lgi = LazyGradientInitializer(m); - std::vector gvs; - for (const auto& p : m->functions) { - gvs.push_back(p.first); - } - for (const auto& gv : gvs) { - m->AddUnchecked(gv, lgi.VisitGlobalVar(gv)); - } - m->Check(); - return m; +Expr LazyGradientInit(const Expr& e, IRModule mod) { + return LazyGradientInitializer(mod).Transform(e); } namespace transform { Pass LazyGradientInit() { - runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext pc) { return relay::LazyGradientInit(m); }; - return CreateModulePass(pass_func, 1, "LazyGradientInit", {}); + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(LazyGradientInit(f, m)); + }; + return CreateFunctionPass(pass_func, 2, "LazyGradientInit", {}); } TVM_REGISTER_GLOBAL("relay._transform.LazyGradientInit").set_body_typed(LazyGradientInit); diff --git a/tests/python/relay/test_pass_lazy_gradient_init.py b/tests/python/relay/test_pass_lazy_gradient_init.py index 7f1ca3ca6719..414926802870 100644 --- a/tests/python/relay/test_pass_lazy_gradient_init.py +++ b/tests/python/relay/test_pass_lazy_gradient_init.py @@ -21,7 +21,6 @@ from tvm.relay import create_executor, transform from tvm.relay.testing import rand, run_infer_type from tvm.testing import assert_allclose -from tvm.relay.prelude import Prelude import pytest def test_tc(): @@ -81,6 +80,7 @@ def test_add_tuple(): mod["main"] = y mod = transform.LazyGradientInit()(mod) + mod = tvm.transform.PrintIR(show_meta_data=True)(mod) y = mod["main"] assert mod["main"].checked_type == relay.FuncType([t], tensor_type) @@ -391,41 +391,5 @@ def test_ones_like(): y = ex.evaluate(y)(x) assert_allclose(y.asnumpy(), x.asnumpy() + np.ones_like(x.asnumpy())) -def test_list_adt(): - """test prelude functions on list ADT. 
which is a recursive ADT""" - mod = tvm.IRModule() - p = Prelude(mod) - - cons = p.cons - nil = p.nil - - mod = transform.LazyGradientInit()(mod) - - ex = create_executor(mod=mod) - - def to_list_adt(list): - l = nil() - for x in list: - l = cons(relay.const(x), l) - return ex.evaluate(l) - - def from_list_adt(list): - l = [] - def rec(x): - if x.constructor.tag == cons.tag: - l.insert(0, x.fields[0].asnumpy().tolist()) - rec(x.fields[1]) - rec(list) - return l - - # test sum - x = np.random.randint(1,101,10) - assert sum(x) == ex.evaluate(mod['sum'])(to_list_adt(x)).asnumpy() - - # test reverse - x = np.random.rand(10) - actual = from_list_adt(ex.evaluate(mod['rev'])(to_list_adt(x))) - assert_allclose(x[::-1], actual) - if __name__ == "__main__": pytest.main([__file__]) From ed183e352888034512b91bd177b406ea878e5997 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Sun, 21 Jun 2020 19:17:42 -0700 Subject: [PATCH 09/19] remove unrelated changes --- include/tvm/ir/module.h | 5 ----- src/ir/module.cc | 39 --------------------------------------- 2 files changed, 44 deletions(-) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 63640c3d65df..7af84b687f5f 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -85,11 +85,6 @@ class IRModuleNode : public Object { */ TVM_DLL void AddUnchecked(const GlobalVar& var, const BaseFunc& func); - /*! - * \brief Infer the type of all global functions - */ - TVM_DLL void Check(); - /*! * \brief Add a type-level definition to the global environment. * \param var The var of the global type definition. diff --git a/src/ir/module.cc b/src/ir/module.cc index faa5e6a5a71a..f28a30709e8a 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -231,45 +231,6 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) { global_var_map_.Set(var->name_hint, var); } -void IRModuleNode::Check() { - const auto& mod = GetRef(this); - const auto& globalvars = this->GetGlobalVars(); - - // first pass: fill in type for all functions - for (const auto& var : globalvars) { - relay::Function f = Downcast(this->Lookup(var)); - auto func = Downcast(relay::DeDup(std::move(f))); - - auto fv = relay::FreeVars(func); - auto ftv = relay::FreeTypeVars(func, mod); - if (fv.size() != 0) { - LOG(WARNING) << "There are free variables: " << fv << " in function: " << AsText(func, false) - << std::endl; - } - if (ftv.size() != 0) { - LOG(WARNING) << "There are free type variables: " << ftv - << " in function: " << AsText(func, false) << std::endl; - } - auto func_copy = relay::Function(concat(func->params, fv), func->body, func->ret_type, - concat(func->type_params, ftv), func->attrs); - - func_copy->checked_type_ = func_copy->func_type_annotation(); - mod->AddUnchecked(var, func_copy); - } - - // second pass: type inference on every function - for (const auto& var : globalvars) { - auto func = Downcast(this->Lookup(var)); - relay::Function checked_func = InferType(func, mod, var); - - Type type = checked_func->checked_type(); - CHECK(type.as() == nullptr) << "NULL"; - - var->checked_type_ = type; - mod->AddUnchecked(var, func); - } -} - void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { // We hash the global type var name to use as a globally unique prefix for tags. 
// The hash will be used as the most significant byte of the tag, with the index of From 47814b9f2b5dcdf02b6b7d720491c13efc8a5172 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 19 Aug 2020 11:57:15 -0700 Subject: [PATCH 10/19] seems to be working? --- src/relay/transforms/type_infer.cc | 58 ++++++++++++++++++--------- tests/python/relay/test_type_infer.py | 3 +- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 05fe4b458523..67afaa6e4d19 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -118,6 +118,37 @@ class TypeInferencer : private ExprFunctor, void Solve(); Expr ResolveType(Expr expr); + // Lazily get type for expr + // expression, we will populate it now, and return the result. + Type GetType(const Expr& expr) { + auto it = type_map_.find(expr); + if (it != type_map_.end() && it->second.checked_type.defined()) { + return it->second.checked_type; + } + Type ret = this->VisitExpr(expr); + CHECK(ret.defined()); + KindCheck(ret, mod_); + ResolvedTypeInfo& rti = type_map_[expr]; + rti.checked_type = ret; + return ret; + } + + // Lazily get type for expr + // expression, we will populate it now, and return the result. + Type GetTypeGlobalVar(const GlobalVar& expr) { + auto f = Downcast(mod_->Lookup(expr)); + Type ret = GetType(f); + auto it = type_map_.find(expr); + if (it != type_map_.end() && it->second.checked_type.defined()) { + return it->second.checked_type; + } + CHECK(ret.defined()); + KindCheck(ret, mod_); + ResolvedTypeInfo& rti = type_map_[expr]; + rti.checked_type = ret; + return ret; + } + private: // type resolver that maps back to type class Resolver; @@ -152,21 +183,6 @@ class TypeInferencer : private ExprFunctor, } } - // Lazily get type for expr - // expression, we will populate it now, and return the result. 
- Type GetType(const Expr& expr) { - auto it = type_map_.find(expr); - if (it != type_map_.end() && it->second.checked_type.defined()) { - return it->second.checked_type; - } - Type ret = this->VisitExpr(expr); - CHECK(ret.defined()); - KindCheck(ret, mod_); - ResolvedTypeInfo& rti = type_map_[expr]; - rti.checked_type = ret; - return ret; - } - void ReportFatalError(const ObjectRef& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); @@ -777,10 +793,10 @@ IRModule InferTypeAll(const IRModule& mod) { // second pass, fill in constraints for (const auto& var : globalvars) { - relay::Function func = Downcast(mod->Lookup(var)); - ti.SetCurrentFunc(var); - ti.GenerateConstraints(func); - // ti.Infer(func); + // relay::Function func = Downcast(mod->Lookup(var)); + // ti.SetCurrentFunc(var); + // ti.GenerateConstraints(func); + ti.GetTypeGlobalVar(var); } std::cout << "done generating constraints" << std::endl; @@ -792,9 +808,11 @@ IRModule InferTypeAll(const IRModule& mod) { // third pass for (const auto& var : globalvars) { relay::Function func = Downcast(mod->Lookup(var)); - ti.SetCurrentFunc(var); + // ti.SetCurrentFunc(var); Expr func_ret = ti.ResolveType(func); + std::cout << "resolved " << var << std::endl; //Expr func_ret = ti.Infer(func); + std::cout << var << ": " << func_ret->checked_type() << std::endl; mod->AddUnchecked(var, Downcast(func_ret)); } diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index a49646588c75..aca0860f964a 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -390,7 +390,7 @@ def body(var, call_func): g = body(y, f_gv) mod = tvm.IRModule() - p = Prelude(mod) + # p = Prelude(mod) mod.add_unchecked(f_gv, f) mod.add_unchecked(g_gv, g) mod = transform.InferTypeAll()(mod) @@ -421,7 +421,6 @@ def body(var, call_func, type_param): g = body(y, f_gv, [ty]) mod = tvm.IRModule() - p = Prelude(mod) mod.add_unchecked(f_gv, f) mod.add_unchecked(g_gv, g) mod = transform.InferTypeAll()(mod) From 27646002859518cc3d9d243dd8b4a6aeacd8b00b Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 19 Aug 2020 19:24:52 -0700 Subject: [PATCH 11/19] add prelude test --- tests/python/relay/test_type_infer.py | 48 +++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index aca0860f964a..5273f08633f8 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -399,6 +399,54 @@ def body(var, call_func): tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected) tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected) +def test_mutual_recursion_list_sum(): + # f[A](x: List[A]) = match x { + # Cons(a, Nil) => a + # Cons(_, b) => g(b) + # } + # g[B](y: List[B]) = match y { + # Cons(a, Nil) => a + # Cons(_, b) => f(b) + # } + p = Prelude() + l = p.l + A = relay.TypeVar("x") + B = relay.TypeVar("y") + + x = relay.Var("x", l(A)) + y = relay.Var("y", l(B)) + + f_gv = relay.GlobalVar('f') + g_gv = relay.GlobalVar('g') + + def body(var, call_func, type_param): + a = relay.Var("a") + b = relay.Var("b") + body = relay.Match( + var, + [ + relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternVar(a), relay.PatternConstructor(p.nil)]), a), + relay.Clause(relay.PatternConstructor(p.cons, [relay.PatternWildcard(), relay.PatternVar(b)]), relay.Call(call_func, [b])) + ], + complete=False + ) + func = 
relay.Function([var], body, type_params=type_param) + return func + + f = body(x, g_gv, [A]) + g = body(y, f_gv, [B]) + + mod = p.mod + mod.add_unchecked(f_gv, f) + mod.add_unchecked(g_gv, g) + mod = transform.InferTypeAll()(mod) + + tv = relay.TypeVar("test") + expected = relay.FuncType([l(tv)], tv, [tv]) + tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected) + tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected) + + def test_id_mutual(): # f(x) = g(x) # g(y) = f(y) From 2907ee8199c812d132e252d281c8135b138302a1 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 19 Aug 2020 19:41:06 -0700 Subject: [PATCH 12/19] working rn --- include/tvm/relay/transform.h | 8 +++++++ src/relay/transforms/type_infer.cc | 2 +- tests/python/relay/test_type_infer.py | 32 ++------------------------- 3 files changed, 11 insertions(+), 31 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index c83beb133670..73ce05df6e63 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -210,6 +210,14 @@ TVM_DLL Pass FastMath(); */ TVM_DLL Pass InferType(); +/*! + * \brief Infer the type of all functions in a module. + * + * This pass should be used when typechecking modules + * with mutually recursive functions. + * + * \return The pass. + */ TVM_DLL Pass InferTypeAll(); /*! diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 67afaa6e4d19..f87a435640eb 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -789,7 +789,7 @@ IRModule InferTypeAll(const IRModule& mod) { mod->AddUnchecked(var, func); } - TypeInferencer ti = TypeInferencer(mod, GlobalVar("all")); + TypeInferencer ti = TypeInferencer(mod, GlobalVar("dummy")); // second pass, fill in constraints for (const auto& var : globalvars) { diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 5273f08633f8..6de40f3fda19 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -364,7 +364,7 @@ def test_let_polymorphism(): int32 = relay.TensorType((), "int32") tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) -def test_mutual_recursion2(): +def test_mutual_recursion(): # f(x) = if x > 0 then g(x - 1) else 0 # g(y) = if y > 0 then f(y - 1) else 0 tensortype = relay.TensorType((), 'float32') @@ -399,7 +399,7 @@ def body(var, call_func): tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected) tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected) -def test_mutual_recursion_list_sum(): +def test_mutual_recursion_adt(): # f[A](x: List[A]) = match x { # Cons(a, Nil) => a # Cons(_, b) => g(b) @@ -446,34 +446,6 @@ def body(var, call_func, type_param): tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected) tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected) - -def test_id_mutual(): - # f(x) = g(x) - # g(y) = f(y) - - tx = relay.TypeVar("x") - ty = relay.TypeVar("y") - - x = relay.Var("x", tx) - y = relay.Var("y", ty) - - f_gv = relay.GlobalVar('f') - g_gv = relay.GlobalVar('g') - - def body(var, call_func, type_param): - body = relay.Call(call_func, [var]) - func = relay.Function([var], body, type_params=type_param) - return func - - f = body(x, g_gv, [tx]) - g = body(y, f_gv, [ty]) - - mod = tvm.IRModule() - mod.add_unchecked(f_gv, f) - mod.add_unchecked(g_gv, g) - mod = transform.InferTypeAll()(mod) - - def test_if(): choice_t = relay.FuncType([], 
relay.scalar_type('bool')) f = relay.Var('f', choice_t) From 9e58d7eb80ba352317c00a7ce85e28d67ced8c9e Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 19 Aug 2020 20:02:47 -0700 Subject: [PATCH 13/19] fixup --- python/tvm/ir/module.py | 2 +- python/tvm/relay/transform/transform.py | 2 +- src/ir/module.cc | 3 +- src/relay/analysis/type_solver.h | 4 +-- src/relay/transforms/type_infer.cc | 44 ++++++++----------------- 5 files changed, 19 insertions(+), 36 deletions(-) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 9b13f6e34862..ca7f5a5b1120 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -90,7 +90,7 @@ def _add(self, var, val, update=False): def add_unchecked(self, var, val): assert isinstance(val, _expr.RelayExpr) _ffi_api.Module_AddUnchecked(self, var, val) - + def __getitem__(self, var): """Lookup a global definition by name or by variable. diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index dd154de61d8d..2d1ea829a92b 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -100,7 +100,7 @@ def InferType(): def InferTypeAll(): return _ffi_api.InferTypeAll() - + def FoldScaleAxis(): """Fold the scaling of axis into weights of conv2d/dense. This pass will invoke both forward and backward scale folding. diff --git a/src/ir/module.cc b/src/ir/module.cc index f28a30709e8a..55b1beae3b96 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -403,7 +403,8 @@ TVM_REGISTER_GLOBAL("ir.Module_Add").set_body([](TVMArgs args, TVMRetValue* ret) TVM_REGISTER_GLOBAL("ir.Module_AddDef").set_body_method(&IRModuleNode::AddTypeDef); -TVM_REGISTER_GLOBAL("ir.Module_AddUnchecked").set_body_method(&IRModuleNode::AddUnchecked); +TVM_REGISTER_GLOBAL("ir.Module_AddUnchecked") + .set_body_method(&IRModuleNode::AddUnchecked); TVM_REGISTER_GLOBAL("ir.Module_GetGlobalVar") .set_body_method(&IRModuleNode::GetGlobalVar); diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index ab1593415b04..e0748057a099 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -66,9 +66,7 @@ class TypeSolver { TypeSolver(const GlobalVar& current_func, const IRModule& _mod, ErrorReporter* err_reporter); ~TypeSolver(); - void SetCurrentFunc(GlobalVar current_func) { - this->current_func = current_func; - } + void SetCurrentFunc(GlobalVar current_func) { this->current_func = current_func; } /*! * \brief Add a type constraint to the solver. diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index f87a435640eb..1de8ba3ca4b1 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -114,7 +114,6 @@ class TypeInferencer : private ExprFunctor, this->solver_.SetCurrentFunc(current_func); } - void GenerateConstraints(Expr expr); void Solve(); Expr ResolveType(Expr expr); @@ -133,20 +132,13 @@ class TypeInferencer : private ExprFunctor, return ret; } - // Lazily get type for expr + // Lazily get type for GlobalVar // expression, we will populate it now, and return the result. 
Type GetTypeGlobalVar(const GlobalVar& expr) { + // make sure Function is typechecked auto f = Downcast(mod_->Lookup(expr)); Type ret = GetType(f); - auto it = type_map_.find(expr); - if (it != type_map_.end() && it->second.checked_type.defined()) { - return it->second.checked_type; - } - CHECK(ret.defined()); - KindCheck(ret, mod_); - ResolvedTypeInfo& rti = type_map_[expr]; - rti.checked_type = ret; - return ret; + return GetType(expr); } private: @@ -568,7 +560,6 @@ class TypeInferencer : private ExprFunctor, } return FuncType(c->inputs, TypeCall(c->belong_to, types), td->type_vars, {}); } - }; class TypeInferencer::Resolver : public ExprMutator, PatternMutator { @@ -716,10 +707,6 @@ Expr TypeInferencer::Infer(Expr expr) { return resolved_expr; } -void TypeInferencer::GenerateConstraints(Expr expr) { - GetType(expr); -} - void TypeInferencer::Solve() { solver_.Solve(); @@ -789,30 +776,26 @@ IRModule InferTypeAll(const IRModule& mod) { mod->AddUnchecked(var, func); } + // use dummy var, will be updated as we fill constraints/ + // solve for each GlobalVar TypeInferencer ti = TypeInferencer(mod, GlobalVar("dummy")); // second pass, fill in constraints for (const auto& var : globalvars) { - // relay::Function func = Downcast(mod->Lookup(var)); - // ti.SetCurrentFunc(var); - // ti.GenerateConstraints(func); + ti.SetCurrentFunc(var); ti.GetTypeGlobalVar(var); } - std::cout << "done generating constraints" << std::endl; - + // solve constraints ti.Solve(); - std::cout << "done solving" << std::endl; - - // third pass + // third pass, resolve types for (const auto& var : globalvars) { + ti.SetCurrentFunc(var); + relay::Function func = Downcast(mod->Lookup(var)); - // ti.SetCurrentFunc(var); Expr func_ret = ti.ResolveType(func); - std::cout << "resolved " << var << std::endl; - //Expr func_ret = ti.Infer(func); - std::cout << var << ": " << func_ret->checked_type() << std::endl; + // add function back to module mod->AddUnchecked(var, Downcast(func_ret)); } @@ -833,10 +816,11 @@ Pass InferTypeAll() { return CreateModulePass(pass_func, 0, "InferTypeAll", {}); } - TVM_REGISTER_GLOBAL("relay._transform.InferType").set_body_typed([]() { return InferType(); }); -TVM_REGISTER_GLOBAL("relay._transform.InferTypeAll").set_body_typed([]() { return InferTypeAll(); }); +TVM_REGISTER_GLOBAL("relay._transform.InferTypeAll").set_body_typed([]() { + return InferTypeAll(); +}); } // namespace transform From 58b05a6109c399f295dd408b08b793bdbea3a60f Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Wed, 19 Aug 2020 21:28:00 -0700 Subject: [PATCH 14/19] add even/odd test --- tests/python/relay/test_type_infer.py | 60 +++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 6de40f3fda19..ef4fc4f8434e 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -25,6 +25,7 @@ from tvm import relay from tvm.relay import op, transform, analysis from tvm.relay import Any +from tvm.relay.testing import add_nat_definitions def run_infer_type(expr, mod=None): if not mod: @@ -446,6 +447,65 @@ def body(var, call_func, type_param): tvm.ir.assert_structural_equal(mod[f_gv].checked_type, expected) tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected) +def test_mutual_recursion_peano(): + # even and odd function for peano function + # even(x: nat) = match x { + # z => true + # s(a: nat) => odd(a) + # } + # odd(x: nat) = match x { + # z => false + # s(a: nat) => even(a) + # } + p = 
Prelude()
+    add_nat_definitions(p)
+    z = p.z
+    s = p.s
+
+    even_gv = relay.GlobalVar('even')
+    odd_gv = relay.GlobalVar('odd')
+
+    def create_even_func(odd_f):
+        x = relay.Var("x")
+        a = relay.Var("a")
+        true = tvm.nd.array(np.array(True))
+
+        body = relay.Match(
+            x,
+            [
+                relay.Clause(relay.PatternConstructor(z), relay.Constant(true)),
+                relay.Clause(relay.PatternConstructor(s, [relay.PatternVar(a)]), relay.Call(odd_f, [a]))
+            ]
+        )
+        func = relay.Function([x], body)
+        return func
+    def create_odd_func(even_f):
+        x = relay.Var("x")
+        a = relay.Var("a")
+        false = tvm.nd.array(np.array(False))
+
+        body = relay.Match(
+            x,
+            [
+                relay.Clause(relay.PatternConstructor(z), relay.Constant(false)),
+                relay.Clause(relay.PatternConstructor(s, [relay.PatternVar(a)]), relay.Call(even_f, [a]))
+            ]
+        )
+        func = relay.Function([x], body)
+        return func
+
+    even_func = create_even_func(odd_gv)
+    odd_func = create_odd_func(even_gv)
+
+    mod = p.mod
+    mod.add_unchecked(even_gv, even_func)
+    mod.add_unchecked(odd_gv, odd_func)
+    mod = transform.InferTypeAll()(mod)
+
+    expected = relay.FuncType([p.nat], relay.TensorType((), dtype='bool'))
+    tvm.ir.assert_structural_equal(mod[even_gv].checked_type, expected)
+    tvm.ir.assert_structural_equal(mod[odd_gv].checked_type, expected)
+
 def test_if():
     choice_t = relay.FuncType([], relay.scalar_type('bool'))
     f = relay.Var('f', choice_t)

From fbafb06321b61827ead5cc9240f875499596b570 Mon Sep 17 00:00:00 2001
From: Andrew Liu 
Date: Thu, 20 Aug 2020 14:37:18 -0700
Subject: [PATCH 15/19] fix resolving type conflict

---
 src/relay/transforms/type_infer.cc    | 34 +++++++++++++++++++++++++--
 tests/python/relay/test_type_infer.py |  2 +-
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc
index 1de8ba3ca4b1..dec00db895df 100644
--- a/src/relay/transforms/type_infer.cc
+++ b/src/relay/transforms/type_infer.cc
@@ -86,6 +86,31 @@ struct ResolvedTypeInfo {
   Array type_args = Array(ObjectPtr(nullptr));
 };
 
+// helper class to dedup typevars of a type
+// - types do not have to be already typechecked
+//
+// This is used to Dedup GlobalVar type to avoid
+// incorrect type resolving across different usages
+class DeDupType : public TypeMutator, public ExprMutator, public PatternMutator {
+ public:
+  TypeVar Fresh(const TypeVar& tv) {
+    TypeVar ret = TypeVar(tv->name_hint, tv->kind);
+    type_rename_[tv] = ret;
+    return ret;
+  }
+
+  Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; }
+
+  Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); }
+
+  Type VisitType_(const TypeVarNode* op) final {
+    TypeVar v = GetRef(op);
+    return type_rename_.count(v) != 0 ? 
type_rename_.at(v) : Fresh(v);
+  }
+
+ private:
+  std::unordered_map type_rename_;
+};
 //
 // The inference algorithm can roughly be devided into three stages:
 // - Populate the constraints by visiting the expression (TypeInferencer.GetType)
@@ -122,6 +147,11 @@ class TypeInferencer : private ExprFunctor,
   Type GetType(const Expr& expr) {
     auto it = type_map_.find(expr);
     if (it != type_map_.end() && it->second.checked_type.defined()) {
+      if (expr.as() != nullptr) {
+        // if we don't dedup GlobalVarNode, two functions that use the same GlobalVar
+        // may resolve to the same type incorrectly
+        return DeDupType().VisitType(it->second.checked_type);
+      }
       return it->second.checked_type;
     }
     Type ret = this->VisitExpr(expr);
@@ -135,7 +165,8 @@ class TypeInferencer : private ExprFunctor,
   // Lazily get type for GlobalVar
   // expression, we will populate it now, and return the result.
   Type GetTypeGlobalVar(const GlobalVar& expr) {
-    // make sure Function is typechecked
+    // we have to visit function
+    // or else it may not be type-checked
    auto f = Downcast(mod_->Lookup(expr));
     Type ret = GetType(f);
     return GetType(expr);
@@ -320,7 +351,6 @@ class TypeInferencer : private ExprFunctor,
         this->ReportFatalError(match, ss);
       }
     }
-
     return rtype;
   }
 
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index ef4fc4f8434e..f343165b638a 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -502,7 +502,7 @@ def create_odd_func(even_f):
     mod.add_unchecked(odd_gv, odd_func)
     mod = transform.InferTypeAll()(mod)
 
-    expected = relay.FuncType([p.nat], relay.TensorType((), dtype='bool'))
+    expected = relay.FuncType([p.nat()], relay.TensorType((), dtype='bool'))
     tvm.ir.assert_structural_equal(mod[even_gv].checked_type, expected)
     tvm.ir.assert_structural_equal(mod[odd_gv].checked_type, expected)
 
From 0cdcf2afbfc041c73fadf305d373bf501d15dc52 Mon Sep 17 00:00:00 2001
From: Andrew Liu 
Date: Thu, 20 Aug 2020 20:30:08 -0700
Subject: [PATCH 16/19] remove type_params where possible

---
 src/relay/op/type_relations.cc        |  1 +
 tests/python/relay/test_type_infer.py | 29 ++++++++++++++-------------
 2 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 0647ec9780f3..dd9464a3f946 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -118,6 +118,7 @@ bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& att
   CHECK_EQ(types.size(), 3);
   // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1]
   //            << ",Out:" << types[2] << std::endl;
+  reporter->Assign(types[0], types[1]);
   if (auto* t0 = types[0].as()) {
     if (auto* t1 = types[1].as()) {
       CHECK_EQ(t0->dtype, t1->dtype);
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index f343165b638a..fea82dc10499 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -370,8 +370,8 @@ def test_mutual_recursion():
     # g(y) = if y > 0 then f(y - 1) else 0
     tensortype = relay.TensorType((), 'float32')
 
-    x = relay.Var("x", tensortype)
-    y = relay.Var("y", tensortype)
+    x = relay.Var("x")
+    y = relay.Var("y")
 
     zero = relay.Constant(tvm.nd.array(np.array(0, dtype='float32')))
     one = relay.Constant(tvm.nd.array(np.array(1, dtype='float32')))
@@ -401,27 +401,28 @@ def body(var, call_func):
     tvm.ir.assert_structural_equal(mod[g_gv].checked_type, expected)
 
 def test_mutual_recursion_adt():
-    # f[A](x: List[A]) = match x {
+    # f[A](x: A) = match 
x { # Cons(a, Nil) => a # Cons(_, b) => g(b) # } - # g[B](y: List[B]) = match y { + # g[B](y: B) = match y { # Cons(a, Nil) => a # Cons(_, b) => f(b) # } p = Prelude() l = p.l - A = relay.TypeVar("x") - B = relay.TypeVar("y") - x = relay.Var("x", l(A)) - y = relay.Var("y", l(B)) + A = relay.TypeVar("A") + B = relay.TypeVar("B") + + x = relay.Var("x") + y = relay.Var("y") f_gv = relay.GlobalVar('f') g_gv = relay.GlobalVar('g') def body(var, call_func, type_param): - a = relay.Var("a") + a = relay.Var("a", type_param) b = relay.Var("b") body = relay.Match( var, @@ -431,11 +432,11 @@ def body(var, call_func, type_param): ], complete=False ) - func = relay.Function([var], body, type_params=type_param) + func = relay.Function([var], body, type_params=[type_param]) return func - f = body(x, g_gv, [A]) - g = body(y, f_gv, [B]) + f = body(x, g_gv, A) + g = body(y, f_gv, B) mod = p.mod mod.add_unchecked(f_gv, f) @@ -449,11 +450,11 @@ def body(var, call_func, type_param): def test_mutual_recursion_peano(): # even and odd function for peano function - # even(x: nat) = match x { + # even(x) = match x { # z => true # s(a: nat) => odd(a) # } - # odd(x: nat) = match x { + # odd(x) = match x { # z => false # s(a: nat) => even(a) # } From 7524af105d853949a3467e95ccf6aca5af9d38c0 Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Thu, 20 Aug 2020 21:42:09 -0700 Subject: [PATCH 17/19] fix broadcastcomprel --- src/relay/op/type_relations.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index dd9464a3f946..169f7d388fec 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -118,7 +118,6 @@ bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& att CHECK_EQ(types.size(), 3); // DLOG(INFO) << "In1:" << types[0] << ",In2:" << types[1] // << ",Out:" << types[2] << std::endl; - reporter->Assign(types[0], types[1]); if (auto* t0 = types[0].as()) { if (auto* t1 = types[1].as()) { CHECK_EQ(t0->dtype, t1->dtype); @@ -127,6 +126,7 @@ bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& att return true; } } + reporter->Assign(types[0], types[1]); return false; } From 4dc95fcbe439b1720366ec86c7f8f1c60e1f338f Mon Sep 17 00:00:00 2001 From: Andrew Liu Date: Mon, 24 Aug 2020 11:50:13 -0700 Subject: [PATCH 18/19] fix changes --- include/tvm/relay/transform.h | 9 +++++++++ src/relay/analysis/type_solver.h | 2 +- src/relay/transforms/de_duplicate.cc | 27 +++++++++++++++++++++++++++ src/relay/transforms/type_infer.cc | 27 +-------------------------- 4 files changed, 38 insertions(+), 27 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 73ce05df6e63..4f57e956b0c9 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -503,6 +503,15 @@ TVM_DLL Function UnCPS(const Function& f); */ TVM_DLL Expr DeDup(const Expr& e); +/*! + * \brief Deduplicate the bound type variables in the type. + * + * \param e the type (does not have to be typechecked). + * + * \return the deduplicated type. 
+ */ +TVM_DLL Type DeDupType(const Type& e); + } // namespace relay } // namespace tvm diff --git a/src/relay/analysis/type_solver.h b/src/relay/analysis/type_solver.h index e0748057a099..37d8e45bf279 100644 --- a/src/relay/analysis/type_solver.h +++ b/src/relay/analysis/type_solver.h @@ -66,7 +66,7 @@ class TypeSolver { TypeSolver(const GlobalVar& current_func, const IRModule& _mod, ErrorReporter* err_reporter); ~TypeSolver(); - void SetCurrentFunc(GlobalVar current_func) { this->current_func = current_func; } + void SetCurrentFunc(const GlobalVar& current_func) { this->current_func = current_func; } /*! * \brief Add a type constraint to the solver. diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index d90e5c584df3..6bc7873cfab8 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -99,6 +99,33 @@ Expr DeDup(const Expr& e) { return ret; } +// dedup bound type variables in type +// - types do not have to be already typechecked +Type DeDupType(const Type& e) { + class DeDupTypeMutator : public TypeMutator, public ExprMutator, public PatternMutator { + public: + TypeVar Fresh(const TypeVar& tv) { + TypeVar ret = TypeVar(tv->name_hint, tv->kind); + type_rename_[tv] = ret; + return ret; + } + + Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } + + Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } + + Type VisitType_(const TypeVarNode* op) final { + TypeVar v = GetRef(op); + return type_rename_.count(v) != 0 ? type_rename_.at(v) : Fresh(v); + } + + private: + std::unordered_map type_rename_; + }; + + Type ret = DeDupTypeMutator().VisitType(e); + return ret; +} TVM_REGISTER_GLOBAL("relay._transform.dedup").set_body_typed(DeDup); } // namespace relay diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index dec00db895df..3fc871435a6a 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -86,31 +86,6 @@ struct ResolvedTypeInfo { Array type_args = Array(ObjectPtr(nullptr)); }; -// helper class to dedup typevars of a type -// - types do not have to be already typechecked -// -// This is used to Dedup GlobalVar type to avoid -// incorrect type resolving across different usages -class DeDupType : public TypeMutator, public ExprMutator, public PatternMutator { - public: - TypeVar Fresh(const TypeVar& tv) { - TypeVar ret = TypeVar(tv->name_hint, tv->kind); - type_rename_[tv] = ret; - return ret; - } - - Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } - - Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } - - Type VisitType_(const TypeVarNode* op) final { - TypeVar v = GetRef(op); - return type_rename_.count(v) != 0 ? 
type_rename_.at(v) : Fresh(v);
-  }
-
- private:
-  std::unordered_map type_rename_;
-};
 //
 // The inference algorithm can roughly be devided into three stages:
 // - Populate the constraints by visiting the expression (TypeInferencer.GetType)
@@ -150,7 +125,7 @@ class TypeInferencer : private ExprFunctor,
       if (expr.as() != nullptr) {
         // if we don't dedup GlobalVarNode, two functions that use the same GlobalVar
         // may resolve to the same type incorrectly
-        return DeDupType().VisitType(it->second.checked_type);
+        return DeDupType(it->second.checked_type);
       }
       return it->second.checked_type;
     }

From e1c2c5787823f71f53796be0a74c665e05de5463 Mon Sep 17 00:00:00 2001
From: Andrew Liu 
Date: Mon, 24 Aug 2020 14:12:02 -0700
Subject: [PATCH 19/19] remove type relation for binary op

---
 src/relay/op/type_relations.cc        | 1 -
 tests/python/relay/test_type_infer.py | 7 ++++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 169f7d388fec..0647ec9780f3 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -126,7 +126,6 @@ bool BroadcastCompRel(const Array& types, int num_inputs, const Attrs& att
       return true;
     }
   }
-  reporter->Assign(types[0], types[1]);
   return false;
 }
 
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index fea82dc10499..73dcc2cf9621 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -369,9 +369,10 @@ def test_mutual_recursion():
     # f(x) = if x > 0 then g(x - 1) else 0
     # g(y) = if y > 0 then f(y - 1) else 0
     tensortype = relay.TensorType((), 'float32')
-
-    x = relay.Var("x")
-    y = relay.Var("y")
+    # we need to annotate with tensortype
+    # because binary op does not add relations between operands
+    x = relay.Var("x", tensortype)
+    y = relay.Var("y", tensortype)
 
     zero = relay.Constant(tvm.nd.array(np.array(0, dtype='float32')))
     one = relay.Constant(tvm.nd.array(np.array(1, dtype='float32')))
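
For reference, a minimal usage sketch of the InferTypeAll pass built up in this
series, condensed from the test_mutual_recursion test above. It assumes only the
pieces these patches add (the add_unchecked binding from PATCH 13 and the
InferTypeAll pass from PATCH 10); relay.const, relay.greater, and relay.subtract
are standard Relay helpers:

    import tvm
    from tvm import relay
    from tvm.relay import transform

    t = relay.TensorType((), 'float32')
    x = relay.Var("x", t)
    y = relay.Var("y", t)
    f_gv = relay.GlobalVar('f')
    g_gv = relay.GlobalVar('g')

    def body(var, callee):
        # if var > 0 then callee(var - 1) else 0
        zero = relay.const(0, 'float32')
        one = relay.const(1, 'float32')
        return relay.Function(
            [var],
            relay.If(relay.greater(var, zero),
                     relay.Call(callee, [relay.subtract(var, one)]),
                     zero))

    mod = tvm.IRModule()
    # add_unchecked defers type checking, so f and g can reference
    # each other before either function has been checked
    mod.add_unchecked(f_gv, body(x, g_gv))
    mod.add_unchecked(g_gv, body(y, f_gv))
    # one solver run fills in and solves constraints for both functions
    mod = transform.InferTypeAll()(mod)
    assert mod[f_gv].checked_type == relay.FuncType([t], t)
    assert mod[g_gv].checked_type == relay.FuncType([t], t)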
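
The GlobalVar deduplication from PATCH 15/18 matters whenever one polymorphic
global is used at two different types. A hypothetical sketch of that situation
(the id function and names below are illustrative, not part of the patches; the
failure mode is the one described by the comment in GetType):

    import tvm
    from tvm import relay
    from tvm.relay import transform

    a = relay.TypeVar('a')
    v = relay.Var('v', a)
    id_gv = relay.GlobalVar('id')
    id_fn = relay.Function([v], v, a, [a])  # fn[a](v: a) -> a

    xi = relay.Var('xi', relay.TensorType((), 'int32'))
    xf = relay.Var('xf', relay.TensorType((), 'float32'))
    main_gv = relay.GlobalVar('main')
    # 'id' is called at int32 and at float32; without DeDupType both call
    # sites would see the same cached TypeVar 'a', and the solver could
    # incorrectly unify the two instantiations
    main_fn = relay.Function(
        [xi, xf],
        relay.Tuple([relay.Call(id_gv, [xi]), relay.Call(id_gv, [xf])]))

    mod = tvm.IRModule()
    mod.add_unchecked(id_gv, id_fn)
    mod.add_unchecked(main_gv, main_fn)
    mod = transform.InferTypeAll()(mod)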