From 3b166dbe57b393f155b2e18f0aaf3bfe454d96a1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 30 Nov 2018 23:01:22 +0900 Subject: [PATCH 1/5] add codebase walkthrough doc --- docs/dev/codebase_walkthrough.rst | 227 ++++++++++++++++++++++++++++++ docs/dev/index.rst | 3 +- 2 files changed, 229 insertions(+), 1 deletion(-) create mode 100644 docs/dev/codebase_walkthrough.rst diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst new file mode 100644 index 000000000000..578c5b0c5023 --- /dev/null +++ b/docs/dev/codebase_walkthrough.rst @@ -0,0 +1,227 @@ +======================================= +**TVM Codebase Walkthrough by Example** +======================================= + +Getting to know a new codebase can be a challenge. This is especially true for a codebase like that of TVM, where different components interact in non-obvious ways. In this guide, we try to illustrate the key elements that comprise a compilation pipeline with a simple example. For each important step, we show where in the codebase it is implemented. The purpose is to let new developers and interested users dive into the codebase more quickly. + +******************************************* +Codebase Structure Overview +******************************************* + +At the root of the TVM repository, we have following subdirectories that together comprise a bulk of the codebase. + +- ``src`` - C++ code for operator compilation and deployment runtimes. +- ``python`` - Python frontend that wraps C++ functions and objects implemented in ``src``. +- ``topi`` - Compute definitions and backend schedules for standard neural network operators. +- ``nnvm`` - C++ code and Python frontend for graph optimization and compilation. Depends on three directories above. + +Using standard Deep Learning terminologies, ``nnvm`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructures implemented in ``src`` and ``python``. Operators corresponding to each node are registered in ``nnvm``. Registration can be done via C++ or Python. Implemenations for operators are in ``topi``, and they are also coded in either C++ or Python. + +When an user invokes graph compilation by ``nnvm.compiler.build(...)``, the following sequence of actions happens for each node in the graph: + +- Look up an operator implementation by querying the operator registry +- Generate a compute expression and a schdule for the operator +- Compile the operator into object code + +One of the interesting aspects of TVM codebase is that interop between C++ and Python is not unidirectional. Typically, all code that do heavy liftings are implemented in C++, and Python bindings are provided for user interface. This is also true in TVM, but in TVM codebase, C++ code also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in nnvm. + +At the time of writing (Nov. 30, 2018), there is an going effort to reimplement functionality offered by ``nnvm`` in a new intermidiate representation called Relay. New Relay code resides in ``src/relay`` and ``python/tvm/relay``. + +This guide focuses on contents in ``src`` and ``python`` subdirectories. We may cover ``topi`` and ``nnvm`` in another documents. + +******************************************* +Vector Add Example +******************************************* + +We use a simple example that does not use topi or nnvm. The example is vector addition, which is covered in detail in `this tutorial `_. + +:: + + n = 1024 + A = tvm.placeholder((n,), name='A') + B = tvm.placeholder((n,), name='B') + C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") + +Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlining C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is an subclass of ``NodeBase``. + +:: + + @register_node + class Tensor(NodeBase, _expr.ExprOp): + """Tensor object, to construct, see function.Tensor""" + + def __call__(self, *indices): + ... + +The Node system is the basis of exposing C++ types to frontend languages, including Python. The way TVM implements Python wrapping is not straightforward. It is briefly covered in `this document `_, and details are in ``python/tvm/_ffi/`` if you are interested. + +``Tensor`` is created by functions in ``python/tvm/api.py``, which in turn calls into C++ functions exposed in ``src/api/api_lang.cc``. All C++ functions that are callable from Python are exposed in the ``src/api`` subdirectory. For example, the ``tvm.compute()`` function above calls into ``_ComputeOp`` api exposed in ``src/api/api_lang.cc``: + +:: + + TVM_REGISTER_API("_ComputeOp") + .set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ComputeOpNode::make(args[0], + args[1], + args[2], + args[3], + args[4]); + }); + +We use ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of `PackedFunc `_. ``PackedFunc`` is another mechanism by which TVM implements C++ and Python interop. In particular, this is what makes calling Python functions from the C++ codebase very easy. + +A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensors`` to it. This way we can keep track of dependencies between ``Operation``. + +We pass the operation corresponding to the output tensor ``C`` to ``tvm.create_schedule()`` function in ``python/tvm/schedule.py``. + +:: + + s = tvm.create_schedule(C.op) + +This function is mapped to the C++ function in ``include/tvm/schedule.h``. + +:: + + inline Schedule create_schedule(Array ops) { + return ScheduleNode::make(ops); + } + +``Schedule`` consists of collections of ``Stage`` and output ``Operation``. + +``Stage`` corresponds to one ``Operation``. In the vector add example above, there are two placeholder ops and one compute op, so the schedule ``s`` contains three stages. Each ``Stage`` holds information about a loop nest structure, types of each loop (``Parallel``, ``Vectorized``, ``Unrolled``), and where to execute its computation in the loop nest of the next ``Stage``, if any. + +``Schedule`` and ``Stage`` are defined in ``tvm/python/schedule.py``, ``include/tvm/schedule.h``, and ``src/schedule/schedule_ops.cc``. + +To keep it simple, we call ``tvm.build(...)`` on the default schedule created by ``create_schedule()`` function above. + +:: + + target = "cuda" + fadd = tvm.build(s, [A, B, C], target) + +``tvm.build()``, defined in ``python/tvm/build_module.py``, takes a schedule, input and output ``Tensor``, and a target, and returns a ``tvm.Module`` object, defined in ``python/tvm/module.py``. A ``Module`` object contains a compiled function which can be invoked with function call syntax. + +The process of ``tvm.build()`` can be divided into two steps: + +- Lowering, where an high level, initial loop nest structures are transformed into a final, low level IR +- Code generation, where target machine code is generated from the low level IR + +Lowering is done by ``tvm.lower()`` function, defined in ``python/tvm/build_module.py``. First, bound inference is peformed, and an initial loop nest structure is created. + +:: + + def lower(sch, + args, + name="default_function", + binds=None, + simple_mode=False): + ... + bounds = schedule.InferBound(sch) + stmt = schedule.ScheduleOps(sch, bounds) + ... + +Bound inference is a process where all loop bounds and sizes of intermidiate buffers are inferred. If you target the CUDA backend and you use shared memory, its minimum size is automatically determined here. Bound inference is implemented in ``src/schedule/bound.cc``, ``src/schedule/graph.cc`` and ``src/schedule/message_passing.cc``. + +``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects that changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``. + +Next, we apply a number of lowering passes to ``stmt``. These passes are implemented in ``src/pass`` subdirectory. For example, if you have applied ``vectorize`` or ``unroll`` primitives to your schedule, they are applied in loop vectorization and unrolling passes below. + +:: + + ... + stmt = ir_pass.VectorizeLoop(stmt) + ... + stmt = ir_pass.UnrollLoop( + stmt, + cfg.auto_unroll_max_step, + cfg.auto_unroll_max_depth, + cfg.auto_unroll_max_extent, + cfg.unroll_explicit) + ... + +After lowering is done, ``build()`` function generates target machine code from the lowered function. This code can contain SSE or AVX instructions if you target x86, or PTX instructions for CUDA target. In addition to target specific machine code, TVM also generates host side code that is responsible for memory management, kernel launch etc. + +Code generation is done by ``build_module()`` function, defined in ``python/tvm/codege.py``. On the C++ side, code generation is implemented in ``src/codegen`` subdirectory. ``build_module()`` Python function will reach ``Build()`` function below in ``src/codegen/codegen.cc``: + +:: + + runtime::Module Build(const Array& funcs, + const std::string& target) { + std::string build_f_name = "codegen.build_" + target; + const PackedFunc* bf = runtime::Registry::Get(build_f_name); + runtime::Module m = (*bf)(funcs, target); + return m; + } + + +``Build()`` function looks up code generators for a particular target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this: + +:: + + TVM_REGISTER_API("codegen.build_cuda") + .set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = BuildCUDA(args[0]); + }); + +``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. ``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code. + +``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlining target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, CUDA backends implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages CUDA driver API. ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. + +The returned module, which can be thought of as a combination of a compiled function and a device API, can be invoked on TVM's NDArray objects. + +:: + + ctx = tvm.context(target, 0) + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) + b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx) + c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx) + fadd(a, b, c) + output = c.asnumpy() + +Under the hood, TVM allocates device memory and manages memory transfer automatically. To do that, each backend needs to subclass ``DeviceAPI`` class, defined in ``include/tvm/runtime/device_api.h``, and override memory management methods to use device specific API. For example, the CUDA backend implements ``CUDADeviceAPI`` in ``src/runtime/cuda/cuda_device_api.cc`` to use ``cudaMalloc``, ``cudaMemcpy`` etc. + +The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunction()`` method of ``ModuleNode`` is called to get a ``PackedFunc`` that can be used for a kernel call. For example, in ``src/runtime/cuda/cuda_module.cc`` the CUDA backend implements ``CUDAModuleNode::GetFunction()`` like this: + +:: + + PackedFunc CUDAModuleNode::GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) { + auto it = fmap_.find(name); + const FunctionInfo& info = it->second; + CUDAWrappedFunc f; + f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags); + return PackFuncVoidAddr(f, info.arg_types); + } + +The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn calls ``operator()`` of ``CUDAWrappedFunc`` in ``src/runtime/cuda/cuda_module.cc``, where finally we see the ``cuLaunchKernel`` driver call: + +:: + + class CUDAWrappedFunc { + public: + void Init(...) + ... + void operator()(TVMArgs args, + TVMRetValue* rv, + void** void_args) const { + int device_id; + CUDA_CALL(cudaGetDevice(&device_id)); + if (fcache_[device_id] == nullptr) { + fcache_[device_id] = m_->GetFunc(device_id, func_name_); + } + CUstream strm = static_cast(CUDAThreadEntry::ThreadLocal()->stream); + ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); + CUresult result = cuLaunchKernel( + fcache_[device_id], + wl.grid_dim(0), + wl.grid_dim(1), + wl.grid_dim(2), + wl.block_dim(0), + wl.block_dim(1), + wl.block_dim(2), + 0, strm, void_args, 0); + } + }; + +This concludes an overview of how TVM compiles and executes a function. You are encouraged to dive into the details of the codebase. diff --git a/docs/dev/index.rst b/docs/dev/index.rst index 2734a816dc68..3f4944fe1d52 100644 --- a/docs/dev/index.rst +++ b/docs/dev/index.rst @@ -13,4 +13,5 @@ In this part of documentation, we share the rationale for the specific choices m nnvm_overview hybrid_script relay_intro - relay_add_op \ No newline at end of file + relay_add_op + codebase_walkthrough From b9db4c979e2d46642fb11c95387941368c592ad6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 1 Dec 2018 16:44:03 +0900 Subject: [PATCH 2/5] fix per review comment --- docs/dev/codebase_walkthrough.rst | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 578c5b0c5023..54dab44340fd 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -15,9 +15,9 @@ At the root of the TVM repository, we have following subdirectories that togethe - ``topi`` - Compute definitions and backend schedules for standard neural network operators. - ``nnvm`` - C++ code and Python frontend for graph optimization and compilation. Depends on three directories above. -Using standard Deep Learning terminologies, ``nnvm`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructures implemented in ``src`` and ``python``. Operators corresponding to each node are registered in ``nnvm``. Registration can be done via C++ or Python. Implemenations for operators are in ``topi``, and they are also coded in either C++ or Python. +Using standard Deep Learning terminologies, ``nnvm`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructures implemented in ``src`` and ``python``. Operators corresponding to each node are registered in ``nnvm``. Registration can be done via C++ or Python. Implementations for operators are in ``topi``, and they are also coded in either C++ or Python. -When an user invokes graph compilation by ``nnvm.compiler.build(...)``, the following sequence of actions happens for each node in the graph: +When a user invokes graph compilation by ``nnvm.compiler.build(...)``, the following sequence of actions happens for each node in the graph: - Look up an operator implementation by querying the operator registry - Generate a compute expression and a schdule for the operator @@ -25,7 +25,7 @@ When an user invokes graph compilation by ``nnvm.compiler.build(...)``, the foll One of the interesting aspects of TVM codebase is that interop between C++ and Python is not unidirectional. Typically, all code that do heavy liftings are implemented in C++, and Python bindings are provided for user interface. This is also true in TVM, but in TVM codebase, C++ code also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in nnvm. -At the time of writing (Nov. 30, 2018), there is an going effort to reimplement functionality offered by ``nnvm`` in a new intermidiate representation called Relay. New Relay code resides in ``src/relay`` and ``python/tvm/relay``. +At the time of writing (Nov. 30, 2018), there is an ongoing effort to reimplement functionality offered by ``nnvm`` in a new intermediate representation called Relay. New Relay code resides in ``src/relay`` and ``python/tvm/relay``. This guide focuses on contents in ``src`` and ``python`` subdirectories. We may cover ``topi`` and ``nnvm`` in another documents. @@ -42,7 +42,7 @@ We use a simple example that does not use topi or nnvm. The example is vector ad B = tvm.placeholder((n,), name='B') C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") -Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlining C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is an subclass of ``NodeBase``. +Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlining C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``. :: @@ -70,7 +70,7 @@ The Node system is the basis of exposing C++ types to frontend languages, includ We use ``TVM_REGISTER_*`` macro to expose C++ functions to frontend languages, in the form of `PackedFunc `_. ``PackedFunc`` is another mechanism by which TVM implements C++ and Python interop. In particular, this is what makes calling Python functions from the C++ codebase very easy. -A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensors`` to it. This way we can keep track of dependencies between ``Operation``. +A ``Tensor`` object has an ``Operation`` object associated with it, defined in ``python/tvm/tensor.py``, ``include/tvm/operation.h``, and ``src/tvm/op`` subdirectory. A ``Tensor`` is an output of its ``Operation`` object. Each ``Operation`` object has in turn ``input_tensors()`` method, which returns a list of input ``Tensor`` to it. This way we can keep track of dependencies between ``Operation``. We pass the operation corresponding to the output tensor ``C`` to ``tvm.create_schedule()`` function in ``python/tvm/schedule.py``. @@ -103,7 +103,7 @@ To keep it simple, we call ``tvm.build(...)`` on the default schedule created by The process of ``tvm.build()`` can be divided into two steps: -- Lowering, where an high level, initial loop nest structures are transformed into a final, low level IR +- Lowering, where a high level, initial loop nest structures are transformed into a final, low level IR - Code generation, where target machine code is generated from the low level IR Lowering is done by ``tvm.lower()`` function, defined in ``python/tvm/build_module.py``. First, bound inference is peformed, and an initial loop nest structure is created. @@ -154,7 +154,7 @@ Code generation is done by ``build_module()`` function, defined in ``python/tvm/ } -``Build()`` function looks up code generators for a particular target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this: +``Build()`` function looks up the code generator for the given target in the ``PackedFunc`` registry, and invokes the function found. For example, ``codegen.build_cuda`` function is registered in ``src/codegen/build_cuda_on.cc``, like this: :: @@ -165,7 +165,7 @@ Code generation is done by ``build_module()`` function, defined in ``python/tvm/ ``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. ``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code. -``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlining target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, CUDA backends implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages CUDA driver API. ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. +``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlining target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages CUDA driver API. ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. The returned module, which can be thought of as a combination of a compiled function and a device API, can be invoked on TVM's NDArray objects. @@ -178,7 +178,7 @@ The returned module, which can be thought of as a combination of a compiled func fadd(a, b, c) output = c.asnumpy() -Under the hood, TVM allocates device memory and manages memory transfer automatically. To do that, each backend needs to subclass ``DeviceAPI`` class, defined in ``include/tvm/runtime/device_api.h``, and override memory management methods to use device specific API. For example, the CUDA backend implements ``CUDADeviceAPI`` in ``src/runtime/cuda/cuda_device_api.cc`` to use ``cudaMalloc``, ``cudaMemcpy`` etc. +Under the hood, TVM allocates device memory and manages memory transfers automatically. To do that, each backend needs to subclass ``DeviceAPI`` class, defined in ``include/tvm/runtime/device_api.h``, and override memory management methods to use device specific API. For example, the CUDA backend implements ``CUDADeviceAPI`` in ``src/runtime/cuda/cuda_device_api.cc`` to use ``cudaMalloc``, ``cudaMemcpy`` etc. The first time you invoke the compiled module with ``fadd(a, b, c)``, ``GetFunction()`` method of ``ModuleNode`` is called to get a ``PackedFunc`` that can be used for a kernel call. For example, in ``src/runtime/cuda/cuda_module.cc`` the CUDA backend implements ``CUDAModuleNode::GetFunction()`` like this: From 36ffabde487ec9efeb9b02bfe64b5aa00087ce9a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 Dec 2018 22:08:23 +0900 Subject: [PATCH 3/5] fix per review --- docs/dev/codebase_walkthrough.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 54dab44340fd..7587380a3dd3 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -42,7 +42,7 @@ We use a simple example that does not use topi or nnvm. The example is vector ad B = tvm.placeholder((n,), name='B') C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") -Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlining C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``. +Here, types of ``A``, ``B``, ``C`` are ``tvm.tensor.Tensor``, defined in ``python/tvm/tensor.py``. The Python ``Tensor`` is backed by C++ ``Tensor``, implemented in ``include/tvm/tensor.h`` and ``src/lang/tensor.cc``. All Python types in TVM can be thought of as a handle to the underlying C++ type with the same name. If you look at the definition of Python ``Tensor`` type below, you can see it is a subclass of ``NodeBase``. :: From 2c59243202e2de535f7e6f15cca150de9ba17569 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 Dec 2018 23:21:46 +0900 Subject: [PATCH 4/5] more fix --- docs/dev/codebase_walkthrough.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 7587380a3dd3..9ac438762281 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -165,7 +165,7 @@ Code generation is done by ``build_module()`` function, defined in ``python/tvm/ ``BuildCUDA()`` above generates CUDA kernel source from the lowered IR using ``CodeGenCUDA`` class defined in ``src/codegen/codegen_cuda.cc``, and compile the kernel using NVRTC. If you target a backend that uses LLVM, which includes x86, ARM, NVPTX and AMDGPU, code generation is done primarily by ``CodeGenLLVM`` class defined in ``src/codegen/llvm/codegen_llvm.cc``. ``CodeGenLLVM`` translates TVM IR into LLVM IR, runs a number of LLVM optimization passes, and generates target machine code. -``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlining target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages CUDA driver API. ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. +``Build()`` function in ``src/codegen/codegen.cc`` returns a ``runtime::Module`` object, defined in ``include/tvm/runtime/module.h`` and ``src/runtime/module.cc``. A ``Module`` object is a container for the underlying target specific ``ModuleNode`` object. Each backend implements a subclass of ``ModuleNode`` to add target specific runtime API calls. For example, the CUDA backend implements ``CUDAModuleNode`` class in ``src/runtime/cuda/cuda_module.cc``, which manages CUDA driver API. ``BuildCUDA()`` function above wraps ``CUDAModuleNode`` with ``runtime::Module`` and return it to the Python side. The LLVM backend implements ``LLVMModuleNode`` in ``src/codegen/llvm/llvm_module.cc``, which handles JIT execution of compiled code. Other subclasses of ``ModuleNode`` can be found under subdirectories of ``src/runtime`` corresponding to each backend. The returned module, which can be thought of as a combination of a compiled function and a device API, can be invoked on TVM's NDArray objects. From c4aa14904214c57c64b83ef33cbf8c1dcc0521ba Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 17 Dec 2018 22:11:11 +0900 Subject: [PATCH 5/5] replace nnvm with relay --- docs/dev/codebase_walkthrough.rst | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/docs/dev/codebase_walkthrough.rst b/docs/dev/codebase_walkthrough.rst index 9ac438762281..6f5cff8a06d6 100644 --- a/docs/dev/codebase_walkthrough.rst +++ b/docs/dev/codebase_walkthrough.rst @@ -11,29 +11,28 @@ Codebase Structure Overview At the root of the TVM repository, we have following subdirectories that together comprise a bulk of the codebase. - ``src`` - C++ code for operator compilation and deployment runtimes. +- ``src/relay`` - Implementation of Relay, a new IR for deep learning framework superseding ``nnvm`` below. - ``python`` - Python frontend that wraps C++ functions and objects implemented in ``src``. - ``topi`` - Compute definitions and backend schedules for standard neural network operators. -- ``nnvm`` - C++ code and Python frontend for graph optimization and compilation. Depends on three directories above. +- ``nnvm`` - C++ code and Python frontend for graph optimization and compilation. After the introduction of Relay, it remains in the codebase for backward compatibility. -Using standard Deep Learning terminologies, ``nnvm`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructures implemented in ``src`` and ``python``. Operators corresponding to each node are registered in ``nnvm``. Registration can be done via C++ or Python. Implementations for operators are in ``topi``, and they are also coded in either C++ or Python. +Using standard Deep Learning terminologies, ``src/relay`` is the component that manages a computational graph, and nodes in a graph are compiled and executed using infrastructures implemented in the rest of ``src``. ``python`` provides python bindings for the C++ API and driver code that users can use to execute compilation. Operators corresponding to each node are registered in ``src/relay/op``. Implementations for operators are in ``topi``, and they are coded in either C++ or Python. -When a user invokes graph compilation by ``nnvm.compiler.build(...)``, the following sequence of actions happens for each node in the graph: +Relay is the new IR for deep networks that is intended to replace NNVM. If you have used NNVM, Relay provides equivalent or better functionalities. In fact, Relay goes beyond a traditional way of thinking deep networks in terms of computational graphs. But for the purpose of this document, we can think of Relay as a traditional computational graph framework. You can read more about Relay `here `_. + +When a user invokes graph compilation by ``relay.build(...)`` (or ``nnvm.compiler.build(...)`` for the older API), the following sequence of actions happens for each node in the graph: - Look up an operator implementation by querying the operator registry - Generate a compute expression and a schdule for the operator - Compile the operator into object code -One of the interesting aspects of TVM codebase is that interop between C++ and Python is not unidirectional. Typically, all code that do heavy liftings are implemented in C++, and Python bindings are provided for user interface. This is also true in TVM, but in TVM codebase, C++ code also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in nnvm. - -At the time of writing (Nov. 30, 2018), there is an ongoing effort to reimplement functionality offered by ``nnvm`` in a new intermediate representation called Relay. New Relay code resides in ``src/relay`` and ``python/tvm/relay``. - -This guide focuses on contents in ``src`` and ``python`` subdirectories. We may cover ``topi`` and ``nnvm`` in another documents. +One of the interesting aspects of TVM codebase is that interop between C++ and Python is not unidirectional. Typically, all code that do heavy liftings are implemented in C++, and Python bindings are provided for user interface. This is also true in TVM, but in TVM codebase, C++ code also call into functions defined in a Python module. For example, the convolution operator is implemented in Python, and its implementation is invoked from C++ code in Relay. ******************************************* Vector Add Example ******************************************* -We use a simple example that does not use topi or nnvm. The example is vector addition, which is covered in detail in `this tutorial `_. +We use a simple example that uses the low level TVM API directly. The example is vector addition, which is covered in detail in `this tutorial `_. :: @@ -120,7 +119,7 @@ Lowering is done by ``tvm.lower()`` function, defined in ``python/tvm/build_modu stmt = schedule.ScheduleOps(sch, bounds) ... -Bound inference is a process where all loop bounds and sizes of intermidiate buffers are inferred. If you target the CUDA backend and you use shared memory, its minimum size is automatically determined here. Bound inference is implemented in ``src/schedule/bound.cc``, ``src/schedule/graph.cc`` and ``src/schedule/message_passing.cc``. +Bound inference is the process where all loop bounds and sizes of intermidiate buffers are inferred. If you target the CUDA backend and you use shared memory, its required minimum size is automatically determined here. Bound inference is implemented in ``src/schedule/bound.cc``, ``src/schedule/graph.cc`` and ``src/schedule/message_passing.cc``. ``stmt``, which is the output of ``ScheduleOps()``, represents an initial loop nest structure. If you have applied ``reorder`` or ``split`` primitives to your schedule, then the initial loop nest already reflects that changes. ``ScheduleOps()`` is defined in ``src/schedule/schedule_ops.cc``. @@ -224,4 +223,4 @@ The ``PackedFunc``'s overloaded ``operator()`` will be called, which in turn cal } }; -This concludes an overview of how TVM compiles and executes a function. You are encouraged to dive into the details of the codebase. +This concludes an overview of how TVM compiles and executes a function. Although we did not detail TOPI or Relay, at the end all neural network operators go through the same compilation process as above. You are encouraged to dive into the details of the rest of the codebase.