From e1c59a0c3e5071468025671133737ca5cdaac81a Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 20 Oct 2022 19:32:32 -0400 Subject: [PATCH 01/47] Initial spec --- relax_spec.md | 749 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 749 insertions(+) create mode 100644 relax_spec.md diff --git a/relax_spec.md b/relax_spec.md new file mode 100644 index 000000000000..40f69e4ba91c --- /dev/null +++ b/relax_spec.md @@ -0,0 +1,749 @@ +# Informal Relax Language Specification + +Note: Text in «double chevrons» indicates features not present in the current prototype. + +In order to develop and test Relax, it is important for compiler developers to agree on what a given program in Relax means and what makes it valid so that test cases can be evaluated independently of any particular Relax implementation. This document is intended to describe Relax's grammar constructs (its [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree), or AST), the semantics of its grammar (what the different constructs mean), Relax's type system and type-checking rules (what makes a Relax program valid), and its rules for reasoning about tensor shapes in detailed though still informal terms. If necessary, we may encode these rules more formally to allow for more automated analysis. + +Though this document will use the TVMScript front end for some examples, specifying the mapping from Python's AST to Relax's AST will be deferred until the parser becomes more stable. + +# Table of Contents + +1. [Overview](#overview) +2. [Top-Level Program Organization](#top-level-program-organization-irmodule) +3. [Values in Relax](#values-in-relax) +4. [Variable Scoping](#variable-scoping) +5. [Well-Formedness Criteria](#well-formedness-criteria) +6. [Types in Relax](#types-in-relax) +7. [Shapes in Relax](#shapes-in-relax) +8. [Semantics](#detailed-semantics) + +# Overview + +This section will outline the grammar of Relax and give very brief descriptions of the different components, including the semantics, type system, and shape system. The rest of this document will provide more detailed descriptions of these facets of the language, including the validity conditions that the type system and shape system uphold. + +## Grammar + +Below is a diagram of the various AST constructs in Relax, including types. In code, these are defined on the C++ side in `include/tvm/relax/{expr.h, type.h}` and in Python in `python/tvm/relax/{expr.py, ty.py}`. This diagram will give the names of the AST nodes and the types and names of their members. The semantics will describe what computation each construct represents; an AST is simply data. A Relax program consists of an `IRModule` with global variables bound to Relax functions that implement the computations of interest. + +(On the notation: `[x]` means "a list of `x`," `x?` means "optionally `x`," `{x: y}` means "a map of `x` to `y`," `x | y` means "`x` or `y`," and `#` is used for comments.) + +``` +# PrimExprs are defined in TIR, see include/tvm/tir/expr.h +# They are intended to have the same semantics as in TIR +PrimExpr ::= + Var(name: string) # shape variables + | IntImm(value: int64) + | Add(a: PrimExpr, b: PrimExpr) + | Sub(a: PrimExpr, b: PrimExpr) + | Mul(a: PrimExpr, b: PrimExpr) + | Div(a: PrimExpr, b: PrimExpr) + | Min(a: PrimExpr, b: PrimExpr) + | Max(a: PrimExpr, b: PrimExpr) + | Not(a: PrimExpr) + | And(a: PrimExpr, b: PrimExpr) + | Or(a: PrimExpr, b: PrimExpr) + | Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr) + # (others may be added later, as deemed necessary) + +Type ::= DynTensorType(ndim: int, dtype: DataType) + | ShapeType() + | ObjectType() + | TupleType(fields: [Type]) + | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») + +# expressions +Expr ::= Constant(data: NDArray) + # scoped to functions or SeqExprs + | Var(name_hint: string) + # scoped to DataflowBlocks + | DataflowVar(name_hint: string) + | GlobalVar(name_hint: string) + | Tuple(fields: [Expr]) + | SeqExpr(blocks: [BindingBlock], body: Expr) + | Function(params: [Var], body: Expr, ret_type: Type?, attrs: Attrs?) + | If(cond: Expr, true_branch: Expr, false_branch: Expr) + | ExternFunc(global_symbol: string) + | Call(op: Expr, args: [Expr], type_args: [Type], attrs: Attrs?) + | ShapeExpr(values: [PrimExpr]) + | TupleGetItem(tuple_value: Expr, index: int) + | Op(op_name: string) + | RuntimeDepShape() + +# binding blocks (analogous to sequence of statements) +BindingBlock ::= + BindingBlock(bindings: [Binding]) + | DataflowBlock(bindings: [Binding]) + +# bindings (analogous to statements) +Binding ::= + VarBinding(var: Var|DataflowVar, value: Expr) + | MatchShape(var: (Var|DataflowVar)?, pattern: [PrimExpr], value: Expr) + +# Relax programs are IRModules. Modules may bind global variables either to +# Relax functions or TIR PrimFuncs (specified separately). +# The Relax compiler may analyze and modify the TIR PrimFUncs as well. +Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) +``` + +## Expression Survey + +This specification provides a more detailed description of what each expression and type represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. + +1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). +2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. +3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchShape` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. +4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchShape` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." +5. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. + 1. For `ExternFunc` nodes, the call will look up the registered `PackedFunc` by its global symbol and will call it with the given arguments (note that a TIR `PrimFunc` can be compiled into a `PackedFunc` and called using `ExternFunc` by defining a `global_symbol` attribute in the `PrimFunc`). «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» + 2. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s) «and `cast` (performs dynamic type conversions).» + 3. Any other expression must evaluate to a closure; the closure will then be called with the given arguments. + + Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. + +6. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. +7. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. +8. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: + 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). + 2. `MatchShape`s: The `value` expression is evaluated and the resulting shape is dynamically checked against the shape denoted by the `PrimExpr`s in the `pattern` field. + 1. If `value` evaluates to a tensor value, the pattern will be checked against the shape of the tensor; if it evaluates to a shape value, the pattern will be checked directly against the shape. + 2. Any shape dimension in the pattern that consists of a single new shape variable is treated as a binding: The variable is bound to the size of the corresponding dimension of the value being matched. + 3. If the shapes do not match, an error is triggered. If there is a variable provided, the value is bound to the `var` expression (if the variable is omitted, the shape check is performed and any shape variables are updated, but no new binding is introduced). Shape variables introduced in a `SeqExpr` are similarly scoped to the `SeqExpr`. + + The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. + +9. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +10. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. + + The function can have shape annotations on the parameters and a return shape parameter. When the function is called, the annotations on parameters checked against the argument values in similar fashion to `MatchShape` and can introduce new shape variables that are scoped to the function. + + «A function mapped bound to a `GlobalVar` can have a `global_symbol` attribute defined to indicate that it should be externally linked externally (be accessible outside the `IRModule`). The absence of a `global_symbol` attribute on a function definition bound to a `GlobalVar` indicates that it is "private" and hence can be called only within the `IRModule`.» + +11. `RuntimeDepShape` nodes are used to denote that a shape is unknown at compile time and must be deduced at run time. These nodes may appear only in shape annotations and have no run-time semantics of their own. + +## Purity and Dataflow Blocks + +A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. «In Relax, we conservatively assume that any function that calls an impure function is itself impure, though the attribute `force_pure` on a function can be used as an override (e.g., if a function creates a new tensor, mutates it, and returns it, that is still pure but does not satisfy the conservative rule).» + +Above, it is mentioned that `DataflowBlock`s are not allowed to contain constructs featuring control flow (`If` nodes or recursive calls to the current function) or calls to impure functions. This ensures that `DataflowBlock`s represent a directed acyclic graph of pure operations, which is similar to the graph-like abstractions of traditional deep learning frameworks. This allows many common optimizations from past frameworks to be directly adapted to `DataflowBlock`s without having to accommodate additional reasoning about more expressive features like control flow and side effects. + +There is one visible side effect that Relax permits inside otherwise "pure" functions, namely exiting the program with an error. This can arise in the following cases: + +- Shape matching errors (from `MatchShape` or from implicit shape checks upon calling a Relax function) +- Errors raised by otherwise pure Relax operators or `PackedFunc`s, such as in `cast` (which dynamically checks types). Since the purity of operators or `PackedFunc`s must be manually registered, this means that it is permissible to register an operator or `PackedFunc` as being pure if its only side effect is issuing an error in some cases. + +Even though an abnormal program exit is a visible side effect and removing or reordering it changes the observable semantics, it would be too great a restriction to prohibit error checking inside `DataflowBlock`s. Relax does not have any notion of exception handling, so the only consequence of a failed safety check can be exiting the program. It is permissible for the compiler to reorder, duplicate, or eliminate `MatchShape`, `cast`, or otherwise pure operations that have the potential of failing, provided that doing so does not change the value returned by the program or any other visible behavior. + +To indicate that an operator or `PackedFunc` that can abort with an error should *never* be reordered or removed by the compiler, it should *not* be marked as pure. However, this means that it cannot be used inside a `DataflowBlock`. + +Note that in some programming languages like Koka, non-termination is also considered a side effect, since it can in some sense be "observed" by a user and affects the visible behavior of a program (e.g., if there is an infinite loop before a print statement, the print will never happen). However, since non-termination cannot be automatically detected in general and is unlikely to arise in deep learning models, we do not attempt to systematically track non-termination in Relax. In general, the Relax compiler is allowed to reorder or remove otherwise pure function calls even if they may not terminate. For example, if a pure function `f` that returns an integer scalar does not terminate, it is permissible in principle to rewrite `f() - f()` to 0. + +Exiting with an error and infinitely looping are traditionally considered "[divergence](https://en.wikipedia.org/wiki/Divergence_(computer_science))" in the programming languages literature. As a general principle, Relax's compiler is permitted to turn a program that diverges into a program that does not diverge (provided that no other visible effects change) so long as it never transforms a program that does not diverge into one that diverges. + +## Type System Survey + +The types in Relax correspond to the broad categories of the values given above: + +1. `DynTensorType` corresponds to tensor values, giving the scalar data type and the number of dimensions (rank), both of which are optional. +2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. +3. `ShapeType` corresponds to shape values. +4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» +5. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. + +The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» + +## Shape System Survey + +In Relax, tensor shapes are not handled in the type system; each expression instead a has an associated shape expression. In many cases, these shape computations can allow for statically concluding that two shapes are the same and thus eliminate the need for dynamic checks via `MatchShape`. However, when shapes cannot be statically concluded to be the same, it may be necessary for there to be dynamic checks. The compiler is also free to make use of shape expressions for memory planning purposes. «Relax is "strongly shaped," meaning that if the compiler cannot conclude that shapes match in certain cases, an error will be issued and an explicit `MatchShape` will be required.» + +--- + +# Top-level Program Organization: `IRModule` + +As with Relay, the top level of organization for a Relax program is an `IRModule`. An `IRModule` contains mappings of global variables to functions, both Relax functions as well as TIR functions (which can be called from Relax). The global function called `main` is usually considered the entry point to the program (meaning that execution starts by calling that function), though any function with a `global_symbol` attribute can be specified as the entry point during compilation. In the AST (see below), the names of Relax functions in the `IRModule`s are `GlobalVar` nodes. + +Oftentimes, compiler passes operate only on particular functions or add new functions to the `IRModule`, but a pass can operate over the entirety of a Relax program by iterating through all the functions in an `IRModule`. + +# Values in Relax + +Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. + +- *Tensors* are n-dimensional arrays of scalar values (which can be integers of fixed bitwidths, floats of fixed bitwidths, or bools). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. +- *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations return no value (as may be the case in some `PackedFunc` or operator calls that have side effects). +- *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time type information* (RTTI) indicating their argument types and result type, in order to facilitate dynamic type checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTTI is left up to the compiler implementation to determine so long as the `cast` operator can verify the type of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» +- *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values of type `ObjectType`. + +## Representation of Values at Run Time + +Because Relax supports calls to arbitrary `PackedFunc`s that can operate on a low level, it is necessary to define a convention for how values will be represented at run time. At this time, the specification does not require any specific representation and permits compiler implementations to choose their own representations, provided that each value type listed above can be recognized at run time (for dynamic type checks). This means that Relax programs that call `PackedFunc`s directly are not portable across compiler implementations: The `PackedFunc`s used must be able to operate on the run-time representations of values. + +Possible specification in terms of the TVM object system: + +- Tensors are represented at run time as `NDArray`s (see `include/tvm/NDArray.h`). +- Tuples are represented using TVM ADTs (algebraic data types), which are arrays of TVM objects with a tag (see `include/tvm/runtime/container/adt.h`). Tuples use a tag of 0. +- At run time, closures are represented as a `ClosureObj` (see `include/tvm/runtime/container/closure.h`); in the Relax VM these more specifically use the `VMClosureObj` (see [`https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h`](https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h)). +- Shape values are represented at run time as a `ShapeTuple` (see `include/tvm/runtime/container/shape_tuple.h`). +- We require objects other than the above values used by and returned by `PackedFunc` to inherit from TVM's `Object` class (defined in `include/tvm/runtime/Object.h`). Note that `PackedFunc`s are capable of using and returning all TVM POD (plain-old data) values (see `include/tvm/runtimes/packed_func.h`), which includes some representations that do not inherit from `Object`. In the future, we may define semantics for other values, but at present, these are *unsupported* in Relax and we make no guarantees about the semantics of calling `PackedFunc`s that use or return anything that does not inherit from `Object`. + +# Variable Scoping + +There are four relevant scopes in Relax, which determine where variables are visible and can be used: + +1. Global: `GlobalVar`s can be referenced from any function in the `IRModule`, whether a Relax function or a TIR `PrimFunc`. All global functions are visible to each other and to themselves, allowing for mutual recursion. +2. Function: The parameters to a function (ordinary `Var` nodes) can be referenced from anywhere in that function. In a recursive binding (a `Binding` node where the RHS is a `Function` node or `GlobalVar` being mapped to a function at the `IRModule` level), the variable being bound is also scoped to that function, allowing for defining a recursive function. +3. `SeqExpr`: `Var` nodes defined in a `BindingBlock` in a `SeqExpr` node can be referenced in any later binding within the same `BindingBlock`, in any binding within any later `BindingBlock` in that `SeqExpr` node, or in the `SeqExpr`'s body expression. The variables defined in the `BindingBlock`s leave scope once the `SeqExpr` returns. +4. `DataflowBlock`: `DataflowVar`s introduced in a `DataflowBlock` can be referenced in any later binding within that `DataflowBlock`, but leave scope *once that `DataflowBlock` finishes executing*. Definitions in a `DataflowBlock` that are intended to leave the `DataflowBlock` should be bound to an ordinary `Var`. + +# Well-Formedness Criteria + +Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid. + +1. `DataflowVar`s can be bound only inside `DataflowBlock`s. Additionally, a `DataflowVar` may not be used outside of the `DataflowBlock` in which it is defined. +2. A `Var` of any kind used in the program must be either a function parameter or appear on the LHS of a binding exactly once. In the binding where a `Var` is defined, the same `Var` is permitted to occur in the RHS of the binding only if the binding is defining a function (i.e., local functions are permitted to be recursive). +3. A `Var` of any kind may not appear before it is bound. Namely, if a `Var` is bound in a `BindingBlock` in a `SeqExpr`, that `Var` may not appear in bindings that precede the one where it appears on the LHS. +4. «A return shape annotation for a function is not allowed to use any shape variables that are not in scope at the function definition. That is, the only shape variables that can appear on the return shape annotation are those defined in the outer scope or those introduced in the argument shape annotations.» +5. In each function, `PrimExpr` variables (shape variables) similarly may not appear in `ShapeExpr`s or shape annotations before the shape variables are bound (either in function signatures or `MatchShape` bindings). A shape variable is bound only when it appears in a dimension by itself (for example, a dimension consisting of `x` will bind `x`; however, `2*x` is not a binding and is considered an error if `x` has not yet been bound) in a `MatchShape` node or a function argument shape annotation. +6. The following constructs are not permitted to occur inside `DataflowBlock`s, which must be side effect– and control flow–free: + 1. Recursive calls to the current function + 2. Calls to a global function that is mutually recursive with the current function + 3. `If` nodes + + «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during type checking.» + +7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return type annotation is *required*. [TODO: Do we also require a return shape annotation in such cases?]» +8. `Op` nodes may appear only as the `op` argument to `Call` nodes. +9. `ExternFunc` expressions may appear only as the `op` argument to `Call` nodes. +10. The `type_args` field is used only for type checking calls of `ExternFunc`s and `Op`s. Calls to `ExternFunc`s must have exactly one type argument, indicating the return type. Calls to `Op`s may use `type_args` as they wish. No other calls may have a non-empty `type_args`. +11. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. +12. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. +13. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» +14. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» +15. «Any `PackedFunc` or operator called in a shape annotation or `shape_` expression must be pure and be annotated as such.» +16. The node `RuntimeDepShape` may appear only in shape annotations and `shape_` expressions. It has no defined semantics at run time. + +# Types in Relax + +Relax presently has five types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: + +1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. +2. `ShapeType`, referring to shape values. +3. `FuncType`, referring to functions (closures). `FuncType`s specify the types of their parameters, a return type, and whether the function is pure. +4. `TupleType`, referring to tuple values, giving the types of their fields. +5. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. + +## Subtyping + +Relax implements subtyping, which means that members of types can be accepted where members of their supertypes are accepted. We will denote the subtyping relationship as `T1 <: T2`, indicating that `T1` is a subtype of `T2`. For example. if `T1 <: T2` and some function expects an argument of type `T2`, then passing a member of type `T1` to that function is permitted; passing a member of type `T2` as an argument to a function that expects type `T1` for that argument is *not* permitted—the value would have to be dynamically cast to `T1` using the `cast` operator. + +### Rules for Subtyping + +1. Reflexivity: For all types `T`, `T <: T`. +2. Transitivity: For all types `T1`, `T2`, and `T3`, if `T1 <: T2` and `T2 <: T3`, then `T1 <: T3`. +3. For all types `T`, `T <: ObjectType`. Hence, `ObjectType` is a supertype to all Relax types (all values in Relax are members of `ObjectType`). +4. Rules for `DynTensorType`: + 1. For all fixed `ndim` values `m`, where `m` ≥ 0, and `dtype`s `d`, `DynTensorType(ndim=m, dtype=d) <: DynTensorType(ndim=m, dtype=Void)`. + 2. For all fixed `ndim` values `m` and `dtype`s `d` that are not `Void`, `DynTensorType(ndim=m, dtype=d) <: DynTensorType(ndim=-1, dtype=d)`. + 3. Corollary: `DynTensorType(ndim=-1, dtype=Void)` is a supertype to all tensor types, since it refers to any possible tensor value. +5. Suppose we have types `T1 <: T1'`, `T2 <: T2'`, …, `Tn <: Tn'`. Then `TupleType(fields=[T1, T2, ..., Tn]) <: TupleType(fields=[T1', T2', ..., Tn'])`. +6. Rules for `FuncType`: + 1. Impure functions are supertypes to pure functions. Namely, if we have types `T1`, `T2`, …, `Tn` and `Tr`, then `FuncType(arg_types=[T1, T2, ..., Tn], ret_type=Tr, pure=True) <: FuncType(arg_types=[T1, T2, ..., Tn], ret_type=Tr, pure=False)`. + 2. Suppose we have types `T1' <: T1`, `T2' <: T2`, …, `Tn' <: Tn` and `Tr <: Tr'`. Then `FuncType(arg_types=[T1, T2, ... Tn], ret_type=Tr, pure=p) <: FuncType(arg_types=[T1', T2', ..., Tn'], ret_type=Tr', pure=p)`. Note the direction of the subtyping relationships for the argument and return types: We must be able to *call* this function with the *same* arguments and *use the returned value* wherever it is accepted—hence a function that takes more general arguments and returns a more specific return value can be used in place of the original. + +These rules allow us to define the least upper bound (LUB) for any two types `T1` and `T2`, meaning that it is the most specific type `T` for which `T1 <: T` and `T2 <: T` ("most specific" meaning that if there exists some other `T'` for which `T1 <: T'` and `T2 <: T'`, then `T <: T'`). The LUB is guaranteed to exist for any two types because `Object` is a supertype to all types. + +Note that the rule for obtaining the LUB of function types relies on the counterpart to the LUB, the greatest lower bound (GLB). The GLB is not guaranteed to exist for any two types in Relax, as there is no single type that is a subtype of all others. + +We can give an algorithm for determining the LUB and GLB for two types, in pseudocode: + +```python +def find_glb(T1 : Type, T2 : Type) -> Type?: + if T1 == T2: # syntactic equality + return T2 + if T1 is ObjectType: + return T2 + if T2 is ObjectType: + return T1 + if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType: + return None + if T1 and T2 are both DynTensorType: + ret_ndim = T1.ndim + ret_dtype = T1.dtype + if ret_ndim == -1: + ret_ndim == T2.ndim + if ret_dtype == Void: + ret_dtype = T2.dtype + if ret_ndim != -1 and T2.ndim != ret_ndim: + # mismatch, so there's no common lower bound + return None + if ret_dtype != Void and T2.dtype != ret_dtype: + return None + return DynTensorType(ret_ndim, ret_dtype) + if T1 and T2 are both TupleType: + if they do not have the same length: + return None + fields = [] + for field1, field2 in zip(T1.fields, T2.fields): + glb = find_glb(field1, field2) + if glb is None: + return None + fields.append(glb) + return TupleType(fields) + if T1 and T2 are both FuncType: + «if they are not both pure or both impure:» + «return None» + purity = T1.purity + if they do not have the same arity: + return None + # mutual recursion with finding the LUB + arg_types = [ + find_lub(arg_type1, arg_type2) + for arg_type1, arg_type2 in zip(T1.arg_types, T2.arg_types) + ] + ret_type = find_glb(T1.ret_type, T2.ret_type) + if ret_type is None: + return None + return FuncType(arg_types, ret_type, purity) + +def find_lub(T1 : Type, T2 : Type) -> Type: + if T1 == T2: # syntactic equality + return T1 + if T1 or T2 is ObjectType: + return Object + if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType: + return ObjectType + if T1 and T2 are both DynTensorType: + res_ndim = T1.ndim + res_dtype = T1.dtype + if T1.ndim != T2.ndim: + res_ndim = -1 + if T1.dtype != T2.dtype: + res_dtype = Void + return DynTensorType(res_ndim, res_dtype) + if T1 and T2 are both TupleType: + if they do not have the same length: + return ObjectType + return TupleType([ + find_lub(field1, field2) + for field1, field2 in zip(T1.fields, T2.fields) + ]) + if T1 and T2 are both FuncType: + «purity = (True iff they're both pure)» + if they do not have the same arity: + return ObjectType + arg_types = [] + for arg_type1, arg_type2 in zip(T1.arg_types, T2.arg_types): + # potential mutual recursion + glb = find_glb(arg_type1, arg_type2) + if glb is None: + return ObjectType + arg_types.append(glb) + return FuncType(arg_types, find_lub(T1.ret_type, T2.ret_type), «purity») +``` + +### When Type Conversions are Necessary + +For two types `T1` and `T2`, if `T1 <: T2`, then a value of type `T1` can be passed anywhere a value of type `T2` is expected without any need for type conversions or dynamic checks. + +*However*, if `T1 <: T2`, then passing a value of type `T2` where `T1` is expected can only be done if there has been some kind of dynamic check or conversion of that value. «Relax is *strongly* *typed*, meaning that the compiler will give an error in this situation and require an explicit conversion via the `cast` operator, which inspects the value's run-time representation and exits the program with an error message if the value is not a subtype of T1.» + +If `T1` is not a subtype of `T2` and `T2` is not a subtype of `T1`, then it is always a type error to pass a value of either type where a value of the other is expected (no member of either type can be a member of the other). + +## Type Checking Rules + +The type checking rules for Relax are relatively simple and allow in some cases for types to be inferred without user annotations. Below, we describe how the types for each expression can be derived and when type checking should return an error. + +Let us consider a typing context `Γ`, which is a map of variables to types. + +1. «We type check the entire `IRModule` one function definition at a time. To handle mutual recursion, we prepopulate `Γ` with the annotated types of all global functions that are called mutually recursively. We then proceed to check the types of the global functions one at a time.» +2. Given a variable `v`, if `v` is in `Γ`, then we return `Γ[v]`. Otherwise, it is an error. +3. Given a constant expression `Constant(data)`, the type is `DynTensorType(ndim=n, dtype=d)` where `n` is the number of dimensions in `data` and `d` is its data type (recall that `data` is an `NDArray` literal). +4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType`. +5. The type of a `RuntimeDepShape` expression is `ShapeType`. +6. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. +7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: + 1. If `op` is a Relax `Op` node, then we look up its registered `FInferType` property. `FInferType` is a macro that takes in the `Call` node and produces a type. We return the type `op.FInferType(Call(op, [a1, ..., an], type_args=[aT]))`. The implementation of `FInferType` is free to throw errors. + 2. If `op` is `ExternFunc`, then use the sole member of `type_args` (calls to `ExternFunc`s are required to have exactly one `type_args` member) `aT` as the return type. Packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function itself to do any validation. + 3. Otherwise, check the types of the subexpressions, left to right. Suppose `op` has the type `Tf`, `a1` has type `T1`, …, and `an` has type `Tn`. If `Tf` is not a function type with exactly `n` arguments, we consider it a type error and require an explicit cast. Suppose `Tf` has argument types `T1'`, `T2'`, …, `Tn'`. Consider it a type error if any of the following does not hold: `T1 <: T1'`, `T2 <: T2'`, …, or `Tn <: Tn'`. Then the return type is `Tf.ret_type`. +8. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. +9. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» +10. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. + 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. + 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» + 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. + 3. For each `MatchShape(v: T, e, shape_pattern)`, where `T` is an optional type annotation, let the checked type of `e` be `T'`. + 1. If `T'` is `ShapeType`, then emit an error if `T` is not a supertype of `ShapeType`. Add `v` to `Γ` with type `T`. + 2. If `T'` is `DynTensorType`: + 1. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. Let the datatype of `T'` be `d`. + 2. If `T` is not a supertype of `DynTensorType(ndim=len(shape_pattern), dtype=d)`, then emit an error. If `T` is a subtype of that type, emit an error and request a cast. + 3. Add `v` to `Γ` with type `T`. + 3. If `T'` is `ObjectType`, then the only type we can conclude for `v` is `ObjectType`. If `T` is not `ObjectType`, emit an error and request a cast. + 4. If `T'` is `TupleType` or `FuncType`, emit a type error. + 2. If the current block is a `DataflowBlock`, remove any `DataflowVar`s from `Γ` after we have finished processing the bindings in that block. + 3. Finally, the type of the `SeqExpr` is the type of `body` checked under the current `Γ`. Afterwards, remove all bindings added from the `b0`, `b1`, …, `bn` from `Γ`. +11. Let us consider a function `Function(v1 : T1, v2 : T2, ..., vn : Tn, attrs=a) -> Tr: body`. + 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, where `p` is `True` if a `pure` attribute is included and `False` otherwise. Remove `fv` from `Γ` before returning. + 2. Add `v1` to `Γ` with type `T1`, `v2` with type `T2`, …, and `tn` with type `Tn`. Recursively type check `body` in this new context: + 1. «Determining purity: If `body` contains any call to a function whose return type does not specify that it is pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not specify the `pure` attribute, then we consider the function to be (potentially) impure. If all calls are to functions whose return type specifies purity or that include the `pure` attribute on the call or `Op`, then the function is treated as pure.» + 2. «Suppose the purity defined in the previous step is `p'`. Suppose the annotated function purity (in the attributes) is `p`. If `p'` is false while `p` is true, then it is a type error; if `p` was omitted, use `p'` for `p`.» + 3. «If the function has the attribute "`force_pure`," then consider `p` to be true, even if the check above judged the function not to be pure. The compiler may emit a warning in this situation.» + 4. Suppose the result of type-checking `body` is `Tr'`. If the current function is not recursive or a mutually recursive global function and `Tr` was omitted, consider the function to have type `FuncType([T1, T2, ..., Tn], Tr', pure=p)`. If `Tr' <: Tr`, then we consider the function to have type `FuncType([T1, T2, ..., Tn], Tr, pure=p)`. «If `Tr <: Tr'`, return an error and require an explicit cast in the function body.» If `Tr'` is not a subtype of `Tr` and `Tr` is also not a subtype of `Tr'` (meaning a dynamic cast cannot succeed), this is an error. + 5. Remove `v1`, `v2`, …, and `vn` from `Γ` before returning. + +# Shapes in Relax + +In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. In Relax, to allow for greater flexibility for variable-shape tensors and make it easier to implement new operators, shapes can be checked at run time. Though every expression in Relax has a shape associated with it just as expressions also have types, there is no requirement that the shape be expressed at compile time. Instead, the compiler merely requires that an expression's shape define *a way* to compute a fully specified shape at run time. Users have the ability to make use of shape variables and arithmetic expressions to encode a wide variety of shape constraints that can be checked dynamically. + +Nevertheless, in many cases, these shapes can be analyzed at compile time (particularly when they are consist of constants or deducible variables) to facilitate compile-time optimization much like is possible with Relay or TIR. Through constant propagation, function inlining, and other partial evaluation–like transformations, we can potentially eliminate many more dynamic checks by allowing some shape computations to be simplified at compile time. + +## Defining Shape Computations + +In Relax, each expression has an associated shape computation, which defines how that expression's shape can be computed based on the shapes of its subexpressions. We will refer to this computation as `shape_`, as that is what it is called in the implementation. This essentially serves as a mechanism for propagating shape annotations on variable bindings and function definitions to other expressions and enable more compile-time analysis of shapes. In particular, `shape_` is useful for memory planning. These computations can also be used to simplify shape checking and eliminate many dynamic checks. + +### Expressing Dimensions + +A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds to a dimension. The use of TIR `PrimExpr`s for shape dimension allows shape computations to express complex constraints that include variables and integer arithmetic expressions in addition to just constant dimensions. + +**Scope of Shape Variables** + +Shape variables can be introduced in two places in a Relax program: In a function signature, where they may be included with the argument shapes and return shape annotations, or in `MatchShape` bindings. Shape variables used in the function signature are scoped to the entire function in which they appear. Shape variables used in `MatchShape` bindings are scoped only to the `SeqExpr` in which they appear. + +**Informal Semantics of `PrimExpr`s for Dimensions** + +1. Shape variables can be bound to a value exactly once: at the start of a function for shape annotations on function arguments, in `MatchShape` bindings, or before a function returns (for shape variables on the return type). In particular, matching a `PrimExpr` consisting only of an uninitialized shape variable is treated as its binding (see below on `MatchShape`). After a shape variable has been bound for the first time, future uses of it will refer to the same value. +2. It is not legal to use a shape var that has not yet been bound. This results in an error at run time, though most cases can be detected at compile time. +3. «Local functions will "capture" defined shape variables from the parent scope with their present values in the resulting closure.» +4. If all variables in the `PrimExpr` are defined, `PrimExpr` arithmetic will generally be evaluated according to the semantics of TIR. + +### Evaluating `MatchShape` + +`MatchShape` allows for binding shape variables in Relax. It can be used with either tensor values or shape values, and in both cases the evaluation of the `PrimExpr`s proceeds similarly. + +1. Evaluating `MatchShape(v, t, s)`, where `t` is a tensor value and `s` is a list of `PrimExpr`s corresponding to shape dimensions: + 1. Suppose `s` is `(p1, p2, ..., pn)` , where each variables is a `PrimExpr`. We evaluate `p1`, then `p2`, and so, in that order according to the following rules (corresponding to the `i`th dimension): + 1. If the current `PrimExpr` consists only of an uninitialized shape variable, we bind the shape variable in that scope to the concrete value of the `i`th dimension of the value of `t`. + 2. Evaluate the current `PrimExpr` and compare it to the concrete value of the `i`th dimension of `t`. Raise an error if they do not match. + 2. If `v` is provided, bind `t` to `v` (see the general semantics for how that should be implemented). +2. Evaluating `MatchShape(v, S, s)`, where `S` is a shape value proceeds identically to the above, except the `PrimExpr`s are compared to the `i`th element of `S`. + +### General Shape Computation Grammar + +Shape computations can consist of the following expressions, which are a subset of general Relax `Expr`s: + +``` +ShapeCompExpr ::= ShapeExpr(dims: [PrimExpr]) + | RuntimeDepShape() + | Tuple(fields: [ShapeCompExpr]) + | Call(op: Op|ExternFunc, args: [Var|Constant]) + | TupleGetItem(tuple_value: ShapeCompExpr, index: int) +``` + +The shape expressions can be interpreted as follows: + +- `ShapeExpr` describes the shape of a tensor as a list of dimensions +- `Tuple` describes the shapes of each member of a tuple +- `TupleGetItem` describes the shape of a member of a tuple +- `Call` describes the shape of a function (or operator) call return value in terms of its arguments +- `RuntimeDepShape` describes shapes that are unknown at compile time (like when a shape annotation is omitted) or the shapes of values that don't have shapes (like shapes themselves, paradoxically: they *are* shapes but do not *have* shapes). + +The `PrimExpr`s in a `ShapeCompExpr` can reference the same shape variables as in shape annotations, with the same semantics. + +**Restrictions** + +Shape computations are allowed to include calls to operators and even `PackedFunc`s, but these operators and `PackedFunc`s *must* be pure. Shape computations are primarily used for memory planning and it is at the compiler's discretion when, if ever, to evaluate them (except as described below), hence they must not have side effects. + +**Shape Annotations** + +For shape annotations, we use `ShapeCompExpr` as the grammar, as with `shape_` expressions. `ShapeExpr` is used to annotate shapes of tensor values, `Tuple` is used to annotate the shapes of tuple values, and `RuntimeDepShape` is used to indicate annotations that have been omitted or shapes that cannot be known at compile time (like the shapes of tensors whose rank is unknown at compile time). `Call` is used to annotate the shapes of calls to operators and `TupleGetItem` annotates the shapes of tuple indices. + +For example, suppose we have a tuple where some fields are tensors like the following: + +```python +x : Tuple(Tensor((m, n), "int32"), Tuple(), Tensor((), "int32"), Tensor(_, "int32")) = ... +``` + +It has the shape annotation + +```python +Tuple([ShapeExpr([m, n]), Tuple([]), ShapeExpr([]), RuntimeDepShape]) +``` + +Note that it is [a well-formedness requirement](https://www.notion.so/Informal-Relax-Language-Specification-d1fdedb8fae84f0d82b9f880f25e7370) that if any field in a type has a `ShapeExpr` annotation, it must be a `DynTensorType` with an `ndim` matching the number of dimensions in the `ShapeExpr`. For example, in the above function signatures, the `ndim` in the type annotations must be 2. + +### Assigning Shape Variables at the Start and End of a Function + +Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchShape`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchShape` bindings. Suppose a function has the following signature, where the `Ti` are type annotation and the `Si` are shape annotations: + +```python +def f(arg1 : (T1, S1), arg2 : (T2, S2), ..., argn : (Tn, Sn)) -> (Tr, Sr): + return body +``` + +This can be treated as a macro that expands to + +```python +def f(arg1 : T1, arg2 : T2, ..., argn : Tn) -> Tr: + check_annotation(arg1, T1, S1) + check_annotation(arg2, T2, S2) + ... + check_annotation(argn, Tn, Sn) + ret_var = body + check_annotation(ret_var, Tr, Sr) + return ret_var +``` + +Because `MatchShape` is defined only for tensor and shape values, we must use a macro to handle other possible types that may be passed into a function, given here in pseudocode: + +```python +def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: + if s is a ShapeExpr: + tmp = fresh_var() + # type checking should ensure that e is always a tensor + return SeqExpr( + [BindingBlock([MatchShape(tmp, e, s.dims)])], + tmp + ) + else if s is a Tuple: + # type checking should ensure that e is always a tuple and the lengths match + shapes = s.fields + tmp = fresh_var() + return SeqExpr( + [BindingBlock([ + VarBinding(tmp, e), + # recursive in case we have nested tuples + VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, 0), shapes[0])), + VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, 1), shapes[1])), + ..., + VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, n-1), shapes[n-1])) + ])], tmp + ) + else if s is a Call: + tmp = fresh_var() + return SeqExpr( + [BindingBlock([ + VarBinding(tmp, e), + # completely dynamic check that does not assign shape vars. + VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) + ])], tmp + ) + else if s is TupleGetItem: + val = s.tuple_value + if val is Tuple: + return check_annotation(e, val.fields[s.index]) + # otherwise, evaluate it + return SeqExpr( + [BindingBlock([ + VarBinding(tmp, e), + VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) + ])], tmp + ) + else if s is RuntimeDepShape: + # no need to check + return e +``` + +### Evaluating Shape Expressions + +Every shape expression in the program (`shape_`) is associated with a program expression. Other than in the above procedure for checking function parameter shapes and the return shape, the specification does not guarantee that any `shape_` expression will ever be evaluated or how many times it may be evaluated; `shape_` is intended primarily for the benefit of memory planning. Hence, all `shape_` expressions must be pure and must be guaranteed to terminate. The `shape_` for a given expression `e` is intended to be evaluated *before* `e`. + +Shape expressions follow the same evaluation rules as general program expressions. In particular, shape functions are permitted to reference any variable that is in scope at the point of its associated expression; i.e., when evaluated, they form closures that capture any free variables (Relax variables and shape variables) referenced in their body. The `RuntimeDepShape` expression has no semantics at run time and indicates a shape that cannot be predicted in advance. If a `RuntimeDepShape` is encountered at any point while dynamically checking a shape match (see the `check_annotation` procedure above), it should "short-circuit" the match and cause the match to succeed immediately. + +### Building Up `shape_` for Each Expression + +For each expression type, we can recursively build up an associated `shape_` expression according to the following rules: + +1. For `Constant(value)`, the `shape_` expression is a `ShapeExpr` corresponding to the concrete shape of `value`. For example, for `Constant(1)`, `shape_` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape_` is `ShapeExpr([2])`. +2. For `Tuple(fields)`, `shape_` can be defined as `Tuple([field.shape_ for field in fields])`. +3. For `ShapeExpr`s, `shape_` is `RuntimeDepShape`. +4. `RuntimeDepShape` expressions should appear only in shape expressions; their `shape_` is not defined. +5. For `If(cond, true_branch, false_branch)`, we compare the `shape_` of `true_branch` and `false_branch`. If these can be proven equivalent (by a method that the compiler implementation is free to determine), then the `If` node's `shape_` is that shape. If they do not match, then we set it to `RuntimeDepShape`. +6. For `SeqExpr`, we set the `shape_` to be the `shape_` of the body expression. The `shape_` must respect the scoping rules for the `SeqExpr`: If the `shape_` of the body expression contains shape variables not defined in the outer scope (i.e., shape variables that are scoped to the `SeqExpr` only) or if the `shape_` contains any `Var`s or `DataflowVar`s scoped to the `SeqExpr`, use `RuntimeDepShape` as the shape. +7. For handling variable bindings: + 1. For the arguments to a function, set the `shape_` to the annotated shape. If the annotation is omitted, use `RuntimeDepShape`. + 2. In the general `VarBinding(v, e)`, if `v` does not have a shape annotation or the annotation is `RuntimeDepShape`, then we set the `shape_` of `v` to the `shape_` of `e`. If `v` has a shape annotation, then if the `shape_` of `e` can be proven equivalent to the shape annotation, use the shape annotation for the `shape_` of `v`. «Otherwise, give an error and require an explicit `MatchShape`.» + + It is up to the compiler implementation to decide what method to use for attempting to prove equivalence. + + 3. For bindings where the RHS is a function literal or assigning the `shape_` of a `GlobalVar`, see the rule for `Function` nodes. + 4. For `MatchShape(var, value, shape)`, we set the `shape_` of `var` to `shape`, as it will be dynamically checked. +8. For `TupleGetItem(tuple_value, i)`, we examine the `shape_` of `tuple_value`; suppose it is `s`. If `s` is a `Tuple`, then we use its `i`th field. If it is `RuntimeDepShape`, we use `RuntimeDepShape`. If it is a `Call` to a function that returns a tuple with at least `i + 1` members, set the `shape_` to `TupleGetItem(s, i)`. Otherwise, raise an error at compile time (though this should not happen if type checking has passed). +9. For `Call` nodes: + 1. For a call to an `ExternFunc`, we use `RuntimeDepShape` because we cannot analyze the shapes of arbitrary `PackedFunc`s and must check dynamically. + 2. For a call to an `Op`, we use the manually defined `FInferShape` macro if it has been defined and `RuntimeDepShape` if it has not. `FInferShape` is a function that takes in the call node and produces a `ShapeCompExpr`. + 3. For all other cases with `Call(op, args)`, we consider the following cases: + 1. If `op` is a `GlobalVar` or a `Var` that refers to a function defined in the current scope, look up the `Function` node it references; let us call it `f`. Similarly, if `op` is itself a `Function` node, let `f` be `op`. + + Attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) on `f`'s return shape. A pseudocode procedure for this beta-reduction is given below, as a macro. + + 1. If the return shape of `f` is a `Call` node or contains any `Call` nodes, substitute any parameters of `f` for the corresponding member of `args`. (E.g., if `f` has parameters `p1`, `p2`, …, `pn` and any of these variables appears in the return shape, `p1` should be replaced with the first member of `args`; `p2`, with the second; etc.) If any member of `args` that is substituted this way is not a `Var` or `Constant`, consider beta-reduction to fail. + 2. For each shape annotation in the parameters of `f`, attempt to match it with the `shape_` of the corresponding member of `args`, substituting shape variables in the return shape accordingly. If the `shape_` of the member of `args` is `RuntimeDepShape`, consider beta-reduction to fail. If the `shape_` is not `RuntimeDepShape` but is incompatible with the parameter's shape annotation (e.g., a `Tuple` where a `ShapeExpr` was expected), report an error at compile time. + + If `f`'s return shape is `RuntimeDepShape`, then consider the call result to have `RuntimeDepShape`. If beta-reduction is considered to fail, then consider the call result to have `RuntimeDepShape`. If it succeeds, use the resulting shape as the `shape_` of the call result. + + 2. Otherwise, consider the result of the call to have `RuntimeDepShape`. +10. For a function node, set the `shape_` to `RuntimeDepShape`. + +### Procedure for Substituting a Function Return Shape to Determine the Shape of a Call + +The `substitute_shape` procedure defined below describes how the shape expression for a call result can be defined given the call arguments and the return shape annotation on the corresponding function node. Note that this procedure can obtain much more precise results in the cases of `Call` or `TupleGetItem` return shapes. + +```python +def map_shape_vars(param_shape: ShapeCompExpr, arg_shape: ShapeCompExpr, shape_var_mapping: {tir::Var : PrimExpr}) -> bool: + if param_shape is RuntimeDepShape or arg_shape is RuntimeDepShape: + return False + if param_shape is ShapeExpr and arg_shape is ShapeExpr: + if len(param_shape.values) != len(arg_shape.values): + raise UnificationError("Shapes are of incompatible ranks") + for param_dim, arg_dim in zip(param_shape.values, arg_shape.values): + if param_dim in shape_var_mapping: + # syntactic equality + if arg_dim != shape_var_mapping[param_dim]: + # if they are statically not equal, e.g., 5 != 7 or 3 + 3 != 3*3 + if can_prove_not_equal(arg_dim, shape_var_mapping[param_dim]): + raise UnificationError("Incompatible dimensions") + else: + return False + else: + shape_var_mapping[param_dim] = arg_dim + return True + if param_shape is Tuple and arg_shape is Tuple: + if len(param_shape.fields) != len(arg_shape.fields): + raise UnificationError("Tuples are of incompatible lengths") + for param_field, arg_field in zip(param_shape.fields, arg_shape.fields): + ret = map_shape_vars(param_field, arg_field, shape_var_mapping) + if not ret: + return False + return True + if param_shape is TupleGetItem and arg_shape is TupleGetItem: + # Does not necessarily indicate a unification error, + # depending on what the tuple values are. + # Constant folding the TupleGetItem nodes could improve this unification case + if param_shape.index != arg_shape.index: + return False + return map_shape_vars(param_shape.tup_value, arg_shape.tup_value) + if param_shape is Call and arg_shape is Call: + # no dimension mapping to do in this case + return True + # if either is a Call or TupleGetItem, it is possible that the shapes + # can match dynamically even if they don't match statically + if (param_shape is Call + or param_shape is TupleGetItem + or arg_shape is Call + or arg_shape is TupleGetItem): + return False + raise UnificationError("Incompatible shape constructs") + +def substitute_vars(target: Expr, var_mapping: {Var: Expr}, shape_var_mapping: {tir::Var: PrimExpr}) -> Expr: + def substitute_shape_vars(target: PrimExpr): + if target is tir::Var: + if target in shape_var_mapping: + return shape_var_mapping[target] + else: + return target + # proceed recursively in all subexpressions, checking for vars + + if target is Var: + if target in var_mapping: + return var_mapping[target] + return target + if target is ShapeExpr: + return ShapeExpr([ + substitute_shape_vars(dim) + for dim in target.values + ]) + # recurse through all other cases, checking for vars and shape exprs analogously + +def substitute_shape(func_params, arg_exprs, ret_shape): + var_mapping = {param: arg_expr for param, arg_expr in zip(func_params, arg_exprs)} + shape_var_mapping = {} + for param, arg_expr in zip(func_params, arg_exprs): + can_unify = map_shape_vars(param.shape_, arg_expr.shape_, shape_var_mapping) + if not can_unify: + return RuntimeDepShape() + + new_shape = substitute_vars(ret_shape, var_mapping, shape_var_mapping) + if new_shape contains any free (Relax or shape) variables: + return RuntimeDepShape() + return new_shape +``` + +### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks + +There can be some complexity involved in checking whether two shapes match during shape inference. A very simple, conservative method for determining equality is simply using alpha-equivalence: If the two shapes have the same structure, then they are equivalent. However, this method is conservative and can overlook numerical properties in `PrimExpr`s. We leave it up to compiler implementations as to whether to use more advanced methods for proving equivalence, such as attempting to use algebraic rewrite rules. (As a consequence, portability requires inserting dynamic checks wherever there needs to be a comparison of shapes.) + +Note that optimizations like function inlining or constant folding could allow for simplifying many shape annotations and expressions and make it possible to conclude at compile time that shapes in more cases are equivalent. In general, developing compiler infrastructure for partial evaluation and reasoning about common situations with shape annotations may eliminate many dynamic checks. + +Applying some kind of normalization or algebraic simplifications to `PrimExpr`s used in shape annotations and in `shape_` fields can also make it easier to conclude that certain dynamic checks may not be necessary by increasing the likelihood that more `shape_` expressions could be made syntactically identical to the shape annotations. It would also be possible to generate compile-time warnings if analysis reveals that two shapes may not match (either using rewrite rules or by trying random values for shape variables and checking). + +Since most dynamic shape checks are done for safety, it may be feasible to introduce a compilation mode that eliminates almost all dynamic shape checks. Some shape checks may not be possible to eliminate, since the body of the program may construct `ShapeExpr`s and use them in calls to `PackedFunc`s, so some bindings to shape variables may need to be preserved, per a liveness analysis. + +## Possible Extensions to the Shape Expression System + +We may consider two possible extensions to the shape expression system in order to accommodate two further cases: + +1. An explicit wildcard dimension (e.g., using `tir::Any`) to allow for dimensions to be specified as "unknown" in function return shapes. As described at present, the only way for a function to specify a partly unknown return shape is to make the entire return shape unknown (`RuntimeDepShape`), which loses partial shape information. +2. Adding `shape_` expressions consisting of functions, to allow arbitrary closures to have a known shape. This would allow the shapes of calls to closures of unknown origin (namely, in a higher-order function) to have their shapes correctly inferred rather than made `RuntimeDepShape`. + +In both cases, these additions would entail additional complexity (shape inference macros for operators would have to deal with potential `tir::Any` nodes and we would have to define rules for constructing, calling, and simplifying functions in `shape_` expressions). However, the advantage of implementing these features would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using `RuntimeDepShape` means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchShape` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. + +# Detailed Semantics + +## Program Entry Point + +In the `IRModule`, every mapping of a `GlobalVar` to a `Function` node or a TIR `PrimFunc` should be processed first and added to the global scope. «Global functions that have a `global_symbol` attribute should be externally linked, meaning that they can be invoked as program entry points; those that do not have a `global_symbol` attribute can be called only from within the global functions in the `IRModule`.» + +The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax; these objects have type `Object` and can be used only by the `call_tir` operator. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. + +## Evaluating Expressions + +For each expression, we define how it affects the program's visible state and the order in which they are evaluated. Below, all evaluation results are passed by reference (and hence possibly alias) unless it is explicitly specified that they allocate new values. + +1. The node `Constant(value)` creates a new tensor whose contents are `value`. +2. A variable (whether `Var`, `DataflowVar` , or `GlobalVar`) evaluates to the stored value for that variable in the current scope. +3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. +4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per type checking, must evaluate to a tuple) and then returning the `i`th field of the result. +5. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. +6. `RuntimeDepShape` expressions must not appear in the general body of a program; it is a well-formedness error if they do. They do not have any defined semantics. +7. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. +8. The node `If(cond, true_branch, false_branch)` is evaluated as follows: + 1. First `cond` is evaluated. Let the result be `r` (per type checking, it must be a bool scalar). + 2. If `r` is true, evaluate the `true_branch` and return its result. + 3. If `r` is false, evaluate the `false_branch` and return its result. +9. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: + 1. If `op` is an `ExternFunc` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Next, look up the `PackedFunc` registered under the global symbol name. If it exists (it is an error at run time if it does not), call the `PackedFunc` using the given arguments and return the result. Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. + 2. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» + 3. In all other cases, first evaluate `op` (it must evaluate to a closure). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. +10. For the node `SeqExpr(blocks, body)`, we evaluate as follows: + 1. Push a new scope onto the stack. + 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: + 1. If the binding is `MatchShape(var, value, shape)`, perform the shape matching and shape variable updates as described in the shape evaluation section. If `var` is provided, `var` will be bound to `value` in the current scope; this assignment is aliasing and no new value is allocated. If `var` is not provided, then the shape check is performed and shape variables are updated, but no new binding is introduced. + 2. If the binding is `VarBinding(var, value)`, then evaluate `value` and bind `var` to that value in the current scope; this assignment is aliasing and no new value is allocated. + 3. If `block` is a `DataflowBlock`, remove all `DataflowVar`s bound in the block from the current scope before proceeding to the next block. + 3. After iterating through the binding blocks, evaluate `body` in the current scope. That will be the return value of the `SeqExpr`. + 4. Pop the scope, removing any `Var` bindings introduced in the `SeqExpr`. This should also remove any shape variables introduced and bound in the `SeqExpr` as well. + +### Optimizations + +Optimizations are allowed to reorder and modify the operations of a program in any way so long as they do not change the value returned by evaluating the program or any visible behavior of the program. For the purposes of compilation, visible behaviors consist of side effects like mutating values in the program or external effects like I/O (printing to the console, creating files, etc.) and the order and number of times in which they happen. + +«Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchShape` or `cast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» + +The specification makes no guarantees about certain memory-related properties and hence also does not consider them to be "visible behaviors": + +- Whether an allocation happens at a given point. Compiler implementations are permitted to reuse already-allocated memory if it would not interfere with visible state in any other way, per the aliasing rules (`PackedFunc`s or operators may mutate values that are passed to them and those mutations should be visible as per aliasing in this specification). Copying values or sharing representations (e.g., interning constants) between values may be done only if they will not affect any other visible behaviors, dependent on the aliasing behavior. +- It is entirely the domain of compiler implementations to make guarantees (or not) as to whether memory allocations will succeed. +- `PackedFunc`s or operators can, in principle, access information about the machine's state and make changes to allocation policies or the state that affect how memory allocations are performed. The specification makes no guarantees in such an event. + +These semantic rules assume a single thread of evaluation on a single host machine. At this time, it is unspecified as to how Relax programs should behave if split over distinct threads or across multiple machines. + +### Notable Operators + +The above evaluation rules are general, but leave much room for implementations of operators to specify custom semantics. Certain operators are used to perform common operations and will be discussed here as well. + +- `call_tir(prim_func, arg1, arg2, ..., argn, shape, type_args=[aT])`: `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). The `shape` argument gives the shapes of the result of calling the TIR `PrimFunc`: It must be either of `ShapeType` (corresponding to returning a single tensor) or `TupleType` whose members are `ShapeType` (corresponding to returning a tuples of tensors). The type arg `aT` gives the type of the result of calling the `PrimFunc` and it must correspond to `shape` (namely, if `shape` is of `ShapeType`, `aT` must be a `DynTensorType`; if `shape` is of `TupleType`, `aT` must be a `TupleType` whose fields are `ShapeType`). `aT` is used especially to provide the `dtype` of returned tensors. + + Based on `shape`, the resulting tensor or tuple `r` will be allocated according to the sizes given in `shape`. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. + +- `call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type. +- `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. +- «`cast(v, type_args=[aT])`: Given an argument `v`, it dynamically checks if `v`'s run-time representation is a subtype of `aT`. If it is not, it exits the program with an error message. Otherwise, it returns `v`.» + From 1bd7f8e2cab0e133fd071fb6cca25818f0713c94 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 25 Oct 2022 18:16:36 -0400 Subject: [PATCH 02/47] call_dps_packed is not yet implemented --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 40f69e4ba91c..4123539323fc 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -743,7 +743,7 @@ The above evaluation rules are general, but leave much room for implementations Based on `shape`, the resulting tensor or tuple `r` will be allocated according to the sizes given in `shape`. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. -- `call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type. +- «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type.» - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. - «`cast(v, type_args=[aT])`: Given an argument `v`, it dynamically checks if `v`'s run-time representation is a subtype of `aT`. If it is not, it exits the program with an error message. Otherwise, it returns `v`.» From 59a0e6fc6e109d78e44b54d9de0eecbaaa2a983d Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 25 Oct 2022 18:22:07 -0400 Subject: [PATCH 03/47] Many shape mechanics are still unimplemented --- relax_spec.md | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 4123539323fc..1f4c828868ec 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -432,16 +432,16 @@ Shape computations can consist of the following expressions, which are a subset ``` ShapeCompExpr ::= ShapeExpr(dims: [PrimExpr]) | RuntimeDepShape() - | Tuple(fields: [ShapeCompExpr]) + | «Tuple(fields: [ShapeCompExpr])» | Call(op: Op|ExternFunc, args: [Var|Constant]) - | TupleGetItem(tuple_value: ShapeCompExpr, index: int) + | «TupleGetItem(tuple_value: ShapeCompExpr, index: int)» ``` The shape expressions can be interpreted as follows: - `ShapeExpr` describes the shape of a tensor as a list of dimensions -- `Tuple` describes the shapes of each member of a tuple -- `TupleGetItem` describes the shape of a member of a tuple +- «`Tuple` describes the shapes of each member of a tuple» +- «`TupleGetItem` describes the shape of a member of a tuple» - `Call` describes the shape of a function (or operator) call return value in terms of its arguments - `RuntimeDepShape` describes shapes that are unknown at compile time (like when a shape annotation is omitted) or the shapes of values that don't have shapes (like shapes themselves, paradoxically: they *are* shapes but do not *have* shapes). @@ -453,9 +453,9 @@ Shape computations are allowed to include calls to operators and even `PackedFun **Shape Annotations** -For shape annotations, we use `ShapeCompExpr` as the grammar, as with `shape_` expressions. `ShapeExpr` is used to annotate shapes of tensor values, `Tuple` is used to annotate the shapes of tuple values, and `RuntimeDepShape` is used to indicate annotations that have been omitted or shapes that cannot be known at compile time (like the shapes of tensors whose rank is unknown at compile time). `Call` is used to annotate the shapes of calls to operators and `TupleGetItem` annotates the shapes of tuple indices. +For shape annotations, we use `ShapeCompExpr` as the grammar, as with `shape_` expressions. `ShapeExpr` is used to annotate shapes of tensor values, «`Tuple` is used to annotate the shapes of tuple values», and `RuntimeDepShape` is used to indicate annotations that have been omitted or shapes that cannot be known at compile time (like the shapes of tensors whose rank is unknown at compile time). `Call` is used to annotate the shapes of calls to operators and «`TupleGetItem` annotates the shapes of tuple indices.» -For example, suppose we have a tuple where some fields are tensors like the following: +«For example, suppose we have a tuple where some fields are tensors like the following: ```python x : Tuple(Tensor((m, n), "int32"), Tuple(), Tensor((), "int32"), Tensor(_, "int32")) = ... @@ -466,12 +466,13 @@ It has the shape annotation ```python Tuple([ShapeExpr([m, n]), Tuple([]), ShapeExpr([]), RuntimeDepShape]) ``` +» Note that it is [a well-formedness requirement](https://www.notion.so/Informal-Relax-Language-Specification-d1fdedb8fae84f0d82b9f880f25e7370) that if any field in a type has a `ShapeExpr` annotation, it must be a `DynTensorType` with an `ndim` matching the number of dimensions in the `ShapeExpr`. For example, in the above function signatures, the `ndim` in the type annotations must be 2. -### Assigning Shape Variables at the Start and End of a Function +### «Assigning Shape Variables at the Start and End of a Function» -Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchShape`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchShape` bindings. Suppose a function has the following signature, where the `Ti` are type annotation and the `Si` are shape annotations: +«Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchShape`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchShape` bindings. Suppose a function has the following signature, where the `Ti` are type annotation and the `Si` are shape annotations: ```python def f(arg1 : (T1, S1), arg2 : (T2, S2), ..., argn : (Tn, Sn)) -> (Tr, Sr): @@ -490,6 +491,7 @@ def f(arg1 : T1, arg2 : T2, ..., argn : Tn) -> Tr: check_annotation(ret_var, Tr, Sr) return ret_var ``` +» Because `MatchShape` is defined only for tensor and shape values, we must use a macro to handle other possible types that may be passed into a function, given here in pseudocode: @@ -502,7 +504,7 @@ def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: [BindingBlock([MatchShape(tmp, e, s.dims)])], tmp ) - else if s is a Tuple: + «else if s is a Tuple: # type checking should ensure that e is always a tuple and the lengths match shapes = s.fields tmp = fresh_var() @@ -515,7 +517,7 @@ def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: ..., VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, n-1), shapes[n-1])) ])], tmp - ) + )» else if s is a Call: tmp = fresh_var() return SeqExpr( @@ -525,7 +527,7 @@ def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) ])], tmp ) - else if s is TupleGetItem: + «else if s is TupleGetItem: val = s.tuple_value if val is Tuple: return check_annotation(e, val.fields[s.index]) @@ -535,7 +537,7 @@ def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: VarBinding(tmp, e), VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) ])], tmp - ) + )» else if s is RuntimeDepShape: # no need to check return e @@ -552,7 +554,7 @@ Shape expressions follow the same evaluation rules as general program expression For each expression type, we can recursively build up an associated `shape_` expression according to the following rules: 1. For `Constant(value)`, the `shape_` expression is a `ShapeExpr` corresponding to the concrete shape of `value`. For example, for `Constant(1)`, `shape_` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape_` is `ShapeExpr([2])`. -2. For `Tuple(fields)`, `shape_` can be defined as `Tuple([field.shape_ for field in fields])`. +2. «For `Tuple(fields)`, `shape_` can be defined as `Tuple([field.shape_ for field in fields])`.» 3. For `ShapeExpr`s, `shape_` is `RuntimeDepShape`. 4. `RuntimeDepShape` expressions should appear only in shape expressions; their `shape_` is not defined. 5. For `If(cond, true_branch, false_branch)`, we compare the `shape_` of `true_branch` and `false_branch`. If these can be proven equivalent (by a method that the compiler implementation is free to determine), then the `If` node's `shape_` is that shape. If they do not match, then we set it to `RuntimeDepShape`. @@ -565,11 +567,11 @@ For each expression type, we can recursively build up an associated `shape_` exp 3. For bindings where the RHS is a function literal or assigning the `shape_` of a `GlobalVar`, see the rule for `Function` nodes. 4. For `MatchShape(var, value, shape)`, we set the `shape_` of `var` to `shape`, as it will be dynamically checked. -8. For `TupleGetItem(tuple_value, i)`, we examine the `shape_` of `tuple_value`; suppose it is `s`. If `s` is a `Tuple`, then we use its `i`th field. If it is `RuntimeDepShape`, we use `RuntimeDepShape`. If it is a `Call` to a function that returns a tuple with at least `i + 1` members, set the `shape_` to `TupleGetItem(s, i)`. Otherwise, raise an error at compile time (though this should not happen if type checking has passed). +8. «For `TupleGetItem(tuple_value, i)`, we examine the `shape_` of `tuple_value`; suppose it is `s`. If `s` is a `Tuple`, then we use its `i`th field. If it is `RuntimeDepShape`, we use `RuntimeDepShape`. If it is a `Call` to a function that returns a tuple with at least `i + 1` members, set the `shape_` to `TupleGetItem(s, i)`. Otherwise, raise an error at compile time (though this should not happen if type checking has passed).» 9. For `Call` nodes: 1. For a call to an `ExternFunc`, we use `RuntimeDepShape` because we cannot analyze the shapes of arbitrary `PackedFunc`s and must check dynamically. 2. For a call to an `Op`, we use the manually defined `FInferShape` macro if it has been defined and `RuntimeDepShape` if it has not. `FInferShape` is a function that takes in the call node and produces a `ShapeCompExpr`. - 3. For all other cases with `Call(op, args)`, we consider the following cases: + 3. «For all other cases with `Call(op, args)`, we consider the following cases: 1. If `op` is a `GlobalVar` or a `Var` that refers to a function defined in the current scope, look up the `Function` node it references; let us call it `f`. Similarly, if `op` is itself a `Function` node, let `f` be `op`. Attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) on `f`'s return shape. A pseudocode procedure for this beta-reduction is given below, as a macro. @@ -580,6 +582,7 @@ For each expression type, we can recursively build up an associated `shape_` exp If `f`'s return shape is `RuntimeDepShape`, then consider the call result to have `RuntimeDepShape`. If beta-reduction is considered to fail, then consider the call result to have `RuntimeDepShape`. If it succeeds, use the resulting shape as the `shape_` of the call result. 2. Otherwise, consider the result of the call to have `RuntimeDepShape`. + » 10. For a function node, set the `shape_` to `RuntimeDepShape`. ### Procedure for Substituting a Function Return Shape to Determine the Shape of a Call From de63d6d43dd56f7f800a266622485c97445a9162 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 25 Oct 2022 18:23:41 -0400 Subject: [PATCH 04/47] Indicate datatype in AST diagram --- relax_spec.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index 1f4c828868ec..d38afff8f963 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -51,6 +51,12 @@ Type ::= DynTensorType(ndim: int, dtype: DataType) | TupleType(fields: [Type]) | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») +DataType ::= + Int(bitwidth: int) + | Float(bitwidth: int) + | Bool() + | Void() + # expressions Expr ::= Constant(data: NDArray) # scoped to functions or SeqExprs From ce7c76b7505b7dbb2f3ee488fe5267aff89b7825 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 25 Oct 2022 18:37:14 -0400 Subject: [PATCH 05/47] Add text about variable shadowing --- relax_spec.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index d38afff8f963..6357a23c0db4 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -200,6 +200,27 @@ There are four relevant scopes in Relax, which determine where variables are vis 3. `SeqExpr`: `Var` nodes defined in a `BindingBlock` in a `SeqExpr` node can be referenced in any later binding within the same `BindingBlock`, in any binding within any later `BindingBlock` in that `SeqExpr` node, or in the `SeqExpr`'s body expression. The variables defined in the `BindingBlock`s leave scope once the `SeqExpr` returns. 4. `DataflowBlock`: `DataflowVar`s introduced in a `DataflowBlock` can be referenced in any later binding within that `DataflowBlock`, but leave scope *once that `DataflowBlock` finishes executing*. Definitions in a `DataflowBlock` that are intended to leave the `DataflowBlock` should be bound to an ordinary `Var`. +Note that Relax variables must be bound _exactly_ once. A global variable is bound if it is mapped to a function in the `IRModule` and a local variable is bound if it appears as a function parameter or if it appears on the left-hand side (LHS) of a binding (`VarBinding` or `MatchShape`). + +«If there is another binding to a local variable with the same name as an already-bound variable, that is binding is considered to _shadow_ the previous binding, i.e., it is a binding to a new, distinct variable that happens to have the same name as the existing variable. The new, shadowing variable will exist only in the current scope; if the older variable was defined in an outer scope, then future uses of that name will refer to the older variable. [See the Wikipedia page for more information on variable shadowing.](https://en.wikipedia.org/wiki/Variable_shadowing)» + +Below is an example of shadowing, in pseudocode: + +```python +@R.function +def func(x: Tensor) -> Tensor: + if True: + # the true branch will be a nested SeqExpr and hence a new scope + # this x will shadow the function parameter x + x = R.const(1) + R.print(x) # prints 1 + # the inner x goes out of scope + else: + R.print("not executed") + R.print(x) # this x is the function parameter + return x +``` + # Well-Formedness Criteria Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid. From 5d04bdcee89c099e970b268cc0aebe67a73c46d6 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 22 Nov 2022 22:45:38 -0500 Subject: [PATCH 06/47] Discuss differences from Relay --- relax_spec.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index 6357a23c0db4..89866323ca8e 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -21,6 +21,12 @@ Though this document will use the TVMScript front end for some examples, specify This section will outline the grammar of Relax and give very brief descriptions of the different components, including the semantics, type system, and shape system. The rest of this document will provide more detailed descriptions of these facets of the language, including the validity conditions that the type system and shape system uphold. +## Differences from Relay + +Per the [original workshop paper](https://arxiv.org/abs/1810.00952) and the [later report](https://arxiv.org/abs/1904.08368), Relay was designed to be a high-level functional language for expressing deep learning models at a high level. While Relay is not entirely pure (the `Ref` type is modeled after reference types in SML and similar functional languages), the assumption in Relay is that tensor operators are generally pure, meaning that they do not change the program state other than by producing new values. Additionally, Relay's type system also requires operators to have type relations that infer static tensor types or conclude that a dimension is unknown at compile time (`Any`). The need to register type relations and ensure operators' purity makes it difficult to add new operators to Relay and particularly difficult to call directly into TIR or external libraries, which are often not pure; any such extension requires adding new operators and abstracting over any impurity. + +While Relax aims to be as general and expressive in Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has an associated shape _computation_ associated with it, in addition to a type. These shape computations support static reasoning about shapes in many cases, but also facilitate a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. + ## Grammar Below is a diagram of the various AST constructs in Relax, including types. In code, these are defined on the C++ side in `include/tvm/relax/{expr.h, type.h}` and in Python in `python/tvm/relax/{expr.py, ty.py}`. This diagram will give the names of the AST nodes and the types and names of their members. The semantics will describe what computation each construct represents; an AST is simply data. A Relax program consists of an `IRModule` with global variables bound to Relax functions that implement the computations of interest. From 6cab2cbbf36420673a5cbda1504981b2ef556ece Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 22 Nov 2022 22:56:35 -0500 Subject: [PATCH 07/47] Correct typo --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 89866323ca8e..6084c4e286dc 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -25,7 +25,7 @@ This section will outline the grammar of Relax and give very brief descriptions Per the [original workshop paper](https://arxiv.org/abs/1810.00952) and the [later report](https://arxiv.org/abs/1904.08368), Relay was designed to be a high-level functional language for expressing deep learning models at a high level. While Relay is not entirely pure (the `Ref` type is modeled after reference types in SML and similar functional languages), the assumption in Relay is that tensor operators are generally pure, meaning that they do not change the program state other than by producing new values. Additionally, Relay's type system also requires operators to have type relations that infer static tensor types or conclude that a dimension is unknown at compile time (`Any`). The need to register type relations and ensure operators' purity makes it difficult to add new operators to Relay and particularly difficult to call directly into TIR or external libraries, which are often not pure; any such extension requires adding new operators and abstracting over any impurity. -While Relax aims to be as general and expressive in Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has an associated shape _computation_ associated with it, in addition to a type. These shape computations support static reasoning about shapes in many cases, but also facilitate a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. +While Relax aims to be as general and expressive as Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has an associated shape _computation_ associated with it, in addition to a type. These shape computations support static reasoning about shapes in many cases, but also facilitate a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. ## Grammar From b766e91d83e6b71eaaf4c1eec00a3691cc0bb405 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 1 Dec 2022 22:48:23 -0500 Subject: [PATCH 08/47] Add description of PackedFuncType --- relax_spec.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 6084c4e286dc..5899fdfdcaee 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -56,6 +56,7 @@ Type ::= DynTensorType(ndim: int, dtype: DataType) | ObjectType() | TupleType(fields: [Type]) | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») + | PackedFuncType() DataType ::= Int(bitwidth: int) @@ -159,7 +160,8 @@ The types in Relax correspond to the broad categories of the values given above: 2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. 3. `ShapeType` corresponds to shape values. 4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» -5. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. +5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). Since packed functions are not first-class values (`ExternFunc` can appear only in the `op` position of a `Call` node), these do not actually correspond to any value in Relax, but can be used to assign a type to `ExternFunc` nodes. +6. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» @@ -342,7 +344,7 @@ def find_lub(T1 : Type, T2 : Type) -> Type: return T1 if T1 or T2 is ObjectType: return Object - if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType: + if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType, or both PackedFuncType: return ObjectType if T1 and T2 are both DynTensorType: res_ndim = T1.ndim @@ -393,13 +395,14 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType`. 5. The type of a `RuntimeDepShape` expression is `ShapeType`. 6. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. -7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: +7. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. +8. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: 1. If `op` is a Relax `Op` node, then we look up its registered `FInferType` property. `FInferType` is a macro that takes in the `Call` node and produces a type. We return the type `op.FInferType(Call(op, [a1, ..., an], type_args=[aT]))`. The implementation of `FInferType` is free to throw errors. 2. If `op` is `ExternFunc`, then use the sole member of `type_args` (calls to `ExternFunc`s are required to have exactly one `type_args` member) `aT` as the return type. Packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function itself to do any validation. 3. Otherwise, check the types of the subexpressions, left to right. Suppose `op` has the type `Tf`, `a1` has type `T1`, …, and `an` has type `Tn`. If `Tf` is not a function type with exactly `n` arguments, we consider it a type error and require an explicit cast. Suppose `Tf` has argument types `T1'`, `T2'`, …, `Tn'`. Consider it a type error if any of the following does not hold: `T1 <: T1'`, `T2 <: T2'`, …, or `Tn <: Tn'`. Then the return type is `Tf.ret_type`. -8. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. -9. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» -10. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. +9. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. +10. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» +11. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. @@ -413,7 +416,7 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 4. If `T'` is `TupleType` or `FuncType`, emit a type error. 2. If the current block is a `DataflowBlock`, remove any `DataflowVar`s from `Γ` after we have finished processing the bindings in that block. 3. Finally, the type of the `SeqExpr` is the type of `body` checked under the current `Γ`. Afterwards, remove all bindings added from the `b0`, `b1`, …, `bn` from `Γ`. -11. Let us consider a function `Function(v1 : T1, v2 : T2, ..., vn : Tn, attrs=a) -> Tr: body`. +12. Let us consider a function `Function(v1 : T1, v2 : T2, ..., vn : Tn, attrs=a) -> Tr: body`. 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, where `p` is `True` if a `pure` attribute is included and `False` otherwise. Remove `fv` from `Γ` before returning. 2. Add `v1` to `Γ` with type `T1`, `v2` with type `T2`, …, and `tn` with type `Tn`. Recursively type check `body` in this new context: 1. «Determining purity: If `body` contains any call to a function whose return type does not specify that it is pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not specify the `pure` attribute, then we consider the function to be (potentially) impure. If all calls are to functions whose return type specifies purity or that include the `pure` attribute on the call or `Op`, then the function is treated as pure.» From 91ecf52808f2829c8654119f8c60811c5c447ba0 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 7 Dec 2022 17:05:56 -0500 Subject: [PATCH 09/47] Add a couple of missed references to PackedFuncType --- relax_spec.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 5899fdfdcaee..b690d6ab7b89 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -258,13 +258,14 @@ Prior to type-checking and shape inference, Relax programs must conform to certa # Types in Relax -Relax presently has five types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: +Relax presently has six types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: 1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. 2. `ShapeType`, referring to shape values. 3. `FuncType`, referring to functions (closures). `FuncType`s specify the types of their parameters, a return type, and whether the function is pure. 4. `TupleType`, referring to tuple values, giving the types of their fields. -5. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. +5. `PackedFuncType`, referring to the type of PackedFunctions. +6. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. ## Subtyping @@ -298,7 +299,7 @@ def find_glb(T1 : Type, T2 : Type) -> Type?: return T2 if T2 is ObjectType: return T1 - if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType: + if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType, or not both PackedFuncType: return None if T1 and T2 are both DynTensorType: ret_ndim = T1.ndim From 002a7aa940804fdc5f9bba0510000211f0784a39 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 7 Dec 2022 17:07:56 -0500 Subject: [PATCH 10/47] Add forward pointer to the type-checking rule for local functions --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index b690d6ab7b89..c41ddf04ada1 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -406,7 +406,7 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 11. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» - 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. + 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, the type annotation `T` is not optional and we add `v` to `Γ` before type-checking the function body; see the rule for `Function` nodes.) 3. For each `MatchShape(v: T, e, shape_pattern)`, where `T` is an optional type annotation, let the checked type of `e` be `T'`. 1. If `T'` is `ShapeType`, then emit an error if `T` is not a supertype of `ShapeType`. Add `v` to `Γ` with type `T`. 2. If `T'` is `DynTensorType`: From 58d70f8a53dbd5c20af86862e3f9eedec08b7caa Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Dec 2022 23:35:07 -0500 Subject: [PATCH 11/47] Describe normal form in the spec --- relax_spec.md | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index c41ddf04ada1..84868ac1392e 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -12,10 +12,11 @@ Though this document will use the TVMScript front end for some examples, specify 2. [Top-Level Program Organization](#top-level-program-organization-irmodule) 3. [Values in Relax](#values-in-relax) 4. [Variable Scoping](#variable-scoping) -5. [Well-Formedness Criteria](#well-formedness-criteria) -6. [Types in Relax](#types-in-relax) -7. [Shapes in Relax](#shapes-in-relax) -8. [Semantics](#detailed-semantics) +5. [Normal Form](#normal-form) +6. [Well-Formedness Criteria](#well-formedness-criteria) +7. [Types in Relax](#types-in-relax) +8. [Shapes in Relax](#shapes-in-relax) +9. [Semantics](#detailed-semantics) # Overview @@ -229,10 +230,31 @@ def func(x: Tensor) -> Tensor: return x ``` +# Normal Form + +To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the type- and shape-checking rules for operators rely on macros (`FInferType` and `FInferShape`), _this means that the structure of the program can affect type and shape inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these type- and shape-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying type- or shape-checking. + +The normal form for Relax is very similar to ANF; differences will be noted. Here are the criteria required for a program to be in normal form: +1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, `RuntimeDepShape`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. +2. `SeqExpr`s may appear only in the following locations: + 1. In the `body` field of a `Function` node. + 2. In the `true_branch` and `false_branch` fields of `If` nodes. +3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. +4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. + +Programs that are parsed should be "normalized" before performing type-checking or shape-checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: +1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. +2. If the function body is already a `SeqExpr`, consolidate all `BindingBlock`s, then check if the `body` field of the `SeqExpr` is a leaf expression. If not, bind it to a new var in the final `BindingBlock` and replace the `SeqExpr` body with the new var. +3. If the function body is not a `SeqExpr`, then recurse down the body's AST, binding any nested non-leaf expressions to a var in the current scope (doing this process in breadth-first order from left to right will respect the evaluation order in the semantics). If the body itself is a non-leaf expression, finally bind it to a var and have the final `SeqExpr` return the new var. +4. If an `If` node is encountered, ensure the `true_branch` and `false_branch` fields are `SeqExpr`s (consolidate `BindingBlock`s if necessary) or "wrap" them in `SeqExpr`s in the same manner as the function body. +5. If a `SeqExpr` node is encountered as the `value` node in a binding, "flatten" the `SeqExpr` by adding its bindings to the current scope and replacing the `SeqExpr` with its body. If the `SeqExpr` body is a non-leaf expression, normalize it recursively in the same manner as in step 3 before replacing the binding. Note that if the current scope (the location of the binding) is a `DataflowBlock` and the nested `SeqExpr` contains an ordinary `BindingBlock`, that indicates a malformed program. + + # Well-Formedness Criteria -Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid. +Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid, which includes conforming to the expectations of the above-described normal form. +The following criteria apply to all programs (including before normalization): 1. `DataflowVar`s can be bound only inside `DataflowBlock`s. Additionally, a `DataflowVar` may not be used outside of the `DataflowBlock` in which it is defined. 2. A `Var` of any kind used in the program must be either a function parameter or appear on the LHS of a binding exactly once. In the binding where a `Var` is defined, the same `Var` is permitted to occur in the RHS of the binding only if the binding is defining a function (i.e., local functions are permitted to be recursive). 3. A `Var` of any kind may not appear before it is bound. Namely, if a `Var` is bound in a `BindingBlock` in a `SeqExpr`, that `Var` may not appear in bindings that precede the one where it appears on the LHS. @@ -256,6 +278,8 @@ Prior to type-checking and shape inference, Relax programs must conform to certa 15. «Any `PackedFunc` or operator called in a shape annotation or `shape_` expression must be pure and be annotated as such.» 16. The node `RuntimeDepShape` may appear only in shape annotations and `shape_` expressions. It has no defined semantics at run time. +Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. + # Types in Relax Relax presently has six types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: From 53a2eb70f9ee844ee3757f0be972a0d328b3cb09 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 20 Dec 2022 19:54:54 -0500 Subject: [PATCH 12/47] Specify consolidating empty binding blocks --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 84868ac1392e..b27e75c8e020 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -240,7 +240,7 @@ The normal form for Relax is very similar to ANF; differences will be noted. Her 1. In the `body` field of a `Function` node. 2. In the `true_branch` and `false_branch` fields of `If` nodes. 3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. -4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. +4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. Empty `BindingBlock`s should be dropped. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. If all the `BindingBlock`s are empty, then the `blocks` field of the `SeqExpr` should be set to an empty list. Programs that are parsed should be "normalized" before performing type-checking or shape-checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: 1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. From 0dd7e86f31f883ffe0d6b4b3c631fe1d985d315b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 4 Jan 2023 19:16:43 -0500 Subject: [PATCH 13/47] Add ndim parameter to ShapeType --- relax_spec.md | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index b27e75c8e020..d0e9db96a50b 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -53,7 +53,7 @@ PrimExpr ::= # (others may be added later, as deemed necessary) Type ::= DynTensorType(ndim: int, dtype: DataType) - | ShapeType() + | ShapeType(ndim: int) | ObjectType() | TupleType(fields: [Type]) | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») @@ -159,7 +159,7 @@ The types in Relax correspond to the broad categories of the values given above: 1. `DynTensorType` corresponds to tensor values, giving the scalar data type and the number of dimensions (rank), both of which are optional. 2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. -3. `ShapeType` corresponds to shape values. +3. `ShapeType` corresponds to shape values, optionally giving the number of dimensions in the shape. 4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» 5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). Since packed functions are not first-class values (`ExternFunc` can appear only in the `op` position of a `Call` node), these do not actually correspond to any value in Relax, but can be used to assign a type to `ExternFunc` nodes. 6. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. @@ -285,7 +285,7 @@ Additionally, the criteria for normal form listed in the previous section must a Relax presently has six types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: 1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. -2. `ShapeType`, referring to shape values. +2. `ShapeType`, referring to shape values. The number of dimensions in the shape as given as `ndim` and is optional (using -1 for `ndim` indicates an unknown number of dimensions). 3. `FuncType`, referring to functions (closures). `FuncType`s specify the types of their parameters, a return type, and whether the function is pure. 4. `TupleType`, referring to tuple values, giving the types of their fields. 5. `PackedFuncType`, referring to the type of PackedFunctions. @@ -325,6 +325,13 @@ def find_glb(T1 : Type, T2 : Type) -> Type?: return T1 if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType, or not both PackedFuncType: return None + if T1 and T2 are both ShapeType: + ret_ndim = T1.ndim + if ret_ndim == -1: + ret_ndim == T2.ndim + if ret_ndim != -1 and T2.ndim != ret_ndim: + return None + return ShapeType(ret_ndim) if T1 and T2 are both DynTensorType: ret_ndim = T1.ndim ret_dtype = T1.dtype @@ -371,6 +378,11 @@ def find_lub(T1 : Type, T2 : Type) -> Type: return Object if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType, or both PackedFuncType: return ObjectType + if T1 and T2 are both ShapeType: + res_ndim = T1.ndim + if T1.ndim != T2.ndim: + res_ndim = -1 + return ShapeType(res_ndim) if T1 and T2 are both DynTensorType: res_ndim = T1.ndim res_dtype = T1.dtype @@ -417,8 +429,8 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 1. «We type check the entire `IRModule` one function definition at a time. To handle mutual recursion, we prepopulate `Γ` with the annotated types of all global functions that are called mutually recursively. We then proceed to check the types of the global functions one at a time.» 2. Given a variable `v`, if `v` is in `Γ`, then we return `Γ[v]`. Otherwise, it is an error. 3. Given a constant expression `Constant(data)`, the type is `DynTensorType(ndim=n, dtype=d)` where `n` is the number of dimensions in `data` and `d` is its data type (recall that `data` is an `NDArray` literal). -4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType`. -5. The type of a `RuntimeDepShape` expression is `ShapeType`. +4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType(n)`, where `n` is the length of `dims`. +5. The type of a `RuntimeDepShape` expression is `ShapeType(-1)`. 6. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. 7. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. 8. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: @@ -432,7 +444,10 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, the type annotation `T` is not optional and we add `v` to `Γ` before type-checking the function body; see the rule for `Function` nodes.) 3. For each `MatchShape(v: T, e, shape_pattern)`, where `T` is an optional type annotation, let the checked type of `e` be `T'`. - 1. If `T'` is `ShapeType`, then emit an error if `T` is not a supertype of `ShapeType`. Add `v` to `Γ` with type `T`. + 1. If `T'` is `ShapeType`: + 1. Emit an error if `T` is not a supertype of `ShapeType`. + 2. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. + 3. Add `v` to `Γ` with type `T`. 2. If `T'` is `DynTensorType`: 1. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. Let the datatype of `T'` be `d`. 2. If `T` is not a supertype of `DynTensorType(ndim=len(shape_pattern), dtype=d)`, then emit an error. If `T` is a subtype of that type, emit an error and request a cast. From c2b242e7d92b68bead306e91a0e9ba684d6c2f7e Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 7 Jan 2023 21:10:25 -0500 Subject: [PATCH 14/47] StructInfo update --- relax_spec.md | 768 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 457 insertions(+), 311 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index d0e9db96a50b..0e2b4c36604f 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -2,7 +2,7 @@ Note: Text in «double chevrons» indicates features not present in the current prototype. -In order to develop and test Relax, it is important for compiler developers to agree on what a given program in Relax means and what makes it valid so that test cases can be evaluated independently of any particular Relax implementation. This document is intended to describe Relax's grammar constructs (its [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree), or AST), the semantics of its grammar (what the different constructs mean), Relax's type system and type-checking rules (what makes a Relax program valid), and its rules for reasoning about tensor shapes in detailed though still informal terms. If necessary, we may encode these rules more formally to allow for more automated analysis. +In order to develop and test Relax, it is important for compiler developers to agree on what a given program in Relax means and what makes it valid so that test cases can be evaluated independently of any particular Relax implementation. This document is intended to describe Relax's grammar constructs (its [abstract syntax tree](https://en.wikipedia.org/wiki/Abstract_syntax_tree), or AST), the semantics of its grammar (what the different constructs mean), Relax's type system and type-checking rules (what makes a Relax program valid), and its rules for reasoning about structural information (such as tensor shapes) in detailed though still informal terms. If necessary, we may encode these rules more formally to allow for more automated analysis. Though this document will use the TVMScript front end for some examples, specifying the mapping from Python's AST to Relax's AST will be deferred until the parser becomes more stable. @@ -15,7 +15,7 @@ Though this document will use the TVMScript front end for some examples, specify 5. [Normal Form](#normal-form) 6. [Well-Formedness Criteria](#well-formedness-criteria) 7. [Types in Relax](#types-in-relax) -8. [Shapes in Relax](#shapes-in-relax) +8. [Structural Information in Relax](#structural-information-in-relax) 9. [Semantics](#detailed-semantics) # Overview @@ -26,7 +26,7 @@ This section will outline the grammar of Relax and give very brief descriptions Per the [original workshop paper](https://arxiv.org/abs/1810.00952) and the [later report](https://arxiv.org/abs/1904.08368), Relay was designed to be a high-level functional language for expressing deep learning models at a high level. While Relay is not entirely pure (the `Ref` type is modeled after reference types in SML and similar functional languages), the assumption in Relay is that tensor operators are generally pure, meaning that they do not change the program state other than by producing new values. Additionally, Relay's type system also requires operators to have type relations that infer static tensor types or conclude that a dimension is unknown at compile time (`Any`). The need to register type relations and ensure operators' purity makes it difficult to add new operators to Relay and particularly difficult to call directly into TIR or external libraries, which are often not pure; any such extension requires adding new operators and abstracting over any impurity. -While Relax aims to be as general and expressive as Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has an associated shape _computation_ associated with it, in addition to a type. These shape computations support static reasoning about shapes in many cases, but also facilitate a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. +While Relax aims to be as general and expressive as Relay, Relax is intended to make it much easier to interoperate with external libraries and especially with TIR. In particular, Relax includes a mechanism for calling arbitrary TVM `PackedFunc`s (which can call external libraries) and special support for TIR. The language accordingly does not assume that such operations are pure, though this does require reasoning about aliasing and similar issues. Additionally, tensor shapes are no longer handled during type checking; each expression has associated structural information associated with it, in addition to a type. This structural information supports static reasoning about tensor shapes in many cases, but also facilitates a fallback to dynamic checking when that is not possible. This approach to shapes allows for richer shape constraints and other structural properties to be checked at run time (such as with _symbolic_ shapes, where some dimensions are variables) and allows for more quickly integrating calls into TIR or external libraries into Relax code by obviating the need for type relations. ## Grammar @@ -59,29 +59,33 @@ Type ::= DynTensorType(ndim: int, dtype: DataType) | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») | PackedFuncType() -DataType ::= - Int(bitwidth: int) - | Float(bitwidth: int) - | Bool() - | Void() +DataType ::= Int(bitwidth: int) + | Float(bitwidth: int) + | Bool() + | Void() + +StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) + | ShapeStructInfo(values: [PrimExpr]?, ndim: int) + | ObjectStructInfo() + | TupleStructInfo(fields: [StructInfo]) + | FuncStructInfo(params: [StructInfo]?, ret: StructInfo, derive_func: EnvFunc?*) # expressions Expr ::= Constant(data: NDArray) # scoped to functions or SeqExprs - | Var(name_hint: string) + | Var(name_hint: string, struct_info_annotation: StructInfo?) # scoped to DataflowBlocks - | DataflowVar(name_hint: string) + | DataflowVar(name_hint: string, struct_info_annotation: StructInfo?) | GlobalVar(name_hint: string) | Tuple(fields: [Expr]) | SeqExpr(blocks: [BindingBlock], body: Expr) - | Function(params: [Var], body: Expr, ret_type: Type?, attrs: Attrs?) + | Function(params: [Var], body: Expr, ret_struct_info: StructInfo?, attrs: Attrs?) | If(cond: Expr, true_branch: Expr, false_branch: Expr) | ExternFunc(global_symbol: string) | Call(op: Expr, args: [Expr], type_args: [Type], attrs: Attrs?) | ShapeExpr(values: [PrimExpr]) | TupleGetItem(tuple_value: Expr, index: int) | Op(op_name: string) - | RuntimeDepShape() # binding blocks (analogous to sequence of statements) BindingBlock ::= @@ -91,7 +95,7 @@ BindingBlock ::= # bindings (analogous to statements) Binding ::= VarBinding(var: Var|DataflowVar, value: Expr) - | MatchShape(var: (Var|DataflowVar)?, pattern: [PrimExpr], value: Expr) + | MatchCast(var: (Var|DataflowVar)?, struct_info: StructInfo, value: Expr) # Relax programs are IRModules. Modules may bind global variables either to # Relax functions or TIR PrimFuncs (specified separately). @@ -99,13 +103,15 @@ Binding ::= Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) ``` +*The `derive_func` field of `FuncStructInfo` is a macro in the meta-language: Given a function call and the variable mapping context, return the `StructInfo` of the result. This field is used only at compile time for reasoning about the `StructInfo` of calls to `ExternFunc`s. + ## Expression Survey This specification provides a more detailed description of what each expression and type represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. 1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). 2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. -3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchShape` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. +3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchShape` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. 1. For `ExternFunc` nodes, the call will look up the registered `PackedFunc` by its global symbol and will call it with the given arguments (note that a TIR `PrimFunc` can be compiled into a `PackedFunc` and called using `ExternFunc` by defining a `global_symbol` attribute in the `PrimFunc`). «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» @@ -118,22 +124,23 @@ This specification provides a more detailed description of what each expression 7. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. 8. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). - 2. `MatchShape`s: The `value` expression is evaluated and the resulting shape is dynamically checked against the shape denoted by the `PrimExpr`s in the `pattern` field. - 1. If `value` evaluates to a tensor value, the pattern will be checked against the shape of the tensor; if it evaluates to a shape value, the pattern will be checked directly against the shape. - 2. Any shape dimension in the pattern that consists of a single new shape variable is treated as a binding: The variable is bound to the size of the corresponding dimension of the value being matched. - 3. If the shapes do not match, an error is triggered. If there is a variable provided, the value is bound to the `var` expression (if the variable is omitted, the shape check is performed and any shape variables are updated, but no new binding is introduced). Shape variables introduced in a `SeqExpr` are similarly scoped to the `SeqExpr`. + 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. + 1. The types must match: All `StructInfo` variants correspond to a type (`TensorStructInfo` to `DynTensorType`, `ShapeStructInfo` to `ShapeType`, etc.) and each type corresponds to a value (`DynTensorType` to a tensor value, `ShapeType` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: + 1. For comparing tensor values to `TensorStructInfo`, `ndim` must match the number of dimensions in the tensor value (unless `ndim` is -1) and `dtype` must match the datatype used (unless `dtype` is `Void`). If `shape` has been specified, the shape of the value must match that encoded by `shape`; if specified, `shape` must be either a `Var` already bound in the current scope or a `ShapeExpr`. + 2. For comparing shape values to `ShapeStructInfo`, `ndim` must match the number of dimensions in the shape value (unless `ndim` is -1). If `values` has been specified, the shape value must match that encoded by `values`. + 3. «For comparing closures (function values) to `FuncStructInfo`, it is necessary for the compiled program to track run-time structural information for closures, since it is not possible to introspect the closure; this subject will be discussed in further detail later in the document.» + 2. When comparing tensor values with `TensorStructInfo` or shape values with `ShapeStructInfo`, any member of `shape` in `TensorStructInfo` (if `shape` is a `ShapeExpr`) or `values` in `ShapeStructInfo` that consists of a single new (hitherto unbound) shape variable is treated as a binding: The shape variable is bound to the size of the corresponding dimension of the value being matched. + 3. If there is a variable provided, the value is bound to the `var` expression (if the variable is omitted, the structural check is performed and any shape variables are updated, but no new binding is introduced). Shape variables introduced in a `SeqExpr` are similarly scoped to the `SeqExpr`. The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. 9. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. 10. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. - The function can have shape annotations on the parameters and a return shape parameter. When the function is called, the annotations on parameters checked against the argument values in similar fashion to `MatchShape` and can introduce new shape variables that are scoped to the function. + The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. «A function mapped bound to a `GlobalVar` can have a `global_symbol` attribute defined to indicate that it should be externally linked externally (be accessible outside the `IRModule`). The absence of a `global_symbol` attribute on a function definition bound to a `GlobalVar` indicates that it is "private" and hence can be called only within the `IRModule`.» -11. `RuntimeDepShape` nodes are used to denote that a shape is unknown at compile time and must be deduced at run time. These nodes may appear only in shape annotations and have no run-time semantics of their own. - ## Purity and Dataflow Blocks A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. «In Relax, we conservatively assume that any function that calls an impure function is itself impure, though the attribute `force_pure` on a function can be used as an override (e.g., if a function creates a new tensor, mutates it, and returns it, that is still pure but does not satisfy the conservative rule).» @@ -142,10 +149,10 @@ Above, it is mentioned that `DataflowBlock`s are not allowed to contain construc There is one visible side effect that Relax permits inside otherwise "pure" functions, namely exiting the program with an error. This can arise in the following cases: -- Shape matching errors (from `MatchShape` or from implicit shape checks upon calling a Relax function) -- Errors raised by otherwise pure Relax operators or `PackedFunc`s, such as in `cast` (which dynamically checks types). Since the purity of operators or `PackedFunc`s must be manually registered, this means that it is permissible to register an operator or `PackedFunc` as being pure if its only side effect is issuing an error in some cases. +- Casting errors (from `MatchCast` or from implicit structural information checks upon calling a Relax function) +- Errors raised by otherwise pure Relax operators or `PackedFunc`s. Since the purity of operators or `PackedFunc`s must be manually registered, this means that it is permissible to register an operator or `PackedFunc` as being pure if its only side effect is issuing an error in some cases. -Even though an abnormal program exit is a visible side effect and removing or reordering it changes the observable semantics, it would be too great a restriction to prohibit error checking inside `DataflowBlock`s. Relax does not have any notion of exception handling, so the only consequence of a failed safety check can be exiting the program. It is permissible for the compiler to reorder, duplicate, or eliminate `MatchShape`, `cast`, or otherwise pure operations that have the potential of failing, provided that doing so does not change the value returned by the program or any other visible behavior. +Even though an abnormal program exit is a visible side effect and removing or reordering it changes the observable semantics, it would be too great a restriction to prohibit error checking inside `DataflowBlock`s. Relax does not have any notion of exception handling, so the only consequence of a failed safety check can be exiting the program. It is permissible for the compiler to reorder, duplicate, or eliminate `MatchCast`, or otherwise pure operations that have the potential of failing, provided that doing so does not change the value returned by the program or any other visible behavior. To indicate that an operator or `PackedFunc` that can abort with an error should *never* be reordered or removed by the compiler, it should *not* be marked as pure. However, this means that it cannot be used inside a `DataflowBlock`. @@ -166,7 +173,7 @@ The types in Relax correspond to the broad categories of the values given above: The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» -## Shape System Survey +## Structural Information System Survey In Relax, tensor shapes are not handled in the type system; each expression instead a has an associated shape expression. In many cases, these shape computations can allow for statically concluding that two shapes are the same and thus eliminate the need for dynamic checks via `MatchShape`. However, when shapes cannot be statically concluded to be the same, it may be necessary for there to be dynamic checks. The compiler is also free to make use of shape expressions for memory planning purposes. «Relax is "strongly shaped," meaning that if the compiler cannot conclude that shapes match in certain cases, an error will be issued and an explicit `MatchShape` will be required.» @@ -183,7 +190,7 @@ Oftentimes, compiler passes operate only on particular functions or add new func Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. - *Tensors* are n-dimensional arrays of scalar values (which can be integers of fixed bitwidths, floats of fixed bitwidths, or bools). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. -- *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations return no value (as may be the case in some `PackedFunc` or operator calls that have side effects). +- *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time type information* (RTTI) indicating their argument types and result type, in order to facilitate dynamic type checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTTI is left up to the compiler implementation to determine so long as the `cast` operator can verify the type of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values of type `ObjectType`. @@ -222,8 +229,8 @@ def func(x: Tensor) -> Tensor: # the true branch will be a nested SeqExpr and hence a new scope # this x will shadow the function parameter x x = R.const(1) - R.print(x) # prints 1 - # the inner x goes out of scope + R.print(x) # prints 1 + # the inner x goes out of scope else: R.print("not executed") R.print(x) # this x is the function parameter @@ -232,17 +239,17 @@ def func(x: Tensor) -> Tensor: # Normal Form -To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the type- and shape-checking rules for operators rely on macros (`FInferType` and `FInferShape`), _this means that the structure of the program can affect type and shape inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these type- and shape-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying type- or shape-checking. +To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the type- and structure-checking rules for operators rely on macros (`FInferShapeInfo`), _this means that the structure of the program can affect type and structure inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these type- and structure-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying type or structure checking. The normal form for Relax is very similar to ANF; differences will be noted. Here are the criteria required for a program to be in normal form: -1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, `RuntimeDepShape`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. +1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. 2. `SeqExpr`s may appear only in the following locations: 1. In the `body` field of a `Function` node. 2. In the `true_branch` and `false_branch` fields of `If` nodes. 3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. 4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. Empty `BindingBlock`s should be dropped. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. If all the `BindingBlock`s are empty, then the `blocks` field of the `SeqExpr` should be set to an empty list. -Programs that are parsed should be "normalized" before performing type-checking or shape-checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: +Programs that are parsed should be "normalized" before performing type checking or structure checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: 1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. 2. If the function body is already a `SeqExpr`, consolidate all `BindingBlock`s, then check if the `body` field of the `SeqExpr` is a leaf expression. If not, bind it to a new var in the final `BindingBlock` and replace the `SeqExpr` body with the new var. 3. If the function body is not a `SeqExpr`, then recurse down the body's AST, binding any nested non-leaf expressions to a var in the current scope (doing this process in breadth-first order from left to right will respect the evaluation order in the semantics). If the body itself is a non-leaf expression, finally bind it to a var and have the final `SeqExpr` return the new var. @@ -258,8 +265,8 @@ The following criteria apply to all programs (including before normalization): 1. `DataflowVar`s can be bound only inside `DataflowBlock`s. Additionally, a `DataflowVar` may not be used outside of the `DataflowBlock` in which it is defined. 2. A `Var` of any kind used in the program must be either a function parameter or appear on the LHS of a binding exactly once. In the binding where a `Var` is defined, the same `Var` is permitted to occur in the RHS of the binding only if the binding is defining a function (i.e., local functions are permitted to be recursive). 3. A `Var` of any kind may not appear before it is bound. Namely, if a `Var` is bound in a `BindingBlock` in a `SeqExpr`, that `Var` may not appear in bindings that precede the one where it appears on the LHS. -4. «A return shape annotation for a function is not allowed to use any shape variables that are not in scope at the function definition. That is, the only shape variables that can appear on the return shape annotation are those defined in the outer scope or those introduced in the argument shape annotations.» -5. In each function, `PrimExpr` variables (shape variables) similarly may not appear in `ShapeExpr`s or shape annotations before the shape variables are bound (either in function signatures or `MatchShape` bindings). A shape variable is bound only when it appears in a dimension by itself (for example, a dimension consisting of `x` will bind `x`; however, `2*x` is not a binding and is considered an error if `x` has not yet been bound) in a `MatchShape` node or a function argument shape annotation. +4. «A return structural annotation for a function is not allowed to use any shape variables that are not in scope at the function definition. That is, the only shape variables that can appear on the return structural annotation are those defined in the outer scope or those introduced in the argument structural annotations.» +5. In each function, `PrimExpr` variables (shape variables) similarly may not appear in `ShapeExpr`s or shape annotations before the shape variables are bound (either in function signatures or `MatchCast` bindings). A shape variable is bound only when it appears in a dimension by itself (for example, a dimension consisting of `x` will bind `x`; however, `2*x` is not a binding and is considered an error if `x` has not yet been bound) in a `MatchCast` node or a function argument shape annotation. 6. The following constructs are not permitted to occur inside `DataflowBlock`s, which must be side effect– and control flow–free: 1. Recursive calls to the current function 2. Calls to a global function that is mutually recursive with the current function @@ -267,22 +274,25 @@ The following criteria apply to all programs (including before normalization): «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during type checking.» -7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return type annotation is *required*. [TODO: Do we also require a return shape annotation in such cases?]» +7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» 8. `Op` nodes may appear only as the `op` argument to `Call` nodes. 9. `ExternFunc` expressions may appear only as the `op` argument to `Call` nodes. -10. The `type_args` field is used only for type checking calls of `ExternFunc`s and `Op`s. Calls to `ExternFunc`s must have exactly one type argument, indicating the return type. Calls to `Op`s may use `type_args` as they wish. No other calls may have a non-empty `type_args`. +10. The `type_args` field is used only for type checking calls of `ExternFunc`s and `Op`s. No other calls may have a non-empty `type_args`. 11. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. 12. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. 13. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» 14. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» -15. «Any `PackedFunc` or operator called in a shape annotation or `shape_` expression must be pure and be annotated as such.» -16. The node `RuntimeDepShape` may appear only in shape annotations and `shape_` expressions. It has no defined semantics at run time. +15. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. +16. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. +17. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. # Types in Relax -Relax presently has six types, defined in the implementation in `python/tvm/relax/ty.py` and `include/tvm/relax/type.h`: +Relax's type system is intended to enforce strong guarantees that values are passed correctly between expressions. The design emphasis is on simplicity, aiming to leave more complex analysis to the structural information. + +Relax presently has six types, corresponding to the values in the language: 1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. 2. `ShapeType`, referring to shape values. The number of dimensions in the shape as given as `ndim` and is optional (using -1 for `ndim` indicates an unknown number of dimensions). @@ -291,6 +301,32 @@ Relax presently has six types, defined in the implementation in `python/tvm/rela 5. `PackedFuncType`, referring to the type of PackedFunctions. 6. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. +## Erasing Structural Information into Types + +Several type-checking rules rely on structural annotations or rules for defining the structural information for a call to an `Op` or `PackedFunc`. In general, types are simpler than structural information (to facilitate more precise reasoning). Structural information can be convereted into a type as follows (in pseudocode): + +```python +def erase_struct_info(si: StructInfo) -> Type: + if si is TensorStructInfo: + return DynTensorType(ndim=si.ndim, dtype=si.dtype) + if si is ShapeStructInfo: + return ShapeType(ndim=si.ndim) + if si is TupleStructInfo: + return TupleType(fields=[erase_struct_info(field) for field in si.fields]) + if si is FuncStructInfo: + # this should be the case only for packed funcs + if si.params is not specified: + return PackedFuncType() + return FuncType( + arg_types=[erase_struct_info(arg_type) for arg_type in si.params], + ret_type=erase_struct_info(si.ret) + pure=False) # TODO: This suggests we should either handle purity + # in StructInfo entirely (and not make it part of the type) + # or include it in both StructInfo and the type system + # only remaining case is ObjectStructInfo + return ObjectType() +``` + ## Subtyping Relax implements subtyping, which means that members of types can be accepted where members of their supertypes are accepted. We will denote the subtyping relationship as `T1 <: T2`, indicating that `T1` is a subtype of `T2`. For example. if `T1 <: T2` and some function expects an argument of type `T2`, then passing a member of type `T1` to that function is permitted; passing a member of type `T2` as an argument to a function that expects type `T1` for that argument is *not* permitted—the value would have to be dynamically cast to `T1` using the `cast` operator. @@ -416,7 +452,7 @@ def find_lub(T1 : Type, T2 : Type) -> Type: For two types `T1` and `T2`, if `T1 <: T2`, then a value of type `T1` can be passed anywhere a value of type `T2` is expected without any need for type conversions or dynamic checks. -*However*, if `T1 <: T2`, then passing a value of type `T2` where `T1` is expected can only be done if there has been some kind of dynamic check or conversion of that value. «Relax is *strongly* *typed*, meaning that the compiler will give an error in this situation and require an explicit conversion via the `cast` operator, which inspects the value's run-time representation and exits the program with an error message if the value is not a subtype of T1.» +*However*, if `T1 <: T2`, then passing a value of type `T2` where `T1` is expected can only be done if there has been some kind of dynamic check or conversion of that value. «Relax is *strongly typed*, meaning that the compiler will give an error in this situation and require an explicit conversion via a `MatchCast` node, which inspects the value's run-time representation.» If `T1` is not a subtype of `T2` and `T2` is not a subtype of `T1`, then it is always a type error to pass a value of either type where a value of the other is expected (no member of either type can be a member of the other). @@ -430,339 +466,452 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 2. Given a variable `v`, if `v` is in `Γ`, then we return `Γ[v]`. Otherwise, it is an error. 3. Given a constant expression `Constant(data)`, the type is `DynTensorType(ndim=n, dtype=d)` where `n` is the number of dimensions in `data` and `d` is its data type (recall that `data` is an `NDArray` literal). 4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType(n)`, where `n` is the length of `dims`. -5. The type of a `RuntimeDepShape` expression is `ShapeType(-1)`. -6. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. -7. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. -8. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT])`: - 1. If `op` is a Relax `Op` node, then we look up its registered `FInferType` property. `FInferType` is a macro that takes in the `Call` node and produces a type. We return the type `op.FInferType(Call(op, [a1, ..., an], type_args=[aT]))`. The implementation of `FInferType` is free to throw errors. - 2. If `op` is `ExternFunc`, then use the sole member of `type_args` (calls to `ExternFunc`s are required to have exactly one `type_args` member) `aT` as the return type. Packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function itself to do any validation. +5. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. +6. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. +7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT1, aT2, ..., aTn])`: + 1. If `op` is a Relax `Op` node, then we look up its registered `FInferStructInfo` property. `FInferStructInfo` is a macro that takes in the `Call` node and produces structural information. Invoke `op.FInferStructInfo(Call(op, [a1, ..., an], type_args=[aT1, aT2, ..., aTn]))` and convert the result to a type using the `erase_struct_info` procedure defined above. The implementation of `FInferStructInfo` is free to throw errors. + 2. If `op` is `ExternFunc`, note that packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function's implementation to do any validation at run time. However, the type system uses the `type_args` field to determine the result type as follows: + 1. If there are no `type_args`, the resulting type is `ObjectType()`. + 2. If there is exactly one member of `type_args`, use that as the return type. + 3. If there are multiple members of `type_args`, then the type is `TupleType(fields=[aT1, aT2, ..., aTn])`. 3. Otherwise, check the types of the subexpressions, left to right. Suppose `op` has the type `Tf`, `a1` has type `T1`, …, and `an` has type `Tn`. If `Tf` is not a function type with exactly `n` arguments, we consider it a type error and require an explicit cast. Suppose `Tf` has argument types `T1'`, `T2'`, …, `Tn'`. Consider it a type error if any of the following does not hold: `T1 <: T1'`, `T2 <: T2'`, …, or `Tn <: Tn'`. Then the return type is `Tf.ret_type`. -9. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. -10. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» -11. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. +8. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. +9. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» +10. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» - 2. For each binding `VarBinding(v : T, e)` in the current block, where `T` is the optional annotation on `v`, check the type of `e` and suppose it is `T'`. If `T` has been omitted, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, the type annotation `T` is not optional and we add `v` to `Γ` before type-checking the function body; see the rule for `Function` nodes.) - 3. For each `MatchShape(v: T, e, shape_pattern)`, where `T` is an optional type annotation, let the checked type of `e` be `T'`. - 1. If `T'` is `ShapeType`: - 1. Emit an error if `T` is not a supertype of `ShapeType`. - 2. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. - 3. Add `v` to `Γ` with type `T`. - 2. If `T'` is `DynTensorType`: - 1. If the `ndim` of `T'` is `n` ≥ 0, then emit an error if the length of the given `shape_pattern` is not `n`. Let the datatype of `T'` be `d`. - 2. If `T` is not a supertype of `DynTensorType(ndim=len(shape_pattern), dtype=d)`, then emit an error. If `T` is a subtype of that type, emit an error and request a cast. - 3. Add `v` to `Γ` with type `T`. - 3. If `T'` is `ObjectType`, then the only type we can conclude for `v` is `ObjectType`. If `T` is not `ObjectType`, emit an error and request a cast. - 4. If `T'` is `TupleType` or `FuncType`, emit a type error. + 2. For each binding `VarBinding(v, e)` in the current block, check the type of `e` and suppose it is `T'`. If `v` has a structural annotation, then let `T` be the corresponding type (via the `erase_struct_info` procedure above). If there is no annotation, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and otherwise add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, we require `v` to have a structural annotation add `v` to `Γ` with its annotated type before type-checking the function body; see the rule for `Function` nodes.) + 3. For each `MatchCast(v, e, struct_info)`: + 1. Check the type of `e` and let it be `T'`. + 2. Let `T''` be the type corresponding to `struct_info` (via the `erase_struct_info` procedure). + 3. Emit a warning if `T'` is not a supertype of `T''` and `T''` is also not a supertype of `T'`; this indicates that the cast is _guaranteed_ to fail at run time. + 4. If `v` has been defined and it has a structural annotation, then let `T` be its corresponding type (via `erase_struct_info`). + 5. If `T` has been defined, then emit an error if `T` is not a supertype of `T''`. + 6. If `v` has been defined and does not have a structural annotation, then add `v` to `Γ` with type `T''`. If `T` has also been defined, then add `v` to `Γ` with type `T`. 2. If the current block is a `DataflowBlock`, remove any `DataflowVar`s from `Γ` after we have finished processing the bindings in that block. 3. Finally, the type of the `SeqExpr` is the type of `body` checked under the current `Γ`. Afterwards, remove all bindings added from the `b0`, `b1`, …, `bn` from `Γ`. -12. Let us consider a function `Function(v1 : T1, v2 : T2, ..., vn : Tn, attrs=a) -> Tr: body`. - 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, where `p` is `True` if a `pure` attribute is included and `False` otherwise. Remove `fv` from `Γ` before returning. +11. Let us consider a function `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`. All of the vars are required to have structural annotations; let `T1` be the type corresponding to `v1`'s annotation (via `erase_struct_info`), `T2` be the type corresponding to `v2`'s annotation, etc.. + 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, «where `p` is `True` if a `pure` attribute is included and `False` otherwise». Remove `fv` from `Γ` before returning. 2. Add `v1` to `Γ` with type `T1`, `v2` with type `T2`, …, and `tn` with type `Tn`. Recursively type check `body` in this new context: 1. «Determining purity: If `body` contains any call to a function whose return type does not specify that it is pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not specify the `pure` attribute, then we consider the function to be (potentially) impure. If all calls are to functions whose return type specifies purity or that include the `pure` attribute on the call or `Op`, then the function is treated as pure.» 2. «Suppose the purity defined in the previous step is `p'`. Suppose the annotated function purity (in the attributes) is `p`. If `p'` is false while `p` is true, then it is a type error; if `p` was omitted, use `p'` for `p`.» 3. «If the function has the attribute "`force_pure`," then consider `p` to be true, even if the check above judged the function not to be pure. The compiler may emit a warning in this situation.» - 4. Suppose the result of type-checking `body` is `Tr'`. If the current function is not recursive or a mutually recursive global function and `Tr` was omitted, consider the function to have type `FuncType([T1, T2, ..., Tn], Tr', pure=p)`. If `Tr' <: Tr`, then we consider the function to have type `FuncType([T1, T2, ..., Tn], Tr, pure=p)`. «If `Tr <: Tr'`, return an error and require an explicit cast in the function body.» If `Tr'` is not a subtype of `Tr` and `Tr` is also not a subtype of `Tr'` (meaning a dynamic cast cannot succeed), this is an error. + 4. Suppose the result of type-checking `body` is `Tr'`. If the current function is not recursive or a mutually recursive global function and `ret_struct_info` is undefined, consider the function to have type `FuncType([T1, T2, ..., Tn], Tr', pure=p)`. If `ret_struct_info` is defined, then let `Tr` be `erase_struct_info(ret_struct_info)`. If `Tr' <: Tr`, then we consider the function to have type `FuncType([T1, T2, ..., Tn], Tr, pure=p)`. «If `Tr <: Tr'`, return an error and require an explicit cast in the function body.» If `Tr'` is not a subtype of `Tr` and `Tr` is also not a subtype of `Tr'` (meaning a dynamic cast cannot succeed), this is an error. 5. Remove `v1`, `v2`, …, and `vn` from `Γ` before returning. -# Shapes in Relax +# Structural Information in Relax -In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. In Relax, to allow for greater flexibility for variable-shape tensors and make it easier to implement new operators, shapes can be checked at run time. Though every expression in Relax has a shape associated with it just as expressions also have types, there is no requirement that the shape be expressed at compile time. Instead, the compiler merely requires that an expression's shape define *a way* to compute a fully specified shape at run time. Users have the ability to make use of shape variables and arithmetic expressions to encode a wide variety of shape constraints that can be checked dynamically. +In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. While this allows Relay's type system to make strong guarantees about tensor shapes, it results in greater complexity in type checking and makes it difficult to implement new operators or handle cases like tensors with symbolic shapes. -Nevertheless, in many cases, these shapes can be analyzed at compile time (particularly when they are consist of constants or deducible variables) to facilitate compile-time optimization much like is possible with Relay or TIR. Through constant propagation, function inlining, and other partial evaluation–like transformations, we can potentially eliminate many more dynamic checks by allowing some shape computations to be simplified at compile time. +Relax instead aims to facilitate analysis of more complex properties like shapes by tracking _structural information_ pertaining, encoding as much analysis as is feasible at compile-time in a _"best-effort"_ fashion. Anything that cannot be proved statically can instead be checked at run time. Each Relax expression has structural information associated with it just as it has a type. Indeed, the structural information for each expression can be simplified into a type (recall [the procedure for doing so](#erasing-structural-information-into-types)), so the structural information for an expression can be thought of as an extended type that is checked in a less precise manner. The best-effort nature of the structural system in Relax means that the analysis may detect _some_ errors at compile time and report them, but it may give warnings when it _cannot_ draw conclusions, perhaps suggesting that dynamic checks via `MatchCast` should be inserted. Note that the precision of the static analysis can potentially be improved by some compile-time optimizations like constant propagation, function inlining, and other partial evaluation–like transformations. -## Defining Shape Computations +Tensor shapes are the primary motivation for including structural information in Relax, as shape information is particularly important for memory planning. Relax's structural information system uses expressions to encode tensor shapes, which allows for using shape variables and arithmetic expressions to encode a rich variety of shape constraints. Note, however, that the structural system could potentially be extended to encode and analyze further information, like tensor sparsity or density. -In Relax, each expression has an associated shape computation, which defines how that expression's shape can be computed based on the shapes of its subexpressions. We will refer to this computation as `shape_`, as that is what it is called in the implementation. This essentially serves as a mechanism for propagating shape annotations on variable bindings and function definitions to other expressions and enable more compile-time analysis of shapes. In particular, `shape_` is useful for memory planning. These computations can also be used to simplify shape checking and eliminate many dynamic checks. +## Defining Structural Information -### Expressing Dimensions +As with types, the structural information in Relax corresponds to the values in the language: +* `TensorStructInfo` describes tensor values. Like in `DynTensorType`, the `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` whose type is `ShapeType`. If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation (that returns a shape). which can be useful for memory planning. +* `ShapeStructInfo` describes shape values. Like `ShapeType`, it has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. +* `TupleStructInfo` describes tuple values, namely by giving the structural information for each of the tuple's members via `fields`. +* `FuncStructInfo` describes closure values or `PackedFunc`s. There are two ways in which to specify `FuncStructInfo`: + 1. By specifying `params` and `ret` (for closures). `params` gives the structural information corresponding to each of the function's parameters and `ret` gives the structural information corresponding to the result of calling the function. In this case, the `derive_func` field is left undefined. + 2. By giving a `derive_func` macro (for `PackedFunc`s). The `derive_func` macro is takes a call to the corresponding `PackedFunc` and the variable mapping context and returns the `StructInfo` of the result. In this case, the `params` field is left undefined and the `ret` field is ignored. +* `ObjectStructInfo` describes arbitrary object values. -A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds to a dimension. The use of TIR `PrimExpr`s for shape dimension allows shape computations to express complex constraints that include variables and integer arithmetic expressions in addition to just constant dimensions. +While these categories correspond closely to types, they serve as a mechanism for propagating further information (especially as given in shape annotations in variable bindings) throughout the program and facilitating more static analysis. + +### Expressing Shape Dimensions + +A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds to a dimension. The use of TIR `PrimExpr`s for shape dimensions allows shape computations to express complex constraints that include variables and integer arithmetic expressions in addition to just constant dimensions. **Scope of Shape Variables** -Shape variables can be introduced in two places in a Relax program: In a function signature, where they may be included with the argument shapes and return shape annotations, or in `MatchShape` bindings. Shape variables used in the function signature are scoped to the entire function in which they appear. Shape variables used in `MatchShape` bindings are scoped only to the `SeqExpr` in which they appear. +New shape variables can be bound in two places in a Relax program: In `TensorStructInfo` or `ShapeStructInfo` annotations on function parameters or as the `struct_info` parameter in a `MatchCast` binding. Shape variables used in the function signature are scoped to the entire function in which they appear (including in the return structural annotation). Shape variables used in `MatchShape` bindings are scoped only to the `SeqExpr` in which they appear. **Informal Semantics of `PrimExpr`s for Dimensions** -1. Shape variables can be bound to a value exactly once: at the start of a function for shape annotations on function arguments, in `MatchShape` bindings, or before a function returns (for shape variables on the return type). In particular, matching a `PrimExpr` consisting only of an uninitialized shape variable is treated as its binding (see below on `MatchShape`). After a shape variable has been bound for the first time, future uses of it will refer to the same value. -2. It is not legal to use a shape var that has not yet been bound. This results in an error at run time, though most cases can be detected at compile time. +1. Shape variables can be bound to a value exactly once, either at the start of a function for shape annotations on function arguments or in `MatchCast` bindings. In particular, matching a `PrimExpr` consisting only of an uninitialized shape variable is treated as its binding (see below on `MatchCast`). After a shape variable has been bound for the first time, future uses of it will refer to the same value. +2. It is not legal to use a shape var that has not yet been bound. This results in an error at compile time. 3. «Local functions will "capture" defined shape variables from the parent scope with their present values in the resulting closure.» 4. If all variables in the `PrimExpr` are defined, `PrimExpr` arithmetic will generally be evaluated according to the semantics of TIR. -### Evaluating `MatchShape` +### Evaluating `MatchCast` -`MatchShape` allows for binding shape variables in Relax. It can be used with either tensor values or shape values, and in both cases the evaluation of the `PrimExpr`s proceeds similarly. +Because structural information is checked in a "best-effort" fashion, it is not always possible for the compiler to statically draw conclusions about all details of a given value's structural information. Hence, `MatchCast` allows for checking this information at run time, similar to a typecast. However, `MatchCast` also allows for binding shape variables in the process of pattern matching, hence the "match" portion of its name. -1. Evaluating `MatchShape(v, t, s)`, where `t` is a tensor value and `s` is a list of `PrimExpr`s corresponding to shape dimensions: - 1. Suppose `s` is `(p1, p2, ..., pn)` , where each variables is a `PrimExpr`. We evaluate `p1`, then `p2`, and so, in that order according to the following rules (corresponding to the `i`th dimension): - 1. If the current `PrimExpr` consists only of an uninitialized shape variable, we bind the shape variable in that scope to the concrete value of the `i`th dimension of the value of `t`. - 2. Evaluate the current `PrimExpr` and compare it to the concrete value of the `i`th dimension of `t`. Raise an error if they do not match. - 2. If `v` is provided, bind `t` to `v` (see the general semantics for how that should be implemented). -2. Evaluating `MatchShape(v, S, s)`, where `S` is a shape value proceeds identically to the above, except the `PrimExpr`s are compared to the `i`th element of `S`. +This section describes the run-time checking performed by `MatchCast(var, value, struct_info)`, for each combination of value and structural annotation (if `var` is defined, then `value` will be bound to `var` as discussed in the [general section on semantics](#detailed-semantics)). If any check given below fails, an error is raised by the `MatchCast`. -### General Shape Computation Grammar +1. If `struct_info` is `ObjectStructInfo`, then no additional check is performed. All values in Relax are objects. +2. If `struct_info` is `TensorStructInfo(ndim, dtype, shape)`, then check that `value` is a tensor value, that it has a rank of `ndim` (if `ndim` is not -1), a datatype of `dtype` (if `dtype` is not `Void`). If `shape` is defined, consider the following cases: + 1. If `shape` is a `Var`, then check that the concrete shape of `value` matches the value bound to the `Var`. + 2. If `shape` is a `ShapeExpr`, then compare the fields of the `ShapeExpr` to the concrete shape of `value`, dimension by dimension (comparing the `i`th field of the `ShapeExpr` to the `i`th dimension of the shape of `value`). Give an error if the number of the dimensions does not match the number of fields in the `ShapeExpr`. + 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. + 2. Otherwise, evaluate the field of the `ShapeExpr` and ensure that it matches the concrete value of the dimension. +3. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): + 1. If the `i`th member of `values` consists of only an unbound shape variable, then bind that variable to the `i`th field of the the concrete shape value. + 2. Otherwise, evaluate the `i`th member of `values` and check that it is equal to teh `i`th field of the concrete shape value. +4. If `struct_info` is `TupleStructInfo(fields)`, then check that `value` is a tuple value with `n` fields, where `n` is the length of `fields`. Also recursively check the `i`th field of the tuple value against the `i`th member of `fields`. +5. If `struct_info` is `FuncStructInfo(params, ret, derive_func)`, then if `params` is defined, check that `value` is a closure value; if `derive_func` is defined, check that `value` is a `PackedFunc`. No further validation may be done on a `PackedFunc`. «If `value` is a closure value, then it can contain run-time structural information indicating the structural information of its intended arguments and return value that can be compared against `params` and `ret`.» -Shape computations can consist of the following expressions, which are a subset of general Relax `Expr`s: +### Checking Structural Information at the Start and End of a Function -``` -ShapeCompExpr ::= ShapeExpr(dims: [PrimExpr]) - | RuntimeDepShape() - | «Tuple(fields: [ShapeCompExpr])» - | Call(op: Op|ExternFunc, args: [Var|Constant]) - | «TupleGetItem(tuple_value: ShapeCompExpr, index: int)» +«Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchCast`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchCast` bindings. Suppose a function has the following signature, where the `Si` are structural annotations: + +```python +def f(arg1 : S1, arg2 : S2, ..., argn : Sn) -> Sr: + return body ``` -The shape expressions can be interpreted as follows: +This can be treated as a macro that expands to -- `ShapeExpr` describes the shape of a tensor as a list of dimensions -- «`Tuple` describes the shapes of each member of a tuple» -- «`TupleGetItem` describes the shape of a member of a tuple» -- `Call` describes the shape of a function (or operator) call return value in terms of its arguments -- `RuntimeDepShape` describes shapes that are unknown at compile time (like when a shape annotation is omitted) or the shapes of values that don't have shapes (like shapes themselves, paradoxically: they *are* shapes but do not *have* shapes). +```python +def f(arg1, arg2, ..., argn): + MatchCast(arg1, S1) + MatchCast(arg2, S2) + ... + MatchCast(argn, Sn) + ret_var = body + MatchCast(ret_var, Sr) + return ret_var +``` +» -The `PrimExpr`s in a `ShapeCompExpr` can reference the same shape variables as in shape annotations, with the same semantics. +## Deriving the Structural Information for Each Expression -**Restrictions** +For each expression type, we can recursively build up the structural information associated with the expression. -Shape computations are allowed to include calls to operators and even `PackedFunc`s, but these operators and `PackedFunc`s *must* be pure. Shape computations are primarily used for memory planning and it is at the compiler's discretion when, if ever, to evaluate them (except as described below), hence they must not have side effects. +### Auxiliary Procedures -**Shape Annotations** +**`derive_func` for `FuncStructInfo`** -For shape annotations, we use `ShapeCompExpr` as the grammar, as with `shape_` expressions. `ShapeExpr` is used to annotate shapes of tensor values, «`Tuple` is used to annotate the shapes of tuple values», and `RuntimeDepShape` is used to indicate annotations that have been omitted or shapes that cannot be known at compile time (like the shapes of tensors whose rank is unknown at compile time). `Call` is used to annotate the shapes of calls to operators and «`TupleGetItem` annotates the shapes of tuple indices.» +There are two special `derive_func` values built into the compiler that are used for checking the structural information of `PackedFunc`s. -«For example, suppose we have a tuple where some fields are tensors like the following: +The first is `default_derive`, giving a simple way to determine the resulting structural information of a `PackedFunc` from its type arguments. `default_derive` takes one argument that is a `Call` node and is defined as follows: +1. Suppose its call node argument is `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`. +2. If `type_args` is of length 0, then return `ObjectStructInfo()`. +3. If `type_args` is of length 1, then return `wrap_type(aT1)`. +4. If `type_args` is of a greater length than 1, then return `TupleStructInfo(fields=[wrap_type(aT1), wrap_type(aT2), ..., wrap_type(aTn)])`. -```python -x : Tuple(Tensor((m, n), "int32"), Tuple(), Tensor((), "int32"), Tensor(_, "int32")) = ... -``` +The second is `empty_derive`, which is the weakest possible derivation. It simply returns `ObjectStructInfo` regardless of its argument. This is used for worst-case deducation of `StructInfo` for a `PackedFunc`. + +**Wrapping Types** -It has the shape annotation +For deriving the structural information for a `PackedFunc` call, the type arguments are converted into structural information. This is a straightforward procedure, given here in pseudocode: ```python -Tuple([ShapeExpr([m, n]), Tuple([]), ShapeExpr([]), RuntimeDepShape]) +def wrap_type(t: Type) -> StructInfo: + if t is ObjectType: + return ObjectStructInfo() + if t is PackedFuncType: + # leave params undefined; see default_derive below + return FuncStructInfo(ret=ObjectStructInfo(), derive_func=default_derive) + if t is FuncType: + # leave derive_func undefined + return FuncStructInfo( + params=[wrap_type(arg_type) for arg_type in t.arg_types], + ret=wrap_type(t.ret_type) + ) + if t is TupleType: + return TupleStructInfo(fields=[wrap_type(field) for field in t.fields]) + if t is ShapeType: + # leave values undefined + return ShapeStructInfo(ndim=t.ndim) + if t is DynTensorType: + # leave shape undefined + return TensorStructInfo(ndim=t.ndim, dtype=t.dtype) ``` -» - -Note that it is [a well-formedness requirement](https://www.notion.so/Informal-Relax-Language-Specification-d1fdedb8fae84f0d82b9f880f25e7370) that if any field in a type has a `ShapeExpr` annotation, it must be a `DynTensorType` with an `ndim` matching the number of dimensions in the `ShapeExpr`. For example, in the above function signatures, the `ndim` in the type annotations must be 2. -### «Assigning Shape Variables at the Start and End of a Function» +**Erasing Out-of-Scope Information** -«Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchShape`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchShape` bindings. Suppose a function has the following signature, where the `Ti` are type annotation and the `Si` are shape annotations: +When returning a value from an inner scope to an outer scope (namely, the `body` field of a `SeqExpr`, which may use variables defined in the binding blocks, and the `body` field of a `Function`, which may use variables defined in the function body), it may be possible for the derived `TensorStructInfo` or `ShapeStructInfo` to contain Relax variables or shape vars that have gone out of scope. We defined a procedure to check for any of these out-of-scope variables and weaken the structural information not to include it. The procedure is defined below, in pseudocode: ```python -def f(arg1 : (T1, S1), arg2 : (T2, S2), ..., argn : (Tn, Sn)) -> (Tr, Sr): - return body +def erase_to_well_defined( + s: StructInfo, + var_scope: set of Relax vars in current scope, + shape_var_scope: set of shape vars in current scope) + -> StructInfo: + + if s is ObjectStructInfo: + return s + if s is TensorStructInfo: + if s.shape is defined: + if (s.shape is a Relax var that is not in var_scope + or s.shape is a ShapeExpr that contains any shape var not in shape_var_scope): + # leave shape undefined + return TensorStructInfo(ndim=s.ndim, dtype=s.dtype) + else: + return s + else: + return s + if s is ShapeStructInfo: + if (s.values is defined + and any member of s.values contains a shape var not in shape_var_scope): + # leave values undefined + return ShapeStructInfo(ndim=s.ndim) + if s is TupleStructInfo: + return TupleStructInfo( + fields=[ + erase_to_well_defined(field, var_scope, shape_var_scope) + for field in s.fields + ] + ) + if s is FuncStructInfo: + if params is defined: + return FuncStructInfo( + params=[ + erase_to_well_defined(param, var_scope, shape_var_scope) + for param in s.params + ], + ret=erase_to_well_defined(s.ret, var_scope, shape_var_scope) + ) + else: + return FuncStructInfo( + ret=erase_to_well_defined(s.ret, var_scope, shape_var_scope), + derive_func=s.derive_func + ) ``` -This can be treated as a macro that expands to +**Substituting Free Shape Variables in `FuncStructInfo`** -```python -def f(arg1 : T1, arg2 : T2, ..., argn : Tn) -> Tr: - check_annotation(arg1, T1, S1) - check_annotation(arg2, T2, S2) - ... - check_annotation(argn, Tn, Sn) - ret_var = body - check_annotation(ret_var, Tr, Sr) - return ret_var -``` -» +The `params` field of `FuncStructInfo` can contain free shape variables, indicating that these shape variables are bound to the corresponding dimensions of the argument when the function is called. For checking the compatibility of two function types, we can construct a mapping of shape variables and then substitute shape variables according to the mapping. The mapping can be constructed by doing a simple structural match, as when checking alpha-equivalence. -Because `MatchShape` is defined only for tensor and shape values, we must use a macro to handle other possible types that may be passed into a function, given here in pseudocode: +For clarity, additional detail on how the mapping should be constructed is given here in pseudocode: ```python -def check_annotation(e: Expr, s: ShapeCompExpr) -> Expr: - if s is a ShapeExpr: - tmp = fresh_var() - # type checking should ensure that e is always a tensor - return SeqExpr( - [BindingBlock([MatchShape(tmp, e, s.dims)])], - tmp - ) - «else if s is a Tuple: - # type checking should ensure that e is always a tuple and the lengths match - shapes = s.fields - tmp = fresh_var() - return SeqExpr( - [BindingBlock([ - VarBinding(tmp, e), - # recursive in case we have nested tuples - VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, 0), shapes[0])), - VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, 1), shapes[1])), - ..., - VarBinding(fresh_var(), check_annotation(TupleGetItem(tmp, n-1), shapes[n-1])) - ])], tmp - )» - else if s is a Call: - tmp = fresh_var() - return SeqExpr( - [BindingBlock([ - VarBinding(tmp, e), - # completely dynamic check that does not assign shape vars. - VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) - ])], tmp - ) - «else if s is TupleGetItem: - val = s.tuple_value - if val is Tuple: - return check_annotation(e, val.fields[s.index]) - # otherwise, evaluate it - return SeqExpr( - [BindingBlock([ - VarBinding(tmp, e), - VarBinding(fresh_var(), dynamically_check_shapes(shape_of(tmp), s)) - ])], tmp - )» - else if s is RuntimeDepShape: - # no need to check - return e +def get_shape_var_mapping(S1: StructInfo, S2: StructInfo) -> {tir::Var, PrimExpr}: + if S1 and S2 are not the same type: + return {} + if S1 and S2 are both TupleStructInfo: + if S1.fields and S2.fields don't have the same length: + return {} + ret = {} + for 0 <= i < length of S1.fields: + ret = union of ret and get_shape_var_mapping(S1.fields[i], S2.fields[i]) + return ret + if S1 and S2 are both FuncStructInfo: + if S1 and S2 both have params defined and the params are the same length: + ret = {} + for 0 <= i < length of S1.params: + ret = union of ret and get_shape_var_mapping(S1.params[i], S2.params[i]) + # don't look at the return field; it's not a binding position + return ret + else: + return {} + if S1 and S2 are both ShapeStructInfo: + if S1 and S2 both have values defined and the values are the same length: + ret = {} + for 0 <= i < length of S1.values: + if S1.values[i] is an unbound shape variable: + ret[S1.values[i]] = S1.values[i] + return ret + else: + return {} + if S1 and S2 are both TensorStructInfo: + if ( + S1 and S2 both have shape defined + and the shapes are both ShapeExprs + and their values fields are the same length + ): + ret = {} + for 0 <= i < length of S1.shape.values: + if S1.shape.values[i] is an unbound shape variable: + ret[S1.shape.values[i]] = S2.shape.values[i] + return ret + else: + return {} ``` -### Evaluating Shape Expressions - -Every shape expression in the program (`shape_`) is associated with a program expression. Other than in the above procedure for checking function parameter shapes and the return shape, the specification does not guarantee that any `shape_` expression will ever be evaluated or how many times it may be evaluated; `shape_` is intended primarily for the benefit of memory planning. Hence, all `shape_` expressions must be pure and must be guaranteed to terminate. The `shape_` for a given expression `e` is intended to be evaluated *before* `e`. - -Shape expressions follow the same evaluation rules as general program expressions. In particular, shape functions are permitted to reference any variable that is in scope at the point of its associated expression; i.e., when evaluated, they form closures that capture any free variables (Relax variables and shape variables) referenced in their body. The `RuntimeDepShape` expression has no semantics at run time and indicates a shape that cannot be predicted in advance. If a `RuntimeDepShape` is encountered at any point while dynamically checking a shape match (see the `check_annotation` procedure above), it should "short-circuit" the match and cause the match to succeed immediately. - -### Building Up `shape_` for Each Expression - -For each expression type, we can recursively build up an associated `shape_` expression according to the following rules: - -1. For `Constant(value)`, the `shape_` expression is a `ShapeExpr` corresponding to the concrete shape of `value`. For example, for `Constant(1)`, `shape_` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape_` is `ShapeExpr([2])`. -2. «For `Tuple(fields)`, `shape_` can be defined as `Tuple([field.shape_ for field in fields])`.» -3. For `ShapeExpr`s, `shape_` is `RuntimeDepShape`. -4. `RuntimeDepShape` expressions should appear only in shape expressions; their `shape_` is not defined. -5. For `If(cond, true_branch, false_branch)`, we compare the `shape_` of `true_branch` and `false_branch`. If these can be proven equivalent (by a method that the compiler implementation is free to determine), then the `If` node's `shape_` is that shape. If they do not match, then we set it to `RuntimeDepShape`. -6. For `SeqExpr`, we set the `shape_` to be the `shape_` of the body expression. The `shape_` must respect the scoping rules for the `SeqExpr`: If the `shape_` of the body expression contains shape variables not defined in the outer scope (i.e., shape variables that are scoped to the `SeqExpr` only) or if the `shape_` contains any `Var`s or `DataflowVar`s scoped to the `SeqExpr`, use `RuntimeDepShape` as the shape. -7. For handling variable bindings: - 1. For the arguments to a function, set the `shape_` to the annotated shape. If the annotation is omitted, use `RuntimeDepShape`. - 2. In the general `VarBinding(v, e)`, if `v` does not have a shape annotation or the annotation is `RuntimeDepShape`, then we set the `shape_` of `v` to the `shape_` of `e`. If `v` has a shape annotation, then if the `shape_` of `e` can be proven equivalent to the shape annotation, use the shape annotation for the `shape_` of `v`. «Otherwise, give an error and require an explicit `MatchShape`.» - - It is up to the compiler implementation to decide what method to use for attempting to prove equivalence. - - 3. For bindings where the RHS is a function literal or assigning the `shape_` of a `GlobalVar`, see the rule for `Function` nodes. - 4. For `MatchShape(var, value, shape)`, we set the `shape_` of `var` to `shape`, as it will be dynamically checked. -8. «For `TupleGetItem(tuple_value, i)`, we examine the `shape_` of `tuple_value`; suppose it is `s`. If `s` is a `Tuple`, then we use its `i`th field. If it is `RuntimeDepShape`, we use `RuntimeDepShape`. If it is a `Call` to a function that returns a tuple with at least `i + 1` members, set the `shape_` to `TupleGetItem(s, i)`. Otherwise, raise an error at compile time (though this should not happen if type checking has passed).» -9. For `Call` nodes: - 1. For a call to an `ExternFunc`, we use `RuntimeDepShape` because we cannot analyze the shapes of arbitrary `PackedFunc`s and must check dynamically. - 2. For a call to an `Op`, we use the manually defined `FInferShape` macro if it has been defined and `RuntimeDepShape` if it has not. `FInferShape` is a function that takes in the call node and produces a `ShapeCompExpr`. - 3. «For all other cases with `Call(op, args)`, we consider the following cases: - 1. If `op` is a `GlobalVar` or a `Var` that refers to a function defined in the current scope, look up the `Function` node it references; let us call it `f`. Similarly, if `op` is itself a `Function` node, let `f` be `op`. - - Attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) on `f`'s return shape. A pseudocode procedure for this beta-reduction is given below, as a macro. - - 1. If the return shape of `f` is a `Call` node or contains any `Call` nodes, substitute any parameters of `f` for the corresponding member of `args`. (E.g., if `f` has parameters `p1`, `p2`, …, `pn` and any of these variables appears in the return shape, `p1` should be replaced with the first member of `args`; `p2`, with the second; etc.) If any member of `args` that is substituted this way is not a `Var` or `Constant`, consider beta-reduction to fail. - 2. For each shape annotation in the parameters of `f`, attempt to match it with the `shape_` of the corresponding member of `args`, substituting shape variables in the return shape accordingly. If the `shape_` of the member of `args` is `RuntimeDepShape`, consider beta-reduction to fail. If the `shape_` is not `RuntimeDepShape` but is incompatible with the parameter's shape annotation (e.g., a `Tuple` where a `ShapeExpr` was expected), report an error at compile time. - - If `f`'s return shape is `RuntimeDepShape`, then consider the call result to have `RuntimeDepShape`. If beta-reduction is considered to fail, then consider the call result to have `RuntimeDepShape`. If it succeeds, use the resulting shape as the `shape_` of the call result. - - 2. Otherwise, consider the result of the call to have `RuntimeDepShape`. - » -10. For a function node, set the `shape_` to `RuntimeDepShape`. - -### Procedure for Substituting a Function Return Shape to Determine the Shape of a Call - -The `substitute_shape` procedure defined below describes how the shape expression for a call result can be defined given the call arguments and the return shape annotation on the corresponding function node. Note that this procedure can obtain much more precise results in the cases of `Call` or `TupleGetItem` return shapes. - +**Checking Compatibility** + +In many cases during the derivation of structural information, it is important to judge when two distinct structural information encodings are compatible with each other or when they are too different from each other to be reconciled, which can indicate an error. In the case of shape information, this could mean having two symbolic shapes that can be proven not to be equal to each other. Because shape expressions can contain arithmetic and it can be very difficult to statically prove whether two arithmetic expressions are equal, we permit the compiler implementation to make a best-effort attempt to prove equality for arithmetic expressions. (The user can insert a `MatchCast` to check definitively.) Since the checks are best-effort, the compatibility check will only report incompatibility if two values are _definitely_ different from each other. + +We can check if some structural information `S1` is accepted where structural information `S2` is expected by the process given below, which we refer to as `check_compability(S1, S2)` for convenience. `check_compatibility` can find that `S1` and `S2` are compatible, possibly compatible, or incompatible. "Incompatible" indicates a definite mismatch that should result in a compiler error; "possibly compatible" indicates that the structures may or may not match and should likely result in a compiler warning (indicating that a user may want to insert a dynamic check). An invariant that should should is that if `check_compatibility(S1, S2)` returns "compatible" or "possible compatible", `erase_struct_info(S1) <: erase_struct_info(S2)` should hold; that is, compatibility of structural information should be consistent with typing rules. + +1. If `S2` is `ObjectStructInfo`, then they are compatible. +2. Otherwise, if `S1` and `S2` are not both `TensorStructInfo` or both `TupleStructInfo`, etc. (besides `ObjectStructInfo`), then report an incompatibility. +3. If `S1` and `S2` are both `TupleStructInfo`: + 1. If `S1.fields` is not the same length as `S2.fields`, they are incompatible + 2. Call `check_compability(S1.fields[i], S2.fields[i])` for all `i`. If any pair of fields is incompatible, then `S1` and `S2` are incompatible. If no pair of fields is incompatible but at least one is possibly compatible, then `S1` and `S2` are possibly compatible. If all pairs of fields are compatible, then `S1` and `S2` are compatible. +4. If `S1` and `S2` are both `ShapeStructInfo`: + 1. `S2.ndim` is -1, then they are compatible. + 2. Otherwise, give an error if `S1.ndim` does not match `S2.ndim`. + 3. If `values` is not defined for `S2`, then they are compatible. + 4. If `values` is defined for `S2` but not defined for `S1`, then they are possibly compatible. + 5. If `values` is defined for both `S1` and `S2`, then the two are incompatible if `S1.values[i]` can be proven to be _not_ equal to `S2.values[i]` for some `i`. If all members can be proven to be equal, then they are compatible. Otherwise, if at least one pair of values cannot be proven to be either equal or unequal, then they are possibly compatible. +5. If `S1` and `S2` are both `TensorStructInfo`: + 1. If `S2.dtype` is not `Void` and does not match `S1.dtype`, then they are incompatible. + 2. If `S2.ndim` is not -1 and does not match `S1.ndim`, then they are incompatible. + 3. If `S2.shape` is not defined, then they are compatible. + 4. If `S2.shape` is defined and `S1.shape` is not defined, then they are possibly compatible. + 5. Otherwise, if both `shape` fields are given and either is a `Var`, then consider `S1` and `S2` compatible if the compiler can statically prove that the `Var` holds the same value as the other `shape` field, consider them possibly compatible if the compiler cannot draw a conclusion one way or the other, and consider them incompatible if the `Var` definitely has a different value from the other `shape`. + 6. If both `shape` fields are given and they are both `ShapeExpr` nodes, then `S1` and `S2` are incompatible if the compiler can prove that some dimension of `S1.shape` is _not_ equal to the corresponding dimension of `S2.shape`. Otherwise, if the all dimensions can be proven to be equal, then consider them compatible. If at least one pair of dimensions cannot be proven to be equal or unequal, consider them possibly compatible. +6. If `S1` and `S2` are both `FuncStructInfo`: + 1. If `S1` and `S2` don't both have defined `params` or both have undefined `params`, consider them incompatible. + 2. If both `S1` and `S2` have undefined `params`, consider them compatible if they have an identical `derive_func` and consider them possibly compatible if they have different `derive_func`s (as they is no further way to introspect the `derive_func` and draw static conslusions about `PackedFunc`s). + 3. If `params` is defined for both `S1` and `S2`: + 1. Consider them incompatible if the `params` have different lengths. + 2. Next, map unbound shape variables as follows: Get a variable mapping `m` by applying `get_shape_var_mapping(S1.params[i], S2.params[i])` for all values of `i`, taking the union of all resulting mappings. Next, substitute all occurrences of the shape variables in `S1` with their values in `m`. + 3. If `check_compatible(S2.params[i], S1.params[i])` (note the direction of the check: see the subtyping rule for `FuncType`) is incompatible for any `i` or if `check_compatible(S1.ret, S2.ret)` is incompatible, then they are incompatible. Otherwise, if `check_compatible(S2.params[i], S1.params[i])` is possibly compatible for any `i` or if `check_compatible(S1.ret, S2.ret)` is possibly compatible, consider `S1` and `S2` possibly compatible. Consider `S1` and `S2` compatible only if all checks are compatible. + +**Unification** + +Analogously to subtyping, we can also consider a hierarchy of structural information, considering some structural information to more or less specific than other structural information. Accordingly, we can also define a least upper bound for structural information, as with types. + +We can define an analogue to subtyping for structural information, as below. We say that `S1` is more specific than `S2` and denote it as `S1 <<: S2` (to distinguish from the notation on subtyping) based on the conditions given here. As an invariant, if `S1 <<: S2` holds, then `erase_struct_info(S1) <: erase_struct_info(S2)`, though the converse may not be true. +1. Reflexivity: `S1 <<: S1` for all `S1`. +2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <<: S2` and `S2 <<: S3`, then `S1 <<: S3`. +3. For all `S1`, `S1 <<: ObjectStructInfo()`. +4. For `TensorStructInfo`: + 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=-1, dtype=d)`. + 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=n, dtype=Void, shape=s)`. + 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (not undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. + 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <<: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ or _possibly_ statically equal. +5. For `ShapeStructInfo`: + 1. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (possibly undefined), `ShapeStructInfo(ndim=n, values=v) <<: ShapeStructInfo(ndim=-1)`. + 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <<: ShapeStructInfo(ndim=n, values=undefined)`. + 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <<: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ or _possibly_ statically equal. +6. Given two lists of structural information `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <<: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <<: fields2[i]`. +7. For `FuncStructInfo`: + 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <<: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. + 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <<: F2` only if `F1.derive_func` and `F2.derive_func` are identical. + 3. Given two lists of structural information parameters `P1` and `P2` and two structural information annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <<: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <<: P1[i]` and `R1 <<: R2`. + +Given these rules, we can define how to unify (get the LUB) of two structural information annotations as follows (in pseudocode): ```python -def map_shape_vars(param_shape: ShapeCompExpr, arg_shape: ShapeCompExpr, shape_var_mapping: {tir::Var : PrimExpr}) -> bool: - if param_shape is RuntimeDepShape or arg_shape is RuntimeDepShape: - return False - if param_shape is ShapeExpr and arg_shape is ShapeExpr: - if len(param_shape.values) != len(arg_shape.values): - raise UnificationError("Shapes are of incompatible ranks") - for param_dim, arg_dim in zip(param_shape.values, arg_shape.values): - if param_dim in shape_var_mapping: - # syntactic equality - if arg_dim != shape_var_mapping[param_dim]: - # if they are statically not equal, e.g., 5 != 7 or 3 + 3 != 3*3 - if can_prove_not_equal(arg_dim, shape_var_mapping[param_dim]): - raise UnificationError("Incompatible dimensions") - else: - return False - else: - shape_var_mapping[param_dim] = arg_dim - return True - if param_shape is Tuple and arg_shape is Tuple: - if len(param_shape.fields) != len(arg_shape.fields): - raise UnificationError("Tuples are of incompatible lengths") - for param_field, arg_field in zip(param_shape.fields, arg_shape.fields): - ret = map_shape_vars(param_field, arg_field, shape_var_mapping) - if not ret: - return False - return True - if param_shape is TupleGetItem and arg_shape is TupleGetItem: - # Does not necessarily indicate a unification error, - # depending on what the tuple values are. - # Constant folding the TupleGetItem nodes could improve this unification case - if param_shape.index != arg_shape.index: - return False - return map_shape_vars(param_shape.tup_value, arg_shape.tup_value) - if param_shape is Call and arg_shape is Call: - # no dimension mapping to do in this case - return True - # if either is a Call or TupleGetItem, it is possible that the shapes - # can match dynamically even if they don't match statically - if (param_shape is Call - or param_shape is TupleGetItem - or arg_shape is Call - or arg_shape is TupleGetItem): - return False - raise UnificationError("Incompatible shape constructs") - -def substitute_vars(target: Expr, var_mapping: {Var: Expr}, shape_var_mapping: {tir::Var: PrimExpr}) -> Expr: - def substitute_shape_vars(target: PrimExpr): - if target is tir::Var: - if target in shape_var_mapping: - return shape_var_mapping[target] - else: - return target - # proceed recursively in all subexpressions, checking for vars - - if target is Var: - if target in var_mapping: - return var_mapping[target] - return target - if target is ShapeExpr: - return ShapeExpr([ - substitute_shape_vars(dim) - for dim in target.values +def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: + if S2 is ObjectStructInfo: + return S1 + if S1 is ObjectStructInfo: + return S2 + if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): + return ObjectStructInfo() + if S1 and S2 are both ShapeStructInfo: + if S1.ndim == -1: + return S1 + if S2.ndim == -1: + return S2 + if S1.ndim != S2.ndim: + return ShapeStructInfo(ndim=-1) + if S1.ndim == S2.ndim: + if S1.values is undefined: + return S1 + if S2.values is defined: + return S2 + if S1.values can be statically proven to match S2.values: + return S1 + # values either proven not to match or unknown + return ShapeStructInfo(ndim=S1.ndim) # leave values undefined + if S1 and S2 are both TensorStructInfo: + ndim = S1.ndim if S1.ndim == S2.ndim else -1 + dtype = S1.dtype if S1.dtype == S2.dtype else Void + if ( + S1.ndim == -1 or S2.ndim == -1 or S1.ndim != S2.ndim + or S1.shape is undefined or S2.shape is undefined + ): + return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + # both shapes are defined + if S1.shape can be proven to equal S2.shape: + return S1 + # either proven to be unequal or cannot be concluded whether they are equal + return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + if S1 and S2 are both TupleStructInfo: + if S1.fields and S2.fields are of different lengths: + return ObjectStructInfo() + return TupleStructInfo( + unify_struct_info(S1.fields[i], S2.fields[i]) + for 0 <= i < length of S1.fields ]) - # recurse through all other cases, checking for vars and shape exprs analogously - -def substitute_shape(func_params, arg_exprs, ret_shape): - var_mapping = {param: arg_expr for param, arg_expr in zip(func_params, arg_exprs)} - shape_var_mapping = {} - for param, arg_expr in zip(func_params, arg_exprs): - can_unify = map_shape_vars(param.shape_, arg_expr.shape_, shape_var_mapping) - if not can_unify: - return RuntimeDepShape() - - new_shape = substitute_vars(ret_shape, var_mapping, shape_var_mapping) - if new_shape contains any free (Relax or shape) variables: - return RuntimeDepShape() - return new_shape + if S1 and S2 are both FuncStructInfo: + if S1.params and S2.params are not both defined or both undefined: + return ObjectStructInfo() + if S1.params and S2.params are both undefined: + # they must be the same function, not bothering to check eta-equivalence + if S1.derive_func == S2.derive_func: + return S1 + return FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive) + if S1.params and S2.params are both defined: + if S1.params and S2.params do not have the same length: + return ObjectStructInfo() + unified_params = [] + for 0 <= i < length of S1.params: + unified_param = unify_struct_info(S1.params[i], S2.params[i]) + # That is, if the params judged to be equal, use them. + # If there is some pair that is not equal, + # we can't unify these types except with ObjectStructInfo + # See the use of GLB with FuncTypes + if unified_param <<: S1.params[i] and unified_param <<: S2.params[i]: + unified_params[i] = unified_param + else: + return ObjectStructInfo() + return FuncStructInfo(params=unified_params, ret=unify_struct_info(S1.ret, S2.ret)) ``` +### Derivation Rules + +Let `Δ` be the structural information context for Relax variables (to distinguish from `Γ` for types) and let `Σ` track which shape variables are in scope. + +1. «Prepopulate `Δ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Δ` corresponding to that `GlobalVar`.» +2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Δ[v]` for the structural information. +3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. +4. For `Tuple(fields)`, the resulting structural information is `TupleStructInfo([f.struct_info for f in fields])`, after deriving the structural information for the fields recursively. +5. For `ShapeExpr(values)`, the resulting structural information is `ShapeStructInfo(ndim, values)`, where `ndim` is the length of `values`. +6. For `If(cond, true_branch, false_branch)`, we compare the structural information of `true_branch` and `false_branch` (call these `S_t` and `S_f`, respectively). The resulting structural information is `unify_struct_info(S_t, S_f)`. +7. For `SeqExpr(blocks, body)`: + 1. For each binding block in `blocks` (call the current one `block`): + 1. Process each binding in the block, updating `Δ` and `Σ` accordingly (this is discussed in detail below). + 2. If `block` is a `DataflowBlock`, then remove all `DataflowVar`s introduced in `block` from `Δ` before proceeding to the next block. + 2. Next derive the structural information for `body`. Let us call this `S`. + 3. Remove all Relax variables introduced in `blocks` from `Δ` and all shape variables introduced in `blocks` from `Σ`. + 4. The structural information of the entire `SeqExpr` is `erase_to_well_defined(S, Δ, Σ)`. +8. For handling variable bindings: + 1. If `v` is the argument to a function, then if `v` has a structural annotation `S`, set `Δ[v]` to `S`. Add any unbound shape variables in `S` to `Σ`. If `v` does not have a structural annotation, set `Δ[v]` to `ObjectStructInfo()`. + 2. In the general `VarBinding(v, e)`: + 1. If `e` is a function literal, then recursion is permitted. In this case, `v` must have a structural annotation `Sv`. Derive the structural information for `e` as follows: Set `Δ[v]` to `Sv`, apply the normal rule for function literals (given below) to `e` to derive structural information `Se`, and finally remove `v` from `Δ`. Raise an error if `Se` and `Sv` are not compatible (via `check_compatibility`). + 2. Otherwise, derive the structural information of `e` and call it `Se`. + 3. If `v` has a structural annotation `Sv`, then apply `check_compatibility` to `Sv` and `Se`. If they are compatible, then set `Δ[v]` to `Sv` (respecting the user's intent in giving an annotation). Give a warning if `Sv` is more specific than `Se`. If are not compatible, then raise an error. + 4. If `v` does not have a structural annotation, then set `Δ[v]` to `Se`. + 3. For `MatchCast(v, value, S)`: + 1. Derive the structural information of `value` and call it `Sv`. + 2. Add any new shape variables in `S` to `Σ`. + 3. If `S <<: Sv` and `Sv <<: S` do not both hold, give a warning, as this indicates a cast that will always fail at run time. + 4. If `v` is given and it has a structural annotation `S'`, then give an error if `S` and `S'` are not compatible via `check_compatibility`. If they are compatible, then set `Δ[v]` to `S'` (respecting the user's intent in giving an annotation). (TODO: It doesn't seem very sensible to have a dynamic cast and give a different annotation, perhaps we should simply not permit doing that.) + 5. If `v` is given and it does not have a structural annotation, then set `Δ[v]` to `S`. +9. For `TupleGetItem(tuple_value, i)`, derive the structural information for `tuple_value` and call it `St`. Raise an error if `St` is not `TupleStructInfo`. If `St` is `TupleStructInfo(fields)`, then raise an error if `fields` value has less than `i + 1` members (this should not happen if type checking has passed) and use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. +10. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. +11. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: + 1. For a call to an `Op`, we use the manually defined `FInferStructInfo` macro if it has been defined and `ObjectStructInfo()` if it has not. `FInferStructInfo` is a function that takes in the call node and returns the structural information of the result. + 2. Otherwise, derive the structural information for `op` and call it `Sf`. Next derive the structural information for the args and call it `S1`, `S2`, ..., and `Sn`. + 1. Give an error if `Sf` is not `FuncStructInfo`. + 2. If the `derive_func` field of `Sf` is defined, then apply the `derive_func` macro to the call node to derive the structural information for the call node, ignoring the `ret` field of `Sf`. + 3. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. + 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. Replace all variables in `m` with their mapping in `Sf`. + 5. After the substitutions, give an error if `check_compatibility` indicates that the `i`th member of `params` and `Si` are incompatible for some `i` (warn if they are only possibly compatible). + 6. Use `erase_to_well_defined(Sf.ret, Δ, Σ)` as the resulting structural information. +12. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: + 1. Let `S1`, `S2`, ..., `Sn` be the structural information of the parameters. If `vi` has a structural annotation, then use that annotation for `Si`; if not, use `ObjectStructInfo()`. Let `Sr` be `ret_struct_info` if it is defined and `ObjectStructInfo()` if not. + 2. If the function is bound to a `GlobalVar` `gv`, set `Δ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. + 3. For each of the `vi`, set `Δ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. + 4. Derive the structural information for `body`, calling it `Sb`. + 5. Give an error if `Sb` is incompatible with `Sr` via `check_compatibility` (warn if only possibly compatible). + 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Δ, Σ))`. + 7. Remove all variables added to `Δ` and `Σ` during the derivation. + ### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks There can be some complexity involved in checking whether two shapes match during shape inference. A very simple, conservative method for determining equality is simply using alpha-equivalence: If the two shapes have the same structure, then they are equivalent. However, this method is conservative and can overlook numerical properties in `PrimExpr`s. We leave it up to compiler implementations as to whether to use more advanced methods for proving equivalence, such as attempting to use algebraic rewrite rules. (As a consequence, portability requires inserting dynamic checks wherever there needs to be a comparison of shapes.) Note that optimizations like function inlining or constant folding could allow for simplifying many shape annotations and expressions and make it possible to conclude at compile time that shapes in more cases are equivalent. In general, developing compiler infrastructure for partial evaluation and reasoning about common situations with shape annotations may eliminate many dynamic checks. -Applying some kind of normalization or algebraic simplifications to `PrimExpr`s used in shape annotations and in `shape_` fields can also make it easier to conclude that certain dynamic checks may not be necessary by increasing the likelihood that more `shape_` expressions could be made syntactically identical to the shape annotations. It would also be possible to generate compile-time warnings if analysis reveals that two shapes may not match (either using rewrite rules or by trying random values for shape variables and checking). +Applying some kind of normalization or algebraic simplifications to `PrimExpr`s used in structural information and `MatchCast` bindings can also make it easier to conclude that certain dynamic checks may not be necessary by increasing the likelihood that more derive structural information could be made syntactically identical to the structural annotations. It would also be possible to generate compile-time warnings if analysis reveals that two shapes may not match (either using rewrite rules or by trying random values for shape variables and checking). -Since most dynamic shape checks are done for safety, it may be feasible to introduce a compilation mode that eliminates almost all dynamic shape checks. Some shape checks may not be possible to eliminate, since the body of the program may construct `ShapeExpr`s and use them in calls to `PackedFunc`s, so some bindings to shape variables may need to be preserved, per a liveness analysis. +Since most dynamic structure checks are done for safety, it may be feasible to introduce a compilation mode that eliminates almost all dynamic structure checks. Some structure checks may not be possible to eliminate, since `ShapeExpr`s can use shape variables introduced in `MatchCast` brindings, so this would require some liveness analysis. -## Possible Extensions to the Shape Expression System +## Possible Extension: Indicating Unknown Dimensions -We may consider two possible extensions to the shape expression system in order to accommodate two further cases: +A further case that may be of interest might be using an explicit wildcard dimension (e.g., using `tir::Any`) to allow for dimensions to be specified as "unknown" in function return shapes. As described at present, the only way for a function to specify a partly unknown return shape is to make the entire return shape unknown (`RuntimeDepShape`), which loses partial shape information. -1. An explicit wildcard dimension (e.g., using `tir::Any`) to allow for dimensions to be specified as "unknown" in function return shapes. As described at present, the only way for a function to specify a partly unknown return shape is to make the entire return shape unknown (`RuntimeDepShape`), which loses partial shape information. -2. Adding `shape_` expressions consisting of functions, to allow arbitrary closures to have a known shape. This would allow the shapes of calls to closures of unknown origin (namely, in a higher-order function) to have their shapes correctly inferred rather than made `RuntimeDepShape`. - -In both cases, these additions would entail additional complexity (shape inference macros for operators would have to deal with potential `tir::Any` nodes and we would have to define rules for constructing, calling, and simplifying functions in `shape_` expressions). However, the advantage of implementing these features would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using `RuntimeDepShape` means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchShape` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. +This addition would entail some, as `FInferStructInfo` and `derive_func` macros would have to deal with potential `tir::Any` nodes. However, the advantage of implementing it would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using undefined `shape` fields means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchCast` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. # Detailed Semantics @@ -781,20 +930,19 @@ For each expression, we define how it affects the program's visible state and th 3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. 4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per type checking, must evaluate to a tuple) and then returning the `i`th field of the result. 5. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. -6. `RuntimeDepShape` expressions must not appear in the general body of a program; it is a well-formedness error if they do. They do not have any defined semantics. -7. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. -8. The node `If(cond, true_branch, false_branch)` is evaluated as follows: +6. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. +7. The node `If(cond, true_branch, false_branch)` is evaluated as follows: 1. First `cond` is evaluated. Let the result be `r` (per type checking, it must be a bool scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. -9. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: +8. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: 1. If `op` is an `ExternFunc` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Next, look up the `PackedFunc` registered under the global symbol name. If it exists (it is an error at run time if it does not), call the `PackedFunc` using the given arguments and return the result. Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. 2. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» - 3. In all other cases, first evaluate `op` (it must evaluate to a closure). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. -10. For the node `SeqExpr(blocks, body)`, we evaluate as follows: + 3. In all other cases, first evaluate `op` (it must evaluate to a closure). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) +9. For the node `SeqExpr(blocks, body)`, we evaluate as follows: 1. Push a new scope onto the stack. 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: - 1. If the binding is `MatchShape(var, value, shape)`, perform the shape matching and shape variable updates as described in the shape evaluation section. If `var` is provided, `var` will be bound to `value` in the current scope; this assignment is aliasing and no new value is allocated. If `var` is not provided, then the shape check is performed and shape variables are updated, but no new binding is introduced. + 1. If the binding is `MatchCast(var, value, struct_info)`, perform the structure matching and shape variable updates as described in the structural information section. If `var` is provided, `var` will be bound to `value` in the current scope; this assignment is aliasing and no new value is allocated. If `var` is not provided, then the structural check is performed and shape variables are updated, but no new binding is introduced. 2. If the binding is `VarBinding(var, value)`, then evaluate `value` and bind `var` to that value in the current scope; this assignment is aliasing and no new value is allocated. 3. If `block` is a `DataflowBlock`, remove all `DataflowVar`s bound in the block from the current scope before proceeding to the next block. 3. After iterating through the binding blocks, evaluate `body` in the current scope. That will be the return value of the `SeqExpr`. @@ -804,7 +952,7 @@ For each expression, we define how it affects the program's visible state and th Optimizations are allowed to reorder and modify the operations of a program in any way so long as they do not change the value returned by evaluating the program or any visible behavior of the program. For the purposes of compilation, visible behaviors consist of side effects like mutating values in the program or external effects like I/O (printing to the console, creating files, etc.) and the order and number of times in which they happen. -«Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchShape` or `cast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» +«Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchCast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» The specification makes no guarantees about certain memory-related properties and hence also does not consider them to be "visible behaviors": @@ -824,5 +972,3 @@ The above evaluation rules are general, but leave much room for implementations - «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type.» - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. -- «`cast(v, type_args=[aT])`: Given an argument `v`, it dynamically checks if `v`'s run-time representation is a subtype of `aT`. If it is not, it exits the program with an error message. Otherwise, it returns `v`.» - From 9e1997ab62b47ea2bd8a5f7bbec256ff6b10eb7f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 7 Jan 2023 21:33:16 -0500 Subject: [PATCH 15/47] Further StructInfo revisions --- relax_spec.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 0e2b4c36604f..6f1d2ea80d82 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -175,7 +175,7 @@ The type checking rules assign types to every variable in scope and every type o ## Structural Information System Survey -In Relax, tensor shapes are not handled in the type system; each expression instead a has an associated shape expression. In many cases, these shape computations can allow for statically concluding that two shapes are the same and thus eliminate the need for dynamic checks via `MatchShape`. However, when shapes cannot be statically concluded to be the same, it may be necessary for there to be dynamic checks. The compiler is also free to make use of shape expressions for memory planning purposes. «Relax is "strongly shaped," meaning that if the compiler cannot conclude that shapes match in certain cases, an error will be issued and an explicit `MatchShape` will be required.» +In Relax, tensor shapes are not handled in the type system, even though it would be greatly beneficial for the compiler to make use of shape information for static optimizations. Instead, shape information is tracked using Relax's structural information system, in which every expression has structural information associated with it (like tensor shapes) that is more expressive than its type. Structural information can convey richer properties about expressions, like tensor shapes, and can facilitate a greater degree of static reasoning. However, when it is not feasible for the compiler to draw conclusions about structural information, this information can be checked dynamically via `MatchCast`. The structural information is essentially an extended type system, so `MatchCast` also serves to handle type casting. --- @@ -526,7 +526,7 @@ A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds **Scope of Shape Variables** -New shape variables can be bound in two places in a Relax program: In `TensorStructInfo` or `ShapeStructInfo` annotations on function parameters or as the `struct_info` parameter in a `MatchCast` binding. Shape variables used in the function signature are scoped to the entire function in which they appear (including in the return structural annotation). Shape variables used in `MatchShape` bindings are scoped only to the `SeqExpr` in which they appear. +New shape variables can be bound in two places in a Relax program: In `TensorStructInfo` or `ShapeStructInfo` annotations on function parameters or as the `struct_info` parameter in a `MatchCast` binding. Shape variables used in the function signature are scoped to the entire function in which they appear (including in the return structural annotation). Shape variables used in `MatchCast` bindings are scoped only to the `SeqExpr` in which they appear. **Informal Semantics of `PrimExpr`s for Dimensions** @@ -555,7 +555,7 @@ This section describes the run-time checking performed by `MatchCast(var, value, ### Checking Structural Information at the Start and End of a Function -«Shape variables are bound at the start and end of a function or in `MatchShape` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchCast`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchCast` bindings. Suppose a function has the following signature, where the `Si` are structural annotations: +«Shape variables are bound at the start and end of a function or in `MatchCast` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchCast`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchCast` bindings. Suppose a function has the following signature, where the `Si` are structural annotations: ```python def f(arg1 : S1, arg2 : S2, ..., argn : Sn) -> Sr: From c11ece05b90a3f38e296664bfc6a85e2154543ae Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sat, 7 Jan 2023 21:33:55 -0500 Subject: [PATCH 16/47] Make PackedFunc first-class --- relax_spec.md | 59 ++++++++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 6f1d2ea80d82..b7f5a867ff0b 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -112,17 +112,17 @@ This specification provides a more detailed description of what each expression 1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). 2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. 3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. -4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchShape` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." -5. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. - 1. For `ExternFunc` nodes, the call will look up the registered `PackedFunc` by its global symbol and will call it with the given arguments (note that a TIR `PrimFunc` can be compiled into a `PackedFunc` and called using `ExternFunc` by defining a `global_symbol` attribute in the `PrimFunc`). «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» - 2. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s) «and `cast` (performs dynamic type conversions).» - 3. Any other expression must evaluate to a closure; the closure will then be called with the given arguments. +4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." +5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. +6. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. + 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s) «and `cast` (performs dynamic type conversions).» + 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. - Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. + Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» -6. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. -7. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. -8. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: +7. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. +8. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. +9. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. 1. The types must match: All `StructInfo` variants correspond to a type (`TensorStructInfo` to `DynTensorType`, `ShapeStructInfo` to `ShapeType`, etc.) and each type corresponds to a value (`DynTensorType` to a tensor value, `ShapeType` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: @@ -134,8 +134,8 @@ This specification provides a more detailed description of what each expression The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. -9. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. -10. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. +10. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +11. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. @@ -168,7 +168,7 @@ The types in Relax correspond to the broad categories of the values given above: 2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. 3. `ShapeType` corresponds to shape values, optionally giving the number of dimensions in the shape. 4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» -5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). Since packed functions are not first-class values (`ExternFunc` can appear only in the `op` position of a `Call` node), these do not actually correspond to any value in Relax, but can be used to assign a type to `ExternFunc` nodes. +5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). 6. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» @@ -191,8 +191,9 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Tensors* are n-dimensional arrays of scalar values (which can be integers of fixed bitwidths, floats of fixed bitwidths, or bools). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. - *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). -- *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time type information* (RTTI) indicating their argument types and result type, in order to facilitate dynamic type checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTTI is left up to the compiler implementation to determine so long as the `cast` operator can verify the type of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» +- *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. +- *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. - Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values of type `ObjectType`. ## Representation of Values at Run Time @@ -216,7 +217,7 @@ There are four relevant scopes in Relax, which determine where variables are vis 3. `SeqExpr`: `Var` nodes defined in a `BindingBlock` in a `SeqExpr` node can be referenced in any later binding within the same `BindingBlock`, in any binding within any later `BindingBlock` in that `SeqExpr` node, or in the `SeqExpr`'s body expression. The variables defined in the `BindingBlock`s leave scope once the `SeqExpr` returns. 4. `DataflowBlock`: `DataflowVar`s introduced in a `DataflowBlock` can be referenced in any later binding within that `DataflowBlock`, but leave scope *once that `DataflowBlock` finishes executing*. Definitions in a `DataflowBlock` that are intended to leave the `DataflowBlock` should be bound to an ordinary `Var`. -Note that Relax variables must be bound _exactly_ once. A global variable is bound if it is mapped to a function in the `IRModule` and a local variable is bound if it appears as a function parameter or if it appears on the left-hand side (LHS) of a binding (`VarBinding` or `MatchShape`). +Note that Relax variables must be bound _exactly_ once. A global variable is bound if it is mapped to a function in the `IRModule` and a local variable is bound if it appears as a function parameter or if it appears on the left-hand side (LHS) of a binding (`VarBinding` or `MatchCast`). «If there is another binding to a local variable with the same name as an already-bound variable, that is binding is considered to _shadow_ the previous binding, i.e., it is a binding to a new, distinct variable that happens to have the same name as the existing variable. The new, shadowing variable will exist only in the current scope; if the older variable was defined in an outer scope, then future uses of that name will refer to the older variable. [See the Wikipedia page for more information on variable shadowing.](https://en.wikipedia.org/wiki/Variable_shadowing)» @@ -276,15 +277,13 @@ The following criteria apply to all programs (including before normalization): 7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» 8. `Op` nodes may appear only as the `op` argument to `Call` nodes. -9. `ExternFunc` expressions may appear only as the `op` argument to `Call` nodes. -10. The `type_args` field is used only for type checking calls of `ExternFunc`s and `Op`s. No other calls may have a non-empty `type_args`. -11. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. -12. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. -13. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» -14. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» -15. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. -16. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. -17. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. +9. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. +10. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. +11. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» +12. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» +13. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. +14. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. +15. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. @@ -470,7 +469,7 @@ Let us consider a typing context `Γ`, which is a map of variables to types. 6. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. 7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT1, aT2, ..., aTn])`: 1. If `op` is a Relax `Op` node, then we look up its registered `FInferStructInfo` property. `FInferStructInfo` is a macro that takes in the `Call` node and produces structural information. Invoke `op.FInferStructInfo(Call(op, [a1, ..., an], type_args=[aT1, aT2, ..., aTn]))` and convert the result to a type using the `erase_struct_info` procedure defined above. The implementation of `FInferStructInfo` is free to throw errors. - 2. If `op` is `ExternFunc`, note that packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function's implementation to do any validation at run time. However, the type system uses the `type_args` field to determine the result type as follows: + 2. If `op` has `PackedFuncType`, note that packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function's implementation to do any validation at run time. (TODO: `derive_func` should be used here, propagated from the structural information.) However, the type system uses the `type_args` field to determine the result type as follows: 1. If there are no `type_args`, the resulting type is `ObjectType()`. 2. If there is exactly one member of `type_args`, use that as the return type. 3. If there are multiple members of `type_args`, then the type is `TupleType(fields=[aT1, aT2, ..., aTn])`. @@ -935,10 +934,12 @@ For each expression, we define how it affects the program's visible state and th 1. First `cond` is evaluated. Let the result be `r` (per type checking, it must be a bool scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. -8. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: - 1. If `op` is an `ExternFunc` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Next, look up the `PackedFunc` registered under the global symbol name. If it exists (it is an error at run time if it does not), call the `PackedFunc` using the given arguments and return the result. Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. - 2. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» - 3. In all other cases, first evaluate `op` (it must evaluate to a closure). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. Push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) +8. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. +9. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: + 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» + 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. + 1. If `op` evaluated to a closure, push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) + 2. If `op` evaluated to a `PackedFunc`, simply invoke it. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. 9. For the node `SeqExpr(blocks, body)`, we evaluate as follows: 1. Push a new scope onto the stack. 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: @@ -946,7 +947,7 @@ For each expression, we define how it affects the program's visible state and th 2. If the binding is `VarBinding(var, value)`, then evaluate `value` and bind `var` to that value in the current scope; this assignment is aliasing and no new value is allocated. 3. If `block` is a `DataflowBlock`, remove all `DataflowVar`s bound in the block from the current scope before proceeding to the next block. 3. After iterating through the binding blocks, evaluate `body` in the current scope. That will be the return value of the `SeqExpr`. - 4. Pop the scope, removing any `Var` bindings introduced in the `SeqExpr`. This should also remove any shape variables introduced and bound in the `SeqExpr` as well. + 4. Pop the scope, removing any `Var` or shape variable bindings introduced in the `SeqExpr`. ### Optimizations From 9c55433c2a973af9ba4a53364aad3401e6d2af02 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 9 Jan 2023 15:11:24 -0500 Subject: [PATCH 17/47] erase_to_well_defined should handle unbound shape vars in FuncStructInfo --- relax_spec.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index b7f5a867ff0b..164b0e31b5c0 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -657,13 +657,17 @@ def erase_to_well_defined( ) if s is FuncStructInfo: if params is defined: - return FuncStructInfo( - params=[ - erase_to_well_defined(param, var_scope, shape_var_scope) - for param in s.params - ], + new_params = [] + for param in s.params: + if param contains unbound shape variables: + insert unbound shape variables into shape_var_scope + new_params.append(erase_to_well_defined(param, var_scope, shape_var_scope)) + ret = FuncStructInfo( + params=new_params, ret=erase_to_well_defined(s.ret, var_scope, shape_var_scope) ) + remove any unbound shape variables added into shape_var_scope above + return ret else: return FuncStructInfo( ret=erase_to_well_defined(s.ret, var_scope, shape_var_scope), From e5eab8d2482c074e074cffcce51f01c0ac9151a7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 31 Jan 2023 19:56:16 -0500 Subject: [PATCH 18/47] Include `sinfo_args`, greatly condense discussion of types --- relax_spec.md | 661 ++++++++++++++++---------------------------------- 1 file changed, 208 insertions(+), 453 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 164b0e31b5c0..cc911b314a53 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -14,13 +14,12 @@ Though this document will use the TVMScript front end for some examples, specify 4. [Variable Scoping](#variable-scoping) 5. [Normal Form](#normal-form) 6. [Well-Formedness Criteria](#well-formedness-criteria) -7. [Types in Relax](#types-in-relax) -8. [Structural Information in Relax](#structural-information-in-relax) -9. [Semantics](#detailed-semantics) +7. [Structural Information in Relax](#structural-information-in-relax) +8. [Semantics](#detailed-semantics) # Overview -This section will outline the grammar of Relax and give very brief descriptions of the different components, including the semantics, type system, and shape system. The rest of this document will provide more detailed descriptions of these facets of the language, including the validity conditions that the type system and shape system uphold. +This section will outline the grammar of Relax and give very brief descriptions of the different components, including the semantics and structural information (`StructInfo`) system. The rest of this document will provide more detailed descriptions of these facets of the language, including the validity conditions that the `StructInfo` system upholds. ## Differences from Relay @@ -52,13 +51,6 @@ PrimExpr ::= | Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr) # (others may be added later, as deemed necessary) -Type ::= DynTensorType(ndim: int, dtype: DataType) - | ShapeType(ndim: int) - | ObjectType() - | TupleType(fields: [Type]) - | FuncType(arg_types: [Type], ret_type: Type, «pure: bool») - | PackedFuncType() - DataType ::= Int(bitwidth: int) | Float(bitwidth: int) | Bool() @@ -82,7 +74,7 @@ Expr ::= Constant(data: NDArray) | Function(params: [Var], body: Expr, ret_struct_info: StructInfo?, attrs: Attrs?) | If(cond: Expr, true_branch: Expr, false_branch: Expr) | ExternFunc(global_symbol: string) - | Call(op: Expr, args: [Expr], type_args: [Type], attrs: Attrs?) + | Call(op: Expr, args: [Expr], sinfo_args: [StructInfo], attrs: Attrs?) | ShapeExpr(values: [PrimExpr]) | TupleGetItem(tuple_value: Expr, index: int) | Op(op_name: string) @@ -107,7 +99,7 @@ Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) ## Expression Survey -This specification provides a more detailed description of what each expression and type represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. +This specification provides a more detailed description of what each expression and `StructInfo` represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. 1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). 2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. @@ -115,17 +107,17 @@ This specification provides a more detailed description of what each expression 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. 6. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. - 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s) «and `cast` (performs dynamic type conversions).» + 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s). 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. - Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. «The attribute "pure" can be specified on a call to an `ExternFunc` to indicate that it performs no side effects (for use inside `DataflowBlock`s).» + Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. 7. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. 8. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. 9. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. - 1. The types must match: All `StructInfo` variants correspond to a type (`TensorStructInfo` to `DynTensorType`, `ShapeStructInfo` to `ShapeType`, etc.) and each type corresponds to a value (`DynTensorType` to a tensor value, `ShapeType` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: + 1. The types must match: All `StructInfo` variants correspond to a category of value value (`TensorStructInfo` to a tensor value, `ShapeStructInfo` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: 1. For comparing tensor values to `TensorStructInfo`, `ndim` must match the number of dimensions in the tensor value (unless `ndim` is -1) and `dtype` must match the datatype used (unless `dtype` is `Void`). If `shape` has been specified, the shape of the value must match that encoded by `shape`; if specified, `shape` must be either a `Var` already bound in the current scope or a `ShapeExpr`. 2. For comparing shape values to `ShapeStructInfo`, `ndim` must match the number of dimensions in the shape value (unless `ndim` is -1). If `values` has been specified, the shape value must match that encoded by `values`. 3. «For comparing closures (function values) to `FuncStructInfo`, it is necessary for the compiled program to track run-time structural information for closures, since it is not possible to introspect the closure; this subject will be discussed in further detail later in the document.» @@ -143,7 +135,7 @@ This specification provides a more detailed description of what each expression ## Purity and Dataflow Blocks -A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. «In Relax, we conservatively assume that any function that calls an impure function is itself impure, though the attribute `force_pure` on a function can be used as an override (e.g., if a function creates a new tensor, mutates it, and returns it, that is still pure but does not satisfy the conservative rule).» +A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. Above, it is mentioned that `DataflowBlock`s are not allowed to contain constructs featuring control flow (`If` nodes or recursive calls to the current function) or calls to impure functions. This ensures that `DataflowBlock`s represent a directed acyclic graph of pure operations, which is similar to the graph-like abstractions of traditional deep learning frameworks. This allows many common optimizations from past frameworks to be directly adapted to `DataflowBlock`s without having to accommodate additional reasoning about more expressive features like control flow and side effects. @@ -154,28 +146,22 @@ There is one visible side effect that Relax permits inside otherwise "pure" func Even though an abnormal program exit is a visible side effect and removing or reordering it changes the observable semantics, it would be too great a restriction to prohibit error checking inside `DataflowBlock`s. Relax does not have any notion of exception handling, so the only consequence of a failed safety check can be exiting the program. It is permissible for the compiler to reorder, duplicate, or eliminate `MatchCast`, or otherwise pure operations that have the potential of failing, provided that doing so does not change the value returned by the program or any other visible behavior. -To indicate that an operator or `PackedFunc` that can abort with an error should *never* be reordered or removed by the compiler, it should *not* be marked as pure. However, this means that it cannot be used inside a `DataflowBlock`. - Note that in some programming languages like Koka, non-termination is also considered a side effect, since it can in some sense be "observed" by a user and affects the visible behavior of a program (e.g., if there is an infinite loop before a print statement, the print will never happen). However, since non-termination cannot be automatically detected in general and is unlikely to arise in deep learning models, we do not attempt to systematically track non-termination in Relax. In general, the Relax compiler is allowed to reorder or remove otherwise pure function calls even if they may not terminate. For example, if a pure function `f` that returns an integer scalar does not terminate, it is permissible in principle to rewrite `f() - f()` to 0. Exiting with an error and infinitely looping are traditionally considered "[divergence](https://en.wikipedia.org/wiki/Divergence_(computer_science))" in the programming languages literature. As a general principle, Relax's compiler is permitted to turn a program that diverges into a program that does not diverge (provided that no other visible effects change) so long as it never transforms a program that does not diverge into one that diverges. -## Type System Survey - -The types in Relax correspond to the broad categories of the values given above: - -1. `DynTensorType` corresponds to tensor values, giving the scalar data type and the number of dimensions (rank), both of which are optional. -2. `TupleType` corresponds to tuple values, giving the type of each member of the tuple. -3. `ShapeType` corresponds to shape values, optionally giving the number of dimensions in the shape. -4. `FunctionType` corresponds to function values (closures), giving the types of the parameters, the return type, «and whether the function is pure.» -5. `PackedFuncType` is the type given to arbitrary packed functions (external functions). -6. `ObjectType` is the parent type of all Relax types and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. +## Structural Information (`StructInfo`) System Survey -The type checking rules assign types to every variable in scope and every type of expression based on the values it returns, making use of subtyping to assign more general types when a more specific one cannot be determined. «Relax is strongly typed, meaning that if a type encountered is less specific than the one expected, an error will be issued and an explicit cast (via the `cast` operator) will be required.» +Analogously to a type system in most languages, Relax tracks structural information (referred to as `StructInfo` in the implementation) related to the categories of values in Relax: +1. `TensorStructInfo` corresponds to tensor values, giving the scalar data type, the number of dimensions (rank), and an expression that computes the tensor's shape (either a `ShapeExpr` or a `Var`), all of which are optional. +2. `TupleStructInfo` corresponds to tuple values, giving the `StructInfo` for each member of the tuple. +3. `ShapeStructInfo` corresponds to shape values, optionally giving the number of dimensions in the shape and an expression that computes the shape's dimensions (either a `ShapeExpr` or a `Var`). +4. `FunctionStructInfo` corresponds to function values (closures) and `PackedFunc`s (external functions), giving the types of the parameters, the return type, «and whether the function is pure.» +5. `ObjectStructInfo` is a parent to all Relax `StructInfo` and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. -## Structural Information System Survey +`StructInfo` is assigned to every variable in scope and every type of expression based on the values it returns via a set of inference rules defined later in the specification, making use of subtyping to assign more general `StructInfo` when a more specific one cannot be determined. «Relax is strongly typed, meaning that if the `StructInfo` inferred is less specific than the one expected, an error will be issued and an explicit check via `MatchCast` will be required.» -In Relax, tensor shapes are not handled in the type system, even though it would be greatly beneficial for the compiler to make use of shape information for static optimizations. Instead, shape information is tracked using Relax's structural information system, in which every expression has structural information associated with it (like tensor shapes) that is more expressive than its type. Structural information can convey richer properties about expressions, like tensor shapes, and can facilitate a greater degree of static reasoning. However, when it is not feasible for the compiler to draw conclusions about structural information, this information can be checked dynamically via `MatchCast`. The structural information is essentially an extended type system, so `MatchCast` also serves to handle type casting. +In Relax, tensor shapes are not statically handled in the type system, even though it would be greatly beneficial for the compiler to make use of shape information for static optimizations. Instead, shape information is tracked using Relax's structural information system, in which every expression has structural information associated with it (like tensor shapes) that is more expressive than its type. `StructInfo` can convey richer properties about expressions, like tensor shapes, and can facilitate a greater degree of static reasoning. However, when it is not feasible for the compiler to draw conclusions about structural information, this information can be checked dynamically via `MatchCast`. The structural information is essentially an extended type system, so `MatchCast` also serves to handle type casting. --- @@ -194,11 +180,11 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. -- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values of type `ObjectType`. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. ## Representation of Values at Run Time -Because Relax supports calls to arbitrary `PackedFunc`s that can operate on a low level, it is necessary to define a convention for how values will be represented at run time. At this time, the specification does not require any specific representation and permits compiler implementations to choose their own representations, provided that each value type listed above can be recognized at run time (for dynamic type checks). This means that Relax programs that call `PackedFunc`s directly are not portable across compiler implementations: The `PackedFunc`s used must be able to operate on the run-time representations of values. +Because Relax supports calls to arbitrary `PackedFunc`s that can operate on a low level, it is necessary to define a convention for how values will be represented at run time. At this time, the specification does not require any specific representation and permits compiler implementations to choose their own representations, provided that each value type listed above can be recognized at run time (for dynamic `StructInfo` checks). This means that Relax programs that call `PackedFunc`s directly are not portable across compiler implementations: The `PackedFunc`s used must be able to operate on the run-time representations of values. Possible specification in terms of the TVM object system: @@ -240,7 +226,7 @@ def func(x: Tensor) -> Tensor: # Normal Form -To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the type- and structure-checking rules for operators rely on macros (`FInferShapeInfo`), _this means that the structure of the program can affect type and structure inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these type- and structure-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying type or structure checking. +To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the `StructInfo`-checking rules for operators rely on macros (`FInferShapeInfo`), _this means that the structure of the program can affect `StructInfo` inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these `StructInfo`-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying `StructInfo` checking. The normal form for Relax is very similar to ANF; differences will be noted. Here are the criteria required for a program to be in normal form: 1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. @@ -250,7 +236,7 @@ The normal form for Relax is very similar to ANF; differences will be noted. Her 3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. 4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. Empty `BindingBlock`s should be dropped. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. If all the `BindingBlock`s are empty, then the `blocks` field of the `SeqExpr` should be set to an empty list. -Programs that are parsed should be "normalized" before performing type checking or structure checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: +Programs that are parsed should be "normalized" before performing `StructInfo` checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: 1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. 2. If the function body is already a `SeqExpr`, consolidate all `BindingBlock`s, then check if the `body` field of the `SeqExpr` is a leaf expression. If not, bind it to a new var in the final `BindingBlock` and replace the `SeqExpr` body with the new var. 3. If the function body is not a `SeqExpr`, then recurse down the body's AST, binding any nested non-leaf expressions to a var in the current scope (doing this process in breadth-first order from left to right will respect the evaluation order in the semantics). If the body itself is a non-leaf expression, finally bind it to a var and have the final `SeqExpr` return the new var. @@ -260,7 +246,7 @@ Programs that are parsed should be "normalized" before performing type checking # Well-Formedness Criteria -Prior to type-checking and shape inference, Relax programs must conform to certain syntactic criteria to be valid, which includes conforming to the expectations of the above-described normal form. +Prior to `StructInfo` checking, Relax programs must conform to certain syntactic criteria to be valid, which includes conforming to the expectations of the above-described normal form. The following criteria apply to all programs (including before normalization): 1. `DataflowVar`s can be bound only inside `DataflowBlock`s. Additionally, a `DataflowVar` may not be used outside of the `DataflowBlock` in which it is defined. @@ -273,11 +259,11 @@ The following criteria apply to all programs (including before normalization): 2. Calls to a global function that is mutually recursive with the current function 3. `If` nodes - «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during type checking.» + «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during `StructInfo` checking.» 7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» 8. `Op` nodes may appear only as the `op` argument to `Call` nodes. -9. If a variable has both a type annotation and a shape annotation, the `ndim` of any `DynTensorType`s must match the number of dimensions in the corresponding shape annotation. +9. If a variable has a `StructInfo` annotation, the `ndim` of any `TensorStructInfo` and `ShapeStructInfo`s must match the number of dimensions in their `shape` and `values` fields, respectively. 10. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. 11. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» 12. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» @@ -287,238 +273,23 @@ The following criteria apply to all programs (including before normalization): Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. -# Types in Relax - -Relax's type system is intended to enforce strong guarantees that values are passed correctly between expressions. The design emphasis is on simplicity, aiming to leave more complex analysis to the structural information. - -Relax presently has six types, corresponding to the values in the language: - -1. `DynTensorType`, referring to tensor values (referred to in the front-end as `Tensor`). In Relax, tensor types keep track of the rank (number of dimensions) in the `ndim` field and the data type of the tensor data in the `dtype` field. Both the rank and data type are optional: Using -1 for `ndim` indicates that the tensor is of unknown rank and using `DataType::Void()` for `dtype` indicates that it's of unknown data type. -2. `ShapeType`, referring to shape values. The number of dimensions in the shape as given as `ndim` and is optional (using -1 for `ndim` indicates an unknown number of dimensions). -3. `FuncType`, referring to functions (closures). `FuncType`s specify the types of their parameters, a return type, and whether the function is pure. -4. `TupleType`, referring to tuple values, giving the types of their fields. -5. `PackedFuncType`, referring to the type of PackedFunctions. -6. `ObjectType`, referring to any Relax value, including values used and returned by `PackedFunc` or operator calls that do not belong in any of the above categories. - -## Erasing Structural Information into Types - -Several type-checking rules rely on structural annotations or rules for defining the structural information for a call to an `Op` or `PackedFunc`. In general, types are simpler than structural information (to facilitate more precise reasoning). Structural information can be convereted into a type as follows (in pseudocode): - -```python -def erase_struct_info(si: StructInfo) -> Type: - if si is TensorStructInfo: - return DynTensorType(ndim=si.ndim, dtype=si.dtype) - if si is ShapeStructInfo: - return ShapeType(ndim=si.ndim) - if si is TupleStructInfo: - return TupleType(fields=[erase_struct_info(field) for field in si.fields]) - if si is FuncStructInfo: - # this should be the case only for packed funcs - if si.params is not specified: - return PackedFuncType() - return FuncType( - arg_types=[erase_struct_info(arg_type) for arg_type in si.params], - ret_type=erase_struct_info(si.ret) - pure=False) # TODO: This suggests we should either handle purity - # in StructInfo entirely (and not make it part of the type) - # or include it in both StructInfo and the type system - # only remaining case is ObjectStructInfo - return ObjectType() -``` - -## Subtyping - -Relax implements subtyping, which means that members of types can be accepted where members of their supertypes are accepted. We will denote the subtyping relationship as `T1 <: T2`, indicating that `T1` is a subtype of `T2`. For example. if `T1 <: T2` and some function expects an argument of type `T2`, then passing a member of type `T1` to that function is permitted; passing a member of type `T2` as an argument to a function that expects type `T1` for that argument is *not* permitted—the value would have to be dynamically cast to `T1` using the `cast` operator. - -### Rules for Subtyping - -1. Reflexivity: For all types `T`, `T <: T`. -2. Transitivity: For all types `T1`, `T2`, and `T3`, if `T1 <: T2` and `T2 <: T3`, then `T1 <: T3`. -3. For all types `T`, `T <: ObjectType`. Hence, `ObjectType` is a supertype to all Relax types (all values in Relax are members of `ObjectType`). -4. Rules for `DynTensorType`: - 1. For all fixed `ndim` values `m`, where `m` ≥ 0, and `dtype`s `d`, `DynTensorType(ndim=m, dtype=d) <: DynTensorType(ndim=m, dtype=Void)`. - 2. For all fixed `ndim` values `m` and `dtype`s `d` that are not `Void`, `DynTensorType(ndim=m, dtype=d) <: DynTensorType(ndim=-1, dtype=d)`. - 3. Corollary: `DynTensorType(ndim=-1, dtype=Void)` is a supertype to all tensor types, since it refers to any possible tensor value. -5. Suppose we have types `T1 <: T1'`, `T2 <: T2'`, …, `Tn <: Tn'`. Then `TupleType(fields=[T1, T2, ..., Tn]) <: TupleType(fields=[T1', T2', ..., Tn'])`. -6. Rules for `FuncType`: - 1. Impure functions are supertypes to pure functions. Namely, if we have types `T1`, `T2`, …, `Tn` and `Tr`, then `FuncType(arg_types=[T1, T2, ..., Tn], ret_type=Tr, pure=True) <: FuncType(arg_types=[T1, T2, ..., Tn], ret_type=Tr, pure=False)`. - 2. Suppose we have types `T1' <: T1`, `T2' <: T2`, …, `Tn' <: Tn` and `Tr <: Tr'`. Then `FuncType(arg_types=[T1, T2, ... Tn], ret_type=Tr, pure=p) <: FuncType(arg_types=[T1', T2', ..., Tn'], ret_type=Tr', pure=p)`. Note the direction of the subtyping relationships for the argument and return types: We must be able to *call* this function with the *same* arguments and *use the returned value* wherever it is accepted—hence a function that takes more general arguments and returns a more specific return value can be used in place of the original. - -These rules allow us to define the least upper bound (LUB) for any two types `T1` and `T2`, meaning that it is the most specific type `T` for which `T1 <: T` and `T2 <: T` ("most specific" meaning that if there exists some other `T'` for which `T1 <: T'` and `T2 <: T'`, then `T <: T'`). The LUB is guaranteed to exist for any two types because `Object` is a supertype to all types. - -Note that the rule for obtaining the LUB of function types relies on the counterpart to the LUB, the greatest lower bound (GLB). The GLB is not guaranteed to exist for any two types in Relax, as there is no single type that is a subtype of all others. - -We can give an algorithm for determining the LUB and GLB for two types, in pseudocode: +# Structural Information (`StructInfo`) in Relax -```python -def find_glb(T1 : Type, T2 : Type) -> Type?: - if T1 == T2: # syntactic equality - return T2 - if T1 is ObjectType: - return T2 - if T2 is ObjectType: - return T1 - if T1 and T2 are not both DynTensorType, not both TupleType, not both FuncType, or not both ShapeType, or not both PackedFuncType: - return None - if T1 and T2 are both ShapeType: - ret_ndim = T1.ndim - if ret_ndim == -1: - ret_ndim == T2.ndim - if ret_ndim != -1 and T2.ndim != ret_ndim: - return None - return ShapeType(ret_ndim) - if T1 and T2 are both DynTensorType: - ret_ndim = T1.ndim - ret_dtype = T1.dtype - if ret_ndim == -1: - ret_ndim == T2.ndim - if ret_dtype == Void: - ret_dtype = T2.dtype - if ret_ndim != -1 and T2.ndim != ret_ndim: - # mismatch, so there's no common lower bound - return None - if ret_dtype != Void and T2.dtype != ret_dtype: - return None - return DynTensorType(ret_ndim, ret_dtype) - if T1 and T2 are both TupleType: - if they do not have the same length: - return None - fields = [] - for field1, field2 in zip(T1.fields, T2.fields): - glb = find_glb(field1, field2) - if glb is None: - return None - fields.append(glb) - return TupleType(fields) - if T1 and T2 are both FuncType: - «if they are not both pure or both impure:» - «return None» - purity = T1.purity - if they do not have the same arity: - return None - # mutual recursion with finding the LUB - arg_types = [ - find_lub(arg_type1, arg_type2) - for arg_type1, arg_type2 in zip(T1.arg_types, T2.arg_types) - ] - ret_type = find_glb(T1.ret_type, T2.ret_type) - if ret_type is None: - return None - return FuncType(arg_types, ret_type, purity) - -def find_lub(T1 : Type, T2 : Type) -> Type: - if T1 == T2: # syntactic equality - return T1 - if T1 or T2 is ObjectType: - return Object - if T1 or T2 are not both DynTensorType, or both TupleType, or both FuncType, or both ShapeType, or both PackedFuncType: - return ObjectType - if T1 and T2 are both ShapeType: - res_ndim = T1.ndim - if T1.ndim != T2.ndim: - res_ndim = -1 - return ShapeType(res_ndim) - if T1 and T2 are both DynTensorType: - res_ndim = T1.ndim - res_dtype = T1.dtype - if T1.ndim != T2.ndim: - res_ndim = -1 - if T1.dtype != T2.dtype: - res_dtype = Void - return DynTensorType(res_ndim, res_dtype) - if T1 and T2 are both TupleType: - if they do not have the same length: - return ObjectType - return TupleType([ - find_lub(field1, field2) - for field1, field2 in zip(T1.fields, T2.fields) - ]) - if T1 and T2 are both FuncType: - «purity = (True iff they're both pure)» - if they do not have the same arity: - return ObjectType - arg_types = [] - for arg_type1, arg_type2 in zip(T1.arg_types, T2.arg_types): - # potential mutual recursion - glb = find_glb(arg_type1, arg_type2) - if glb is None: - return ObjectType - arg_types.append(glb) - return FuncType(arg_types, find_lub(T1.ret_type, T2.ret_type), «purity») -``` +Structural information in Relax is intended to enforce basic guarantees that values are passed correctly between expressions, while also analyzing more complex properties like tensor shapes in a _"best-effort"_ fashion. Namely, anything that cannot be proved statically can instead be checked at run time. Each Relax expression has structural information associated with it. The best-effort nature of the structural system in Relax means that the analysis may detect _some_ errors at compile time and report them, but it may give warnings when it _cannot_ draw conclusions, perhaps suggesting that dynamic checks via `MatchCast` should be inserted. Note that the precision of the static analysis can potentially be improved by some compile-time optimizations like constant propagation, function inlining, and other partial evaluation–like transformations. -### When Type Conversions are Necessary - -For two types `T1` and `T2`, if `T1 <: T2`, then a value of type `T1` can be passed anywhere a value of type `T2` is expected without any need for type conversions or dynamic checks. - -*However*, if `T1 <: T2`, then passing a value of type `T2` where `T1` is expected can only be done if there has been some kind of dynamic check or conversion of that value. «Relax is *strongly typed*, meaning that the compiler will give an error in this situation and require an explicit conversion via a `MatchCast` node, which inspects the value's run-time representation.» - -If `T1` is not a subtype of `T2` and `T2` is not a subtype of `T1`, then it is always a type error to pass a value of either type where a value of the other is expected (no member of either type can be a member of the other). - -## Type Checking Rules - -The type checking rules for Relax are relatively simple and allow in some cases for types to be inferred without user annotations. Below, we describe how the types for each expression can be derived and when type checking should return an error. - -Let us consider a typing context `Γ`, which is a map of variables to types. - -1. «We type check the entire `IRModule` one function definition at a time. To handle mutual recursion, we prepopulate `Γ` with the annotated types of all global functions that are called mutually recursively. We then proceed to check the types of the global functions one at a time.» -2. Given a variable `v`, if `v` is in `Γ`, then we return `Γ[v]`. Otherwise, it is an error. -3. Given a constant expression `Constant(data)`, the type is `DynTensorType(ndim=n, dtype=d)` where `n` is the number of dimensions in `data` and `d` is its data type (recall that `data` is an `NDArray` literal). -4. Given a shape expression `ShapeExpr(dims)`, its type is `ShapeType(n)`, where `n` is the length of `dims`. -5. Given a tuple literal `Tuple([e1, e2, ..., en])`, suppose that `e1` has the type `T1`, `e2` has the type `T2`, …, and `en` has the type `Tn`. Then the type is `TupleType([T1, T2, .., Tn])`. -6. Given an `ExternFunc` expression, assign it the type `PackedFuncType()`. -7. Given a call node `Call(op=op, args=[a1, a2, ..., an], type_args=[aT1, aT2, ..., aTn])`: - 1. If `op` is a Relax `Op` node, then we look up its registered `FInferStructInfo` property. `FInferStructInfo` is a macro that takes in the `Call` node and produces structural information. Invoke `op.FInferStructInfo(Call(op, [a1, ..., an], type_args=[aT1, aT2, ..., aTn]))` and convert the result to a type using the `erase_struct_info` procedure defined above. The implementation of `FInferStructInfo` is free to throw errors. - 2. If `op` has `PackedFuncType`, note that packed functions may be passed any combination of values and return any value; it is the responsibility of the packed function's implementation to do any validation at run time. (TODO: `derive_func` should be used here, propagated from the structural information.) However, the type system uses the `type_args` field to determine the result type as follows: - 1. If there are no `type_args`, the resulting type is `ObjectType()`. - 2. If there is exactly one member of `type_args`, use that as the return type. - 3. If there are multiple members of `type_args`, then the type is `TupleType(fields=[aT1, aT2, ..., aTn])`. - 3. Otherwise, check the types of the subexpressions, left to right. Suppose `op` has the type `Tf`, `a1` has type `T1`, …, and `an` has type `Tn`. If `Tf` is not a function type with exactly `n` arguments, we consider it a type error and require an explicit cast. Suppose `Tf` has argument types `T1'`, `T2'`, …, `Tn'`. Consider it a type error if any of the following does not hold: `T1 <: T1'`, `T2 <: T2'`, …, or `Tn <: Tn'`. Then the return type is `Tf.ret_type`. -8. Given a conditional expression `If(cond, true_branch, false_branch)`, we first assert that the type of `cond` is `DynTensorType(ndim=0, dtype=bool)`, giving an error otherwise. Next, we recursively check the types of `true_branch` and `false_branch`; suppose they yield types `Tt` and `Tf`, respectively. The return type will be `T = LUB(Tt, Tf)`. -9. For a `TupleGetItem(t, i)` expression, suppose that the type of `t` is `T`. If `T` is a tuple type with at least `i` members, then return the `i`th type in `T`. «Give a type-checking error and require an explicit cast if `T` is not a tuple type with at least `i` members.» -10. Let us consider a sequence expression `SeqExpr(blocks = [b0, b1, ..., bn], body)`. - 1. We type check the binding blocks `b0`, `b1`, …, `bn` in order. For each block, we go through the bindings in order. - 1. «If the current block is a `DataflowBlock`, consider it an error if any binding contains a call to an expression with a function type that is not pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not have a `pure` attribute.» - 2. For each binding `VarBinding(v, e)` in the current block, check the type of `e` and suppose it is `T'`. If `v` has a structural annotation, then let `T` be the corresponding type (via the `erase_struct_info` procedure above). If there is no annotation, then add `v` to `Γ` with type `T'`. If `T` has been defined, then emit an error if `T'` is not a subtype of `T` and otherwise add `v` to `Γ` with type `T`. «If `T'` is a supertype of `T`, emit an error and require a cast.» Note that this means that annotated types can be *less specific* than the inferred type and that a user annotation forces the type system to consider the variable as having a less specific type than it does. (Note: In the case where `e` is a `Function` literal, we require `v` to have a structural annotation add `v` to `Γ` with its annotated type before type-checking the function body; see the rule for `Function` nodes.) - 3. For each `MatchCast(v, e, struct_info)`: - 1. Check the type of `e` and let it be `T'`. - 2. Let `T''` be the type corresponding to `struct_info` (via the `erase_struct_info` procedure). - 3. Emit a warning if `T'` is not a supertype of `T''` and `T''` is also not a supertype of `T'`; this indicates that the cast is _guaranteed_ to fail at run time. - 4. If `v` has been defined and it has a structural annotation, then let `T` be its corresponding type (via `erase_struct_info`). - 5. If `T` has been defined, then emit an error if `T` is not a supertype of `T''`. - 6. If `v` has been defined and does not have a structural annotation, then add `v` to `Γ` with type `T''`. If `T` has also been defined, then add `v` to `Γ` with type `T`. - 2. If the current block is a `DataflowBlock`, remove any `DataflowVar`s from `Γ` after we have finished processing the bindings in that block. - 3. Finally, the type of the `SeqExpr` is the type of `body` checked under the current `Γ`. Afterwards, remove all bindings added from the `b0`, `b1`, …, `bn` from `Γ`. -11. Let us consider a function `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`. All of the vars are required to have structural annotations; let `T1` be the type corresponding to `v1`'s annotation (via `erase_struct_info`), `T2` be the type corresponding to `v2`'s annotation, etc.. - 1. For handling recursive calls: If the function has been bound to a name `fv` (which may be a `Var` or a `GlobalVar`), then add `fv` to `Γ` with the type `FuncType([T1, T2, ..., Tn], Tr, pure=p)` and proceed as below, «where `p` is `True` if a `pure` attribute is included and `False` otherwise». Remove `fv` from `Γ` before returning. - 2. Add `v1` to `Γ` with type `T1`, `v2` with type `T2`, …, and `tn` with type `Tn`. Recursively type check `body` in this new context: - 1. «Determining purity: If `body` contains any call to a function whose return type does not specify that it is pure, a call to an `ExternFunc` that does not have the `pure` attribute, or a call to an `Op` that does not specify the `pure` attribute, then we consider the function to be (potentially) impure. If all calls are to functions whose return type specifies purity or that include the `pure` attribute on the call or `Op`, then the function is treated as pure.» - 2. «Suppose the purity defined in the previous step is `p'`. Suppose the annotated function purity (in the attributes) is `p`. If `p'` is false while `p` is true, then it is a type error; if `p` was omitted, use `p'` for `p`.» - 3. «If the function has the attribute "`force_pure`," then consider `p` to be true, even if the check above judged the function not to be pure. The compiler may emit a warning in this situation.» - 4. Suppose the result of type-checking `body` is `Tr'`. If the current function is not recursive or a mutually recursive global function and `ret_struct_info` is undefined, consider the function to have type `FuncType([T1, T2, ..., Tn], Tr', pure=p)`. If `ret_struct_info` is defined, then let `Tr` be `erase_struct_info(ret_struct_info)`. If `Tr' <: Tr`, then we consider the function to have type `FuncType([T1, T2, ..., Tn], Tr, pure=p)`. «If `Tr <: Tr'`, return an error and require an explicit cast in the function body.» If `Tr'` is not a subtype of `Tr` and `Tr` is also not a subtype of `Tr'` (meaning a dynamic cast cannot succeed), this is an error. - 5. Remove `v1`, `v2`, …, and `vn` from `Γ` before returning. - -# Structural Information in Relax - -In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. While this allows Relay's type system to make strong guarantees about tensor shapes, it results in greater complexity in type checking and makes it difficult to implement new operators or handle cases like tensors with symbolic shapes. - -Relax instead aims to facilitate analysis of more complex properties like shapes by tracking _structural information_ pertaining, encoding as much analysis as is feasible at compile-time in a _"best-effort"_ fashion. Anything that cannot be proved statically can instead be checked at run time. Each Relax expression has structural information associated with it just as it has a type. Indeed, the structural information for each expression can be simplified into a type (recall [the procedure for doing so](#erasing-structural-information-into-types)), so the structural information for an expression can be thought of as an extended type that is checked in a less precise manner. The best-effort nature of the structural system in Relax means that the analysis may detect _some_ errors at compile time and report them, but it may give warnings when it _cannot_ draw conclusions, perhaps suggesting that dynamic checks via `MatchCast` should be inserted. Note that the precision of the static analysis can potentially be improved by some compile-time optimizations like constant propagation, function inlining, and other partial evaluation–like transformations. - -Tensor shapes are the primary motivation for including structural information in Relax, as shape information is particularly important for memory planning. Relax's structural information system uses expressions to encode tensor shapes, which allows for using shape variables and arithmetic expressions to encode a rich variety of shape constraints. Note, however, that the structural system could potentially be extended to encode and analyze further information, like tensor sparsity or density. +Tensor shapes are the primary motivation for including structural information in Relax, as shape information is particularly important for memory planning. In Relay, shapes are part of tensor types and there is much analysis of tensor shapes done at compile time. While this allows Relay's type system to make strong guarantees about tensor shapes, it results in greater complexity in type checking and makes it difficult to implement new operators or handle cases like tensors with symbolic shapes. By contrast, Relax's `StructInfo` system uses expressions to encode tensor shapes, which allows for using shape variables and arithmetic expressions to encode a rich variety of shape constraints. Note, however, that the structural system could potentially be extended to encode and analyze further information, like tensor sparsity or density. ## Defining Structural Information -As with types, the structural information in Relax corresponds to the values in the language: -* `TensorStructInfo` describes tensor values. Like in `DynTensorType`, the `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` whose type is `ShapeType`. If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation (that returns a shape). which can be useful for memory planning. -* `ShapeStructInfo` describes shape values. Like `ShapeType`, it has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. -* `TupleStructInfo` describes tuple values, namely by giving the structural information for each of the tuple's members via `fields`. +The structural information in Relax corresponds to the values in the language: +* `TensorStructInfo` describes tensor values. The `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` (with `ShapeStructInfo`). If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation that returns a shape value, which can be useful for memory planning. +* `ShapeStructInfo` describes shape values. It has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. +* `TupleStructInfo` describes tuple values, namely by giving the `StructInfo` for each of the tuple's members via `fields`. * `FuncStructInfo` describes closure values or `PackedFunc`s. There are two ways in which to specify `FuncStructInfo`: - 1. By specifying `params` and `ret` (for closures). `params` gives the structural information corresponding to each of the function's parameters and `ret` gives the structural information corresponding to the result of calling the function. In this case, the `derive_func` field is left undefined. + 1. By specifying `params` and `ret` (for closures). `params` gives the `StructInfo` corresponding to each of the function's parameters and `ret` gives the `StructInfo` corresponding to the result of calling the function. In this case, the `derive_func` field is left undefined. 2. By giving a `derive_func` macro (for `PackedFunc`s). The `derive_func` macro is takes a call to the corresponding `PackedFunc` and the variable mapping context and returns the `StructInfo` of the result. In this case, the `params` field is left undefined and the `ret` field is ignored. * `ObjectStructInfo` describes arbitrary object values. -While these categories correspond closely to types, they serve as a mechanism for propagating further information (especially as given in shape annotations in variable bindings) throughout the program and facilitating more static analysis. - ### Expressing Shape Dimensions A tensor shape is a tuple of TIR `PrimExpr`s, where each `PrimExpr` corresponds to a dimension. The use of TIR `PrimExpr`s for shape dimensions allows shape computations to express complex constraints that include variables and integer arithmetic expressions in addition to just constant dimensions. @@ -544,7 +315,7 @@ This section describes the run-time checking performed by `MatchCast(var, value, 2. If `struct_info` is `TensorStructInfo(ndim, dtype, shape)`, then check that `value` is a tensor value, that it has a rank of `ndim` (if `ndim` is not -1), a datatype of `dtype` (if `dtype` is not `Void`). If `shape` is defined, consider the following cases: 1. If `shape` is a `Var`, then check that the concrete shape of `value` matches the value bound to the `Var`. 2. If `shape` is a `ShapeExpr`, then compare the fields of the `ShapeExpr` to the concrete shape of `value`, dimension by dimension (comparing the `i`th field of the `ShapeExpr` to the `i`th dimension of the shape of `value`). Give an error if the number of the dimensions does not match the number of fields in the `ShapeExpr`. - 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. + 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. 2. Otherwise, evaluate the field of the `ShapeExpr` and ensure that it matches the concrete value of the dimension. 3. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): 1. If the `i`th member of `values` consists of only an unbound shape variable, then bind that variable to the `i`th field of the the concrete shape value. @@ -575,9 +346,110 @@ def f(arg1, arg2, ..., argn): ``` » +## Subtyping for `StructInfo` + +Relax implements subtyping for `StructInfo`, which means that values with some `StructInfo` can be accepted where values with more general `StructInfo` are accepted We will denote the subtyping relationship as `S1 <: S2`, indicating that `S1` is a subtype of `S2`. For example. if `S1 <: S2` and some function expects an argument with `StructInfo` `S2`, then passing a value with `StructInfo` `S1` to that function is permitted; passing a value with `StructInfo` `S2` as an argument to a function that expects `S1` for that argument is *not* permitted—the value would have to be dynamically cast to `S1` using `MatchCast`. + +Note that judging subtyping requires potentially reasoning about arbitrary `ShapeExpr`s. We assume that the compiler is able to draw the following three conclusions about two shape expressions, acting conservatively (it will consider values to be _definitely_ equal or _definitely not_ equal only if it is certain): +* They are _definitely_ statically equal in all cases. +* They are _possibly_ statically equal. +* They are _definitely not_ statically equal in at least one case. + +1. Reflexivity: `S1 <: S1` for all `S1`. +2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <: S2` and `S2 <: S3`, then `S1 <<: S3`. +3. For all `S1`, `S1 <: ObjectStructInfo()`. +4. For `TensorStructInfo`: + 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=-1, dtype=d)`. + 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=Void, shape=s)`. + 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s`, `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. + 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ statically equal. We say that `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` _possibly_ holds if `s1` and `s2` are _possibly_ statically equal. +5. For `ShapeStructInfo`: + 1. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (possibly undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=-1)`. + 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=n, values=undefined)`. + 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ statically equal. We say that `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` _possibly_ holds if `v1` and `v2` are _possibly_ statically equal. +6. Given two lists of `StructInfo` `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <: fields2[i]`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships for the fields only possibly holds. +7. For `FuncStructInfo`: + 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. + 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <: F2` only if `F1.derive_func` and `F2.derive_func` are identical. + 3. Given two lists of `StructInfo` parameters `P1` and `P2` and two `StructInfo` annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <: P1[i]` and `R1 <: R2`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships given only possibly holds. + +These rules allow us to define the least upper bound (LUB) for any two `StructInfo` `S1` and `S2`, meaning that it is the most specific `StructInfo` `S` for which `S1 <: S` and `S2 <: S` ("most specific" meaning that if there exists some other `S'` for which `S1 <: S'` and `S2 <: S'`, then `S <: S'`), modulo reasoning about arithmetic (for example, the compiler may judge that two shape expressions are _possibly_ equivalent rather than _definitely_ equivalent). The LUB is guaranteed to exist for any two `StructInfo` because all `StructInfo` are subtypes of `ObjectStructInfo`. + +We can define how to find the LUB of two structural information annotations (modulo arithmetic reasoning) as follows, in pseudocode: + +```python +def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: + if S2 is ObjectStructInfo: + return S1 + if S1 is ObjectStructInfo: + return S2 + if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): + return ObjectStructInfo() + if S1 and S2 are both ShapeStructInfo: + if S1.ndim == -1: + return S1 + if S2.ndim == -1: + return S2 + if S1.ndim != S2.ndim: + return ShapeStructInfo(ndim=-1) + if S1.ndim == S2.ndim: + if S1.values is undefined: + return S1 + if S2.values is defined: + return S2 + if S1.values can be statically proven to match S2.values: + return S1 + # values either proven not to match or unknown + return ShapeStructInfo(ndim=S1.ndim) # leave values undefined + if S1 and S2 are both TensorStructInfo: + ndim = S1.ndim if S1.ndim == S2.ndim else -1 + dtype = S1.dtype if S1.dtype == S2.dtype else Void + if ( + S1.ndim == -1 or S2.ndim == -1 or S1.ndim != S2.ndim + or S1.shape is undefined or S2.shape is undefined + ): + return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + # both shapes are defined + if S1.shape can be proven to equal S2.shape: + return S1 + # either proven to be unequal or cannot be concluded whether they are equal + return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + if S1 and S2 are both TupleStructInfo: + if S1.fields and S2.fields are of different lengths: + return ObjectStructInfo() + return TupleStructInfo( + unify_struct_info(S1.fields[i], S2.fields[i]) + for 0 <= i < length of S1.fields + ]) + if S1 and S2 are both FuncStructInfo: + if S1.params and S2.params are not both defined or both undefined: + return ObjectStructInfo() + if S1.params and S2.params are both undefined: + # they must be the same function, not bothering to check eta-equivalence + if S1.derive_func == S2.derive_func: + return S1 + return FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive) + if S1.params and S2.params are both defined: + if S1.params and S2.params do not have the same length: + return ObjectStructInfo() + unified_params = [] + for 0 <= i < length of S1.params: + unified_param = unify_struct_info(S1.params[i], S2.params[i]) + # That is, if the params judged to be equal, use them. + # If there is some pair that is not equal, + # we can't unify these types except with ObjectStructInfo. + # This rule should suffice in practice; otherwise we would + # need to give a full definition of the GLB + if unified_param <: S1.params[i] and unified_param <: S2.params[i]: + unified_params[i] = unified_param + else: + return ObjectStructInfo() + return FuncStructInfo(params=unified_params, ret=unify_struct_info(S1.ret, S2.ret)) +``` + ## Deriving the Structural Information for Each Expression -For each expression type, we can recursively build up the structural information associated with the expression. +For each kind of expression, we can recursively build up the structural information associated with the expression. ### Auxiliary Procedures @@ -585,41 +457,14 @@ For each expression type, we can recursively build up the structural information There are two special `derive_func` values built into the compiler that are used for checking the structural information of `PackedFunc`s. -The first is `default_derive`, giving a simple way to determine the resulting structural information of a `PackedFunc` from its type arguments. `default_derive` takes one argument that is a `Call` node and is defined as follows: -1. Suppose its call node argument is `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`. -2. If `type_args` is of length 0, then return `ObjectStructInfo()`. -3. If `type_args` is of length 1, then return `wrap_type(aT1)`. -4. If `type_args` is of a greater length than 1, then return `TupleStructInfo(fields=[wrap_type(aT1), wrap_type(aT2), ..., wrap_type(aTn)])`. +The first is `default_derive`, giving a simple way to determine the resulting structural information of a `PackedFunc` from its `StructInfo` arguments. `default_derive` takes one argument that is a `Call` node and is defined as follows: +1. Suppose its call node argument is `Call(op, [arg1, arg2, ..., argn], sinfo_args=[aS1, aS2, ..., aSn])`. +2. If `sinfo_args` is of length 0, then return `ObjectStructInfo()`. +3. If `sinfo_args` is of length 1, then return `aS1`. +4. If `sinfo_args` is of a greater length than 1, then return `TupleStructInfo(fields=[aS1, aS2, ..., aSn])`. The second is `empty_derive`, which is the weakest possible derivation. It simply returns `ObjectStructInfo` regardless of its argument. This is used for worst-case deducation of `StructInfo` for a `PackedFunc`. -**Wrapping Types** - -For deriving the structural information for a `PackedFunc` call, the type arguments are converted into structural information. This is a straightforward procedure, given here in pseudocode: - -```python -def wrap_type(t: Type) -> StructInfo: - if t is ObjectType: - return ObjectStructInfo() - if t is PackedFuncType: - # leave params undefined; see default_derive below - return FuncStructInfo(ret=ObjectStructInfo(), derive_func=default_derive) - if t is FuncType: - # leave derive_func undefined - return FuncStructInfo( - params=[wrap_type(arg_type) for arg_type in t.arg_types], - ret=wrap_type(t.ret_type) - ) - if t is TupleType: - return TupleStructInfo(fields=[wrap_type(field) for field in t.fields]) - if t is ShapeType: - # leave values undefined - return ShapeStructInfo(ndim=t.ndim) - if t is DynTensorType: - # leave shape undefined - return TensorStructInfo(ndim=t.ndim, dtype=t.dtype) -``` - **Erasing Out-of-Scope Information** When returning a value from an inner scope to an outer scope (namely, the `body` field of a `SeqExpr`, which may use variables defined in the binding blocks, and the `body` field of a `Function`, which may use variables defined in the function body), it may be possible for the derived `TensorStructInfo` or `ShapeStructInfo` to contain Relax variables or shape vars that have gone out of scope. We defined a procedure to check for any of these out-of-scope variables and weaken the structural information not to include it. The procedure is defined below, in pseudocode: @@ -725,180 +570,59 @@ def get_shape_var_mapping(S1: StructInfo, S2: StructInfo) -> {tir::Var, PrimExpr return {} ``` -**Checking Compatibility** - -In many cases during the derivation of structural information, it is important to judge when two distinct structural information encodings are compatible with each other or when they are too different from each other to be reconciled, which can indicate an error. In the case of shape information, this could mean having two symbolic shapes that can be proven not to be equal to each other. Because shape expressions can contain arithmetic and it can be very difficult to statically prove whether two arithmetic expressions are equal, we permit the compiler implementation to make a best-effort attempt to prove equality for arithmetic expressions. (The user can insert a `MatchCast` to check definitively.) Since the checks are best-effort, the compatibility check will only report incompatibility if two values are _definitely_ different from each other. - -We can check if some structural information `S1` is accepted where structural information `S2` is expected by the process given below, which we refer to as `check_compability(S1, S2)` for convenience. `check_compatibility` can find that `S1` and `S2` are compatible, possibly compatible, or incompatible. "Incompatible" indicates a definite mismatch that should result in a compiler error; "possibly compatible" indicates that the structures may or may not match and should likely result in a compiler warning (indicating that a user may want to insert a dynamic check). An invariant that should should is that if `check_compatibility(S1, S2)` returns "compatible" or "possible compatible", `erase_struct_info(S1) <: erase_struct_info(S2)` should hold; that is, compatibility of structural information should be consistent with typing rules. - -1. If `S2` is `ObjectStructInfo`, then they are compatible. -2. Otherwise, if `S1` and `S2` are not both `TensorStructInfo` or both `TupleStructInfo`, etc. (besides `ObjectStructInfo`), then report an incompatibility. -3. If `S1` and `S2` are both `TupleStructInfo`: - 1. If `S1.fields` is not the same length as `S2.fields`, they are incompatible - 2. Call `check_compability(S1.fields[i], S2.fields[i])` for all `i`. If any pair of fields is incompatible, then `S1` and `S2` are incompatible. If no pair of fields is incompatible but at least one is possibly compatible, then `S1` and `S2` are possibly compatible. If all pairs of fields are compatible, then `S1` and `S2` are compatible. -4. If `S1` and `S2` are both `ShapeStructInfo`: - 1. `S2.ndim` is -1, then they are compatible. - 2. Otherwise, give an error if `S1.ndim` does not match `S2.ndim`. - 3. If `values` is not defined for `S2`, then they are compatible. - 4. If `values` is defined for `S2` but not defined for `S1`, then they are possibly compatible. - 5. If `values` is defined for both `S1` and `S2`, then the two are incompatible if `S1.values[i]` can be proven to be _not_ equal to `S2.values[i]` for some `i`. If all members can be proven to be equal, then they are compatible. Otherwise, if at least one pair of values cannot be proven to be either equal or unequal, then they are possibly compatible. -5. If `S1` and `S2` are both `TensorStructInfo`: - 1. If `S2.dtype` is not `Void` and does not match `S1.dtype`, then they are incompatible. - 2. If `S2.ndim` is not -1 and does not match `S1.ndim`, then they are incompatible. - 3. If `S2.shape` is not defined, then they are compatible. - 4. If `S2.shape` is defined and `S1.shape` is not defined, then they are possibly compatible. - 5. Otherwise, if both `shape` fields are given and either is a `Var`, then consider `S1` and `S2` compatible if the compiler can statically prove that the `Var` holds the same value as the other `shape` field, consider them possibly compatible if the compiler cannot draw a conclusion one way or the other, and consider them incompatible if the `Var` definitely has a different value from the other `shape`. - 6. If both `shape` fields are given and they are both `ShapeExpr` nodes, then `S1` and `S2` are incompatible if the compiler can prove that some dimension of `S1.shape` is _not_ equal to the corresponding dimension of `S2.shape`. Otherwise, if the all dimensions can be proven to be equal, then consider them compatible. If at least one pair of dimensions cannot be proven to be equal or unequal, consider them possibly compatible. -6. If `S1` and `S2` are both `FuncStructInfo`: - 1. If `S1` and `S2` don't both have defined `params` or both have undefined `params`, consider them incompatible. - 2. If both `S1` and `S2` have undefined `params`, consider them compatible if they have an identical `derive_func` and consider them possibly compatible if they have different `derive_func`s (as they is no further way to introspect the `derive_func` and draw static conslusions about `PackedFunc`s). - 3. If `params` is defined for both `S1` and `S2`: - 1. Consider them incompatible if the `params` have different lengths. - 2. Next, map unbound shape variables as follows: Get a variable mapping `m` by applying `get_shape_var_mapping(S1.params[i], S2.params[i])` for all values of `i`, taking the union of all resulting mappings. Next, substitute all occurrences of the shape variables in `S1` with their values in `m`. - 3. If `check_compatible(S2.params[i], S1.params[i])` (note the direction of the check: see the subtyping rule for `FuncType`) is incompatible for any `i` or if `check_compatible(S1.ret, S2.ret)` is incompatible, then they are incompatible. Otherwise, if `check_compatible(S2.params[i], S1.params[i])` is possibly compatible for any `i` or if `check_compatible(S1.ret, S2.ret)` is possibly compatible, consider `S1` and `S2` possibly compatible. Consider `S1` and `S2` compatible only if all checks are compatible. - -**Unification** - -Analogously to subtyping, we can also consider a hierarchy of structural information, considering some structural information to more or less specific than other structural information. Accordingly, we can also define a least upper bound for structural information, as with types. - -We can define an analogue to subtyping for structural information, as below. We say that `S1` is more specific than `S2` and denote it as `S1 <<: S2` (to distinguish from the notation on subtyping) based on the conditions given here. As an invariant, if `S1 <<: S2` holds, then `erase_struct_info(S1) <: erase_struct_info(S2)`, though the converse may not be true. -1. Reflexivity: `S1 <<: S1` for all `S1`. -2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <<: S2` and `S2 <<: S3`, then `S1 <<: S3`. -3. For all `S1`, `S1 <<: ObjectStructInfo()`. -4. For `TensorStructInfo`: - 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=-1, dtype=d)`. - 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=n, dtype=Void, shape=s)`. - 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (not undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <<: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. - 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <<: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ or _possibly_ statically equal. -5. For `ShapeStructInfo`: - 1. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (possibly undefined), `ShapeStructInfo(ndim=n, values=v) <<: ShapeStructInfo(ndim=-1)`. - 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <<: ShapeStructInfo(ndim=n, values=undefined)`. - 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <<: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ or _possibly_ statically equal. -6. Given two lists of structural information `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <<: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <<: fields2[i]`. -7. For `FuncStructInfo`: - 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <<: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. - 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <<: F2` only if `F1.derive_func` and `F2.derive_func` are identical. - 3. Given two lists of structural information parameters `P1` and `P2` and two structural information annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <<: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <<: P1[i]` and `R1 <<: R2`. - -Given these rules, we can define how to unify (get the LUB) of two structural information annotations as follows (in pseudocode): -```python -def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: - if S2 is ObjectStructInfo: - return S1 - if S1 is ObjectStructInfo: - return S2 - if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): - return ObjectStructInfo() - if S1 and S2 are both ShapeStructInfo: - if S1.ndim == -1: - return S1 - if S2.ndim == -1: - return S2 - if S1.ndim != S2.ndim: - return ShapeStructInfo(ndim=-1) - if S1.ndim == S2.ndim: - if S1.values is undefined: - return S1 - if S2.values is defined: - return S2 - if S1.values can be statically proven to match S2.values: - return S1 - # values either proven not to match or unknown - return ShapeStructInfo(ndim=S1.ndim) # leave values undefined - if S1 and S2 are both TensorStructInfo: - ndim = S1.ndim if S1.ndim == S2.ndim else -1 - dtype = S1.dtype if S1.dtype == S2.dtype else Void - if ( - S1.ndim == -1 or S2.ndim == -1 or S1.ndim != S2.ndim - or S1.shape is undefined or S2.shape is undefined - ): - return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined - # both shapes are defined - if S1.shape can be proven to equal S2.shape: - return S1 - # either proven to be unequal or cannot be concluded whether they are equal - return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined - if S1 and S2 are both TupleStructInfo: - if S1.fields and S2.fields are of different lengths: - return ObjectStructInfo() - return TupleStructInfo( - unify_struct_info(S1.fields[i], S2.fields[i]) - for 0 <= i < length of S1.fields - ]) - if S1 and S2 are both FuncStructInfo: - if S1.params and S2.params are not both defined or both undefined: - return ObjectStructInfo() - if S1.params and S2.params are both undefined: - # they must be the same function, not bothering to check eta-equivalence - if S1.derive_func == S2.derive_func: - return S1 - return FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive) - if S1.params and S2.params are both defined: - if S1.params and S2.params do not have the same length: - return ObjectStructInfo() - unified_params = [] - for 0 <= i < length of S1.params: - unified_param = unify_struct_info(S1.params[i], S2.params[i]) - # That is, if the params judged to be equal, use them. - # If there is some pair that is not equal, - # we can't unify these types except with ObjectStructInfo - # See the use of GLB with FuncTypes - if unified_param <<: S1.params[i] and unified_param <<: S2.params[i]: - unified_params[i] = unified_param - else: - return ObjectStructInfo() - return FuncStructInfo(params=unified_params, ret=unify_struct_info(S1.ret, S2.ret)) -``` - ### Derivation Rules -Let `Δ` be the structural information context for Relax variables (to distinguish from `Γ` for types) and let `Σ` track which shape variables are in scope. +Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track which shape variables are in scope. -1. «Prepopulate `Δ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Δ` corresponding to that `GlobalVar`.» -2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Δ[v]` for the structural information. +1. «Prepopulate `Γ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Γ` corresponding to that `GlobalVar`.» +2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Γ[v]` for the structural information. 3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. -4. For `Tuple(fields)`, the resulting structural information is `TupleStructInfo([f.struct_info for f in fields])`, after deriving the structural information for the fields recursively. +4. For `Tuple(fields)`, suppose that `fields` is comprised of expressions `E1`, `E2`, ..., `En`. Let the `StructInfo` for these expressions be `S1`, `S2`, ..., `Sn`, respectively. Then the resulting `StructInfo` is `TupleStructInfo(fields=[S1, S2, ..., Sn])`. 5. For `ShapeExpr(values)`, the resulting structural information is `ShapeStructInfo(ndim, values)`, where `ndim` is the length of `values`. 6. For `If(cond, true_branch, false_branch)`, we compare the structural information of `true_branch` and `false_branch` (call these `S_t` and `S_f`, respectively). The resulting structural information is `unify_struct_info(S_t, S_f)`. 7. For `SeqExpr(blocks, body)`: 1. For each binding block in `blocks` (call the current one `block`): - 1. Process each binding in the block, updating `Δ` and `Σ` accordingly (this is discussed in detail below). - 2. If `block` is a `DataflowBlock`, then remove all `DataflowVar`s introduced in `block` from `Δ` before proceeding to the next block. + 1. Process each binding in the block, updating `Γ` and `Σ` accordingly (this is discussed in detail below). + 2. If `block` is a `DataflowBlock`, then remove all `DataflowVar`s introduced in `block` from `Γ` before proceeding to the next block. 2. Next derive the structural information for `body`. Let us call this `S`. - 3. Remove all Relax variables introduced in `blocks` from `Δ` and all shape variables introduced in `blocks` from `Σ`. - 4. The structural information of the entire `SeqExpr` is `erase_to_well_defined(S, Δ, Σ)`. + 3. Remove all Relax variables introduced in `blocks` from `Γ` and all shape variables introduced in `blocks` from `Σ`. + 4. The structural information of the entire `SeqExpr` is `erase_to_well_defined(S, Γ, Σ)`. 8. For handling variable bindings: - 1. If `v` is the argument to a function, then if `v` has a structural annotation `S`, set `Δ[v]` to `S`. Add any unbound shape variables in `S` to `Σ`. If `v` does not have a structural annotation, set `Δ[v]` to `ObjectStructInfo()`. + 1. If `v` is the argument to a function, then if `v` has a structural annotation `S`, set `Γ[v]` to `S`. Add any unbound shape variables in `S` to `Σ`. If `v` does not have a structural annotation, set `Γ[v]` to `ObjectStructInfo()`. 2. In the general `VarBinding(v, e)`: - 1. If `e` is a function literal, then recursion is permitted. In this case, `v` must have a structural annotation `Sv`. Derive the structural information for `e` as follows: Set `Δ[v]` to `Sv`, apply the normal rule for function literals (given below) to `e` to derive structural information `Se`, and finally remove `v` from `Δ`. Raise an error if `Se` and `Sv` are not compatible (via `check_compatibility`). + 1. If `e` is a function literal, then recursion is permitted. In this case, `v` must have a structural annotation `Sv`. Derive the structural information for `e` as follows: Set `Γ[v]` to `Sv`, apply the normal rule for function literals (given below) to `e` to derive structural information `Se`, and finally remove `v` from `Γ`. Raise an error if `Se` and `Sv` are not compatible (via `check_compatibility`). 2. Otherwise, derive the structural information of `e` and call it `Se`. - 3. If `v` has a structural annotation `Sv`, then apply `check_compatibility` to `Sv` and `Se`. If they are compatible, then set `Δ[v]` to `Sv` (respecting the user's intent in giving an annotation). Give a warning if `Sv` is more specific than `Se`. If are not compatible, then raise an error. - 4. If `v` does not have a structural annotation, then set `Δ[v]` to `Se`. + 3. If `v` has a structural annotation `Sv`, then apply `check_compatibility` to `Sv` and `Se`. If they are compatible, then set `Γ[v]` to `Sv` (respecting the user's intent in giving an annotation). Give a warning if `Sv` is more specific than `Se`. If are not compatible, then raise an error. + 4. If `v` does not have a structural annotation, then set `Γ[v]` to `Se`. 3. For `MatchCast(v, value, S)`: 1. Derive the structural information of `value` and call it `Sv`. 2. Add any new shape variables in `S` to `Σ`. - 3. If `S <<: Sv` and `Sv <<: S` do not both hold, give a warning, as this indicates a cast that will always fail at run time. - 4. If `v` is given and it has a structural annotation `S'`, then give an error if `S` and `S'` are not compatible via `check_compatibility`. If they are compatible, then set `Δ[v]` to `S'` (respecting the user's intent in giving an annotation). (TODO: It doesn't seem very sensible to have a dynamic cast and give a different annotation, perhaps we should simply not permit doing that.) - 5. If `v` is given and it does not have a structural annotation, then set `Δ[v]` to `S`. -9. For `TupleGetItem(tuple_value, i)`, derive the structural information for `tuple_value` and call it `St`. Raise an error if `St` is not `TupleStructInfo`. If `St` is `TupleStructInfo(fields)`, then raise an error if `fields` value has less than `i + 1` members (this should not happen if type checking has passed) and use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. + 3. If `S <: Sv` and `Sv <: S` both do not hold, give a warning, as this indicates a cast that will _always_ fail at run time. (Conversely, if `Sv <: S`, then the cast will always succeed.) + 4. If `v` is given and it has a structural annotation `S'`, then give an error if `S <: S'` does not hold. If they are compatible, then set `Γ[v]` to `S'` (respecting the user's intent in giving an annotation). (TODO: It doesn't seem very sensible to have a dynamic cast and give a different annotation, perhaps we should simply not permit doing that.) + 5. If `v` is given and it does not have a structural annotation, then set `Γ[v]` to `S`. +9. For `TupleGetItem(tuple_value, i)`: + 1. Derive the structural information for `tuple_value` and call it `St`. + 2. Raise an error if `St` is not `TupleStructInfo`. + 3. If `St` is `TupleStructInfo(fields)`, then raise an error if `fields` value has less than `i + 1` members. + 4. Use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. 10. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. 11. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: 1. For a call to an `Op`, we use the manually defined `FInferStructInfo` macro if it has been defined and `ObjectStructInfo()` if it has not. `FInferStructInfo` is a function that takes in the call node and returns the structural information of the result. 2. Otherwise, derive the structural information for `op` and call it `Sf`. Next derive the structural information for the args and call it `S1`, `S2`, ..., and `Sn`. 1. Give an error if `Sf` is not `FuncStructInfo`. 2. If the `derive_func` field of `Sf` is defined, then apply the `derive_func` macro to the call node to derive the structural information for the call node, ignoring the `ret` field of `Sf`. - 3. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. + 3. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. Let the members of params be `P1`, `P2`, ..., `Pn`. 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. Replace all variables in `m` with their mapping in `Sf`. - 5. After the substitutions, give an error if `check_compatibility` indicates that the `i`th member of `params` and `Si` are incompatible for some `i` (warn if they are only possibly compatible). - 6. Use `erase_to_well_defined(Sf.ret, Δ, Σ)` as the resulting structural information. + 5. After the substitutions, give an error if `Pi <: Si` does not hold for some `i` (give a warning if it _possibly_ holds). + 6. Use `erase_to_well_defined(Sf.ret, Γ, Σ)` as the resulting structural information. 12. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: 1. Let `S1`, `S2`, ..., `Sn` be the structural information of the parameters. If `vi` has a structural annotation, then use that annotation for `Si`; if not, use `ObjectStructInfo()`. Let `Sr` be `ret_struct_info` if it is defined and `ObjectStructInfo()` if not. - 2. If the function is bound to a `GlobalVar` `gv`, set `Δ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. - 3. For each of the `vi`, set `Δ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. + 2. If the function is bound to a `GlobalVar` `gv`, set `Γ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. + 3. For each of the `vi`, set `Γ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. 4. Derive the structural information for `body`, calling it `Sb`. 5. Give an error if `Sb` is incompatible with `Sr` via `check_compatibility` (warn if only possibly compatible). - 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Δ, Σ))`. - 7. Remove all variables added to `Δ` and `Σ` during the derivation. + 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Γ, Σ))`. + 7. Remove all variables added to `Γ` and `Σ` during the above steps of the derivation. ### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks @@ -914,7 +638,41 @@ Since most dynamic structure checks are done for safety, it may be feasible to i A further case that may be of interest might be using an explicit wildcard dimension (e.g., using `tir::Any`) to allow for dimensions to be specified as "unknown" in function return shapes. As described at present, the only way for a function to specify a partly unknown return shape is to make the entire return shape unknown (`RuntimeDepShape`), which loses partial shape information. -This addition would entail some, as `FInferStructInfo` and `derive_func` macros would have to deal with potential `tir::Any` nodes. However, the advantage of implementing it would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using undefined `shape` fields means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchCast` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. +This addition would entail some, as `FInferStructInfo` and `derive_func` macros would have to deal with potential `tir::Any` nodes. However, the advantage of implementing it would be increasing the amount of shape information present at compile time and hence that could be used by lower levels of the compiler stack. The present defaults of using more general `StructInfo` means that either of these changes could be pursued in the future without breaking existing code, since these would generally have to be paired with explicit `MatchCast` dynamic checks, which will still work even if we add rules to automatically infer the shapes in those cases. + +## Traditional Types + +For comparison with Relay, it may be useful to simplify `StructInfo` into more traditional types that do not contain any expressions (such as in `TensorStructInfo` and `ShapeStructInfo`). We can define Relax types as follows: + +``` +Type ::= + DynTensorType(ndim: int, dtype: DataType) + | ShapeType(ndim: int) + | TupleType(fields: [Type]) + | PackedFuncType() + | FuncType(arg_types: [Type], ret_type: Type) + | ObjectType() +``` + +We can "erase" `StructInfo` into types by the following procedure (in psuedocode): +```python +def erase_struct_info(si: StructInfo) -> Type: + if si is TensorStructInfo: + return DynTensorType(ndim=si.ndim, dtype=si.dtype) + if si is ShapeStructInfo: + return ShapeType(ndim=si.ndim) + if si is TupleStructInfo: + return TupleType(fields=[erase_struct_info(field) for field in si.fields]) + if si is FuncStructInfo: + # this should be the case only for packed funcs + if si.params is not specified: + return PackedFuncType() + return FuncType( + arg_types=[erase_struct_info(arg_type) for arg_type in si.params], + ret_type=erase_struct_info(si.ret)) + # only remaining case is ObjectStructInfo + return ObjectType() +``` # Detailed Semantics @@ -922,7 +680,7 @@ This addition would entail some, as `FInferStructInfo` and `derive_func` macros In the `IRModule`, every mapping of a `GlobalVar` to a `Function` node or a TIR `PrimFunc` should be processed first and added to the global scope. «Global functions that have a `global_symbol` attribute should be externally linked, meaning that they can be invoked as program entry points; those that do not have a `global_symbol` attribute can be called only from within the global functions in the `IRModule`.» -The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax; these objects have type `Object` and can be used only by the `call_tir` operator. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. +The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax; these objects are of `ObjectStructInfo` and can be used only by the `call_tir` operator. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. ## Evaluating Expressions @@ -931,11 +689,11 @@ For each expression, we define how it affects the program's visible state and th 1. The node `Constant(value)` creates a new tensor whose contents are `value`. 2. A variable (whether `Var`, `DataflowVar` , or `GlobalVar`) evaluates to the stored value for that variable in the current scope. 3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. -4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per type checking, must evaluate to a tuple) and then returning the `i`th field of the result. +4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per `StructInfo` checking, must evaluate to a tuple) and then returning the `i`th field of the result. 5. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. 6. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. 7. The node `If(cond, true_branch, false_branch)` is evaluated as follows: - 1. First `cond` is evaluated. Let the result be `r` (per type checking, it must be a bool scalar). + 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a bool scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. 8. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. @@ -971,9 +729,6 @@ These semantic rules assume a single thread of evaluation on a single host machi The above evaluation rules are general, but leave much room for implementations of operators to specify custom semantics. Certain operators are used to perform common operations and will be discussed here as well. -- `call_tir(prim_func, arg1, arg2, ..., argn, shape, type_args=[aT])`: `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). The `shape` argument gives the shapes of the result of calling the TIR `PrimFunc`: It must be either of `ShapeType` (corresponding to returning a single tensor) or `TupleType` whose members are `ShapeType` (corresponding to returning a tuples of tensors). The type arg `aT` gives the type of the result of calling the `PrimFunc` and it must correspond to `shape` (namely, if `shape` is of `ShapeType`, `aT` must be a `DynTensorType`; if `shape` is of `TupleType`, `aT` must be a `TupleType` whose fields are `ShapeType`). `aT` is used especially to provide the `dtype` of returned tensors. - - Based on `shape`, the resulting tensor or tuple `r` will be allocated according to the sizes given in `shape`. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. - -- «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, shape, type_args=[aT])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the result tensor, so purity is not assumed. A type argument `aT` must be given to specify the return type.» +- `call_tir(prim_func, arg1, arg2, ..., argn, sinfo_args=[aS])`: `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). The `StructInfo` arg `aS` gives the `StructInfo` of the result of calling the `PrimFunc`; it must be a `TensorStructInfo` with a `shape` field corresponding to a constant shape expression and a non-`Void` `dtype`, denoting the shape of the resulting tensor, or a a `TupleStringInfo` where all the `fields` are `TensorStructInfo`. Based on `aS`, the resulting tensor or tuple `r` will be allocated according to the sizes given in their `shape` fields. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. «If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. +- «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, sinfo_args=[aS])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the results, so purity is not assumed. `aS` denotes the `StructInfo` for the result.» - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. From b2a54e894bb1fa38a8968a199cd92620d524b502 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 31 Jan 2023 20:47:11 -0500 Subject: [PATCH 19/47] First draft of PrimValues --- relax_spec.md | 92 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 62 insertions(+), 30 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index cc911b314a53..fc7a2bf33fb0 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -58,6 +58,7 @@ DataType ::= Int(bitwidth: int) StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) | ShapeStructInfo(values: [PrimExpr]?, ndim: int) + | PrimStructInfo(dtype: DataType) | ObjectStructInfo() | TupleStructInfo(fields: [StructInfo]) | FuncStructInfo(params: [StructInfo]?, ret: StructInfo, derive_func: EnvFunc?*) @@ -71,6 +72,9 @@ Expr ::= Constant(data: NDArray) | GlobalVar(name_hint: string) | Tuple(fields: [Expr]) | SeqExpr(blocks: [BindingBlock], body: Expr) + | PrimValue(value: PrimExpr) + | StringImm(value: string) + | DataTypeImm(value: DataType) | Function(params: [Var], body: Expr, ret_struct_info: StructInfo?, attrs: Attrs?) | If(cond: Expr, true_branch: Expr, false_branch: Expr) | ExternFunc(global_symbol: string) @@ -106,15 +110,18 @@ This specification provides a more detailed description of what each expression 3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. -6. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. +5. `PrimValue` nodes construct immutable scalar values from `PrimExpr`s, primarily for interacting with `ExternFunc`s or operators. These scalars are boxed within TVM objects, allowing them to be nested inside TVM's containers. (By contrast, zero-dimensional tensors defined via `Constant` are mutable.) +6. `StringImm` nodes construct strings, intended primarily for interacting with `ExternFunc`s or operators. +7. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators. +8. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s). 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. -7. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. -8. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. -9. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: +9. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. +10. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. +11. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. 1. The types must match: All `StructInfo` variants correspond to a category of value value (`TensorStructInfo` to a tensor value, `ShapeStructInfo` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: @@ -126,8 +133,8 @@ This specification provides a more detailed description of what each expression The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. -10. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. -11. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. +12. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +13. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. @@ -155,9 +162,10 @@ Exiting with an error and infinitely looping are traditionally considered "[dive Analogously to a type system in most languages, Relax tracks structural information (referred to as `StructInfo` in the implementation) related to the categories of values in Relax: 1. `TensorStructInfo` corresponds to tensor values, giving the scalar data type, the number of dimensions (rank), and an expression that computes the tensor's shape (either a `ShapeExpr` or a `Var`), all of which are optional. 2. `TupleStructInfo` corresponds to tuple values, giving the `StructInfo` for each member of the tuple. -3. `ShapeStructInfo` corresponds to shape values, optionally giving the number of dimensions in the shape and an expression that computes the shape's dimensions (either a `ShapeExpr` or a `Var`). -4. `FunctionStructInfo` corresponds to function values (closures) and `PackedFunc`s (external functions), giving the types of the parameters, the return type, «and whether the function is pure.» -5. `ObjectStructInfo` is a parent to all Relax `StructInfo` and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. +3. `PrimStructInfo` corresponds to `PrimValue`s (immutable scalar values), giving their TIR datatype. +4. `ShapeStructInfo` corresponds to shape values, optionally giving the number of dimensions in the shape and an expression that computes the shape's dimensions (either a `ShapeExpr` or a `Var`). +5. `FunctionStructInfo` corresponds to function values (closures) and `PackedFunc`s (external functions), giving the types of the parameters, the return type, «and whether the function is pure.» +6. `ObjectStructInfo` is a parent to all Relax `StructInfo` and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. `StructInfo` is assigned to every variable in scope and every type of expression based on the values it returns via a set of inference rules defined later in the specification, making use of subtyping to assign more general `StructInfo` when a more specific one cannot be determined. «Relax is strongly typed, meaning that if the `StructInfo` inferred is less specific than the one expected, an error will be issued and an explicit check via `MatchCast` will be required.» @@ -180,7 +188,8 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. -- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. +- *Primitive values* (`PrimValue`s) represent immutable scalar values that are primarily intended for being passed to external procedures, like calls to `PackedFunc`s. As a rule of thumb, scalar values intended for arithmetical computations should be 0-rank tensors while scalar values meant to serve as metadata should be `PrimValue`s. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. ## Representation of Values at Run Time @@ -192,6 +201,7 @@ Possible specification in terms of the TVM object system: - Tuples are represented using TVM ADTs (algebraic data types), which are arrays of TVM objects with a tag (see `include/tvm/runtime/container/adt.h`). Tuples use a tag of 0. - At run time, closures are represented as a `ClosureObj` (see `include/tvm/runtime/container/closure.h`); in the Relax VM these more specifically use the `VMClosureObj` (see [`https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h`](https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h)). - Shape values are represented at run time as a `ShapeTuple` (see `include/tvm/runtime/container/shape_tuple.h`). +- Strings are represented using TVM's `String` container (see `include/tvm/runtime/container/string.h`). - We require objects other than the above values used by and returned by `PackedFunc` to inherit from TVM's `Object` class (defined in `include/tvm/runtime/Object.h`). Note that `PackedFunc`s are capable of using and returning all TVM POD (plain-old data) values (see `include/tvm/runtimes/packed_func.h`), which includes some representations that do not inherit from `Object`. In the future, we may define semantics for other values, but at present, these are *unsupported* in Relax and we make no guarantees about the semantics of calling `PackedFunc`s that use or return anything that does not inherit from `Object`. # Variable Scoping @@ -284,6 +294,7 @@ Tensor shapes are the primary motivation for including structural information in The structural information in Relax corresponds to the values in the language: * `TensorStructInfo` describes tensor values. The `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` (with `ShapeStructInfo`). If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation that returns a shape value, which can be useful for memory planning. * `ShapeStructInfo` describes shape values. It has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. +* `PrimStructInfo` describes `PrimValue`s, giving their TIR datatype. * `TupleStructInfo` describes tuple values, namely by giving the `StructInfo` for each of the tuple's members via `fields`. * `FuncStructInfo` describes closure values or `PackedFunc`s. There are two ways in which to specify `FuncStructInfo`: 1. By specifying `params` and `ret` (for closures). `params` gives the `StructInfo` corresponding to each of the function's parameters and `ret` gives the `StructInfo` corresponding to the result of calling the function. In this case, the `derive_func` field is left undefined. @@ -317,11 +328,12 @@ This section describes the run-time checking performed by `MatchCast(var, value, 2. If `shape` is a `ShapeExpr`, then compare the fields of the `ShapeExpr` to the concrete shape of `value`, dimension by dimension (comparing the `i`th field of the `ShapeExpr` to the `i`th dimension of the shape of `value`). Give an error if the number of the dimensions does not match the number of fields in the `ShapeExpr`. 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. 2. Otherwise, evaluate the field of the `ShapeExpr` and ensure that it matches the concrete value of the dimension. -3. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): +3. If `struct_info` is `PrimStructInfo(dtype)`, then check that `value` is a `PrimValue` and that the underlying scalar has datatype `dtype` in TIR (according to TIR's type-checking rules). +4. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): 1. If the `i`th member of `values` consists of only an unbound shape variable, then bind that variable to the `i`th field of the the concrete shape value. 2. Otherwise, evaluate the `i`th member of `values` and check that it is equal to teh `i`th field of the concrete shape value. -4. If `struct_info` is `TupleStructInfo(fields)`, then check that `value` is a tuple value with `n` fields, where `n` is the length of `fields`. Also recursively check the `i`th field of the tuple value against the `i`th member of `fields`. -5. If `struct_info` is `FuncStructInfo(params, ret, derive_func)`, then if `params` is defined, check that `value` is a closure value; if `derive_func` is defined, check that `value` is a `PackedFunc`. No further validation may be done on a `PackedFunc`. «If `value` is a closure value, then it can contain run-time structural information indicating the structural information of its intended arguments and return value that can be compared against `params` and `ret`.» +5. If `struct_info` is `TupleStructInfo(fields)`, then check that `value` is a tuple value with `n` fields, where `n` is the length of `fields`. Also recursively check the `i`th field of the tuple value against the `i`th member of `fields`. +6. If `struct_info` is `FuncStructInfo(params, ret, derive_func)`, then if `params` is defined, check that `value` is a closure value; if `derive_func` is defined, check that `value` is a `PackedFunc`. No further validation may be done on a `PackedFunc`. «If `value` is a closure value, then it can contain run-time structural information indicating the structural information of its intended arguments and return value that can be compared against `params` and `ret`.» ### Checking Structural Information at the Start and End of a Function @@ -368,7 +380,8 @@ Note that judging subtyping requires potentially reasoning about arbitrary `Shap 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=n, values=undefined)`. 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ statically equal. We say that `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` _possibly_ holds if `v1` and `v2` are _possibly_ statically equal. 6. Given two lists of `StructInfo` `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <: fields2[i]`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships for the fields only possibly holds. -7. For `FuncStructInfo`: +7. For `PrimStructInfo`, `PrimStructInfo(dt1) <: PrimStructInfo(dt2)` holds if `dt1` and `dt2` are the same. That is, we do not have subtyping for TIR datatypes or `PrimStructInfo`. +8. For `FuncStructInfo`: 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <: F2` only if `F1.derive_func` and `F2.derive_func` are identical. 3. Given two lists of `StructInfo` parameters `P1` and `P2` and two `StructInfo` annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <: P1[i]` and `R1 <: R2`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships given only possibly holds. @@ -385,6 +398,10 @@ def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: return S2 if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): return ObjectStructInfo() + if S1 and S2 are both PrimStructInfo: + if S1.dtype == S2.dtype: + return S1 + return ObjectStructInfo() if S1 and S2 are both ShapeStructInfo: if S1.ndim == -1: return S1 @@ -478,6 +495,8 @@ def erase_to_well_defined( if s is ObjectStructInfo: return s + if s is PrimStructInfo: + return s if s is TensorStructInfo: if s.shape is defined: if (s.shape is a Relax var that is not in var_scope @@ -530,6 +549,8 @@ For clarity, additional detail on how the mapping should be constructed is given def get_shape_var_mapping(S1: StructInfo, S2: StructInfo) -> {tir::Var, PrimExpr}: if S1 and S2 are not the same type: return {} + if S1 and S2 are both PrimStructInfo: + return {} if S1 and S2 are both TupleStructInfo: if S1.fields and S2.fields don't have the same length: return {} @@ -577,17 +598,20 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 1. «Prepopulate `Γ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Γ` corresponding to that `GlobalVar`.» 2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Γ[v]` for the structural information. 3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. -4. For `Tuple(fields)`, suppose that `fields` is comprised of expressions `E1`, `E2`, ..., `En`. Let the `StructInfo` for these expressions be `S1`, `S2`, ..., `Sn`, respectively. Then the resulting `StructInfo` is `TupleStructInfo(fields=[S1, S2, ..., Sn])`. -5. For `ShapeExpr(values)`, the resulting structural information is `ShapeStructInfo(ndim, values)`, where `ndim` is the length of `values`. -6. For `If(cond, true_branch, false_branch)`, we compare the structural information of `true_branch` and `false_branch` (call these `S_t` and `S_f`, respectively). The resulting structural information is `unify_struct_info(S_t, S_f)`. -7. For `SeqExpr(blocks, body)`: +4. For `PrimValue(prim_expr)`, the resulting `StructInfo` is `PrimStructInfo(dt)`, where `dt` is the datatype of `prim_expr`, derived according to the type-checking rules for TIR. +5. For `StringImm(s)`, the resulting `StructInfo` is `ObjectStructInfo()`. +6. For `DataTypeImm(dt)`, the resulting `StructInfo` is `ObjectStructInfo()`. +7. For `Tuple(fields)`, suppose that `fields` is comprised of expressions `E1`, `E2`, ..., `En`. Let the `StructInfo` for these expressions be `S1`, `S2`, ..., `Sn`, respectively. Then the resulting `StructInfo` is `TupleStructInfo(fields=[S1, S2, ..., Sn])`. +8. For `ShapeExpr(values)`, the resulting structural information is `ShapeStructInfo(ndim, values)`, where `ndim` is the length of `values`. +9. For `If(cond, true_branch, false_branch)`, we compare the structural information of `true_branch` and `false_branch` (call these `S_t` and `S_f`, respectively). The resulting structural information is `unify_struct_info(S_t, S_f)`. +10. For `SeqExpr(blocks, body)`: 1. For each binding block in `blocks` (call the current one `block`): 1. Process each binding in the block, updating `Γ` and `Σ` accordingly (this is discussed in detail below). 2. If `block` is a `DataflowBlock`, then remove all `DataflowVar`s introduced in `block` from `Γ` before proceeding to the next block. 2. Next derive the structural information for `body`. Let us call this `S`. 3. Remove all Relax variables introduced in `blocks` from `Γ` and all shape variables introduced in `blocks` from `Σ`. 4. The structural information of the entire `SeqExpr` is `erase_to_well_defined(S, Γ, Σ)`. -8. For handling variable bindings: +11. For handling variable bindings: 1. If `v` is the argument to a function, then if `v` has a structural annotation `S`, set `Γ[v]` to `S`. Add any unbound shape variables in `S` to `Σ`. If `v` does not have a structural annotation, set `Γ[v]` to `ObjectStructInfo()`. 2. In the general `VarBinding(v, e)`: 1. If `e` is a function literal, then recursion is permitted. In this case, `v` must have a structural annotation `Sv`. Derive the structural information for `e` as follows: Set `Γ[v]` to `Sv`, apply the normal rule for function literals (given below) to `e` to derive structural information `Se`, and finally remove `v` from `Γ`. Raise an error if `Se` and `Sv` are not compatible (via `check_compatibility`). @@ -600,13 +624,13 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 3. If `S <: Sv` and `Sv <: S` both do not hold, give a warning, as this indicates a cast that will _always_ fail at run time. (Conversely, if `Sv <: S`, then the cast will always succeed.) 4. If `v` is given and it has a structural annotation `S'`, then give an error if `S <: S'` does not hold. If they are compatible, then set `Γ[v]` to `S'` (respecting the user's intent in giving an annotation). (TODO: It doesn't seem very sensible to have a dynamic cast and give a different annotation, perhaps we should simply not permit doing that.) 5. If `v` is given and it does not have a structural annotation, then set `Γ[v]` to `S`. -9. For `TupleGetItem(tuple_value, i)`: +12. For `TupleGetItem(tuple_value, i)`: 1. Derive the structural information for `tuple_value` and call it `St`. 2. Raise an error if `St` is not `TupleStructInfo`. 3. If `St` is `TupleStructInfo(fields)`, then raise an error if `fields` value has less than `i + 1` members. 4. Use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. -10. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. -11. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: +13. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. +14. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: 1. For a call to an `Op`, we use the manually defined `FInferStructInfo` macro if it has been defined and `ObjectStructInfo()` if it has not. `FInferStructInfo` is a function that takes in the call node and returns the structural information of the result. 2. Otherwise, derive the structural information for `op` and call it `Sf`. Next derive the structural information for the args and call it `S1`, `S2`, ..., and `Sn`. 1. Give an error if `Sf` is not `FuncStructInfo`. @@ -615,7 +639,7 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. Replace all variables in `m` with their mapping in `Sf`. 5. After the substitutions, give an error if `Pi <: Si` does not hold for some `i` (give a warning if it _possibly_ holds). 6. Use `erase_to_well_defined(Sf.ret, Γ, Σ)` as the resulting structural information. -12. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: +15. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: 1. Let `S1`, `S2`, ..., `Sn` be the structural information of the parameters. If `vi` has a structural annotation, then use that annotation for `Si`; if not, use `ObjectStructInfo()`. Let `Sr` be `ret_struct_info` if it is defined and `ObjectStructInfo()` if not. 2. If the function is bound to a `GlobalVar` `gv`, set `Γ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. 3. For each of the `vi`, set `Γ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. @@ -648,6 +672,7 @@ For comparison with Relay, it may be useful to simplify `StructInfo` into more t Type ::= DynTensorType(ndim: int, dtype: DataType) | ShapeType(ndim: int) + | PrimType(dtype: DataType) | TupleType(fields: [Type]) | PackedFuncType() | FuncType(arg_types: [Type], ret_type: Type) @@ -661,6 +686,8 @@ def erase_struct_info(si: StructInfo) -> Type: return DynTensorType(ndim=si.ndim, dtype=si.dtype) if si is ShapeStructInfo: return ShapeType(ndim=si.ndim) + if si is PrimStructInfo: + return PrimType(dtype=si.dtype) if si is TupleStructInfo: return TupleType(fields=[erase_struct_info(field) for field in si.fields]) if si is FuncStructInfo: @@ -689,20 +716,23 @@ For each expression, we define how it affects the program's visible state and th 1. The node `Constant(value)` creates a new tensor whose contents are `value`. 2. A variable (whether `Var`, `DataflowVar` , or `GlobalVar`) evaluates to the stored value for that variable in the current scope. 3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. -4. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per `StructInfo` checking, must evaluate to a tuple) and then returning the `i`th field of the result. -5. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. -6. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. -7. The node `If(cond, true_branch, false_branch)` is evaluated as follows: +4. The node `PrimType(prim_expr)` evaluates the `PrimExpr` `prim_expr` first, obtaining a resulting `pv`. It then creates an immutable `PrimValue` containing `pv`. +5. The node `StringImm(s)` creates an immutable string container whose contents is `s`. It does not necessarily have to be a _new_ string container if, for example, string interning is implemented. +6. The node `DataTypeImm(dt)` creates a new immutable datatype representation. +7. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per `StructInfo` checking, must evaluate to a tuple) and then returning the `i`th field of the result. +8. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. +9. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. +10. The node `If(cond, true_branch, false_branch)` is evaluated as follows: 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a bool scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. -8. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. -9. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: +11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. +12. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. 1. If `op` evaluated to a closure, push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) 2. If `op` evaluated to a `PackedFunc`, simply invoke it. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. -9. For the node `SeqExpr(blocks, body)`, we evaluate as follows: +13. For the node `SeqExpr(blocks, body)`, we evaluate as follows: 1. Push a new scope onto the stack. 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: 1. If the binding is `MatchCast(var, value, struct_info)`, perform the structure matching and shape variable updates as described in the structural information section. If `var` is provided, `var` will be bound to `value` in the current scope; this assignment is aliasing and no new value is allocated. If `var` is not provided, then the structural check is performed and shape variables are updated, but no new binding is introduced. @@ -717,6 +747,8 @@ Optimizations are allowed to reorder and modify the operations of a program in a «Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchCast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» +For immutable containers like those for the results of `PrimValue`, `StringImm`, and `DataTypeImm`, it is not required for the results of evaluating these expressions to be _new_ containers—it is permitted for the compiler to reuse existing objects provided that the values contained within are identical. This optimization is called [interning](https://en.wikipedia.org/wiki/String_interning). However, for operations that return new mutable values (in particular, operations that return tensor values), those _must_ be newly allocated, since reusing values can affect the behavior under aliasing. + The specification makes no guarantees about certain memory-related properties and hence also does not consider them to be "visible behaviors": - Whether an allocation happens at a given point. Compiler implementations are permitted to reuse already-allocated memory if it would not interfere with visible state in any other way, per the aliasing rules (`PackedFunc`s or operators may mutate values that are passed to them and those mutations should be visible as per aliasing in this specification). Copying values or sharing representations (e.g., interning constants) between values may be done only if they will not affect any other visible behaviors, dependent on the aliasing behavior. From 37f9bdd0b6591a40518d2a87b45451c22535bfac Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 17:19:02 -0500 Subject: [PATCH 20/47] Example of what DataTypeImm is used for --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index fc7a2bf33fb0..d502302be21c 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -112,7 +112,7 @@ This specification provides a more detailed description of what each expression 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. 5. `PrimValue` nodes construct immutable scalar values from `PrimExpr`s, primarily for interacting with `ExternFunc`s or operators. These scalars are boxed within TVM objects, allowing them to be nested inside TVM's containers. (By contrast, zero-dimensional tensors defined via `Constant` are mutable.) 6. `StringImm` nodes construct strings, intended primarily for interacting with `ExternFunc`s or operators. -7. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators. +7. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators (e.g., for TIR intrinsics that take a datatype as an input). 8. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s). 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. From 6e0578f4a932b96be734d53bce4535ab7c3ca340 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 17:22:08 -0500 Subject: [PATCH 21/47] Restrictions on PrimValue and PrimStructInfo --- relax_spec.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index d502302be21c..8b02492e61d0 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -280,6 +280,8 @@ The following criteria apply to all programs (including before normalization): 13. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. 14. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. 15. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. +16. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s. +17. `PrimStructInfo` annotations should use only the `Int` and `Float` datatypes for their `dtype` fields. Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. From 398c46ff27cb282130b308f4fa82e0bf2fc03554 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 17:39:33 -0500 Subject: [PATCH 22/47] Unify the notation used for TIR dtypes and Relax dtypes --- relax_spec.md | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 8b02492e61d0..a3c187855603 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -51,10 +51,11 @@ PrimExpr ::= | Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr) # (others may be added later, as deemed necessary) -DataType ::= Int(bitwidth: int) - | Float(bitwidth: int) - | Bool() - | Void() +# Also from TIR +DataType ::= Int(bits: int, lanes: int) + | UInt(bits: int, lanes: int) + | Float(bits: int, lanes: int) + | Handle(bits: int, lanes: int) StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) | ShapeStructInfo(values: [PrimExpr]?, ndim: int) @@ -99,7 +100,19 @@ Binding ::= Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) ``` -*The `derive_func` field of `FuncStructInfo` is a macro in the meta-language: Given a function call and the variable mapping context, return the `StructInfo` of the result. This field is used only at compile time for reasoning about the `StructInfo` of calls to `ExternFunc`s. +### Notes on `derive_func` + +The `derive_func` field of `FuncStructInfo` is a macro in the meta-language: Given a function call and the variable mapping context, return the `StructInfo` of the result. This field is used only at compile time for reasoning about the `StructInfo` of calls to `ExternFunc`s. + +### Notes on `DataType` and Related Terminology + +The representation of datatypes, `DataType`, in the above AST is taken directly from TIR. However, the usage of datatypes in Relax is more restricted than in TIR. +1. The `lanes` field for the `Int`, `UInt`, and `Float` datatypes must always be 1; we do not directly consider vectorized values in Relax. +2. The `lanes` field for the `Handle` datatype must always be 0, indicating that it is `Void` (see below). The `bits` field for `Handle` should always be set to 64 (it will not be used by Relax). + +We also define the following special notation for datatypes, to be used in the rest of the specification: +1. `Bool()`: This is shorthand for `UInt(bits=1, lanes=1)`, since TIR does not have a separate Boolean type. "True" refers to a value of 1 in this datatype and "false" refers to a value of 0. For convenience, we will refer to Boolean values as a separate datatype in the specification, due to their significance in `If` nodes. +2. `Void()`: This is shorthand for `Handle(bits=64, lanes=0)`. TIR uses this datatype to refer to opaque objects; in Relax, it is used to denote an unknown datatype. ## Expression Survey @@ -183,7 +196,7 @@ Oftentimes, compiler passes operate only on particular functions or add new func Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. -- *Tensors* are n-dimensional arrays of scalar values (which can be integers of fixed bitwidths, floats of fixed bitwidths, or bools). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. +- *Tensors* are n-dimensional arrays of scalar values (which can be signed or unsigned integers of fixed bitwidths, floats of fixed bitwidths, or Boolean values). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. - *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. @@ -280,10 +293,11 @@ The following criteria apply to all programs (including before normalization): 13. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. 14. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. 15. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. -16. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s. -17. `PrimStructInfo` annotations should use only the `Int` and `Float` datatypes for their `dtype` fields. +16. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s (with `lanes` set to 1). +17. `PrimStructInfo` annotations should use only the `Int`, `UInt`, or `Float` datatypes for their `dtype` fields. +18. Per [the notes on `DataType`](#notes-on-datatype-and-related-terminology), any `DataType` annotation must have a `lanes` value of 1 for the `Int`, `UInt`, or `Float` datatypes and a `lanes` value of 0 for the `Handle` (`Void`) datatype. Additionally, `bits` must be 64 for `Void`. The supported bitwidths for `Int` and `UInt` are 1, 8, 16, 32, and 64; the supported bitwidths for `Float` are 16, 32, and 64. -Additionally, the criteria for normal form listed in the previous section must apply to any program that has been normalized. +Additionally, the criteria for normal form listed in [the previous section](#normal-form) must apply to any program that has been normalized. # Structural Information (`StructInfo`) in Relax @@ -374,7 +388,7 @@ Note that judging subtyping requires potentially reasoning about arbitrary `Shap 3. For all `S1`, `S1 <: ObjectStructInfo()`. 4. For `TensorStructInfo`: 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=-1, dtype=d)`. - 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=Void, shape=s)`. + 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=Void(), shape=s)`. 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s`, `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ statically equal. We say that `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` _possibly_ holds if `s1` and `s2` are _possibly_ statically equal. 5. For `ShapeStructInfo`: @@ -725,7 +739,7 @@ For each expression, we define how it affects the program's visible state and th 8. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. 9. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. 10. The node `If(cond, true_branch, false_branch)` is evaluated as follows: - 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a bool scalar). + 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a `Bool` scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. 11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. From db68fcb75f04e54ea8ef8650c0b0b4578b413bc0 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 18:00:49 -0500 Subject: [PATCH 23/47] Note immutability of tuples and shapes and possibility of interning --- relax_spec.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index a3c187855603..89e70c3b7945 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -119,7 +119,7 @@ We also define the following special notation for datatypes, to be used in the r This specification provides a more detailed description of what each expression and `StructInfo` represents and what conditions make them valid. To motivate and provide more context for the full specification later in this document, this section will briefly summarize the purpose of each node. 1. `Constant` nodes construct tensor constants (n-dimensional arrays of scalars). -2. `Tuple` nodes construct a tuple (fixed-size ordered grouping) of Relax values. +2. `Tuple` nodes construct a tuple (immutable fixed-size ordered grouping) of Relax values. 3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. @@ -146,7 +146,7 @@ This specification provides a more detailed description of what each expression The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. -12. `ShapeExpr` nodes construct shape literals. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +12. `ShapeExpr` nodes construct shape literals, which are immutable collections of shape dimensions. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. 13. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. @@ -197,9 +197,9 @@ Oftentimes, compiler passes operate only on particular functions or add new func Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. - *Tensors* are n-dimensional arrays of scalar values (which can be signed or unsigned integers of fixed bitwidths, floats of fixed bitwidths, or Boolean values). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. -- *Tuples* represent a fixed-size grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). +- *Tuples* represent a fixed-size immutable grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» -- *Tensor shapes* (shape values) are tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. +- *Tensor shapes* (shape values) are immutable tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. - *Primitive values* (`PrimValue`s) represent immutable scalar values that are primarily intended for being passed to external procedures, like calls to `PackedFunc`s. As a rule of thumb, scalar values intended for arithmetical computations should be 0-rank tensors while scalar values meant to serve as metadata should be `PrimValue`s. - Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. @@ -731,12 +731,12 @@ For each expression, we define how it affects the program's visible state and th 1. The node `Constant(value)` creates a new tensor whose contents are `value`. 2. A variable (whether `Var`, `DataflowVar` , or `GlobalVar`) evaluates to the stored value for that variable in the current scope. -3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and creates a new tuple value containing `v1`, `v2`, …, and `vn` in that order. +3. The node `Tuple([e1, e2, ..., en])` evaluates `e1` (yielding value `v1`), then `e2` (yielding value `v2`), …, and finally `en` (yielding value `vn`) in that order and returns a tuple value containing `v1`, `v2`, …, and `vn` in that order. 4. The node `PrimType(prim_expr)` evaluates the `PrimExpr` `prim_expr` first, obtaining a resulting `pv`. It then creates an immutable `PrimValue` containing `pv`. 5. The node `StringImm(s)` creates an immutable string container whose contents is `s`. It does not necessarily have to be a _new_ string container if, for example, string interning is implemented. 6. The node `DataTypeImm(dt)` creates a new immutable datatype representation. 7. The node `TupleGetItem(t, i)` is evaluated by first evaluating `t` (which, per `StructInfo` checking, must evaluate to a tuple) and then returning the `i`th field of the result. -8. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and creates a new shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. +8. The node `ShapeExpr([p1, p2, ..., pn])` evaluates the `PrimExpr`s `p1` (yielding dimension value `v1`), `p2` (yielding dimension value `v2`), …, and finally `pn` (yielding dimension value `vn`) in that order, using the current shape context, and returns a shape value whose dimensions are `v1`, `v2`, …, `vn`, in that order. 9. The node `Function([v1, v2, ..., vn], body)` returns a new closure containing the function definition itself and a mapping of any free Relax variables or shape variables in `body` to the values they hold in the current scope when the `Function` node is encountered. If the function is the RHS of a local binding, the bound variable should also be included in the closure's binding map and should be mapped to the closure itself (to allow for recursive calls). Closure capturing is done *by reference*; no values will be copied and references to captured values will alias their values in the outer scope. `DataflowVar`s are not captured by closures. 10. The node `If(cond, true_branch, false_branch)` is evaluated as follows: 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a `Bool` scalar). @@ -763,7 +763,7 @@ Optimizations are allowed to reorder and modify the operations of a program in a «Within `DataflowBlock`s, it is permitted for the compiler to remove or reorder `MatchCast` operations even though this can affect the "visible behavior" of the program (since they can exit with an error). It is also permitted for the compiler to optimize away potential non-termination within `DataflowBlock`s: For example, if some pure function `f` has an integer return type and does not terminate, it is permissible to optimize `f() - f()` to 0 within a `DataflowBlock`. In general, the compiler is permitted to make programs "more defined" (terminating when the original did not terminate, not raising an error when the original raised an error) within a `DataflowBlock`, but never "less defined" (giving an error when the original did not give an error, not terminating when the original did not terminate). Outside of `DataflowBlock`s, error messages and potential non-termination must be preserved faithfully.» -For immutable containers like those for the results of `PrimValue`, `StringImm`, and `DataTypeImm`, it is not required for the results of evaluating these expressions to be _new_ containers—it is permitted for the compiler to reuse existing objects provided that the values contained within are identical. This optimization is called [interning](https://en.wikipedia.org/wiki/String_interning). However, for operations that return new mutable values (in particular, operations that return tensor values), those _must_ be newly allocated, since reusing values can affect the behavior under aliasing. +For immutable containers like those for the results of `Tuple`, `ShapeExpr`, `PrimValue`, `StringImm`, and `DataTypeImm`, it is not required for the results of evaluating these expressions to be _new_ containers—it is permitted for the compiler to reuse existing objects provided that the values contained within are identical. This optimization is called [interning](https://en.wikipedia.org/wiki/String_interning). However, for operations that return new mutable values (in particular, operations that return tensor values), those _must_ be newly allocated, since reusing values can affect the behavior under aliasing. The specification makes no guarantees about certain memory-related properties and hence also does not consider them to be "visible behaviors": From 2173b9f54608feb356676a6193e8167008a33628 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 1 Feb 2023 18:01:36 -0500 Subject: [PATCH 24/47] Fix numbering in expression summary --- relax_spec.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 89e70c3b7945..aea084041172 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -123,18 +123,18 @@ This specification provides a more detailed description of what each expression 3. `Var`, `DataflowVar`, and `GlobalVar` nodes are all variables, referring to named stored values of different kinds. Variables in Relax must be bound exactly once. `GlobalVar`s are bound in the `IRModule` itself and refer to Relax functions or TIR `PrimFunc`s. `Var` nodes are bound either within functions, where they represent function parameters, or in `VarBinding` or `MatchCast` nodes in `BindingBlock`s, as we will discuss below. `DataflowVar`s are similar to `Var`s and can be bound only within `DataflowBlock`s. 4. `PrimExpr`s are used to represent dimensions of shapes in `ShapeExpr` and `MatchCast` nodes. These represent operations on integers with their own `Var` nodes (`tir::Var`), which we will refer to as "shape variables". Shape variables can only be used in other `PrimExpr`s and are scoped like `Var` nodes (`relax::Var`), which we will call "Relax variables." 5. `ExternFunc` nodes evaluate into `PackedFunc`s; the implementation will look up the registered `PackedFunc` by its global symbol. -5. `PrimValue` nodes construct immutable scalar values from `PrimExpr`s, primarily for interacting with `ExternFunc`s or operators. These scalars are boxed within TVM objects, allowing them to be nested inside TVM's containers. (By contrast, zero-dimensional tensors defined via `Constant` are mutable.) -6. `StringImm` nodes construct strings, intended primarily for interacting with `ExternFunc`s or operators. -7. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators (e.g., for TIR intrinsics that take a datatype as an input). -8. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. +6. `PrimValue` nodes construct immutable scalar values from `PrimExpr`s, primarily for interacting with `ExternFunc`s or operators. These scalars are boxed within TVM objects, allowing them to be nested inside TVM's containers. (By contrast, zero-dimensional tensors defined via `Constant` are mutable.) +7. `StringImm` nodes construct strings, intended primarily for interacting with `ExternFunc`s or operators. +8. `DataTypeImm` nodes construct representations of TIR datatypes, intended primarily for interacting with `ExternFunc`s or operators (e.g., for TIR intrinsics that take a datatype as an input). +9. `Call` nodes represent function calls. The callee argument (the `op`) can be an `ExternFunc` node (representing a call to a `PackedFunc`), an `Op` node (representing a call to a Relax operator), or an arbitrary expression. 1. `Op` nodes refer to built-in Relax operators, which the compiler is free to implement as is deemed appropriate. Certain operators implement important operations, like `call_tir` (allows for calling TIR `PrimFunc`s). 2. Any other expression must evaluate to a `PackedFunc` or a closure; the result of evaluating `op` will then be called with the given arguments. Calls to `ExternFunc`s and operators may perform side effects, hence it is important to reason about whether a function call is permitted inside a `DataflowBlock`. -9. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. -10. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. -11. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: +10. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. +11. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. +12. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. 1. The types must match: All `StructInfo` variants correspond to a category of value value (`TensorStructInfo` to a tensor value, `ShapeStructInfo` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: @@ -146,8 +146,8 @@ This specification provides a more detailed description of what each expression The `SeqExpr`'s `body` expression is allowed to reference any `Var`s introduced within the `SeqExpr`'s binding blocks in addition to those that were in the outer scope; the `body` expression is evaluated after the binding blocks and its value is what is returned. Any Relax variables and shape variables introduced in the `SeqExpr` are removed from scope after the expression finishes evaluating. -12. `ShapeExpr` nodes construct shape literals, which are immutable collections of shape dimensions. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. -13. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. +13. `ShapeExpr` nodes construct shape literals, which are immutable collections of shape dimensions. The `PrimExpr`s within it describe how to compute each dimension; they are free to use any shape variables that are in scope. +14. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. From d9a57fde1df4c95a5054c463b3e3ce00a0602566 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 2 Feb 2023 17:58:54 -0500 Subject: [PATCH 25/47] Update the description of call_tir --- relax_spec.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index aea084041172..468788db7bf2 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -777,6 +777,16 @@ These semantic rules assume a single thread of evaluation on a single host machi The above evaluation rules are general, but leave much room for implementations of operators to specify custom semantics. Certain operators are used to perform common operations and will be discussed here as well. -- `call_tir(prim_func, arg1, arg2, ..., argn, sinfo_args=[aS])`: `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). The `StructInfo` arg `aS` gives the `StructInfo` of the result of calling the `PrimFunc`; it must be a `TensorStructInfo` with a `shape` field corresponding to a constant shape expression and a non-`Void` `dtype`, denoting the shape of the resulting tensor, or a a `TupleStringInfo` where all the `fields` are `TensorStructInfo`. Based on `aS`, the resulting tensor or tuple `r` will be allocated according to the sizes given in their `shape` fields. `f` will be called in destination-passing style, like so: `f(arg1, ..., argn, *r)`. The asterisk denotes that if `r` is a tuple, it will be "unrolled," so the call will be `f(arg1, ..., argn, r1, ..., rn)`, where the `ri` are the fields of `r`. `f` is expected to mutate *only* `r` to give the output of the function, hence `call_tir` is considered pure. «If the shape or data type of the actual result do not correspond to `shape` or `aT`, an error is issued.» After the call, `r` is returned. -- «`call_dps_packed(global_symbol, arg1, arg2, ..., argn, sinfo_args=[aS])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol`. The `PackedFunc` may modify `arg1`, `arg2`, …, or `argn` in addition to the results, so purity is not assumed. `aS` denotes the `StructInfo` for the result.» -- `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a new shape object. +- `call_tir(prim_func, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: + - `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). + - `args` must should be an expression that evaluates to a tuple of tensor values (where each member of a tuple will be a tensor argument to the `PrimFunc`). Let us call the members of the tuple `arg1`, `arg2`, ..., `argn`. + - `packed_ints` is an optional argument. If present, it must be a shape value (with `ShapeStructInfo`). Each dimension of the value (which we will call `shape1`, `shape2`, ..., `shapem`) + - The `StructInfo` arguments `aS1` through `aSk` give the `StructInfo` of the results of calling the `PrimFunc`. + - All the `aSi` must be `TensorStructInfo` with a `shape` field consisting of a `ShapeExpr` (possibly containing shape variables) and a non-`Void` `dtype`, denoting the shape of the resulting tensors. + - If there is exactly one member of `sinfo_args`, then the operation returns a single tensor with that shape; if there are multiple or zero members of `sinfo_args`, then the result will have the `StructInfo` `TupleStructInfo(fields=[aS1, as2, ..., aSk])`. + - Based on the `aSi`, the resulting tensors `r1`, `r2`, ..., `rk` will be allocated according to the sizes given in their `shape` fields. + - `f` will be called in destination-passing style, like so: `f(arg1, arg2, ..., argn, shape1, shape2, ..., shapem, r1, r2, ..., rk)`, omitting the `shapei` if `packed_ints` is not given. `f` is expected to mutate *only* the `ri` to give the output of the function, hence `call_tir` is considered pure. + - «If the shape or data type of the actual result do not correspond to the `aSi`, an error is issued.» + - After the call, the `ri` will be returned (returning `r1` directly if there is only a single result, otherwise returning `Tuple(fields=[r1, r2, ..., rk])`). +- «`call_dps_packed(global_symbol, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol` instead of a `PrimFunc` object. The `PackedFunc` may modify any member of `args` (`packed_ints`, if present, is immutable) in addition to the results, so purity is not assumed. The `StructInfo` for the result will be determined int he same manner as in `call_tir`, where it will be `aS1` if `sinfo_args` has a length of 1 and `TupleStructInfo(fields=[aS1, aS2, ..., aSk])` otherwise.» +- `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a shape object. From c718203f31086067b573239804a66132dd26ca48 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 6 Feb 2023 18:18:53 -0500 Subject: [PATCH 26/47] Specify invariants for TensorStructInfo --- relax_spec.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index 468788db7bf2..8e84d2c68d91 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -374,6 +374,16 @@ def f(arg1, arg2, ..., argn): ``` » +### Invariants for `TensorStructInfo` + +Because the `shape` field of `TensorStructInfo` is an expression (either a `Var` or `ShapeExpr`), that expression may have its own `StructInfo`. In any `TensorStructInfo` derived by the below inference rules for `StructInfo` or in any `StructInfo` annotation, the following properties must hold of the `shape` field in `TensorStructInfo`: +1. If the `shape` field is a `Var`, the `Var` must have `ShapeStructInfo`. The `ndim` for the `Var`'s `ShapeStructInfo` must match that of the `TensorStructInfo`. +2. If the `shape` field is a literal `ShapeExpr`, then the `ndim` for the `TensorStructInfo` must match the number of fields in the `shape`'s `values` field (this is noted in the [well-formedness rules](#well-formedness-criteria)). + +Any shape variables that appear in the `ShapeStructInfo` must be in scope where the annotation appears. + +In particular, it is not permitted for the `TensorStructInfo` to have an unknown rank (`ndim` of -1) when the `shape` field has a non-negative `ndim`. + ## Subtyping for `StructInfo` Relax implements subtyping for `StructInfo`, which means that values with some `StructInfo` can be accepted where values with more general `StructInfo` are accepted We will denote the subtyping relationship as `S1 <: S2`, indicating that `S1` is a subtype of `S2`. For example. if `S1 <: S2` and some function expects an argument with `StructInfo` `S2`, then passing a value with `StructInfo` `S1` to that function is permitted; passing a value with `StructInfo` `S2` as an argument to a function that expects `S1` for that argument is *not* permitted—the value would have to be dynamically cast to `S1` using `MatchCast`. From 88ab0e9ae7b6a95c5242ef7273a576526c0b7c19 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 9 Feb 2023 21:09:15 -0500 Subject: [PATCH 27/47] PrimValue, StringImm, DataTypeImm are leaf nodes --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 8e84d2c68d91..978eccf6b60d 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -252,7 +252,7 @@ def func(x: Tensor) -> Tensor: To simplify the writing of Relax passes, we define a normal form for Relax programs, based on the [administrative normal form](https://en.wikipedia.org/wiki/A-normal_form) (A-normal form, or ANF). See [this post](https://matt.might.net/articles/a-normalization/) by Matt Might for a discussion of some of the advantages of ANF in traditional compilation; in particular, ANF results in programs without nesting, which is very convenient for writing program transformations. Because the `StructInfo`-checking rules for operators rely on macros (`FInferShapeInfo`), _this means that the structure of the program can affect `StructInfo` inference_. Putting programs into normal form (and lacking nesting) not only simplifies the writing of these macros but it also ensures that these `StructInfo`-checking rules will be predictable, hence _it is required to transform programs into normal form_ before applying `StructInfo` checking. The normal form for Relax is very similar to ANF; differences will be noted. Here are the criteria required for a program to be in normal form: -1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. +1. Within a `SeqExpr`, the right-hand side of any binding (the `value` field in the AST) must either be a "leaf expression" or a non-leaf expression where all subexpressions are leaf expressions. Leaf expressions are the following: Variables (`Var`, `DataflowVar`, or `GlobalVar`), `Constant`, `ShapeExpr`, `PrimValue`, `StringImm`, `DataTypeImm`, or (_unlike_ ANF) `Tuple`. `Tuple` nodes are considered "leaf" expressions even though they contain nesting purely for convenience in writing passes; many operators rely on grouping arguments using tuples, so that is a form of nesting permitted and expected. Otherwise, non-leaf expressions used as subexpressions must be bound to variables; this includes any non-leaf expressions nested inside a `Tuple`. 2. `SeqExpr`s may appear only in the following locations: 1. In the `body` field of a `Function` node. 2. In the `true_branch` and `false_branch` fields of `If` nodes. From 3971f7c9c4f0007fc00e5b7f7568f51ab6398c82 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 10 Feb 2023 16:34:15 -0500 Subject: [PATCH 28/47] Tuples are represented using Arrays, not ADTs now --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 978eccf6b60d..27953d2023a2 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -211,7 +211,7 @@ Because Relax supports calls to arbitrary `PackedFunc`s that can operate on a lo Possible specification in terms of the TVM object system: - Tensors are represented at run time as `NDArray`s (see `include/tvm/NDArray.h`). -- Tuples are represented using TVM ADTs (algebraic data types), which are arrays of TVM objects with a tag (see `include/tvm/runtime/container/adt.h`). Tuples use a tag of 0. +- Tuples are represented using TVM `Array`s (in contrast to `NDArray`s), which are immutable (see `include/tvm/runtime/container/array.h`). - At run time, closures are represented as a `ClosureObj` (see `include/tvm/runtime/container/closure.h`); in the Relax VM these more specifically use the `VMClosureObj` (see [`https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h`](https://github.com/tlc-pack/relax/blob/relax/include/tvm/runtime/relax_vm/executable.h)). - Shape values are represented at run time as a `ShapeTuple` (see `include/tvm/runtime/container/shape_tuple.h`). - Strings are represented using TVM's `String` container (see `include/tvm/runtime/container/string.h`). From 356ddf109136fbb7b6b8a20b90ce410cda32848d Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 13 Feb 2023 15:41:40 -0500 Subject: [PATCH 29/47] Fix typo --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 27953d2023a2..839c96d9f7e2 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -752,7 +752,7 @@ For each expression, we define how it affects the program's visible state and th 1. First `cond` is evaluated. Let the result be `r` (per `StructInfo` checking, it must be a `Bool` scalar). 2. If `r` is true, evaluate the `true_branch` and return its result. 3. If `r` is false, evaluate the `false_branch` and return its result. -11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not.) Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. +11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not). Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. 12. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. From 4f8bdb40082133a8ed4a122b63f9794246c85a5e Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 14 Feb 2023 21:35:16 -0500 Subject: [PATCH 30/47] Add mention of null value --- relax_spec.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 839c96d9f7e2..60de82fa281f 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -202,7 +202,7 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Tensor shapes* (shape values) are immutable tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. - *Primitive values* (`PrimValue`s) represent immutable scalar values that are primarily intended for being passed to external procedures, like calls to `PackedFunc`s. As a rule of thumb, scalar values intended for arithmetical computations should be 0-rank tensors while scalar values meant to serve as metadata should be `PrimValue`s. -- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. Another noteworthy value in this category is the _null object_ (the result of returning a null pointer in C++ or passing in `None` through the Python FFI), which is returned by the `null_value()` operator. ## Representation of Values at Run Time @@ -800,3 +800,4 @@ The above evaluation rules are general, but leave much room for implementations - After the call, the `ri` will be returned (returning `r1` directly if there is only a single result, otherwise returning `Tuple(fields=[r1, r2, ..., rk])`). - «`call_dps_packed(global_symbol, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol` instead of a `PrimFunc` object. The `PackedFunc` may modify any member of `args` (`packed_ints`, if present, is immutable) in addition to the results, so purity is not assumed. The `StructInfo` for the result will be determined int he same manner as in `call_tir`, where it will be `aS1` if `sinfo_args` has a length of 1 and `TupleStructInfo(fields=[aS1, aS2, ..., aSk])` otherwise.» - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a shape object. +- `null_value()`: Returns a null object (treated as `ObjectStructInfo`). This is used for indicating to operators that an optional argument has been omitted. \ No newline at end of file From 3437b3387246a358144262e9418feec45e290a60 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 2 Mar 2023 17:33:09 -0500 Subject: [PATCH 31/47] Clarify scoping for vars and dataflow vars in a seqexpr --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 60de82fa281f..d12a57e58901 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -135,7 +135,7 @@ This specification provides a more detailed description of what each expression 10. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. 11. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. 12. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: - 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope (`DataflowVar`s are scoped only to the `DataflowBlock` in which they appear; `Var`s are scoped to the entire `SeqExpr`). + 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope: `DataflowVar`s are scoped only to any later bindings in the `DataflowBlock` in which it was defined; `Var`s are scoped to any later bindings within the `BindingBlock` in which they were defined, as well as any bindings in subsequent `BindingBlock`s in the `SeqExpr` and in the `body` field of the `SeqExpr`. 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. 1. The types must match: All `StructInfo` variants correspond to a category of value value (`TensorStructInfo` to a tensor value, `ShapeStructInfo` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: 1. For comparing tensor values to `TensorStructInfo`, `ndim` must match the number of dimensions in the tensor value (unless `ndim` is -1) and `dtype` must match the datatype used (unless `dtype` is `Void`). If `shape` has been specified, the shape of the value must match that encoded by `shape`; if specified, `shape` must be either a `Var` already bound in the current scope or a `ShapeExpr`. From 575c1cd9a69998f7632dcb0a479a1ee9774e578b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 2 Mar 2023 17:36:04 -0500 Subject: [PATCH 32/47] Correct the definition of `Void`. --- relax_spec.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index d12a57e58901..56bf893f763c 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -108,11 +108,11 @@ The `derive_func` field of `FuncStructInfo` is a macro in the meta-language: Giv The representation of datatypes, `DataType`, in the above AST is taken directly from TIR. However, the usage of datatypes in Relax is more restricted than in TIR. 1. The `lanes` field for the `Int`, `UInt`, and `Float` datatypes must always be 1; we do not directly consider vectorized values in Relax. -2. The `lanes` field for the `Handle` datatype must always be 0, indicating that it is `Void` (see below). The `bits` field for `Handle` should always be set to 64 (it will not be used by Relax). +2. The `bits` field for the `Handle` datatype must always be 0, indicating that it is `Void` (see below). The `lanes` field for `Handle` should always be set to 0 (it will not be used by Relax). We also define the following special notation for datatypes, to be used in the rest of the specification: 1. `Bool()`: This is shorthand for `UInt(bits=1, lanes=1)`, since TIR does not have a separate Boolean type. "True" refers to a value of 1 in this datatype and "false" refers to a value of 0. For convenience, we will refer to Boolean values as a separate datatype in the specification, due to their significance in `If` nodes. -2. `Void()`: This is shorthand for `Handle(bits=64, lanes=0)`. TIR uses this datatype to refer to opaque objects; in Relax, it is used to denote an unknown datatype. +2. `Void()`: This is shorthand for `Handle(bits=0, lanes=0)`. TIR uses this datatype to refer to opaque objects; in Relax, it is used to denote an unknown datatype. ## Expression Survey From de6d26b1f462c519d45ce3ae3485d6737b907079 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 2 Mar 2023 19:13:58 -0500 Subject: [PATCH 33/47] Grammar fix --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 56bf893f763c..532c827fddb8 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -135,7 +135,7 @@ This specification provides a more detailed description of what each expression 10. `If` nodes represent branching control flow. First the condition expression is evaluated, and it must evaluate to a Boolean scalar. If the condition is true, the true branch is evaluated and its result is used; otherwise, the false branch is evaluated and its result is used. 11. `TupleGetItem` nodes represent tuple indexing. The `tuple_value` expression must evaluate to a tuple with at least `index + 1` items and the item with the given index will be returned. 12. `SeqExpr` describes a sequence of binding blocks followed by a return expression. The `SeqExpr` opens a new scope. Its binding blocks are evaluated in order and add new variables to the scope. Binding blocks are either ordinary `BindingBlock`s or `DataflowBlock`s and both consist of a series of bindings. `DataflowBlock`s are the only kind allowed to introduce bindings with `DataflowVar`s and it does not permit any constructs featuring control flow (`If` nodes or recursive calls) or calls to (possibly) impure functions. There are two different kinds of bindings: - 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope: `DataflowVar`s are scoped only to any later bindings in the `DataflowBlock` in which it was defined; `Var`s are scoped to any later bindings within the `BindingBlock` in which they were defined, as well as any bindings in subsequent `BindingBlock`s in the `SeqExpr` and in the `body` field of the `SeqExpr`. + 1. `VarBinding`s: The `value` expression (the right-hand side of the binding) of the binding is evaluated first and is bound to the `var` expression, which must be a new `Var` or `DataflowVar` (in a dataflow block). The newly bound variable will have that value for the remainder of the scope: `DataflowVar`s are scoped only to any later bindings in the `DataflowBlock` in which they were defined; `Var`s are scoped to any later bindings within the `BindingBlock` in which they were defined, as well as any bindings in subsequent `BindingBlock`s in the `SeqExpr` and in the `body` field of the `SeqExpr`. 2. `MatchCast`s: The `value` expression is evaluated and the result is dynamically checked against the structural information given in the `struct_info` field. 1. The types must match: All `StructInfo` variants correspond to a category of value value (`TensorStructInfo` to a tensor value, `ShapeStructInfo` to shape values, etc.), so if the structure of `value` does not correspond to `struct_info`, an error is triggered. The structure of `value` is compared recursively with `struct_info`, so all components of `value` must match up with any nested structural information. Special comparison rules: 1. For comparing tensor values to `TensorStructInfo`, `ndim` must match the number of dimensions in the tensor value (unless `ndim` is -1) and `dtype` must match the datatype used (unless `dtype` is `Void`). If `shape` has been specified, the shape of the value must match that encoded by `shape`; if specified, `shape` must be either a `Var` already bound in the current scope or a `ShapeExpr`. From ccafbc8a9029b7af56317abd4c7ad3de338d3a70 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 3 Mar 2023 12:46:02 -0500 Subject: [PATCH 34/47] Correct typos --- relax_spec.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 532c827fddb8..f8a64872a323 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -789,8 +789,8 @@ The above evaluation rules are general, but leave much room for implementations - `call_tir(prim_func, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: - `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). - - `args` must should be an expression that evaluates to a tuple of tensor values (where each member of a tuple will be a tensor argument to the `PrimFunc`). Let us call the members of the tuple `arg1`, `arg2`, ..., `argn`. - - `packed_ints` is an optional argument. If present, it must be a shape value (with `ShapeStructInfo`). Each dimension of the value (which we will call `shape1`, `shape2`, ..., `shapem`) + - `args` must be an expression that evaluates to a tuple of tensor values (where each member of a tuple will be a tensor argument to the `PrimFunc`). Let us call the members of the tuple `arg1`, `arg2`, ..., `argn`. + - `packed_ints` is an optional argument. If present, it must be a shape value (with `ShapeStructInfo`). If present, we will call the dimensions of the value`shape1`, `shape2`, ..., `shapem` for convenience. - The `StructInfo` arguments `aS1` through `aSk` give the `StructInfo` of the results of calling the `PrimFunc`. - All the `aSi` must be `TensorStructInfo` with a `shape` field consisting of a `ShapeExpr` (possibly containing shape variables) and a non-`Void` `dtype`, denoting the shape of the resulting tensors. - If there is exactly one member of `sinfo_args`, then the operation returns a single tensor with that shape; if there are multiple or zero members of `sinfo_args`, then the result will have the `StructInfo` `TupleStructInfo(fields=[aS1, as2, ..., aSk])`. From 581685f0f51cd0f84cb173cff1dd4bc82da74999 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 8 Mar 2023 18:52:32 -0500 Subject: [PATCH 35/47] Fix formatting --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index f8a64872a323..6f38a38c16ee 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -623,7 +623,7 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 1. «Prepopulate `Γ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Γ` corresponding to that `GlobalVar`.» 2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Γ[v]` for the structural information. -3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. +3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value`. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. 4. For `PrimValue(prim_expr)`, the resulting `StructInfo` is `PrimStructInfo(dt)`, where `dt` is the datatype of `prim_expr`, derived according to the type-checking rules for TIR. 5. For `StringImm(s)`, the resulting `StructInfo` is `ObjectStructInfo()`. 6. For `DataTypeImm(dt)`, the resulting `StructInfo` is `ObjectStructInfo()`. From 6bb6f20937eb1efdcc837a791c3b14cfb4d7cab4 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 10 Mar 2023 15:51:52 -0500 Subject: [PATCH 36/47] Update the spec to account for call_dps_packed --- relax_spec.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 6f38a38c16ee..0aaac80c3c6a 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -788,7 +788,7 @@ These semantic rules assume a single thread of evaluation on a single host machi The above evaluation rules are general, but leave much room for implementations of operators to specify custom semantics. Certain operators are used to perform common operations and will be discussed here as well. - `call_tir(prim_func, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: - - `prim_func` must be a `PrimFunc` object in the current `IRModule` (we will call it `f`). + - `prim_func` must be a `GlobalVar` that denotes a `PrimFunc` in the current `IRModule` (we will call it `f`). - `args` must be an expression that evaluates to a tuple of tensor values (where each member of a tuple will be a tensor argument to the `PrimFunc`). Let us call the members of the tuple `arg1`, `arg2`, ..., `argn`. - `packed_ints` is an optional argument. If present, it must be a shape value (with `ShapeStructInfo`). If present, we will call the dimensions of the value`shape1`, `shape2`, ..., `shapem` for convenience. - The `StructInfo` arguments `aS1` through `aSk` give the `StructInfo` of the results of calling the `PrimFunc`. @@ -798,6 +798,13 @@ The above evaluation rules are general, but leave much room for implementations - `f` will be called in destination-passing style, like so: `f(arg1, arg2, ..., argn, shape1, shape2, ..., shapem, r1, r2, ..., rk)`, omitting the `shapei` if `packed_ints` is not given. `f` is expected to mutate *only* the `ri` to give the output of the function, hence `call_tir` is considered pure. - «If the shape or data type of the actual result do not correspond to the `aSi`, an error is issued.» - After the call, the `ri` will be returned (returning `r1` directly if there is only a single result, otherwise returning `Tuple(fields=[r1, r2, ..., rk])`). -- «`call_dps_packed(global_symbol, args, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: Proceeds similarly to `call_tir`, except it calls a `PackedFunc` registered under the name `global_symbol` instead of a `PrimFunc` object. The `PackedFunc` may modify any member of `args` (`packed_ints`, if present, is immutable) in addition to the results, so purity is not assumed. The `StructInfo` for the result will be determined int he same manner as in `call_tir`, where it will be `aS1` if `sinfo_args` has a length of 1 and `TupleStructInfo(fields=[aS1, aS2, ..., aSk])` otherwise.» +- `call_dps_packed(packed_func, args, sinfo_args=[aS1])`: + - `packed_func` must evaluate to a `PackedFunc` object. + - `args` must be a tuple; we will call its elements `arg1`, `arg2`, ..., `argn`. + - The `StructInfo` argument `aS1` may be either a single `TensorStructInfo` (whose `shape` field _must_ be a `ShapeExpr`), which we will call `ts1`, or a `TupleStructInfo` whose fields are all `TensorStructInfo` (whose `shape` fields _must_ be `ShapeExpr`s), which we will call `ts1`, `ts2`, ..., `tsm`. + - Let `r1`, `r2`, ..., `rm` be newly allocated tensors whose shape match the `StructInfo` args `ts1`, `ts2`, ..., `tsm`, respectively. + - Evaluate `f(arg1, arg2, ..., argn, r1, r2, ..., rm)`. + - «If the shape or data type of the actual result do not correspond to the `tsi`, an error is issued.» + - Return `r1` if `aS1` is a single `TensorStructInfo`; otherwise, return `Tuple(fields=[r1, r2, ..., rm])`. - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a shape object. - `null_value()`: Returns a null object (treated as `ObjectStructInfo`). This is used for indicating to operators that an optional argument has been omitted. \ No newline at end of file From 43f0b1e8de62b5265f5b7d942eb3f4db97481fd1 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 14 Mar 2023 13:16:59 -0400 Subject: [PATCH 37/47] Correct typo --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 0aaac80c3c6a..1baebe7286ff 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -394,7 +394,7 @@ Note that judging subtyping requires potentially reasoning about arbitrary `Shap * They are _definitely not_ statically equal in at least one case. 1. Reflexivity: `S1 <: S1` for all `S1`. -2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <: S2` and `S2 <: S3`, then `S1 <<: S3`. +2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <: S2` and `S2 <: S3`, then `S1 <: S3`. 3. For all `S1`, `S1 <: ObjectStructInfo()`. 4. For `TensorStructInfo`: 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=-1, dtype=d)`. From 1e896ea5ef9e743eaa48b4584f775ae5d3c0b2cb Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 30 Mar 2023 22:53:09 -0400 Subject: [PATCH 38/47] Clarify phrasing related to the StructInfo derivation for Call nodes --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 1baebe7286ff..4b31cace5030 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -662,7 +662,7 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 1. Give an error if `Sf` is not `FuncStructInfo`. 2. If the `derive_func` field of `Sf` is defined, then apply the `derive_func` macro to the call node to derive the structural information for the call node, ignoring the `ret` field of `Sf`. 3. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. Let the members of params be `P1`, `P2`, ..., `Pn`. - 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. Replace all variables in `m` with their mapping in `Sf`. + 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. For each shape variable `v` that occurs in `Sf`, replace it with `m[v]` if `v` is in `m`. 5. After the substitutions, give an error if `Pi <: Si` does not hold for some `i` (give a warning if it _possibly_ holds). 6. Use `erase_to_well_defined(Sf.ret, Γ, Σ)` as the resulting structural information. 15. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: From 22ca164a711fe9f1f85f097f4b3a61b1263ad286 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 18 May 2023 16:04:33 -0400 Subject: [PATCH 39/47] Include specification for purity tracking --- relax_spec.md | 59 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 4b31cace5030..49e07c73277f 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -62,7 +62,7 @@ StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) | PrimStructInfo(dtype: DataType) | ObjectStructInfo() | TupleStructInfo(fields: [StructInfo]) - | FuncStructInfo(params: [StructInfo]?, ret: StructInfo, derive_func: EnvFunc?*) + | FuncStructInfo(params: [StructInfo]?, ret: StructInfo, purity: bool, derive_func: EnvFunc?*) # expressions Expr ::= Constant(data: NDArray) @@ -76,7 +76,7 @@ Expr ::= Constant(data: NDArray) | PrimValue(value: PrimExpr) | StringImm(value: string) | DataTypeImm(value: DataType) - | Function(params: [Var], body: Expr, ret_struct_info: StructInfo?, attrs: Attrs?) + | Function(params: [Var], body: Expr, ret_struct_info: StructInfo?, is_pure: bool?, attrs: Attrs?) | If(cond: Expr, true_branch: Expr, false_branch: Expr) | ExternFunc(global_symbol: string) | Call(op: Expr, args: [Expr], sinfo_args: [StructInfo], attrs: Attrs?) @@ -150,6 +150,8 @@ This specification provides a more detailed description of what each expression 14. `Function` nodes represent function definitions, taking in the listed parameters and evaluating the body expression in a new scope (meaning any variables defined from within the function cannot be referenced outside it). Function definitions may be nested in any other expression and they evaluate into closure values, ensuring that functions are first-class. Closures capture any variables from the outer scope that are used in their body, both Relax variables and shape variables. Note that function definitions themselves are anonymous—a function must be registered in the `IRModule` (bound to a `GlobalVar`) or appear on the right-hand side of a binding to have a name in order to be called recursively. The function can have structural annotations on the parameters and a structural annotation for the return value. When the function is called, the annotations on parameters are checked against the argument values in similar fashion to `MatchCast` and can introduce new shape variables that are scoped to the function. Additionally, the structural information of the return value is checked against the annotation before the call returns. + + In addition to the structural annotations for the parameters and the return value, the `is_pure` field on a `Function` node serves to annotate whether the `Function` itself is pure (has no visible side effects) or not. The `StructInfo` system tracks purity in order to judge what calls are permitted inside `DataflowBlock`s. At this time, Relax makes no attempt to infer the purity of functions, so it is required for users to annotate the purity (if no annotation is provided, `is_pure` will be treated as true; since this is by far the most common case for deep learning applications, it is in practice necessarily to annotate purity if the function is _impure_). «A function mapped bound to a `GlobalVar` can have a `global_symbol` attribute defined to indicate that it should be externally linked externally (be accessible outside the `IRModule`). The absence of a `global_symbol` attribute on a function definition bound to a `GlobalVar` indicates that it is "private" and hence can be called only within the `IRModule`.» @@ -177,7 +179,7 @@ Analogously to a type system in most languages, Relax tracks structural informat 2. `TupleStructInfo` corresponds to tuple values, giving the `StructInfo` for each member of the tuple. 3. `PrimStructInfo` corresponds to `PrimValue`s (immutable scalar values), giving their TIR datatype. 4. `ShapeStructInfo` corresponds to shape values, optionally giving the number of dimensions in the shape and an expression that computes the shape's dimensions (either a `ShapeExpr` or a `Var`). -5. `FunctionStructInfo` corresponds to function values (closures) and `PackedFunc`s (external functions), giving the types of the parameters, the return type, «and whether the function is pure.» +5. `FunctionStructInfo` corresponds to function values (closures) and `PackedFunc`s (external functions), giving the types of the parameters, the return type, and whether the function is pure. 6. `ObjectStructInfo` is a parent to all Relax `StructInfo` and corresponds to all the values above as well as any values returned by `PackedFunc` calls that do not fit in the above categories. `StructInfo` is assigned to every variable in scope and every type of expression based on the values it returns via a set of inference rules defined later in the specification, making use of subtyping to assign more general `StructInfo` when a more specific one cannot be determined. «Relax is strongly typed, meaning that if the `StructInfo` inferred is less specific than the one expected, an error will be issued and an explicit check via `MatchCast` will be required.» @@ -198,7 +200,7 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Tensors* are n-dimensional arrays of scalar values (which can be signed or unsigned integers of fixed bitwidths, floats of fixed bitwidths, or Boolean values). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. - *Tuples* represent a fixed-size immutable grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). -- *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» +- *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure, including whether it is pure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are immutable tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. - *Primitive values* (`PrimValue`s) represent immutable scalar values that are primarily intended for being passed to external procedures, like calls to `PackedFunc`s. As a rule of thumb, scalar values intended for arithmetical computations should be 0-rank tensors while scalar values meant to serve as metadata should be `PrimValue`s. @@ -282,7 +284,7 @@ The following criteria apply to all programs (including before normalization): 2. Calls to a global function that is mutually recursive with the current function 3. `If` nodes - «Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during `StructInfo` checking.» + Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during `StructInfo` checking. 7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» 8. `Op` nodes may appear only as the `op` argument to `Call` nodes. @@ -296,6 +298,7 @@ The following criteria apply to all programs (including before normalization): 16. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s (with `lanes` set to 1). 17. `PrimStructInfo` annotations should use only the `Int`, `UInt`, or `Float` datatypes for their `dtype` fields. 18. Per [the notes on `DataType`](#notes-on-datatype-and-related-terminology), any `DataType` annotation must have a `lanes` value of 1 for the `Int`, `UInt`, or `Float` datatypes and a `lanes` value of 0 for the `Handle` (`Void`) datatype. Additionally, `bits` must be 64 for `Void`. The supported bitwidths for `Int` and `UInt` are 1, 8, 16, 32, and 64; the supported bitwidths for `Float` are 16, 32, and 64. +19. If a `Function` `f` has an `attrs` field that includes the attribute `relax.force_pure`, `f`'s `is_pure` field must be set to `True`. Additionally, the criteria for normal form listed in [the previous section](#normal-form) must apply to any program that has been normalized. @@ -349,7 +352,7 @@ This section describes the run-time checking performed by `MatchCast(var, value, 1. If the `i`th member of `values` consists of only an unbound shape variable, then bind that variable to the `i`th field of the the concrete shape value. 2. Otherwise, evaluate the `i`th member of `values` and check that it is equal to teh `i`th field of the concrete shape value. 5. If `struct_info` is `TupleStructInfo(fields)`, then check that `value` is a tuple value with `n` fields, where `n` is the length of `fields`. Also recursively check the `i`th field of the tuple value against the `i`th member of `fields`. -6. If `struct_info` is `FuncStructInfo(params, ret, derive_func)`, then if `params` is defined, check that `value` is a closure value; if `derive_func` is defined, check that `value` is a `PackedFunc`. No further validation may be done on a `PackedFunc`. «If `value` is a closure value, then it can contain run-time structural information indicating the structural information of its intended arguments and return value that can be compared against `params` and `ret`.» +6. If `struct_info` is `FuncStructInfo(params, ret, purity, derive_func)`, then if `params` is defined, check that `value` is a closure value; if `derive_func` is defined, check that `value` is a `PackedFunc`. No further validation may be done on a `PackedFunc`. «If `value` is a closure value, then it can contain run-time structural information indicating its purity and the structural information of its intended arguments and return value that can be compared against `purity`, `params`, and `ret`.» ### Checking Structural Information at the Start and End of a Function @@ -410,7 +413,8 @@ Note that judging subtyping requires potentially reasoning about arbitrary `Shap 8. For `FuncStructInfo`: 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <: F2` only if `F1.derive_func` and `F2.derive_func` are identical. - 3. Given two lists of `StructInfo` parameters `P1` and `P2` and two `StructInfo` annotations `R1` and `R2`, `FuncStructInfo(params=P1, ret=R1) <: FuncStructInfo(params=P2, ret=R2)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <: P1[i]` and `R1 <: R2`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships given only possibly holds. + 3. Given a list of `StructInfo` parameters `P` and a `StructInfo` return annotation `R`, then `FuncStructInfo(params=P, ret=R, purity=True) <: FuncStructInfo(params=P, ret=R, purity=False)`. That is, a pure function can be passed where an impure one is accepted, but not vice versa. + 3. Given two lists of `StructInfo` parameters `P1` and `P2`, two `StructInfo` annotations `R1` and `R2`, and a Boolean `purity`, `FuncStructInfo(params=P1, ret=R1, purity=purity) <: FuncStructInfo(params=P2, ret=R2, purity=purity)` if `P1` and `P2` are the same length and for all `i`, `P2[i] <: P1[i]` and `R1 <: R2`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships given only possibly holds. These rules allow us to define the least upper bound (LUB) for any two `StructInfo` `S1` and `S2`, meaning that it is the most specific `StructInfo` `S` for which `S1 <: S` and `S2 <: S` ("most specific" meaning that if there exists some other `S'` for which `S1 <: S'` and `S2 <: S'`, then `S <: S'`), modulo reasoning about arithmetic (for example, the compiler may judge that two shape expressions are _possibly_ equivalent rather than _definitely_ equivalent). The LUB is guaranteed to exist for any two `StructInfo` because all `StructInfo` are subtypes of `ObjectStructInfo`. @@ -475,6 +479,8 @@ def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: if S1.params and S2.params are both defined: if S1.params and S2.params do not have the same length: return ObjectStructInfo() + # the LUB is pure if they're both pure and false if either isn't + purity = S1.purity and S2.purity unified_params = [] for 0 <= i < length of S1.params: unified_param = unify_struct_info(S1.params[i], S2.params[i]) @@ -487,13 +493,24 @@ def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: unified_params[i] = unified_param else: return ObjectStructInfo() - return FuncStructInfo(params=unified_params, ret=unify_struct_info(S1.ret, S2.ret)) + return FuncStructInfo(params=unified_params, ret=unify_struct_info(S1.ret, S2.ret), purity=purity) ``` ## Deriving the Structural Information for Each Expression For each kind of expression, we can recursively build up the structural information associated with the expression. +### Checking Purity + +The below derivation rules will explain in formal detail how Relax checks the correctness of purity annotations and enforces that impure calls are not made inside `DataflowBlock`s. At a high level, it operates by the following principles: +1. Calls to `ExternFunc`s (which thus includes any expression whose `StructInfo` is `FuncStructInfo` with a `derive_func` included) are assumed to be impure by default. The `call_pure_packed` operator can be used to indicate to the compiler that a particular call to an `ExternFunc` is, in fact, pure. +2. `Op` nodes must have an attribute called `FPurity`, which is a boolean flag that indicates whether or not the operator is pure. If the operator can have visible side effects in any case at all, it should be considered impure. +3. For Relax `Function`s, the purity will depend on the `is_pure` annotation (which must be user-supplied). + +Thus, the `StructInfo` system can determine whether a call is pure based on the above principles: For operators, it refers to `FPurity` and otherwise it refers to the `FuncStructInfo` (using the `purity` field for functions with `params` defined and assuming that any function with a `derive_func` defined is impure). If any such call occurs inside a `DataflowBlock` or a `Function` whose `is_pure` field is set to `True`, that is treated as a type error. + +For verifying the purity of a function, however, there is one workaround permitted: If the function has the `relax.force_pure` attribute mapped to `True` in its `attrs`, then impure calls will be disregarded. This accounts for situations where individual actions may be impure (like mutating a value) but the overall effect of the function is pure (e.g., if the value that is mutated is one that is created inside the function, meaning that no externally-visible memory was ever mutated). This case is unlikely to be common for input programs, though `relax.force_pure` is used frequently in later stages of compilation. + ### Auxiliary Procedures **`derive_func` for `FuncStructInfo`** @@ -657,15 +674,20 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 4. Use `fields[i]` (zero-based) as the structural information for the `TupleGetItem`. 13. For an `ExternFunc` node, the resulting structural information is `FuncStructInfo(params=None, ret=ObjectStructInfo(), derive_func=default_derive)`. 14. For `Call(op, [arg1, arg2, ..., argn], type_args=[aT1, aT2, ..., aTn])`: - 1. For a call to an `Op`, we use the manually defined `FInferStructInfo` macro if it has been defined and `ObjectStructInfo()` if it has not. `FInferStructInfo` is a function that takes in the call node and returns the structural information of the result. + 1. For a call to an `Op`: + 1. We use the manually defined `FInferStructInfo` macro if it has been defined for `op` and `ObjectStructInfo()` as the resulting `StructInfo` if it has not. `FInferStructInfo` is a function that takes in the call node and returns the structural information of the result. + 2. If the current function has `is_pure` set to `True` and the current function does not have `relax.force_pure` mapped to `True` in its `attrs` field _or_ if the current scope is inside a `DataflowBlock`, then consider it a type error if `op` does not have `True` as the value for its `FPurity` attribute. 2. Otherwise, derive the structural information for `op` and call it `Sf`. Next derive the structural information for the args and call it `S1`, `S2`, ..., and `Sn`. 1. Give an error if `Sf` is not `FuncStructInfo`. - 2. If the `derive_func` field of `Sf` is defined, then apply the `derive_func` macro to the call node to derive the structural information for the call node, ignoring the `ret` field of `Sf`. - 3. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. Let the members of params be `P1`, `P2`, ..., `Pn`. - 4. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. For each shape variable `v` that occurs in `Sf`, replace it with `m[v]` if `v` is in `m`. - 5. After the substitutions, give an error if `Pi <: Si` does not hold for some `i` (give a warning if it _possibly_ holds). - 6. Use `erase_to_well_defined(Sf.ret, Γ, Σ)` as the resulting structural information. -15. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info)`: + 2. If the `derive_func` field of `Sf` is defined: + 1. If the current function has `is_pure` set to `True` and the current function does not have `relax.force_pure` mapped to `True` in its `attrs` field _or_ if the current scope is inside a `DataflowBlock`, then give a type error: External functions are assumed to be impure by default (the `call_pure_packed` operator can be used to indicate to the compiler that an external function is, in fact, pure). + 2. Apply the `derive_func` macro to the call node to derive the structural information for the call node, ignoring the `ret` field of `Sf`. Additionally, + 3. If the current function has `is_pure` set to `True` and the current function does not have `relax.force_pure` mapped to `True` in its `attrs` field _or_ if the current scope is inside a `DataflowBlock`, then consider it a type error if `Sf`'s `purity` field is not `True`. + 4. Otherwise, `params` must be defined. Give an error if the length of `params` does not match the number of call arguments. Let the members of params be `P1`, `P2`, ..., `Pn`. + 5. Next, attempt to perform [beta-reduction](https://en.wikipedia.org/wiki/Lambda_calculus#%CE%B2-reduction) by matching unbound shape variables in `params` with the `Si`. Namely, get a shape var mapping `m` by applying `get_shape_var_mapping(params[i], Si)` for all `i` and taking the union of all resulting mappings. For each shape variable `v` that occurs in `Sf`, replace it with `m[v]` if `v` is in `m`. + 6. After the substitutions, give an error if `Pi <: Si` does not hold for some `i` (give a warning if it _possibly_ holds). + 7. Use `erase_to_well_defined(Sf.ret, Γ, Σ)` as the resulting structural information. +15. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info, is_pure, attrs)`: 1. Let `S1`, `S2`, ..., `Sn` be the structural information of the parameters. If `vi` has a structural annotation, then use that annotation for `Si`; if not, use `ObjectStructInfo()`. Let `Sr` be `ret_struct_info` if it is defined and `ObjectStructInfo()` if not. 2. If the function is bound to a `GlobalVar` `gv`, set `Γ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. 3. For each of the `vi`, set `Γ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. @@ -806,5 +828,12 @@ The above evaluation rules are general, but leave much room for implementations - Evaluate `f(arg1, arg2, ..., argn, r1, r2, ..., rm)`. - «If the shape or data type of the actual result do not correspond to the `tsi`, an error is issued.» - Return `r1` if `aS1` is a single `TensorStructInfo`; otherwise, return `Tuple(fields=[r1, r2, ..., rm])`. + - Note that it is assumed that `packed_func` will be pure, so `call_dps_packed` is treated as a pure operator (its `FPurity` is set to `True`). +- `call_pure_packed(func, args, sinfo_args)`: + - `func` must evaluate to a `PackedFunc` object. + - `args` must be a tuple. + - `sinfo_args` must be a non-empty list of `StructInfo`. + - The returned value will have the semantics of `Call(func, args, sinfo_args=sinfo_args)`. However, this call will be assumed to be pure (`call_pure_packed`'s `FPurity` is set to `True`), thus allowing the call to appear inside a `DataflowBlock` or a function whose `is_pure` is set to `True`. + - Note: This operator is intended to be be used for cases where the user knows that calling the packed function will _in reality_ not cause any side effects. If it is used for a call that _does_ result in side effects, then the compiler may end up removing, reordering, or repeating that call; the specification makes no guarantees about the side effects in the callee in that case. - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a shape object. - `null_value()`: Returns a null object (treated as `ObjectStructInfo`). This is used for indicating to operators that an optional argument has been omitted. \ No newline at end of file From 8742b5ffa0a2a17616303ced905572253196e83f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 18 May 2023 16:07:52 -0400 Subject: [PATCH 40/47] Indicate that is_pure will correspond to purity field for FuncStructInfo --- relax_spec.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 49e07c73277f..1d3c61884199 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -689,11 +689,11 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 7. Use `erase_to_well_defined(Sf.ret, Γ, Σ)` as the resulting structural information. 15. For `Function(params=[v1, v2, ..., vn], body, ret_struct_info, is_pure, attrs)`: 1. Let `S1`, `S2`, ..., `Sn` be the structural information of the parameters. If `vi` has a structural annotation, then use that annotation for `Si`; if not, use `ObjectStructInfo()`. Let `Sr` be `ret_struct_info` if it is defined and `ObjectStructInfo()` if not. - 2. If the function is bound to a `GlobalVar` `gv`, set `Γ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr)`. Still check the structural information in `body`, however. + 2. If the function is bound to a `GlobalVar` `gv`, set `Γ[gv]` to `FuncStructInfo(params=[S1, S2, ..., Sn], ret=Sr, purity=is_pure)`. Still check the structural information in `body` per the below steps, however. 3. For each of the `vi`, set `Γ[vi]` to `Si`. Additionally, add all new shape variables introduced in the `Si` to `Σ`. 4. Derive the structural information for `body`, calling it `Sb`. 5. Give an error if `Sb` is incompatible with `Sr` via `check_compatibility` (warn if only possibly compatible). - 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Γ, Σ))`. + 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info, purity=is_pure)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Γ, Σ), purity=is_pure)`. 7. Remove all variables added to `Γ` and `Σ` during the above steps of the derivation. ### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks From dd4fd831d64fa32ffee05ca18587859bf488a5db Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 18 May 2023 17:06:16 -0400 Subject: [PATCH 41/47] Restore definition of check_compatibility (must have been removed accidentally), update to account for purity tracking --- relax_spec.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index 1d3c61884199..689b8aff9b56 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -634,6 +634,39 @@ def get_shape_var_mapping(S1: StructInfo, S2: StructInfo) -> {tir::Var, PrimExpr return {} ``` +**Checking Compatibility** + +In many cases during the derivation of structural information, it is important to judge when two distinct structural information encodings are compatible with each other or when they are too different from each other to be reconciled, which can indicate an error. In the case of shape information, this could mean having two symbolic shapes that can be proven not to be equal to each other. Because shape expressions can contain arithmetic and it can be very difficult to statically prove whether two arithmetic expressions are equal, we permit the compiler implementation to make a best-effort attempt to prove equality for arithmetic expressions. (The user can insert a `MatchCast` to check definitively.) Since the checks are best-effort, the compatibility check will only report incompatibility if two values are _definitely_ different from each other. + +We can check if some structural information `S1` is accepted where structural information `S2` is expected by the process given below, which we refer to as `check_compability(S1, S2)` for convenience. `check_compatibility` can find that `S1` and `S2` are compatible, possibly compatible, or incompatible. "Incompatible" indicates a definite mismatch that should result in a compiler error; "possibly compatible" indicates that the structures may or may not match and should likely result in a compiler warning (indicating that a user may want to insert a dynamic check). An invariant that should should is that if `check_compatibility(S1, S2)` returns "compatible" or "possible compatible", `erase_struct_info(S1) <: erase_struct_info(S2)` should hold; that is, compatibility of structural information should be consistent with typing rules. + +1. If `S2` is `ObjectStructInfo`, then they are compatible. +2. Otherwise, if `S1` and `S2` are not both `TensorStructInfo` or both `TupleStructInfo`, etc. (besides `ObjectStructInfo`), then report an incompatibility. +3. If `S1` and `S2` are both `TupleStructInfo`: + 1. If `S1.fields` is not the same length as `S2.fields`, they are incompatible + 2. Call `check_compability(S1.fields[i], S2.fields[i])` for all `i`. If any pair of fields is incompatible, then `S1` and `S2` are incompatible. If no pair of fields is incompatible but at least one is possibly compatible, then `S1` and `S2` are possibly compatible. If all pairs of fields are compatible, then `S1` and `S2` are compatible. +4. If `S1` and `S2` are both `ShapeStructInfo`: + 1. `S2.ndim` is -1, then they are compatible. + 2. Otherwise, give an error if `S1.ndim` does not match `S2.ndim`. + 3. If `values` is not defined for `S2`, then they are compatible. + 4. If `values` is defined for `S2` but not defined for `S1`, then they are possibly compatible. + 5. If `values` is defined for both `S1` and `S2`, then the two are incompatible if `S1.values[i]` can be proven to be _not_ equal to `S2.values[i]` for some `i`. If all members can be proven to be equal, then they are compatible. Otherwise, if at least one pair of values cannot be proven to be either equal or unequal, then they are possibly compatible. +5. If `S1` and `S2` are both `TensorStructInfo`: + 1. If `S2.dtype` is not `Void` and does not match `S1.dtype`, then they are incompatible. + 2. If `S2.ndim` is not -1 and does not match `S1.ndim`, then they are incompatible. + 3. If `S2.shape` is not defined, then they are compatible. + 4. If `S2.shape` is defined and `S1.shape` is not defined, then they are possibly compatible. + 5. Otherwise, if both `shape` fields are given and either is a `Var`, then consider `S1` and `S2` compatible if the compiler can statically prove that the `Var` holds the same value as the other `shape` field, consider them possibly compatible if the compiler cannot draw a conclusion one way or the other, and consider them incompatible if the `Var` definitely has a different value from the other `shape`. + 6. If both `shape` fields are given and they are both `ShapeExpr` nodes, then `S1` and `S2` are incompatible if the compiler can prove that some dimension of `S1.shape` is _not_ equal to the corresponding dimension of `S2.shape`. Otherwise, if the all dimensions can be proven to be equal, then consider them compatible. If at least one pair of dimensions cannot be proven to be equal or unequal, consider them possibly compatible. +6. If `S1` and `S2` are both `FuncStructInfo`: + 1. If `S1` and `S2` don't both have defined `params` or both have undefined `params`, consider them incompatible. + 2. If both `S1` and `S2` have undefined `params`, consider them compatible if they have an identical `derive_func` and consider them possibly compatible if they have different `derive_func`s (as they is no further way to introspect the `derive_func` and draw static conslusions about `PackedFunc`s). + 3. If `params` is defined for both `S1` and `S2`: + 1. Consider them incompatible if the `params` have different lengths. + 2. If the `purity` of `S1` is `False` but the `purity` of `S2` is `True`, then consider them incompatible. + 3. Next, map unbound shape variables as follows: Get a variable mapping `m` by applying `get_shape_var_mapping(S1.params[i], S2.params[i])` for all values of `i`, taking the union of all resulting mappings. Next, substitute all occurrences of the shape variables in `S1` with their values in `m`. + 4. If `check_compatibility(S2.params[i], S1.params[i])` (note the direction of the check: see the subtyping rule for `FuncType`) is incompatible for any `i` or if `check_compatibility(S1.ret, S2.ret)` is incompatible, then they are incompatible. Otherwise, if `check_compatibility(S2.params[i], S1.params[i])` is possibly compatible for any `i` or if `check_compatibility(S1.ret, S2.ret)` is possibly compatible, consider `S1` and `S2` possibly compatible. Consider `S1` and `S2` compatible only if all checks are compatible. + ### Derivation Rules Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track which shape variables are in scope. From 8fede62a45702c0a3d3c7a58a4be52536cfa9afd Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 10 Jul 2023 18:18:08 -0400 Subject: [PATCH 42/47] Assign a FuncStructInfo to PrimFuncs, specify direct calls to PrimFuncs --- relax_spec.md | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 689b8aff9b56..cd6fe63f75cc 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -31,7 +31,7 @@ While Relax aims to be as general and expressive as Relay, Relax is intended to Below is a diagram of the various AST constructs in Relax, including types. In code, these are defined on the C++ side in `include/tvm/relax/{expr.h, type.h}` and in Python in `python/tvm/relax/{expr.py, ty.py}`. This diagram will give the names of the AST nodes and the types and names of their members. The semantics will describe what computation each construct represents; an AST is simply data. A Relax program consists of an `IRModule` with global variables bound to Relax functions that implement the computations of interest. -(On the notation: `[x]` means "a list of `x`," `x?` means "optionally `x`," `{x: y}` means "a map of `x` to `y`," `x | y` means "`x` or `y`," and `#` is used for comments.) +(On the notation: `[x]` means "a list of `x`," `x?` means "optionally `x`," `{x: y}` means "a map of `x` to `y`," `x | y` means "`x` or `y`," and `#` is used for comments. For the definition of `PrimFunc`, AST constructs are prefixed with `tir::` to indicate that these are the TIR versions of these AST nodes rather than the Relax ones.) ``` # PrimExprs are defined in TIR, see include/tvm/tir/expr.h @@ -51,6 +51,12 @@ PrimExpr ::= | Select(condition: PrimExpr, true_value: PrimExpr, false_value: PrimExpr) # (others may be added later, as deemed necessary) +# See include/tvm/tir/function.h +# Can appear at the module level but otherwise do not interact with any Relax constructs; +# intended to have the same semantics as in TIR +PrimFunc ::= PrimFunc(params: [tir::Var], body: tir::Stmt, ret_type: tir::Type?, + buffer_map: {tir::Var: tir::Buffer}, attrs: Attrs) + # Also from TIR DataType ::= Int(bits: int, lanes: int) | UInt(bits: int, lanes: int) @@ -95,8 +101,8 @@ Binding ::= | MatchCast(var: (Var|DataflowVar)?, struct_info: StructInfo, value: Expr) # Relax programs are IRModules. Modules may bind global variables either to -# Relax functions or TIR PrimFuncs (specified separately). -# The Relax compiler may analyze and modify the TIR PrimFUncs as well. +# Relax functions or TIR PrimFuncs. +# The Relax compiler may analyze and modify the TIR PrimFuncs as well. Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) ``` @@ -203,8 +209,9 @@ Here are the classes of values that Relax operates over, meaning that they can b - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure, including whether it is pure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are immutable tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. - *Packed functions* (`PackedFunc`s or external functions) represent arbitrary opaque functions implemented in TVM. That is, packed functions are routines that are defined outside of Relax and cannot be inspected by the compiler. They can perform side effects and return arbitrary values. +- *TIR `PrimFuncs`* are functions in TIR. They are usually invoked using the `call_tir` operator, but can be called on their own as first-class functions. - *Primitive values* (`PrimValue`s) represent immutable scalar values that are primarily intended for being passed to external procedures, like calls to `PackedFunc`s. As a rule of thumb, scalar values intended for arithmetical computations should be 0-rank tensors while scalar values meant to serve as metadata should be `PrimValue`s. -- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators; additionally, we treat TIR `PrimFunc`s as opaque objects. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. Another noteworthy value in this category is the _null object_ (the result of returning a null pointer in C++ or passing in `None` through the Python FFI), which is returned by the `null_value()` operator. +- Additionally, there are further *arbitrary objects* that do not belong in the above categories. These can be returned by `PackedFunc`s and operators. Though Relax expressions other than `PackedFunc` and operator calls cannot use those objects, Relax should pass around these values faithfully. In the future we may add more value types in order to distinguish between different objects, but at present we treat these all as arbitrary values with `ObjectStructInfo`. Note that, for now, strings and TIR datatypes are also treated as opaque objects. Another noteworthy value in this category is the _null object_ (the result of returning a null pointer in C++ or passing in `None` through the Python FFI), which is returned by the `null_value()` operator. ## Representation of Values at Run Time @@ -728,6 +735,11 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 5. Give an error if `Sb` is incompatible with `Sr` via `check_compatibility` (warn if only possibly compatible). 6. If `ret_struct_info` is defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], ret_struct_info, purity=is_pure)` as the structural information for the function. If `ret_struct_info` is not defined, use `FuncStructInfo(params=[S1, S2, ..., Sn], erase_to_well_defined(Sb, Γ, Σ), purity=is_pure)`. 7. Remove all variables added to `Γ` and `Σ` during the above steps of the derivation. +16. For `PrimFunc(params, body, ret_type, buffer_map, attrs)` at the module level, which is bound to a `GlobalVar`: + 1. Suppose there are `n` members of `params`. For the `i`th member of `params` (let us call it `v`), let `si` be a corresponding `StructInfo` defined as follows: + 1. If `v` is not in `buffer_map`, then `si` is `PrimType(d)`, where `d` is the `dtype` field of `v`. + 2. If `v` is in `buffer_map`, then let `b` be `buffer_map[v]`. Then, `si` is `TensorStructInfo(d, ShapeExpr(s))`, where `d` is the `dtype` field of `b` and `s` is the `shape` field of `b`. + 2. The `StructInfo` for the `PrimFunc` (namely, for the `GlobalVar` to which the `PrimFunc` is bound) is `FuncStructInfo([s0, s1, ..., sn-1], TupleStructInfo([]), purity=False)`. (`PrimFunc`s work by mutating their arguments, so direct calls to `PrimFunc`s are treated as impure; in order to call a `PrimFunc` from within a `DataflowBlock`, use `call_tir`, which allocates fresh tensors for the outputs.) ### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks @@ -788,7 +800,7 @@ def erase_struct_info(si: StructInfo) -> Type: In the `IRModule`, every mapping of a `GlobalVar` to a `Function` node or a TIR `PrimFunc` should be processed first and added to the global scope. «Global functions that have a `global_symbol` attribute should be externally linked, meaning that they can be invoked as program entry points; those that do not have a `global_symbol` attribute can be called only from within the global functions in the `IRModule`.» -The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax; these objects are of `ObjectStructInfo` and can be used only by the `call_tir` operator. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. +The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax, but are also assigned `FuncStructInfo` and can be called like closures. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. ## Evaluating Expressions @@ -813,6 +825,7 @@ For each expression, we define how it affects the program's visible state and th 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. 1. If `op` evaluated to a closure, push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) 2. If `op` evaluated to a `PackedFunc`, simply invoke it. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. + 3. Similarly, if `op` evaluates to a `PrimFunc` representation, the `PrimFunc` is directly called with its arguments (it likely mutates one or more of them as a result). 13. For the node `SeqExpr(blocks, body)`, we evaluate as follows: 1. Push a new scope onto the stack. 2. Iterate through the `BindingBlock`s in `blocks` in order. We will call the current one `block`. For each binding in `Block`: From 4827f9dddb9796e807a2fdce85b010bd8d809c73 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 8 Sep 2023 16:44:53 -0400 Subject: [PATCH 43/47] Add value field to PrimStructInfo, allow shape vars to appear in function signature in any order --- relax_spec.md | 75 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index cd6fe63f75cc..8e7f9397df0f 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -65,7 +65,7 @@ DataType ::= Int(bits: int, lanes: int) StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) | ShapeStructInfo(values: [PrimExpr]?, ndim: int) - | PrimStructInfo(dtype: DataType) + | PrimStructInfo(dtype: DataType, value: PrimExpr?) | ObjectStructInfo() | TupleStructInfo(fields: [StructInfo]) | FuncStructInfo(params: [StructInfo]?, ret: StructInfo, purity: bool, derive_func: EnvFunc?*) @@ -286,26 +286,29 @@ The following criteria apply to all programs (including before normalization): 3. A `Var` of any kind may not appear before it is bound. Namely, if a `Var` is bound in a `BindingBlock` in a `SeqExpr`, that `Var` may not appear in bindings that precede the one where it appears on the LHS. 4. «A return structural annotation for a function is not allowed to use any shape variables that are not in scope at the function definition. That is, the only shape variables that can appear on the return structural annotation are those defined in the outer scope or those introduced in the argument structural annotations.» 5. In each function, `PrimExpr` variables (shape variables) similarly may not appear in `ShapeExpr`s or shape annotations before the shape variables are bound (either in function signatures or `MatchCast` bindings). A shape variable is bound only when it appears in a dimension by itself (for example, a dimension consisting of `x` will bind `x`; however, `2*x` is not a binding and is considered an error if `x` has not yet been bound) in a `MatchCast` node or a function argument shape annotation. -6. The following constructs are not permitted to occur inside `DataflowBlock`s, which must be side effect– and control flow–free: +6. In a function signature, every shape variable must appear in a binding position at least once; however, (for convenience) we do not enforce any ordering amongst the function arguments—for example, it is permitted to have a shape `x * y` in the first argument and have `x` and `y` appear in binding positions in later arguments. In such a case, the dimensions corresponding to the binding positions will be checked first, allowing the variables to be bound. Then the other dimensions will be checked. +7. The following constructs are not permitted to occur inside `DataflowBlock`s, which must be side effect– and control flow–free: 1. Recursive calls to the current function 2. Calls to a global function that is mutually recursive with the current function 3. `If` nodes Calls to Relax functions, `ExternFuncs`, or `Op`s that are not pure are also not permitted, but this must be detected during `StructInfo` checking. -7. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» -8. `Op` nodes may appear only as the `op` argument to `Call` nodes. -9. If a variable has a `StructInfo` annotation, the `ndim` of any `TensorStructInfo` and `ShapeStructInfo`s must match the number of dimensions in their `shape` and `values` fields, respectively. -10. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. -11. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» -12. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» -13. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. -14. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. -15. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. -16. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s (with `lanes` set to 1). -17. `PrimStructInfo` annotations should use only the `Int`, `UInt`, or `Float` datatypes for their `dtype` fields. -18. Per [the notes on `DataType`](#notes-on-datatype-and-related-terminology), any `DataType` annotation must have a `lanes` value of 1 for the `Int`, `UInt`, or `Float` datatypes and a `lanes` value of 0 for the `Handle` (`Void`) datatype. Additionally, `bits` must be 64 for `Void`. The supported bitwidths for `Int` and `UInt` are 1, 8, 16, 32, and 64; the supported bitwidths for `Float` are 16, 32, and 64. -19. If a `Function` `f` has an `attrs` field that includes the attribute `relax.force_pure`, `f`'s `is_pure` field must be set to `True`. +8. «For functions that contain recursive calls to themselves or mutually recursive global functions (i.e., those where function `a` calls function `b` and function `b` calls function `a`), a return structural annotation is *required*.» +9. `Op` nodes may appear only as the `op` argument to `Call` nodes. +10. If a variable has a `StructInfo` annotation, the `ndim` of any `TensorStructInfo` and `ShapeStructInfo`s must match the number of dimensions in their `shape` and `values` fields, respectively. +11. A function definition inside a `DataflowBlock` may not use `DataflowVar`s from the outer scope in its body. We do not define closure capturing for `DataflowVar`s. +12. «At least one global function in the `IRModule` must be externally linked (have a `global_symbol` attribute) in order to serve as a program entry point.» +13. «If a global function has a defined `global_symbol` attribute, the `global_symbol` name must be the same as the `GlobalVar`'s name hint.» +14. If the `shape` field of a `TensorStructInfo` in any structural annotation is given, the only permissible expressions are `Var` (the variable must be in scope at the location of the annotation) or `ShapeExpr` (in which any shape variables used must already be in scope, unless the `TensorStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear in a dimension by itself). Additionally, if the `shape` field is a `ShapeExpr`, the number of dimensions must match the `ndim` field. +15. If the `values` field of a `ShapeStructInfo` in any structural annotation is given, any shape variables used in it must already be in scope, unless the `ShapeStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as a member of `values`. Additionally, the `ndim` field must match the length of `values`. +16. Similarly, if the `value` field of `PrimStructInfo` is defined, any shape variables used in it must already be in scope, unless the `PrimStructInfo` is the `struct_info` field of a `MatchCast`, in which case a new shape variable is allowed to appear by itself as `value`. +17. The `params` and `derive_func` field may not be simultaneously defined in a `FuncStructInfo` annotation; that is, if one is defined, the other must not be defined. Additionally, at least one of `params` and `derive_func` _must_ be defined for each `FuncStructInfo` in an annotation. +18. `PrimValue` nodes are intended only to be used with `value`s consisting of TIR `IntImm`s and `FloatImm`s (with `lanes` set to 1). +19. `PrimStructInfo` annotations should use only the `Int`, `UInt`, or `Float` datatypes for their `dtype` fields. +20. Per [the notes on `DataType`](#notes-on-datatype-and-related-terminology), any `DataType` annotation must have a `lanes` value of 1 for the `Int`, `UInt`, or `Float` datatypes and a `lanes` value of 0 for the `Handle` (`Void`) datatype. Additionally, `bits` must be 64 for `Void`. The supported bitwidths for `Int` and `UInt` are 1, 8, 16, 32, and 64; the supported bitwidths for `Float` are 16, 32, and 64. +21. If a `Function` `f` has an `attrs` field that includes the attribute `relax.force_pure`, `f`'s `is_pure` field must be set to `True`. +22. For `PrimStructInfo`, if the `value` field is defined, the TIR `dtype` for the `PrimExpr` must match the `PrimStructInfo`'s `dtype` field (i.e., the datatypes must be consistent). Additionally, the criteria for normal form listed in [the previous section](#normal-form) must apply to any program that has been normalized. @@ -354,7 +357,7 @@ This section describes the run-time checking performed by `MatchCast(var, value, 2. If `shape` is a `ShapeExpr`, then compare the fields of the `ShapeExpr` to the concrete shape of `value`, dimension by dimension (comparing the `i`th field of the `ShapeExpr` to the `i`th dimension of the shape of `value`). Give an error if the number of the dimensions does not match the number of fields in the `ShapeExpr`. 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. 2. Otherwise, evaluate the field of the `ShapeExpr` and ensure that it matches the concrete value of the dimension. -3. If `struct_info` is `PrimStructInfo(dtype)`, then check that `value` is a `PrimValue` and that the underlying scalar has datatype `dtype` in TIR (according to TIR's type-checking rules). +3. If `struct_info` is `PrimStructInfo(dtype, v)`, then check that `value` is a `PrimValue` and that the underlying scalar has datatype `dtype` in TIR (according to TIR's type-checking rules). If `v` is defined, then check that `value` and `v` match numerically. 4. If `struct_info` is `ShapeStructInfo(ndim, values)`, then check that `value` is a shape value, that it has `ndim` dimensions (if `ndim` is not -1). If `values` is defined, then compare it to the concrete shape value (comparing the `i`th member of `values` to the `i`th field of the shape value): 1. If the `i`th member of `values` consists of only an unbound shape variable, then bind that variable to the `i`th field of the the concrete shape value. 2. Otherwise, evaluate the `i`th member of `values` and check that it is equal to teh `i`th field of the concrete shape value. @@ -363,14 +366,23 @@ This section describes the run-time checking performed by `MatchCast(var, value, ### Checking Structural Information at the Start and End of a Function -«Shape variables are bound at the start and end of a function or in `MatchCast` bindings. We can describe the behavior at the start and end of a function in terms of the semantics of `MatchCast`, as the shape annotations in function arguments and return types are treated as "syntactic sugar" for `MatchCast` bindings. Suppose a function has the following signature, where the `Si` are structural annotations: +«Shape variables are bound at the start and end of a function or in `MatchCast` bindings. This checking is done similarly to `MatchCast`, though with a slight difference: per rule #6 in under the [well-formedness criteria](#well-formedness-criteria), we allow shape variables to appear in arguments in any order so long as shape variables appear in a binding position at least once. This requires us to check the shapes of arguments dimension by dimension in a specific order. + +Suppose a function has the following signature, where the `Si` are structural annotations: ```python def f(arg1 : S1, arg2 : S2, ..., argn : Sn) -> Sr: return body ``` -This can be treated as a macro that expands to +The dimensions corresponding to variables in binding positions are checked first. A binding position is when a shape variable appears by itself in a dimension (a field in `values` in `ShapeStructInfo`, the `value` field of `PrimStructInfo`, or a field in `shape` for `TensorStructInfo`). For each variable in a binding position that appears among the `Si`, we check the corresponding field of the concrete value in order to assign to it a value (recursing down the structure if necessary). + +1. For `PrimStructInfo`, we compare `value` to the concrete primitive value. +2. For `ShapeStructInfo`, we compare the field within `values` to the corresponding member of the shape. +3. For `TensorStructInfo`, we compare the field of `shape` to the length of the corresponding dimension of the tensor. +4. For other `StructInfo`, it may be necessary to match values recursively. In the case of a closure, the runtime shape information will be needed. + +Having found a value, that value is bound to the shape variable. The rest of the tensor shapes and the return value can then be checked from left to right per the following macro: ```python def f(arg1, arg2, ..., argn): @@ -382,6 +394,10 @@ def f(arg1, arg2, ..., argn): MatchCast(ret_var, Sr) return ret_var ``` + +For a concrete example, suppose that the function has a signature +`def f(x: R.Tensor([M * N]), y: R.Tensor([M, N])) -> R.Tensor([N * N, M * M]): ...`. +In this case, `M` would first be bound by checking dimension 0 of the value of `y`, `N` would then be bound by checking dimension 1 of the value of `y`. Next, the shape of `x` would then be compared against `M * N` using the bound values, then the shape of `y` would be compared against `(M, N)` using the bound values. At the end of the function, the shape of the return value would be compared against `(N * M, M * M)` using the bound values. » ### Invariants for `TensorStructInfo` @@ -416,7 +432,10 @@ Note that judging subtyping requires potentially reasoning about arbitrary `Shap 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=n, values=undefined)`. 3. Given an arbitrary `ndim` `n` and two arbitrary sets of values `v1` and `v2` (both defined), `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` if, for all valid `i`, `v1[i]` and `v2[i]` can be proven to be _definitely_ statically equal. We say that `ShapeStructInfo(ndim=n, values=v1) <: ShapeStructInfo(ndim=n, values=v2)` _possibly_ holds if `v1` and `v2` are _possibly_ statically equal. 6. Given two lists of `StructInfo` `fields1` and `fields2`, `TupleStructInfo(fields=fields1) <: TupleStructInfo(fields=fields2)` if `fields1` and `fields2` are the same length and for all `i`, `fields1[i] <: fields2[i]`. We consider the subtyping relationship to _possibly_ hold if any of the subtyping relationships for the fields only possibly holds. -7. For `PrimStructInfo`, `PrimStructInfo(dt1) <: PrimStructInfo(dt2)` holds if `dt1` and `dt2` are the same. That is, we do not have subtyping for TIR datatypes or `PrimStructInfo`. +7. For `PrimStructInfo`: + 1. `PrimStructInfo(dtype=dt1) <: PrimStructInfo(dtype=dt2)` (where the `value` field is undefined for both) holds if `dt1` and `dt2` are the same. That is, we do not have subtyping for TIR datatypes. + 2. Let `dt` be a datatype. `PrimStructInfo(dtype=dt, value=v) <: PrimStructInfo(dtype=dt)` for any `PrimExpr` `v`; that is, if the `value` field is undefined for a `PrimStructInfo`, then it is a supertype to a `PrimStructInfo` with a defined `value` field. + 3. Let `dt` be a datatype. `PrimStructInfo(dtype=dt, value=v1) <: PrimStructInfo(dtype=dt, value=v2)` _definitely_ holds if `v1` and `v2` can be proven to be statically equal. The relation _possibly_ holds if `v1` and `v2` are _possibly_ statically equal. 8. For `FuncStructInfo`: 1. Given an arbitrary derivation function `derive_func`, `FuncStructInfo(ret=ObjectStructInfo(), derive_func=derive_func) <: FuncStructInfo(ret=ObjectStructInfo(), derive_func=empty_derive)`. 2. Corollary, following from reflexivity: For two `FuncStructInfo` `F1` and `F2` with undefined `params`, `F1 <: F2` only if `F1.derive_func` and `F2.derive_func` are identical. @@ -436,9 +455,14 @@ def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: if S1 and S2 do not match types (e.g., not both TensorStructInfo, etc): return ObjectStructInfo() if S1 and S2 are both PrimStructInfo: - if S1.dtype == S2.dtype: + if S1.dtype != S2.dtype: + return ObjectStructInfo() + if S1.value or S2.value is undefined: + return PrimStructInfo(dtype=S1.dtype, value=undefined) + if S1.value can be statically proven to match S2.value: return S1 - return ObjectStructInfo() + # values either proven not to match or unknown + return PrimStructInfo(dtype=S1.dtype, value=undefined) if S1 and S2 are both ShapeStructInfo: if S1.ndim == -1: return S1 @@ -658,14 +682,19 @@ We can check if some structural information `S1` is accepted where structural in 3. If `values` is not defined for `S2`, then they are compatible. 4. If `values` is defined for `S2` but not defined for `S1`, then they are possibly compatible. 5. If `values` is defined for both `S1` and `S2`, then the two are incompatible if `S1.values[i]` can be proven to be _not_ equal to `S2.values[i]` for some `i`. If all members can be proven to be equal, then they are compatible. Otherwise, if at least one pair of values cannot be proven to be either equal or unequal, then they are possibly compatible. -5. If `S1` and `S2` are both `TensorStructInfo`: +5. If `S1` and `S2` are both `PrimStructInfo`: + 1. If `S1.dtype` and `S2.dtype` do not match, then they are incompatible. + 2. If `value` is not defined for `S2`, then they are compatible. + 3. If `value` is defined for `S2` but not for `S1`, then they are possibly compatible. + 4. If `value` is defined for both `S1` and `S2`, then they are compatible if `S1.value` can be statically proven to be equal to `S2.value`. They are possibly compatible if `S1.value` is possibly statically equal to `S2.value` but it cannot be proven. They are incompatible if `S1.value` can be proven to _not_ be statically equal to `S2.value`. +6. If `S1` and `S2` are both `TensorStructInfo`: 1. If `S2.dtype` is not `Void` and does not match `S1.dtype`, then they are incompatible. 2. If `S2.ndim` is not -1 and does not match `S1.ndim`, then they are incompatible. 3. If `S2.shape` is not defined, then they are compatible. 4. If `S2.shape` is defined and `S1.shape` is not defined, then they are possibly compatible. 5. Otherwise, if both `shape` fields are given and either is a `Var`, then consider `S1` and `S2` compatible if the compiler can statically prove that the `Var` holds the same value as the other `shape` field, consider them possibly compatible if the compiler cannot draw a conclusion one way or the other, and consider them incompatible if the `Var` definitely has a different value from the other `shape`. 6. If both `shape` fields are given and they are both `ShapeExpr` nodes, then `S1` and `S2` are incompatible if the compiler can prove that some dimension of `S1.shape` is _not_ equal to the corresponding dimension of `S2.shape`. Otherwise, if the all dimensions can be proven to be equal, then consider them compatible. If at least one pair of dimensions cannot be proven to be equal or unequal, consider them possibly compatible. -6. If `S1` and `S2` are both `FuncStructInfo`: +7. If `S1` and `S2` are both `FuncStructInfo`: 1. If `S1` and `S2` don't both have defined `params` or both have undefined `params`, consider them incompatible. 2. If both `S1` and `S2` have undefined `params`, consider them compatible if they have an identical `derive_func` and consider them possibly compatible if they have different `derive_func`s (as they is no further way to introspect the `derive_func` and draw static conslusions about `PackedFunc`s). 3. If `params` is defined for both `S1` and `S2`: From 47a2f364db02dfe2d99daec47bf2a7fa174b542b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 20 Sep 2023 18:05:11 -0400 Subject: [PATCH 44/47] Note that purity also prohibits references to external state --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 8e7f9397df0f..7747ad667256 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -163,7 +163,7 @@ This specification provides a more detailed description of what each expression ## Purity and Dataflow Blocks -A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. +A function or operator is called "pure" if it does not have side effects, which refers to any change in program state besides returning a result, and depends only on the values of its arguments. Side effects include mutating values other than those they create, aborting the program, or file I/O (including writing to the console). Purity is a useful property for compiler optimizations, since calls to pure functions can be reordered or duplicated or (if the result is unused) eliminated without changing any other program behavior. Note that referring to external state separate from the function arguments (e.g., like the system clock) also renders a function impure; for example, the compiler would not be able to assume that it is safe to reorder the function calls (doing so could affect the results). Most deep learning operators are pure, as they perform arithmetic on tensors and return a new tensor containing the result. Above, it is mentioned that `DataflowBlock`s are not allowed to contain constructs featuring control flow (`If` nodes or recursive calls to the current function) or calls to impure functions. This ensures that `DataflowBlock`s represent a directed acyclic graph of pure operations, which is similar to the graph-like abstractions of traditional deep learning frameworks. This allows many common optimizations from past frameworks to be directly adapted to `DataflowBlock`s without having to accommodate additional reasoning about more expressive features like control flow and side effects. From c0a1f4e90940716caa2f5d8448b1cbad65a77015 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 14 Nov 2023 21:50:23 -0500 Subject: [PATCH 45/47] Add mention of operator-specific normalization rules --- relax_spec.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/relax_spec.md b/relax_spec.md index 7747ad667256..89e9c92ac42f 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -267,6 +267,7 @@ The normal form for Relax is very similar to ANF; differences will be noted. Her 2. In the `true_branch` and `false_branch` fields of `If` nodes. 3. In fact, the `body` field of a `Function` node and the `true_branch` and `false_branch` fields of `If` nodes _must_ be `SeqExpr`s. If these fields are not `SeqExpr`s, they must be "wrapped" in a `SeqExpr`. 4. Within a `SeqExpr`, `BindingBlock`s must be consolidated. For example, if there is a `BindingBlock` that comes after another `BindingBlock`, the two blocks should be combined to form a single `BindingBlock` with all the bindings in the same order. Consecutive `DataflowBlock`s should be consolidated as well. Empty `BindingBlock`s should be dropped. However, a `DataflowBlock` cannot be consolidated with an ordinary `BindingBlock`. If all the `BindingBlock`s are empty, then the `blocks` field of the `SeqExpr` should be set to an empty list. +5. Calls to `Op` nodes can have custom normalization rules in order to ensure that calls to those operators will conform to certain specific rules (ideally, these should be _more_ and not _less_ restrictive than the other rules of normal form). In particular, `call_tir` and related operators include a custom normalization rule that requires the arguments to the `PrimFunc` to be provided as a tuple _literal_, rather than, say, a variable that evaluates to a tuple. Programs that are parsed should be "normalized" before performing `StructInfo` checking or before doing any further optimizations. Note that the process of "flattening" `SeqExpr`s and consolidating `BindingBlock`s does increase the visibility of the variables in those `SeqExpr`s and `BindingBlock`s, but this is safe, since it will not cause any variable to be referenced outside of its original scope. The specification does not require any particular method of normalizing a program so long as the final program conforms to the above-listed criteria. Here is a general approach: 1. For each function in the `IRModule`, ensure that the body is a `SeqExpr`. If the body is not a `SeqExpr`, wrap the function body in a `SeqExpr`, creating a new `BindingBlock` to hold `VarBinding`s for any non-leaf expressions that need to be bound to variables. @@ -274,6 +275,7 @@ Programs that are parsed should be "normalized" before performing `StructInfo` c 3. If the function body is not a `SeqExpr`, then recurse down the body's AST, binding any nested non-leaf expressions to a var in the current scope (doing this process in breadth-first order from left to right will respect the evaluation order in the semantics). If the body itself is a non-leaf expression, finally bind it to a var and have the final `SeqExpr` return the new var. 4. If an `If` node is encountered, ensure the `true_branch` and `false_branch` fields are `SeqExpr`s (consolidate `BindingBlock`s if necessary) or "wrap" them in `SeqExpr`s in the same manner as the function body. 5. If a `SeqExpr` node is encountered as the `value` node in a binding, "flatten" the `SeqExpr` by adding its bindings to the current scope and replacing the `SeqExpr` with its body. If the `SeqExpr` body is a non-leaf expression, normalize it recursively in the same manner as in step 3 before replacing the binding. Note that if the current scope (the location of the binding) is a `DataflowBlock` and the nested `SeqExpr` contains an ordinary `BindingBlock`, that indicates a malformed program. +6. For calls to `Op`s, check if it has a custom normalization rule and apply the custom normalization rule. # Well-Formedness Criteria @@ -309,6 +311,7 @@ The following criteria apply to all programs (including before normalization): 20. Per [the notes on `DataType`](#notes-on-datatype-and-related-terminology), any `DataType` annotation must have a `lanes` value of 1 for the `Int`, `UInt`, or `Float` datatypes and a `lanes` value of 0 for the `Handle` (`Void`) datatype. Additionally, `bits` must be 64 for `Void`. The supported bitwidths for `Int` and `UInt` are 1, 8, 16, 32, and 64; the supported bitwidths for `Float` are 16, 32, and 64. 21. If a `Function` `f` has an `attrs` field that includes the attribute `relax.force_pure`, `f`'s `is_pure` field must be set to `True`. 22. For `PrimStructInfo`, if the `value` field is defined, the TIR `dtype` for the `PrimExpr` must match the `PrimStructInfo`'s `dtype` field (i.e., the datatypes must be consistent). +23. For any `Call` node where the callee (`op` field) is an `Op` node, if the `Op` has a custom normalization rule, the call must conform to that rule. In particular, applying to the normalization rule to the `Call` should not require any further changes. Additionally, the criteria for normal form listed in [the previous section](#normal-form) must apply to any program that has been normalized. From 27296b98a4516cb6b358fb4b26d305d2b78e5fa2 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 20 Dec 2023 18:25:33 -0500 Subject: [PATCH 46/47] Minor changes and add heterogeneous semantics --- relax_spec.md | 107 +++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 84 insertions(+), 23 deletions(-) diff --git a/relax_spec.md b/relax_spec.md index 89e9c92ac42f..3cb18baf4359 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -57,13 +57,22 @@ PrimExpr ::= PrimFunc ::= PrimFunc(params: [tir::Var], body: tir::Stmt, ret_type: tir::Type?, buffer_map: {tir::Var: tir::Buffer}, attrs: Attrs) + +# VDevice is used to indicate target devices for heterogeneous computing +Target ::= Target() # null target + | Target(tag: string) + | Target(config: {String, ObjectRef}) + +VDevice ::= VDevice(tgt: Target, int: dev_id, mem_scope: string) + # Also from TIR DataType ::= Int(bits: int, lanes: int) | UInt(bits: int, lanes: int) | Float(bits: int, lanes: int) | Handle(bits: int, lanes: int) -StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, ndim: int) + +StructInfo ::= TensorStructInfo(shape: Expr?, dtype: DataType, vdevice: VDevice?, ndim: int) | ShapeStructInfo(values: [PrimExpr]?, ndim: int) | PrimStructInfo(dtype: DataType, value: PrimExpr?) | ObjectStructInfo() @@ -103,7 +112,12 @@ Binding ::= # Relax programs are IRModules. Modules may bind global variables either to # Relax functions or TIR PrimFuncs. # The Relax compiler may analyze and modify the TIR PrimFuncs as well. -Program ::= IRModule(funcs: {GlobalVar: Function|PrimFunc}) +# Note that there can be global info other than VDevices, but only VDevice +# is used by Relax at present. +Program ::= IRModule( + funcs: {GlobalVar: Function|PrimFunc} + global_info: {string: VDevice} + ) ``` ### Notes on `derive_func` @@ -181,7 +195,7 @@ Exiting with an error and infinitely looping are traditionally considered "[dive ## Structural Information (`StructInfo`) System Survey Analogously to a type system in most languages, Relax tracks structural information (referred to as `StructInfo` in the implementation) related to the categories of values in Relax: -1. `TensorStructInfo` corresponds to tensor values, giving the scalar data type, the number of dimensions (rank), and an expression that computes the tensor's shape (either a `ShapeExpr` or a `Var`), all of which are optional. +1. `TensorStructInfo` corresponds to tensor values, giving the scalar data type, the number of dimensions (rank), and an expression that computes the tensor's shape (either a `ShapeExpr` or a `Var`), all of which are optional. The optional `vdevice` ("virtual device") field, if present, indicates which device a tensor is located on. Tensor operators must be implemented on the appropriate device. 2. `TupleStructInfo` corresponds to tuple values, giving the `StructInfo` for each member of the tuple. 3. `PrimStructInfo` corresponds to `PrimValue`s (immutable scalar values), giving their TIR datatype. 4. `ShapeStructInfo` corresponds to shape values, optionally giving the number of dimensions in the shape and an expression that computes the shape's dimensions (either a `ShapeExpr` or a `Var`). @@ -204,7 +218,7 @@ Oftentimes, compiler passes operate only on particular functions or add new func Here are the classes of values that Relax operates over, meaning that they can be assigned to variables or be the result of evaluating expressions. -- *Tensors* are n-dimensional arrays of scalar values (which can be signed or unsigned integers of fixed bitwidths, floats of fixed bitwidths, or Boolean values). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. +- *Tensors* are n-dimensional arrays of scalar values (which can be signed or unsigned integers of fixed bitwidths, floats of fixed bitwidths, or Boolean values). A tensor's *shape* is a tuple of the size of each dimension; the number of dimensions is a tensor's *rank*. For example, a vector (1, 2, 3) is a rank-1 tensor of shape `(3,)`. Note that scalars are tensor values with a rank of 0, meaning that their shape is `()`. Tensors can be _located_ on different devices in the program, namely one of the `VDevice`s listed in the `IRModule`'s `global_info` map; if tensors are located on different devices, it may be necessary to insert operators like `to_vdevice` in order to transfer them so that they can be used together in an operator call. - *Tuples* represent a fixed-size immutable grouping of other Relax values (tensors, closures, shapes, objects, or other tuples, to an arbitrary degree of nesting). Note that an empty tuple, i.e., `()`, also called "unit" in functional programming, is commonly used as the return value for operations not intended to return a value (as may be the case in some `PackedFunc` or operator calls that have side effects). - *Closures* are the values resulting from evaluating Relax function expressions; closures can be passed around like other values, ensuring that functions are first-class in Relax. Functions defined in Relax can capture variables from outer scopes. A [closure](https://en.wikipedia.org/wiki/Closure_(computer_programming)) consists of a function and a mapping of any variables "captured" (those are *free variables* in the function body, variables from an outer scope that are neither arguments nor defined within the function but are used in the function) to their values. Closures capture both Relax-level local variables and shape variables from outer scopes. A closure also stores a name for itself when the body contains recursive calls. «Closures additionally carry some *run-time structural information* (RTSI) indicating their argument and result structures, in order to facilitate dynamic structural checks (since it is not otherwise possible to introspect the function contained within a closure); the precise form of the RTSI is left up to the compiler implementation to determine so long as `MatchCast` can verify the structure of a closure, including whether it is pure. Closures can be evaluated in a call node, which results in calling the function with the call's arguments and the captured values.» - *Tensor shapes* (shape values) are immutable tuples of integers describing a tensor shape, obtained by evaluating `ShapeExpr`s. @@ -312,6 +326,7 @@ The following criteria apply to all programs (including before normalization): 21. If a `Function` `f` has an `attrs` field that includes the attribute `relax.force_pure`, `f`'s `is_pure` field must be set to `True`. 22. For `PrimStructInfo`, if the `value` field is defined, the TIR `dtype` for the `PrimExpr` must match the `PrimStructInfo`'s `dtype` field (i.e., the datatypes must be consistent). 23. For any `Call` node where the callee (`op` field) is an `Op` node, if the `Op` has a custom normalization rule, the call must conform to that rule. In particular, applying to the normalization rule to the `Call` should not require any further changes. +24. «All `VDevice`s reference in `StructInfo` annotations _must_ appear in the `IRModule`'s `global_info` map. (Corollary: If no `VDevice` is given in `global_info`, then _all_ `vdevice` fields in `TensorStructInfo` annotations must remain undefined.)» Additionally, the criteria for normal form listed in [the previous section](#normal-form) must apply to any program that has been normalized. @@ -324,7 +339,7 @@ Tensor shapes are the primary motivation for including structural information in ## Defining Structural Information The structural information in Relax corresponds to the values in the language: -* `TensorStructInfo` describes tensor values. The `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` (with `ShapeStructInfo`). If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation that returns a shape value, which can be useful for memory planning. +* `TensorStructInfo` describes tensor values. The `dtype` field gives the datatype (with `Void` indicating a statically unknown datatype), the `ndim` field gives the rank (with -1 indicating a statically unknown rank). Unlike `DynTensorType`, there is an optional `shape` field which, if defined, describes the shape of the tensor using either a `ShapeExpr` or a `Var` (with `ShapeStructInfo`). If `shape` is a `ShapeExpr`, the `PrimExpr`s in the `ShapeExpr`'s dimensions describe how to compute each dimension of the shape (or are constants). If `shape` is a `Var`, the `Var` can assign the result of an arbitrary computation that returns a shape value, which can be useful for memory planning. The `vdevice` field, if present, indicates on which device the tensor value is located, since `NDArray`s can be allocated on different devices (if absent, then that means that the device is unspecified). * `ShapeStructInfo` describes shape values. It has an `ndim` field that gives the number of dimensions in the shape (with -1 indicating that it is statically unknown). It additionally has an optional `values` field. If defined, `values` gives a list of `PrimExpr`s that indicate how to compute the dimensions of the shape, potentially providing further information for static analyses. * `PrimStructInfo` describes `PrimValue`s, giving their TIR datatype. * `TupleStructInfo` describes tuple values, namely by giving the `StructInfo` for each of the tuple's members via `fields`. @@ -355,7 +370,7 @@ Because structural information is checked in a "best-effort" fashion, it is not This section describes the run-time checking performed by `MatchCast(var, value, struct_info)`, for each combination of value and structural annotation (if `var` is defined, then `value` will be bound to `var` as discussed in the [general section on semantics](#detailed-semantics)). If any check given below fails, an error is raised by the `MatchCast`. 1. If `struct_info` is `ObjectStructInfo`, then no additional check is performed. All values in Relax are objects. -2. If `struct_info` is `TensorStructInfo(ndim, dtype, shape)`, then check that `value` is a tensor value, that it has a rank of `ndim` (if `ndim` is not -1), a datatype of `dtype` (if `dtype` is not `Void`). If `shape` is defined, consider the following cases: +2. If `struct_info` is `TensorStructInfo(shape, dtype, vdevice, ndim)`, then check that `value` is a tensor value, that it has a rank of `ndim` (if `ndim` is not -1) and a datatype of `dtype` (if `dtype` is not `Void`), and is located on the device denoted by `vdevice` (if defined). If `shape` is defined, consider the following cases: 1. If `shape` is a `Var`, then check that the concrete shape of `value` matches the value bound to the `Var`. 2. If `shape` is a `ShapeExpr`, then compare the fields of the `ShapeExpr` to the concrete shape of `value`, dimension by dimension (comparing the `i`th field of the `ShapeExpr` to the `i`th dimension of the shape of `value`). Give an error if the number of the dimensions does not match the number of fields in the `ShapeExpr`. 1. If a field of the `ShapeExpr` consists of only an unbound shape variable, then bind that variable to the value of the dimension. @@ -369,7 +384,7 @@ This section describes the run-time checking performed by `MatchCast(var, value, ### Checking Structural Information at the Start and End of a Function -«Shape variables are bound at the start and end of a function or in `MatchCast` bindings. This checking is done similarly to `MatchCast`, though with a slight difference: per rule #6 in under the [well-formedness criteria](#well-formedness-criteria), we allow shape variables to appear in arguments in any order so long as shape variables appear in a binding position at least once. This requires us to check the shapes of arguments dimension by dimension in a specific order. +Shape variables are bound at the start and end of a function or in `MatchCast` bindings. This checking is done similarly to `MatchCast`, though with a slight difference: per rule #6 in under the [well-formedness criteria](#well-formedness-criteria), we allow shape variables to appear in arguments in any order so long as shape variables appear in a binding position at least once. This requires us to check the shapes of arguments dimension by dimension in a specific order. Suppose a function has the following signature, where the `Si` are structural annotations: @@ -401,7 +416,6 @@ def f(arg1, arg2, ..., argn): For a concrete example, suppose that the function has a signature `def f(x: R.Tensor([M * N]), y: R.Tensor([M, N])) -> R.Tensor([N * N, M * M]): ...`. In this case, `M` would first be bound by checking dimension 0 of the value of `y`, `N` would then be bound by checking dimension 1 of the value of `y`. Next, the shape of `x` would then be compared against `M * N` using the bound values, then the shape of `y` would be compared against `(M, N)` using the bound values. At the end of the function, the shape of the return value would be compared against `(N * M, M * M)` using the bound values. -» ### Invariants for `TensorStructInfo` @@ -426,10 +440,11 @@ Note that judging subtyping requires potentially reasoning about arbitrary `Shap 2. Transitivity: For all `S1`, `S2`, and `S3`, if `S1 <: S2` and `S2 <: S3`, then `S1 <: S3`. 3. For all `S1`, `S1 <: ObjectStructInfo()`. 4. For `TensorStructInfo`: - 1. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=-1, dtype=d)`. - 2. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=Void(), shape=s)`. - 3. Given any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s`, `TensorStructInfo(ndim=n, dtype=d, shape=s) <: TensorStructInfo(ndim=n, dtype=d, shape=undefined)`. - 4. Given any datatype `d`, an arbitrary `ndim` `n`, and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` if `s1` and `s2` are _definitely_ statically equal. We say that `TensorStructInfo(ndim=n, dtype=d, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, shape=s2)` _possibly_ holds if `s1` and `s2` are _possibly_ statically equal. + 1. Given any datatype `d`, an arbitrary `ndim` `n`, an arbitrary expression `s` (possibly undefined), and an arbitrary `VDevice` `v` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s, vdevice=v) <: TensorStructInfo(ndim=-1, dtype=d, vdevice=v)`. + 2. Given any datatype `d`, an arbitrary `ndim` `n`, an arbitrary expression `s` (possibly undefined), and an arbitrary `VDevice` `v` (possibly undefined), `TensorStructInfo(ndim=n, dtype=d, shape=s, vdevice=v) <: TensorStructInfo(ndim=n, dtype=Void(), shape=s, vdevice=v)`. + 3. Given any datatype `d`, an arbitrary `ndim` `n`, an arbitrary `VDevice` `v` (possibly undefined), and an arbitrary expression `s`, `TensorStructInfo(ndim=n, dtype=d, vdevice=v, shape=s) <: TensorStructInfo(ndim=n, dtype=d, vdevice=v, shape=undefined)`. + 4. Given any datatype `d`, an arbitrary `ndim` `n`, an arbitrary `VDevice` `v` (possibly undefined), and arbitrary expressions `s1` and `s2` (both defined), then `TensorStructInfo(ndim=n, dtype=d, vdevice=v, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, vdevice=v, shape=s2)` if `s1` and `s2` are _definitely_ statically equal. We say that `TensorStructInfo(ndim=n, dtype=d, vdevice=v, shape=s1) <: TensorStructInfo(ndim=n, dtype=d, vdevice=v, shape=s2)` _possibly_ holds if `s1` and `s2` are _possibly_ statically equal. + 5. Given any `VDevice` `v` (that is defined), any datatype `d`, an arbitrary `ndim` `n`, and an arbitrary expression `s` (possibly undefined), then `TensorStructInfo(ndim=n, dtype=d, vdevice=v, shape=s) <: TensorStructInfo(ndim=n, dtype=d, vdevice=undefined, shape=s)`. 5. For `ShapeStructInfo`: 1. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (possibly undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=-1)`. 2. Given an arbitrary `ndim` `n` and an arbitrary set of values `v` (not undefined), `ShapeStructInfo(ndim=n, values=v) <: ShapeStructInfo(ndim=n, values=undefined)`. @@ -485,6 +500,7 @@ def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: if S1 and S2 are both TensorStructInfo: ndim = S1.ndim if S1.ndim == S2.ndim else -1 dtype = S1.dtype if S1.dtype == S2.dtype else Void + vdev = S1.vdevice if S1.vdevice == S2.vdevice else undefined if ( S1.ndim == -1 or S2.ndim == -1 or S1.ndim != S2.ndim or S1.shape is undefined or S2.shape is undefined @@ -494,7 +510,7 @@ def unify_struct_info(S1: StructInfo, S2: StructInfo) -> StructInfo: if S1.shape can be proven to equal S2.shape: return S1 # either proven to be unequal or cannot be concluded whether they are equal - return TensorStructInfo(ndim=ndim, dtype=dtype) # leave shape undefined + return TensorStructInfo(ndim=ndim, dtype=dtype, vdevice=vdev, shape=undefined) if S1 and S2 are both TupleStructInfo: if S1.fields and S2.fields are of different lengths: return ObjectStructInfo() @@ -691,12 +707,14 @@ We can check if some structural information `S1` is accepted where structural in 3. If `value` is defined for `S2` but not for `S1`, then they are possibly compatible. 4. If `value` is defined for both `S1` and `S2`, then they are compatible if `S1.value` can be statically proven to be equal to `S2.value`. They are possibly compatible if `S1.value` is possibly statically equal to `S2.value` but it cannot be proven. They are incompatible if `S1.value` can be proven to _not_ be statically equal to `S2.value`. 6. If `S1` and `S2` are both `TensorStructInfo`: - 1. If `S2.dtype` is not `Void` and does not match `S1.dtype`, then they are incompatible. - 2. If `S2.ndim` is not -1 and does not match `S1.ndim`, then they are incompatible. - 3. If `S2.shape` is not defined, then they are compatible. - 4. If `S2.shape` is defined and `S1.shape` is not defined, then they are possibly compatible. - 5. Otherwise, if both `shape` fields are given and either is a `Var`, then consider `S1` and `S2` compatible if the compiler can statically prove that the `Var` holds the same value as the other `shape` field, consider them possibly compatible if the compiler cannot draw a conclusion one way or the other, and consider them incompatible if the `Var` definitely has a different value from the other `shape`. - 6. If both `shape` fields are given and they are both `ShapeExpr` nodes, then `S1` and `S2` are incompatible if the compiler can prove that some dimension of `S1.shape` is _not_ equal to the corresponding dimension of `S2.shape`. Otherwise, if the all dimensions can be proven to be equal, then consider them compatible. If at least one pair of dimensions cannot be proven to be equal or unequal, consider them possibly compatible. + 1. If `S2.dtype` is not `Void`, `S1.dtype` is not `Void`, and `S1.dtype` and `S2.dtype` do not match, then they are incompatible. + 2. If `S2.ndim` is not -1, `S1.ndim` is not -1, and `S1.ndim` and `S2.ndim` do not match, then they are incompatible. + 3. If `S2.vdevice` is defined and does not match `S1.vdevice`, then they are incompatible. + 4. If `S2.shape` is not defined, then they are compatible pending step 8. + 5. If `S2.shape` is defined and `S1.shape` is not defined, then they are possibly compatible. + 6. Otherwise, if both `shape` fields are given and either is a `Var`, then consider `S1` and `S2` compatible (pending step 8) if the compiler can statically prove that the `Var` holds the same value as the other `shape` field, consider them possibly compatible if the compiler cannot draw a conclusion one way or the other, and consider them incompatible if the `Var` definitely has a different value from the other `shape`. + 7. If both `shape` fields are given and they are both `ShapeExpr` nodes, then `S1` and `S2` are incompatible if the compiler can prove that some dimension of `S1.shape` is _not_ equal to the corresponding dimension of `S2.shape`. Otherwise, if the all dimensions can be proven to be equal, then consider them compatible pending step 8. If at least one pair of dimensions cannot be proven to be equal or unequal, consider them possibly compatible. + 8. If we have concluded `S1.shape` and `S2.shape` to match in step 4, 6, or 7, then consider `S1` and `S2` possibly compatible if `S1.dtype` is `Void` while `S2.dtype` is not `Void` or if `S1.vdevice` is undefined but `S2.vdevice` is defined. Otherwise, consider `S1` and `S2` compatible. 7. If `S1` and `S2` are both `FuncStructInfo`: 1. If `S1` and `S2` don't both have defined `params` or both have undefined `params`, consider them incompatible. 2. If both `S1` and `S2` have undefined `params`, consider them compatible if they have an identical `derive_func` and consider them possibly compatible if they have different `derive_func`s (as they is no further way to introspect the `derive_func` and draw static conslusions about `PackedFunc`s). @@ -712,7 +730,7 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 1. «Prepopulate `Γ` with the annotated types of all global functions (see the rule for `Function` nodes) that are called mutually recursively. Afterwards check the structural information of the global functions one at a time and populate the entry of `Γ` corresponding to that `GlobalVar`.» 2. For a variable (`Var`, `DataflowVar`, or `GlobalVar`) `v`, look up `Γ[v]` for the structural information. -3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value`. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. +3. For `Constant(value)`, the resulting structural information is `TensorStructInfo(ndim, dtype, shape, vdevice=undefined)` where `ndim` is the concrete rank of `value`, `dtype` is the concrete datatype used in `value`, and `shape` is a `ShapeExpr` giving the concrete shape of `value`. For example, for `Constant(1)`, `shape` is `ShapeExpr([])` and for `Constant([1, 2])`, `shape` is `ShapeExpr([IntImm(2, "int64")])`. 4. For `PrimValue(prim_expr)`, the resulting `StructInfo` is `PrimStructInfo(dt)`, where `dt` is the datatype of `prim_expr`, derived according to the type-checking rules for TIR. 5. For `StringImm(s)`, the resulting `StructInfo` is `ObjectStructInfo()`. 6. For `DataTypeImm(dt)`, the resulting `StructInfo` is `ObjectStructInfo()`. @@ -770,9 +788,44 @@ Let `Γ` be the `StructInfo` context for Relax variables and let `Σ` track whic 16. For `PrimFunc(params, body, ret_type, buffer_map, attrs)` at the module level, which is bound to a `GlobalVar`: 1. Suppose there are `n` members of `params`. For the `i`th member of `params` (let us call it `v`), let `si` be a corresponding `StructInfo` defined as follows: 1. If `v` is not in `buffer_map`, then `si` is `PrimType(d)`, where `d` is the `dtype` field of `v`. - 2. If `v` is in `buffer_map`, then let `b` be `buffer_map[v]`. Then, `si` is `TensorStructInfo(d, ShapeExpr(s))`, where `d` is the `dtype` field of `b` and `s` is the `shape` field of `b`. + 2. If `v` is in `buffer_map`, then let `b` be `buffer_map[v]`. Then, `si` is `TensorStructInfo(dtype=d, shape=ShapeExpr(s), ndim=len(s), vdevice=undefined)`, where `d` is the `dtype` field of `b`, `s` is the `shape` field of `b`. 2. The `StructInfo` for the `PrimFunc` (namely, for the `GlobalVar` to which the `PrimFunc` is bound) is `FuncStructInfo([s0, s1, ..., sn-1], TupleStructInfo([]), purity=False)`. (`PrimFunc`s work by mutating their arguments, so direct calls to `PrimFunc`s are treated as impure; in order to call a `PrimFunc` from within a `DataflowBlock`, use `call_tir`, which allocates fresh tensors for the outputs.) +### Propagating Virtual Device Information + +If the `IRModule` contains `VDevice`s in its global information map, then we additionally propagate virtual device information to `TensorStructInfo` after deriving the `StructInfo` by the above rules. If no `VDevice`s are given in the global information map, then this step is omitted. (Implementation note: This is implemented in the pass `RealizeVDevice`.) Note that this propagation can only succeed if at least some `VDevice`s are manually provided, either through `StructInfo` annotations or calls to related operators like `to_vdevice`. + +We use the following auxiliary procedure, in pseudocode, to set the `vdevice` field in `StructInfo`: + +```python +def update_struct_info(S: StructInfo, v: VDevice) -> StructInfo: + if S is TensorStructInfo: + if S.vdevice is defined and S.vdevice != v: + # this is a compile-time inconsistency + raise error + return TensorStructInfo(ndim=S.ndim, shape=S.shape, dtype=S.dtype, vdevice=v) + «if S is TupleStructInfo: + return TupleStructInfo(fields=[update_struct_info(s, v) for s in S.fields])» + «if S is FuncStructInfo: + if S has a defined derive_func: + return S + return FuncStructInfo(params=S.params, ret=update_struct_info(S.ret, v), purity=S.purity)» + return S +``` + +For each Relax function in the `IRModule`, we will update the `VDevice`s in the "backward" direction by proceeding recursively: +1. For a `Function` node with return `StructInfo` `finfo` and body `body`, suppose its `StructInfo` is `finfo` (which must be `FuncStructInfo`). If the return `StructInfo` (`finfo->ret`) is `TensorStructInfo` with a defined `vdevice` field, then set the `StructInfo` of `body` to `update_struct_info(finfo, v)`. Visit `body` recursively. +2. For a `Call` node with callee `op` (with `StructInfo` `finfo`) and arguments `args` (for which the `i`th member has `StructInfo` `Si`), suppose its `StructInfo` is `S`. If `S` is a `TensorStructInfo` with a defined `vdevice` `v`, set the `StructInfo` of `op` to `update_struct_info(finfo, v)` and for the `i`th member of `args`, set the `StructInfo` to `update_struct_info(Si, v)`. Visit `op` and each member of `args` recursively. +3. For `SeqExpr(blocks=blocks, body=body)`, suppose the `StructInfo` of the entire node is `S`. If `S` is a `TensorStructInfo` with a defined `vdevice` `v`, set the `StructInfo` of `body` to `S`. Recuse down each binding block in `blocks` and each binding in each binding block. + 1. For each `VarBinding(var=var, value=value)`, let the `StructInfo` of `var` be `Svar` and the `StructInfo` of `value` be `Svalue`. If `Svar` is `TensorStructInfo` with a defined `vdevice` `v`, update the `StructInfo` of `value` to `update_struct_info(Svalue, v)`. If `Svalue` is `TensorStructInfo` with a defined `vdevice` `v`, update the `StructInfo` of `var` to `update_struct_info(Svar, v)`. Recurse down `value`. + 2. For each `MatchCast(var=var, value=value, struct_info=S)`, let the `StructInfo` of `var` be `Svar` and the `StructInfo` of `value` be `Svalue`. If `Svar` is `TensorStructInfo` with a defined `vdevice` `v`, update `S` to `update_struct_info(S, v)` and update the `StructInfo` of `value` to `update_struct_info(Svalue, v)`. If `S` is `TensorStructInfo` with a defined `vdevice` `v`, update the `StructInfo` of `var` to `update_struct_info(Svar, v)`. Recurse down `value`. + 3. Finally, recurse down `body`. +4. For all other expressions, recurse down to all child `Expr` nodes, making any changes specified in the above steps. + +The type-checking procedure (the derivation rules) will have to be run again to propagate the `VDevice`s in the "forward" direction and ensure consistency after the `VDevice`s are filled in. + +«After the second run of type-checking, consider it a compilation error if there are any `Call` nodes to `Op`s or `PrimFunc`s remaining where an argument has an undefined `vdevice`.» + ### Note on Proving Shapes Equivalent and Eliminating Dynamic Checks There can be some complexity involved in checking whether two shapes match during shape inference. A very simple, conservative method for determining equality is simply using alpha-equivalence: If the two shapes have the same structure, then they are equivalent. However, this method is conservative and can overlook numerical properties in `PrimExpr`s. We leave it up to compiler implementations as to whether to use more advanced methods for proving equivalence, such as attempting to use algebraic rewrite rules. (As a consequence, portability requires inserting dynamic checks wherever there needs to be a comparison of shapes.) @@ -834,6 +887,10 @@ In the `IRModule`, every mapping of a `GlobalVar` to a `Function` node or a TIR The rules for evaluating `Function` nodes into closures are given below. TIR `PrimFunc`s evaluate into objects that are opaque to Relax, but are also assigned `FuncStructInfo` and can be called like closures. None of the values in global scope is mutable. Execution of a Relax function in an IR module thus begins by evaluating all globally visible functions into a form in which they can be accessed. +## Destination Devices + +«If the `global_info` table in the `IRModule` contains any `VDevice`s, then execution will be distributed across the devices listed. The `to_vdevice` operator is responsible for transferring data (tensors) from one device to another and operators involving these tensors will be implemented on the appropriate device. If no `VDevice`s are listed in the `global_info`, then all (tensor) computations will take place on a single "target" device specified in compilation (if the target is a GPU, then some computations may still take place on a CPU host, but only those concerning metadata or data structures that are not tensors).» + ## Evaluating Expressions For each expression, we define how it affects the program's visible state and the order in which they are evaluated. Below, all evaluation results are passed by reference (and hence possibly alias) unless it is explicitly specified that they allocate new values. @@ -853,7 +910,7 @@ For each expression, we define how it affects the program's visible state and th 3. If `r` is false, evaluate the `false_branch` and return its result. 11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not). Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. 12. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: - 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. It is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the executor implementation). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» + 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. If there are `VDevice`s defined, then (per the `StructInfo` rules for propagating `VDevice` information), all tensor arguments must have a `VDevice` specified in their `StructInfo`; the operator must be implemented on the `VDevice` specified. In all other respects, it is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into language runtime). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. 1. If `op` evaluated to a closure, push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) 2. If `op` evaluated to a `PackedFunc`, simply invoke it. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value. @@ -898,6 +955,7 @@ The above evaluation rules are general, but leave much room for implementations - `f` will be called in destination-passing style, like so: `f(arg1, arg2, ..., argn, shape1, shape2, ..., shapem, r1, r2, ..., rk)`, omitting the `shapei` if `packed_ints` is not given. `f` is expected to mutate *only* the `ri` to give the output of the function, hence `call_tir` is considered pure. - «If the shape or data type of the actual result do not correspond to the `aSi`, an error is issued.» - After the call, the `ri` will be returned (returning `r1` directly if there is only a single result, otherwise returning `Tuple(fields=[r1, r2, ..., rk])`). +- `call_tir_inplace(prim_func, args, inplace_indices, packed_ints, sinfo_args=[aS1, aS2, ..., aSk])`: Behaves similarly to `call_tir`, except the computation will mutate some members of `args` instead of allocating new tensors for all outputs. For each intended output, there must be a corresponding index given in `inplace_indices`: if the index is -1, then that output will be freshly allocated and the `PrimFunc` will take an "output argument" in destination-passing style corresponding to that output; otherwise, the `PrimFunc` will mutate the member of `args` with that index in-place. `prim_func` must be implemented in such a way as to mutate the appropriate arguments directly instead of taking output arguments in destination-passing style. - `call_dps_packed(packed_func, args, sinfo_args=[aS1])`: - `packed_func` must evaluate to a `PackedFunc` object. - `args` must be a tuple; we will call its elements `arg1`, `arg2`, ..., `argn`. @@ -913,5 +971,8 @@ The above evaluation rules are general, but leave much room for implementations - `sinfo_args` must be a non-empty list of `StructInfo`. - The returned value will have the semantics of `Call(func, args, sinfo_args=sinfo_args)`. However, this call will be assumed to be pure (`call_pure_packed`'s `FPurity` is set to `True`), thus allowing the call to appear inside a `DataflowBlock` or a function whose `is_pure` is set to `True`. - Note: This operator is intended to be be used for cases where the user knows that calling the packed function will _in reality_ not cause any side effects. If it is used for a call that _does_ result in side effects, then the compiler may end up removing, reordering, or repeating that call; the specification makes no guarantees about the side effects in the callee in that case. +- `call_inplace_packed(func, args, inplace_indices, sinfo_args)`: Behaves identically to `call_pure_packed`, but `inplace_indices` denote that certain arguments (those with the corresponding indices) are mutated in-place. This is used to indicate intent, so that the compiler can verified that the arguments or any alias of those arguments will not be used after being mutated (this is what allows the operator to be considered pure and used within `DataflowBlock`s even though the `PackedFunc` might have side effects). - `shape_of(t)`: Given a tensor argument `t`, it returns its shape. The return value is a shape object. -- `null_value()`: Returns a null object (treated as `ObjectStructInfo`). This is used for indicating to operators that an optional argument has been omitted. \ No newline at end of file +- `null_value()`: Returns a null object (treated as `ObjectStructInfo`). This is used for indicating to operators that an optional argument has been omitted. +- `hint_on_device(data, device)`: This operator acts as a "hint" to the compiler that `data` should be located on `device`. `device` is not a `VDevice` but is rather a device id (the `dev_id` field on a `VDevice`). This operator is de-sugared into calls to `to_vdevice` if `data` does not have a specified `vdevice` or it differs from `device` (if `data` already matches `device`, then the call is removed entirely). +- `to_vdevice(data, vdevice)`: Move `data` to the `VDevice` corresponding to `vdevice`. If `data` is a tensor, the tensor is copied over to `vdevice`. If `data` is a tuple, all members of the tuple (proceeding recursively for any members that are in turn tuples) are copied over to `vdevice`. \ No newline at end of file From d231e88fed5b58513535828821c20a9ac8121889 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 20 Dec 2023 18:34:01 -0500 Subject: [PATCH 47/47] Typo --- relax_spec.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relax_spec.md b/relax_spec.md index 3cb18baf4359..cc5e4c0bcf6d 100644 --- a/relax_spec.md +++ b/relax_spec.md @@ -910,7 +910,7 @@ For each expression, we define how it affects the program's visible state and th 3. If `r` is false, evaluate the `false_branch` and return its result. 11. The node `ExternFunc(global_symbol)` is evaluated by looking up the global symbol name and returning the `PackedFunc` if it exists (it is an error if it does not). Note that if a TIR `PrimFunc` in the `IRModule` has a global symbol attribute registered, it can be called as an `ExternFunc` using that global symbol as well. 12. The node `Call(op, [arg1, arg2, ..., argn])` is evaluated as follows: - 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. If there are `VDevice`s defined, then (per the `StructInfo` rules for propagating `VDevice` information), all tensor arguments must have a `VDevice` specified in their `StructInfo`; the operator must be implemented on the `VDevice` specified. In all other respects, it is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into language runtime). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» + 1. If `op` is an `Op` node, then evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. If there are `VDevice`s defined, then (per the `StructInfo` rules for propagating `VDevice` information), all tensor arguments must have a `VDevice` specified in their `StructInfo`; the operator must be implemented on the `VDevice` specified. In all other respects, it is up to the compiler implementation to decide how operators should be implemented (some may have an associated `PackedFunc` and others may be built into the language runtime). The operator may mutate its arguments. It is also up to the operator implementation as to whether the result is newly allocated or aliases another value. «(TODO: Once we have operators for logical and AND and OR, we should also define short-circuiting semantics for those.)» 2. Otherwise, first evaluate `op` (it must evaluate to a closure or `PackedFunc`). Next, we evaluate `arg1`, `arg2`, …, `argn` in that order and call the results `a1`, `a2`, …, `an`. 1. If `op` evaluated to a closure, push a new scope onto the stack where arguments `v1`, `v2`, …, `vn` in the closure are bound to `a1`, `a2`, …, and `an`, respectively, and all variables saved in the closure are added to the scope. Evaluate the closure body in this new scope; this will be the return value of the call. Pop the scope before returning the value. (Note that the checking of the structural information of the argument result values and the body values should be done as described in the previous section.) 2. If `op` evaluated to a `PackedFunc`, simply invoke it. `PackedFunc`s may have arbitrary side effect and are responsible for whether the result is a newly allocated value or aliases another value.