From 0eb125c301df26185b81f30ced3fc36fe5aa1d41 Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Thu, 6 Oct 2022 15:40:13 -0700
Subject: [PATCH 1/5] [ROOFLINE] Add support for different dtypes

Support different dtypes in roofline analysis. Only x86 is supported for
now, but the interface is in place to add CUDA support.
---
 python/tvm/utils/roofline/__init__.py  |  64 +++----
 python/tvm/utils/roofline/cuda.py      | 161 ++++++++++++----
 python/tvm/utils/roofline/registry.py  |  46 ++++-
 python/tvm/utils/roofline/x86.py       | 254 ++++++++++++++++---------
 tests/python/unittest/test_roofline.py | 121 ++++++------
 5 files changed, 412 insertions(+), 234 deletions(-)

diff --git a/python/tvm/utils/roofline/__init__.py b/python/tvm/utils/roofline/__init__.py
index 0affb0704997..3b0144cb90e8 100644
--- a/python/tvm/utils/roofline/__init__.py
+++ b/python/tvm/utils/roofline/__init__.py
@@ -15,20 +15,20 @@
 # specific language governing permissions and limitations
 # under the License.
 """Utilities for computing an approximate roofline model"""
-from typing import Dict, Union, Optional
+from typing import Dict, Optional, Union
+
 import numpy as np
-from ... import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
-from ...target import Target
-from ...runtime import profiler_vm, profiling, Device, num_threads
-from ...script import tir as T
-from ...ir.instrument import pass_instrument
+from ... import IRModule, auto_scheduler, build, get_global_func, nd, relay, tir, topi, transform
+from ...contrib import utils
 from ...ir.expr import GlobalVar
+from ...ir.instrument import pass_instrument
 from ...rpc.base import RPC_SESS_MASK
 from ...rpc.client import RPCSession
-from ...contrib import utils
-
-from . import registry, cuda, x86
+from ...runtime import Device, num_threads, profiler_vm, profiling
+from ...script import tir as T
+from ...target import Target
+from . import cuda, registry, x86
 
 
 def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=None):
@@ -131,14 +131,9 @@ def roofline_from_existing(
     :py:func:`roofline_analysis` for more information on which metrics are
     included.
     """
-    with target:
-        peak_bandwidth = registry.estimate_peak_bandwidth(target, dev, remote)
-        peak_flops = registry.estimate_peak_flops(target, dev, remote)
-
-    ridge_point = peak_flops / peak_bandwidth
 
     all_features = {
-        prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
+        prim.attrs["hash"]: (name, prim, auto_scheduler.feature.named_features_from_primfunc(prim))
         for name, prim in tir_functions.items()
         if isinstance(prim, tir.PrimFunc) and "hash" in prim.attrs.keys()
     }
@@ -146,28 +141,17 @@ def roofline_from_existing(
     new_calls = []
     for call in report.calls:
         if "Hash" in call.keys() and call["Hash"] in all_features:
-            _, features = all_features[call["Hash"]]
-
-            flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
-            loaded_bytes = 0.0
-            # assume no more than 100 buffers
-            for i in range(100):
-                if str(target.kind) == "cuda":
-                    # autoscheduler features do not take into account that 1.
-                    # global and shared memory have very different performance
-                    # characteristics -- both are included in the same bytes
-                    # touched count 2. multiple threads accessing the same byte
-                    # of memory does not use the same amount of bandwidth as
-                    # multiple threads accessing different bytes of memory. We
-                    # use unique bytes accessed here to avoid these two issues,
-                    # but this does bias results towards being more compute
-                    # bound.
-                    key = f"B{i}.unique_bytes"
-                else:
-                    key = f"B{i}.bytes"
-                if not key in features.keys():
-                    break
-                loaded_bytes += np.sum(features[key])
+            _, prim, features = all_features[call["Hash"]]
+
+            with target:
+                flops, peak_flops, flops_name = registry.estimate_peak_flops(
+                    prim, features, target, dev, remote
+                )
+                loaded_bytes, peak_bandwidth, bandwidth_name = registry.estimate_peak_bandwidth(
+                    prim, features, target, dev, remote
+                )
+            ridge_point = peak_flops / peak_bandwidth
+
             runtime = call["Duration (us)"].microseconds * 1e-6
             arith_inten = flops / loaded_bytes
             call = dict(call)
@@ -188,8 +172,10 @@ def roofline_from_existing(
         else:
             new_calls.append(call)
     new_configuration = dict(report.configuration.items())
-    new_configuration["Estimated Peak FLOP/s"] = profiling.Ratio(peak_flops)
-    new_configuration["Estimated Peak Bandwidth (byte/second)"] = profiling.Ratio(peak_bandwidth)
+    new_configuration[f"Estimated Peak FLOP/s ({flops_name})"] = profiling.Ratio(peak_flops)
+    new_configuration[
+        f"Estimated Peak Bandwidth ({bandwidth_name}, byte/second)"
+    ] = profiling.Ratio(peak_bandwidth)
     return profiling.Report(new_calls, report.device_metrics, new_configuration)
 
 
diff --git a/python/tvm/utils/roofline/cuda.py b/python/tvm/utils/roofline/cuda.py
index f5a3f5e1dde9..e33978b7d500 100644
--- a/python/tvm/utils/roofline/cuda.py
+++ b/python/tvm/utils/roofline/cuda.py
@@ -15,25 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 """Estimation of peak flops and memory bandwidth for cuda devices"""
-from typing import Optional
-from ...script import tir as T
-from ... import nd, build, transform
-from ...runtime import Device
-from ...target import Target
+import functools
+from typing import Dict, Optional, Tuple
+
+import numpy as np
+
+from ... import build, nd, transform
+from ...contrib import nvcc, utils
 from ...rpc.base import RPC_SESS_MASK
 from ...rpc.client import RPCSession
+from ...runtime import Device
+from ...script import tir as T
+from ...target import Target
+from ...tir import PrimFunc
 from . import registry
-from ...contrib import utils, nvcc
 
 
-@registry.estimate_peak_flops.register("cuda")
+@functools.lru_cache(maxsize=None)
 def estimate_peak_flops_tensorcore(
     target: Target,
     dev: Device,
     remote: Optional[RPCSession],
     mat_dtype: str = "float16",
     acc_dtype: str = "float32",
 ) -> float:
     """Estimate the peak FLOP/s of a cuda device with tensorcores.
 
     This estimate should only be used to compare with operators that can use
@@ -64,12 +69,11 @@ def estimate_peak_flops_tensorcore(
 
     Returns
     -------
-    float
+    peak_flops : float
         Approximate sustained FLOP/s of this target/device combo assuming mma
         instructions. Addition and multiplications are each counted as separate
         FLOPs.
     """
-    assert str(target.kind) == "cuda", "Only CUDA devices have tensorcores"
 
     @T.prim_func
     def peak_flops_tensorcore_tir(
@@ -161,6 +165,51 @@ def estimate_peak_flops_tensorcore(
     return n * 16 * 16 * 16 * 2 * sms * 8 / times.min
 
 
+@registry.estimate_peak_flops.register("cuda")
+def estimate_peak_flops(
+    func: PrimFunc,  # pylint: disable=unused-argument
+    features: Dict[str, np.ndarray],
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession],
+) -> Tuple[float, float, str]:
+    """Estimate the peak FLOP/s of a cuda device.
+
+    Parameters
+    ----------
+    func : PrimFunc
+        Function to estimate peak flops for. Used to check if a specific kind
+        of intrinsic or dtype could be used with this function.
+    features : Dict[str, np.ndarray]
+        Features extracted from `func`. Used to check if a specific kind of
+        intrinsic or dtype could be used with this function.
+    target : Target
+        Target to run on. This should be as specific to the actual hardware as
+        possible.
+    dev : Device
+        Device to run on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must be
+        the same session used to create `dev`.
+
+    Returns
+    -------
+    flops : float
+        Estimated number of flops used by `func`.
+    peak_flops : float
+        Approximate sustained FLOP/s of this target/device combo. Addition and
+        multiplications are each counted as separate FLOPs.
+    name : str
+        Dtype/intrinsic used by `func` to achieve peak flops.
+    """
+    assert nvcc.have_tensorcore(
+        dev.compute_version
+    ), "CUDA roofline only works with devices that have tensorcores"
+    flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+    peak_flops = estimate_peak_flops_tensorcore(target, dev, remote)
+    return flops, peak_flops, "float16 tensorcore"
+
+
 @T.prim_func
 def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.int32) -> None:
     # pylint: disable=invalid-name, missing-function-docstring
@@ -178,37 +227,13 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.i
             B[i, l, j] += A[i, k, l, j]
 
 
-@registry.estimate_peak_bandwidth.register("cuda")
-def estimate_peak_bandwidth(
+@functools.lru_cache(maxsize=None)
+def estimate_peak_bandwidth_global(
     target: Target,
     dev: Device,
     remote: Optional[RPCSession] = None,
 ) -> float:
-    """Estimate peak memory bandwidth of a target/device combo.
-
-    Peak bandwidth is estimated by running a small experiment on the underlying
-    hardware. The peak bandwidth measurement assumes that vector instructions
-    are being used to load the data.
-
-    Parameters
-    ----------
-    target : Target
-        Target to use for measurement. This target should be as specific to the
-        underlying hardware as possible.
-    dev : Device
-        Device to measure peak bandwidth on.
-    remote : Optional[RPCSession]
-        Remote session used to upload artifacts for runtime evaluation. Must be
-        the same session used to create `dev`.
-
-    Returns
-    -------
-    float
-        Peak memory bandwidth in bytes/seconds.
-    """
-    assert nvcc.have_tensorcore(
-        dev.compute_version
-    ), "CUDA roofline only works with devices that have tensorcores"
+    """Estimate peak bandwidth of global memory. See `estimate_peak_bandwidth`."""
     warp_size = dev.warp_size
     # These sizes seem large enough to give the card time to hit a fixpoint on memory bandwidth
     blocks = 1024
@@ -234,3 +259,63 @@ def estimate_peak_bandwidth(
     b = nd.empty((blocks, 4, warp_size), dtype="float32", device=dev)
     times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b)
     return a.numpy().size * 4 / times.min  # 4 bytes per float32
+
+
+@registry.estimate_peak_bandwidth.register("cuda")
+def estimate_peak_bandwidth(
+    func: PrimFunc,  # pylint: disable=unused-argument
+    features: Dict[str, np.ndarray],
+    target: Target,
+    dev: Device,
+    remote: Optional[RPCSession] = None,
+) -> Tuple[float, float, str]:
+    """Estimate peak memory bandwidth of a target/device combo.
+
+    Peak bandwidth is estimated by running a small experiment on the underlying
+    hardware. The peak bandwidth measurement assumes that vector instructions
+    are being used to load the data.
+
+    Parameters
+    ----------
+    func : PrimFunc
+        Function to estimate peak bandwidth for. Used to check if a specific
+        kind of memory could be used with this function.
+    features : Dict[str, np.ndarray]
+        Features extracted from `func`. Used to check if a specific kind of
+        memory could be used with this function.
+    target : Target
+        Target to use for measurement. This target should be as specific to the
+        underlying hardware as possible.
+    dev : Device
+        Device to measure peak bandwidth on.
+    remote : Optional[RPCSession]
+        Remote session used to upload artifacts for runtime evaluation. Must be
+        the same session used to create `dev`.
+
+    Returns
+    -------
+    loaded_bytes : float
+        Estimated bytes loaded by `func`.
+    peak_bandwidth : float
+        Peak memory bandwidth in bytes/seconds.
+    name : str
+        Name of the memory being used.
+    """
+    loaded_bytes = 0.0
+    # assume no more than 100 buffers
+    for i in range(100):
+        # autoscheduler features do not take into account that 1.
+        # global and shared memory have very different performance
+        # characteristics -- both are included in the same bytes
+        # touched count 2. multiple threads accessing the same byte
+        # of memory does not use the same amount of bandwidth as
+        # multiple threads accessing different bytes of memory. We
+        # use unique bytes accessed here to avoid these two issues,
+        # but this does bias results towards being more compute
+        # bound.
+        key = f"B{i}.unique_bytes"
+        if not key in features.keys():
+            break
+        loaded_bytes += np.sum(features[key])
+    peak_bandwidth = estimate_peak_bandwidth_global(target, dev, remote)
+    return loaded_bytes, peak_bandwidth, "global"
diff --git a/python/tvm/utils/roofline/registry.py b/python/tvm/utils/roofline/registry.py
index b3ea522be899..9358529b38ec 100644
--- a/python/tvm/utils/roofline/registry.py
+++ b/python/tvm/utils/roofline/registry.py
@@ -15,18 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 """Definition of generic functions for estimating peak flops and bandwidth"""
-from typing import Optional
-from ...target import Target, generic_func
-from ...runtime import Device
+from typing import Dict, Optional, Tuple
+
+import numpy as np
+
 from ...rpc.client import RPCSession
+from ...runtime import Device
+from ...target import Target, generic_func
+from ...tir import PrimFunc
 
 
 @generic_func
 def estimate_peak_bandwidth(
+    func: PrimFunc,
+    features: Dict[str, np.ndarray],
     target: Target,
     dev: Device,
     remote: Optional[RPCSession] = None,
-) -> float:
+) -> Tuple[float, float, str]:
     """Estimate peak memory bandwidth of a target/device combo.
 
     Peak bandwidth is estimated by running a small experiment on the underlying
@@ -35,6 +41,12 @@ def estimate_peak_bandwidth(
 
     Parameters
    ----------
+    func : PrimFunc
+        Function to estimate peak bandwidth for. Used to check if a specific
+        kind of memory could be used with this function.
+    features : Dict[str, np.ndarray]
+        Features extracted from `func`. Used to check if a specific kind of
+        memory could be used with this function.
     target : Target
         Target to use for measurement. This target should be as specific to the
         underlying hardware as possible.
@@ -46,18 +58,24 @@ def estimate_peak_bandwidth(
 
     Returns
     -------
-    float
+    loaded_bytes : float
+        Estimated bytes loaded by `func`.
+    peak_bandwidth : float
         Peak memory bandwidth in bytes/seconds.
+    name : str
+        Name of the memory being used.
""" raise NotImplementedError() @generic_func def estimate_peak_flops( + func: PrimFunc, + features: Dict[str, np.ndarray], target: Target, dev: Device, remote: Optional[RPCSession], -) -> float: +) -> Tuple[float, float, str]: """ Estimate the maximum number of FLOP/s this target/device combo is capable of reaching by running a test program. This is a generic function that @@ -65,6 +83,12 @@ def estimate_peak_flops( Parameters ---------- + func : PrimFunc + Function to estimate peak flops for. Used to check if a specific kind + intrinsic or dtype could be used with this function. + features : Dict[str, np.ndarry] + Features extracted from `func`. Used to check if a specific kind + intrinsic or dtype could be used with this function. target : Target Target to run on. This should be as specific to the actual hardware as possible to make sure that LLVM generates the best vector code. @@ -76,8 +100,12 @@ def estimate_peak_flops( Returns ------- - float - Approximate sustained FLOP/s of this target/device combo. Each FMA - operation counts as two FLOPs. + flops : float + Estimated number of flops used by `func`. + peak_flops : float + Approximate sustained FLOP/s of this target/device combo assuming + vectorized FMA instructions. Each FMA operation counts as two FLOPs. + name : str + Dtype/intrinsic used by `func` to achieve peak flops. """ raise NotImplementedError() diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py index d4a0e511848f..55435fba4c8a 100644 --- a/python/tvm/utils/roofline/x86.py +++ b/python/tvm/utils/roofline/x86.py @@ -15,15 +15,19 @@ # specific language governing permissions and limitations # under the License. """Estimate peak flops and bandwidth for x86 devices""" -from typing import Optional +import functools +from typing import Dict, Optional, Tuple -from ... import nd, build, topi, transform, get_global_func -from ...target import Target -from ...runtime import Device, num_threads -from ...script import tir as T +import numpy as np + +from ... import build, get_global_func, nd, topi, transform +from ...contrib import utils from ...rpc.base import RPC_SESS_MASK from ...rpc.client import RPCSession -from ...contrib import utils +from ...runtime import DataType, Device, num_threads +from ...script import tir as T +from ...target import Target +from ...tir import PrimFunc from . import registry @@ -44,7 +48,7 @@ def _detect_vec_width_registers( Returns ------- vec_width: int - Width of a vector register on `target`. + Width of a vector register on `target` in bytes. num_vector_registers: int Number of vector registers on `target`. 
""" @@ -57,7 +61,7 @@ def _detect_vec_width_registers( and target.keys[0] == "cpu" ): with target: - vec_width = topi.x86.utils.get_simd_32bit_lanes() # in number of float32s + vec_width = topi.x86.utils.get_simd_32bit_lanes() * 4 # in number of bytes else: raise RuntimeError(f"Cannot determine vector width for target {target}") if num_vector_registers is None: @@ -68,66 +72,41 @@ def _detect_vec_width_registers( return vec_width, num_vector_registers -@T.prim_func -def peakflops_fma_tir( - a: T.handle, - vec_width: T.int32, - iters: T.int32, - num_vector_registers: T.int32, - threads: T.int32, -) -> None: - # pylint: disable=invalid-name, missing-function-docstring - A = T.match_buffer(a, [threads, num_vector_registers, vec_width], "float32") - for t in T.parallel(threads): - for _j in range(iters): - for l in T.unroll(num_vector_registers): - # We want to use as few registers as possible, so we perform - # all operations on the same element - for k in T.vectorized(vec_width): - A[t, l, k] = A[t, l, k] * A[t, l, k] + A[t, l, k] - - -@registry.estimate_peak_flops.register("cpu") -def estimate_peak_fma_flops( +@functools.lru_cache(maxsize=None) +def estimate_peak_fma_vector_flops( target: Target, dev: Device, remote: Optional[RPCSession], + dtype: DataType, vec_width: Optional[int] = None, num_vector_registers: Optional[int] = None, -) -> float: +): + """Estimate peak flops assuming vector fma instructions and no explicit + intrinsics. See estimate_peak_fma_flops. """ - Estimate the maximum number of FLOP/s this target/device combo is capable - of reaching by running a test program. This assumes vectorized f32 FMA - (fused-multiply-add) instructions. - - Parameters - ---------- - target : Target - Target to run on. This should be as specific to the actual hardware as - possible to make sure that LLVM generates the best vector code. - dev : Device - Device to run on. - remote : Optional[RPCSession] - Remote session used to upload artifacts for runtime evaluation. Must be - the same session used to create `dev`. - vec_width : Optional[int] - Vector width of SIMD units on the underlying hardware. Will try to - infer if no value is provided. - num_vector_registers : Optional[int] - Number of vector registers on the underlying hardware. Will try to - infer if no value is provided. + @T.prim_func + def peakflops_fma_tir( + a: T.handle, + vec_width: T.int32, + iters: T.int32, + num_vector_registers: T.int32, + threads: T.int32, + ) -> None: + # pylint: disable=invalid-name, missing-function-docstring + A = T.match_buffer(a, [threads, num_vector_registers, vec_width], dtype) + for t in T.parallel(threads): + for _j in range(iters): + for l in T.unroll(num_vector_registers): + # We want to use as few registers as possible, so we perform + # all operations on the same element + for k in T.vectorized(vec_width): + A[t, l, k] = A[t, l, k] * A[t, l, k] + A[t, l, k] - Returns - ------- - float - Approximate sustained FLOP/s of this target/device combo assuming - vectorized f32 FMA instructions. Each FMA operation counts as two FLOPs. 
- """ - assert str(target.kind) == "llvm", "Only llvm targets are supported" vec_width, num_vector_registers = _detect_vec_width_registers( target, vec_width, num_vector_registers ) + vec_width //= DataType(dtype).bits // 8 iters = 1000000 nthreads = num_threads() specialized = peakflops_fma_tir.specialize( @@ -155,12 +134,72 @@ def estimate_peak_fma_flops( random_fill = get_global_func("tvm.contrib.random.random_fill") assert random_fill, "Please make sure USE_RANDOM is ON in config.cmake" - a = nd.empty((nthreads, num_vector_registers, vec_width), dtype="float32", device=dev) + a = nd.empty((nthreads, num_vector_registers, vec_width), dtype=dtype, device=dev) random_fill(a) times = f.time_evaluator(f.entry_name, dev, repeat=100, number=1)(a) flops = 2 * vec_width * num_vector_registers * nthreads * iters # fma is two flops - flop_s = flops / times.min - return flop_s + return flops / times.min + + +@registry.estimate_peak_flops.register("cpu") +def estimate_peak_fma_flops( + func: PrimFunc, + features: Dict[str, np.ndarray], + target: Target, + dev: Device, + remote: Optional[RPCSession], + vec_width: Optional[int] = None, + num_vector_registers: Optional[int] = None, +) -> Tuple[float, float, str]: + """ + Estimate the maximum number of FLOP/s this target/device combo is capable + of reaching by running a test program. This assumes vectorized FMA + (fused-multiply-add) instructions. + + + Parameters + ---------- + func : PrimFunc + Function to estimate peak flops for. Used to check if a specific kind + intrinsic or dtype could be used with this function. + features : Dict[str, np.ndarry] + Features extracted from `func`. Used to check if a specific kind + intrinsic or dtype could be used with this function. + target : Target + Target to run on. This should be as specific to the actual hardware as + possible to make sure that LLVM generates the best vector code. + dev : Device + Device to run on. + remote : Optional[RPCSession] + Remote session used to upload artifacts for runtime evaluation. Must be + the same session used to create `dev`. + vec_width : Optional[int] + Vector width of SIMD units on the underlying hardware. Will try to + infer if no value is provided. + num_vector_registers : Optional[int] + Number of vector registers on the underlying hardware. Will try to + infer if no value is provided. + + Returns + ------- + flops : float + Estimated number of flops used by `func`. + peak_flops : float + Approximate sustained FLOP/s of this target/device combo assuming + vectorized FMA instructions. Each FMA operation counts as two FLOPs. + name : str + Dtype/intrinsic used by `func` to achieve peak flops. + """ + # assume that the first argument's dtype is the one we want + dtype = list(func.buffer_map.values())[0].dtype + if "int" in dtype: + flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"]) + else: + flops = np.sum(features["int_addsub"] + features["int_mul"] + features["int_mad"]) + peak_flops = estimate_peak_fma_vector_flops( + target, dev, remote, dtype, vec_width, num_vector_registers + ) + return flops, peak_flops, f"{dtype} FMA" @T.prim_func @@ -181,43 +220,14 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, threads: T.int32, vec_width: T. 
B[i, l, j] += A[i, k, l, j] -@registry.estimate_peak_bandwidth.register("cpu") -def estimate_peak_bandwidth( +@functools.lru_cache(maxsize=None) +def estimate_peak_bandwidth_dram( target: Target, dev: Device, remote: Optional[RPCSession], vec_width: Optional[int] = None, ) -> float: - """Estimate peak memory bandwidth of a target/device combo. - - Peak bandwidth is estimated by running a small experiment on the underlying - hardware. The peak bandwidth measurement assumes that vector instructions - are being used to load the data. - - Parameters - ---------- - target : Target - Target to use for measurement. This target should be as specific to the - underlying hardware as possible. - dev : Device - Device to measure peak bandwidth on. - remote : Optional[RPCSession] - Remote session used to upload artifacts for runtime evaluation. Must be - the same session used to create `dev`. - vec_width : Optional[int] - Vector unit width, determined from target if not supplied. - - Returns - ------- - float - Peak memory bandwidth in bytes/seconds. - """ - # Ideally we'd be able to use this code to measure peak bandwidth of the - # different cache levels. If we could just generate load commands, then we - # could use those in a tight loop. Instead we need some code that is - # limited on the cache bandwidth. With the L1 cache we need an operation - # that has a very low arithmetic intensity and we haven't come up with one - # yet. + """Estimate peak bandwidth for DRAM. See estimate_peak_bandwidth.""" vec_width, _ = _detect_vec_width_registers(target, vec_width, 1) specialized = peak_bandwidth_tir.specialize( { @@ -252,3 +262,63 @@ def estimate_peak_bandwidth( random_fill(b) times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b, threads) return a.numpy().size * 4 / times.min # 4 bytes per float32 + + +@registry.estimate_peak_bandwidth.register("cpu") +def estimate_peak_bandwidth( + func: PrimFunc, # pylint: disable=unused-argument + features: Dict[str, np.ndarray], + target: Target, + dev: Device, + remote: Optional[RPCSession], + vec_width: Optional[int] = None, +) -> Tuple[float, float, str]: + """Estimate peak memory bandwidth of a target/device combo. + + Peak bandwidth is estimated by running a small experiment on the underlying + hardware. The peak bandwidth measurement assumes that vector instructions + are being used to load the data. + + Parameters + ---------- + func : PrimFunc + Function to estimate peak bandwidth for. Used to check if a specific + kind of memory could be used with this function. + features : Dict[str, np.ndarry] + Features extracted from `func`. Used to check if a specific kind of + memory could be used with this function. + target : Target + Target to use for measurement. This target should be as specific to the + underlying hardware as possible. + dev : Device + Device to measure peak bandwidth on. + remote : Optional[RPCSession] + Remote session used to upload artifacts for runtime evaluation. Must be + the same session used to create `dev`. + vec_width : Optional[int] + Vector unit width, determined from target if not supplied. + + Returns + ------- + loaded_bytes : float + Estimated bytes loaded by `func`. + peak_bandwidth : float + Peak memory bandwidth in bytes/seconds. + name : str + Name of the memory being used. + """ + # Ideally we'd be able to use this code to measure peak bandwidth of the + # different cache levels. If we could just generate load commands, then we + # could use those in a tight loop. 
Instead we need some code that is + # limited on the cache bandwidth. With the L1 cache we need an operation + # that has a very low arithmetic intensity and we haven't come up with one + # yet. + peak_bandwidth = estimate_peak_bandwidth_dram(target, dev, remote, vec_width) + loaded_bytes = 0.0 + # assume no more than 100 buffers + for i in range(100): + key = f"B{i}.bytes" + if not key in features.keys(): + break + loaded_bytes += np.sum(features[key]) + return loaded_bytes, peak_bandwidth, "DRAM" diff --git a/tests/python/unittest/test_roofline.py b/tests/python/unittest/test_roofline.py index e37f6e085bf6..1b0c600388b3 100644 --- a/tests/python/unittest/test_roofline.py +++ b/tests/python/unittest/test_roofline.py @@ -14,81 +14,90 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import numpy as np -import pytest -from io import StringIO import csv -import os import json +import os import platform +from io import StringIO + +import numpy as np +import pytest import tvm.testing import tvm.utils -from tvm.runtime import profiler_vm -from tvm import relay -from tvm.relay.testing import mlp -from tvm.contrib.debugger import debug_executor -from tvm import rpc +from tvm import relay, rpc from tvm.contrib import utils +from tvm.contrib.debugger import debug_executor +from tvm.relay.testing import mlp +from tvm.runtime import profiler_vm from tvm.runtime.profiling import Report from tvm.script import tir as T -@tvm.testing.parametrize_targets("llvm", "cuda") -def test_estimate_peak_flops(target, dev): - server = rpc.Server(key="roofline_flops") - remote = rpc.connect("127.0.0.1", server.port, key="roofline_flops") - dev = remote.device(target) +@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/12955") +@tvm.testing.requires_llvm +@pytest.mark.parametrize("dtype", ["float32", "int8", "int32"]) +def test_estimate_peak_flops_cpu(dtype): + server = rpc.Server(key="roofline_flops_cpu") + remote = rpc.connect("127.0.0.1", server.port, key="roofline_flops_cpu") + target = tvm.target.Target("llvm -mattr=+fma,+avx2") + dev = remote.device(str(target)) # This test uses vectorized instructions so we need a target that supports them - if target == "llvm": - target = "llvm -mattr=+fma,+avx2" - target = tvm.target.Target(target) - with target: - flops = tvm.utils.roofline.registry.estimate_peak_flops(target, dev, remote) - if str(target.kind) == "llvm": - # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu. - assert ( - flops > 10**9 and flops < 10**14 - ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}" - elif str(target.kind) == "cuda": - # should be able to hit a TFLOP/s with tensor cores - assert ( - flops > 10**12 and flops < 10**14 - ), f"FLOP/s should be between 10^12 and 10^14, but it is {flops}" - else: - raise RuntimeError("Unsupported target " + str(target)) + flops = tvm.utils.roofline.x86.estimate_peak_fma_vector_flops(target, dev, remote, "float32") + # Assume we can achieve 1 GFLOP/s per thread, which is 1 FLOP per cycle on a 1GHz cpu. 
+    assert (
+        flops > 10**9 and flops < 10**14
+    ), f"FLOP/s should be between 10^9 and 10^14, but it is {flops}"
+
+
+@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/12955")
+@tvm.testing.requires_cuda
+def test_estimate_peak_flops_gpu():
+    server = rpc.Server(key="roofline_flops_gpu")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline_flops_gpu")
+    target = tvm.target.Target("cuda")
+    dev = remote.device(str(target))
+    # This test uses tensor core intrinsics, so we need a device that supports them
+    flops = tvm.utils.roofline.cuda.estimate_peak_flops_tensorcore(target, dev, remote)
+    # should be able to hit a TFLOP/s with tensor cores
+    assert (
+        flops > 10**12 and flops < 10**14
+    ), f"FLOP/s should be between 10^12 and 10^14, but it is {flops}"
 
 
 @tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-@tvm.testing.parametrize_targets("llvm", "cuda")
-def test_estimate_peak_bandwidth(target, dev):
-    server = rpc.Server(key="roofline_bandwidth")
-    remote = rpc.connect("127.0.0.1", server.port, key="roofline_bandwidth")
-    dev = remote.device(target)
+@tvm.testing.requires_llvm
+def test_estimate_peak_bandwidth_cpu():
+    server = rpc.Server(key="roofline_bandwidth_cpu")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline_bandwidth_cpu")
+    target = tvm.target.Target("llvm -mattr=+fma,+avx2")
+    dev = remote.device(str(target))
+    # This test uses vectorized instructions so we need a target that supports them
+    bandwidth = tvm.utils.roofline.x86.estimate_peak_bandwidth_dram(target, dev, remote)
+    # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
+    # GB/s, so this should leave enough wiggle room.
+    assert (
+        bandwidth > 10**9 and bandwidth < 10**12
+    ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
+
+
+@tvm.testing.requires_cuda
+def test_estimate_peak_bandwidth_gpu():
+    server = rpc.Server(key="roofline_bandwidth_gpu")
+    remote = rpc.connect("127.0.0.1", server.port, key="roofline_bandwidth_gpu")
+    target = tvm.target.Target("cuda")
+    dev = remote.device(str(target))
     # This test uses vectorized instructions so we need a target that supports them
-    if target == "llvm":
-        target = "llvm -mattr=+fma,+avx2"
-    target = tvm.target.Target(target)
-    with target:
-        bandwidth = tvm.utils.roofline.registry.estimate_peak_bandwidth(target, dev, remote)
-    if str(target.kind) == "llvm":
-        # Assume we can achieve 1 GB/s. DDR2 should transfer somewhere around 6
-        # GB/s, so this should leave enough wiggle room.
-        assert (
-            bandwidth > 10**9 and bandwidth < 10**12
-        ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
-    elif str(target.kind) == "cuda":
-        # should be able to hit a 100 GB/s on a GPU. GTX 280 hits 140 GB/s and
-        # it is really old.
-        assert (
-            bandwidth > 10**11 and bandwidth < 10**13
-        ), f"Bandwidth should be between 10^9 and 10^12, but it is {bandwidth}"
-    else:
-        raise RuntimeError("Unsupported target " + str(target))
+    bandwidth = tvm.utils.roofline.cuda.estimate_peak_bandwidth_global(target, dev, remote)
+    # should be able to hit a 100 GB/s on a GPU. GTX 280 hits 140 GB/s and
+    # it is really old.
+    assert (
+        bandwidth > 10**11 and bandwidth < 10**13
+    ), f"Bandwidth should be between 10^11 and 10^13, but it is {bandwidth}"
 
 
 @tvm.testing.skip_if_32bit(reason="Cannot allocate enough memory on i386")
-@tvm.testing.parametrize_targets("llvm -mattr=+fma+avx2", "cuda")
+@tvm.testing.parametrize_targets("llvm -mattr=+fma,+avx2", "cuda")
 def test_roofline_analysis(target, dev):
     a = relay.var("a", relay.TensorType((512, 512), "float32"))
     b = relay.var("b", relay.TensorType((512, 512), "float32"))

From 68781573c1d72d1add4953ac7047bb7381595bce Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Wed, 12 Oct 2022 12:32:29 -0700
Subject: [PATCH 2/5] whoops: fix swapped int/float flop counts

---
 python/tvm/utils/roofline/x86.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py
index 55435fba4c8a..97f5a7632b05 100644
--- a/python/tvm/utils/roofline/x86.py
+++ b/python/tvm/utils/roofline/x86.py
@@ -193,9 +193,9 @@ def estimate_peak_fma_flops(
     # assume that the first argument's dtype is the one we want
     dtype = list(func.buffer_map.values())[0].dtype
     if "int" in dtype:
-        flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
-    else:
         flops = np.sum(features["int_addsub"] + features["int_mul"] + features["int_mad"])
+    else:
+        flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])

From 96ec2b8a813b9b7da92ab49f2d9d632495714aef Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Thu, 13 Oct 2022 16:31:41 -0700
Subject: [PATCH 3/5] add div, mad*2 to flops count

---
 python/tvm/utils/roofline/cuda.py      | 17 ++++++++++++-----
 python/tvm/utils/roofline/x86.py       | 19 +++++++++++++++----
 tests/python/unittest/test_roofline.py |  2 +-
 3 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/python/tvm/utils/roofline/cuda.py b/python/tvm/utils/roofline/cuda.py
index e33978b7d500..45b10e9afb52 100644
--- a/python/tvm/utils/roofline/cuda.py
+++ b/python/tvm/utils/roofline/cuda.py
@@ -205,7 +205,12 @@ def estimate_peak_flops(
     assert nvcc.have_tensorcore(
         dev.compute_version
     ), "CUDA roofline only works with devices that have tensorcores"
-    flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+    flops = np.sum(
+        features["float_addsub"]
+        + features["float_mul"]
+        + features["float_mad"] * 2
+        + features["float_divmod"]
+    )
     peak_flops = estimate_peak_flops_tensorcore(target, dev, remote)
     return flops, peak_flops, "float16 tensorcore"
 
@@ -228,7 +233,7 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.i
 
 
 @functools.lru_cache(maxsize=None)
-def estimate_peak_bandwidth_global(
+def estimate_peak_bandwidth_global_mem(
     target: Target,
     dev: Device,
     remote: Optional[RPCSession] = None,
@@ -306,8 +307,9 @@ def estimate_peak_bandwidth(
         Name of the memory being used.
     """
     loaded_bytes = 0.0
-    # assume no more than 100 buffers
-    for i in range(100):
+
+    i = 0
+    while True:
         # autoscheduler features do not take into account that 1.
         # global and shared memory have very different performance
         # characteristics -- both are included in the same bytes
         # touched count 2. multiple threads accessing the same byte
         # of memory does not use the same amount of bandwidth as
         # multiple threads accessing different bytes of memory. We
         # use unique bytes accessed here to avoid these two issues,
         # but this does bias results towards being more compute
         # bound.
         key = f"B{i}.unique_bytes"
         if not key in features.keys():
             break
         loaded_bytes += np.sum(features[key])
-    peak_bandwidth = estimate_peak_bandwidth_global(target, dev, remote)
+        i += 1
+    peak_bandwidth = estimate_peak_bandwidth_global_mem(target, dev, remote)
     return loaded_bytes, peak_bandwidth, "global"
diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py
index 97f5a7632b05..1a056b09fbec 100644
--- a/python/tvm/utils/roofline/x86.py
+++ b/python/tvm/utils/roofline/x86.py
@@ -193,9 +193,19 @@ def estimate_peak_fma_flops(
     # assume that the first argument's dtype is the one we want
     dtype = list(func.buffer_map.values())[0].dtype
     if "int" in dtype:
-        flops = np.sum(features["int_addsub"] + features["int_mul"] + features["int_mad"])
+        flops = np.sum(
+            features["int_addsub"]
+            + features["int_mul"]
+            + features["int_mad"] * 2
+            + features["int_divmod"]
+        )
     else:
-        flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
+        flops = np.sum(
+            features["float_addsub"]
+            + features["float_mul"]
+            + features["float_mad"] * 2
+            + features["float_divmod"]
+        )
     peak_flops = estimate_peak_fma_vector_flops(
         target, dev, remote, dtype, vec_width, num_vector_registers
     )
     return flops, peak_flops, f"{dtype} FMA"
@@ -315,10 +325,11 @@ def estimate_peak_bandwidth(
     # yet.
     peak_bandwidth = estimate_peak_bandwidth_dram(target, dev, remote, vec_width)
     loaded_bytes = 0.0
-    # assume no more than 100 buffers
-    for i in range(100):
+    i = 0
+    while True:
         key = f"B{i}.bytes"
         if not key in features.keys():
             break
         loaded_bytes += np.sum(features[key])
+        i += 1
     return loaded_bytes, peak_bandwidth, "DRAM"
diff --git a/tests/python/unittest/test_roofline.py b/tests/python/unittest/test_roofline.py
index 1b0c600388b3..7e5de71987b8 100644
--- a/tests/python/unittest/test_roofline.py
+++ b/tests/python/unittest/test_roofline.py
@@ -88,7 +88,7 @@ def test_estimate_peak_bandwidth_gpu():
     target = tvm.target.Target("cuda")
     dev = remote.device(str(target))
     # This test uses vectorized instructions so we need a target that supports them
-    bandwidth = tvm.utils.roofline.cuda.estimate_peak_bandwidth_global(target, dev, remote)
+    bandwidth = tvm.utils.roofline.cuda.estimate_peak_bandwidth_global_mem(target, dev, remote)
     # should be able to hit a 100 GB/s on a GPU. GTX 280 hits 140 GB/s and
     # it is really old.
     assert (
         bandwidth > 10**11 and bandwidth < 10**13
     ), f"Bandwidth should be between 10^11 and 10^13, but it is {bandwidth}"

From 96e31bd14104adf188aa52185b1a9d19fd9052b3 Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Thu, 13 Oct 2022 16:32:57 -0700
Subject: [PATCH 4/5] remove skips (from bad rebase)

---
 tests/python/unittest/test_roofline.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/python/unittest/test_roofline.py b/tests/python/unittest/test_roofline.py
index 7e5de71987b8..a8bf4df497f6 100644
--- a/tests/python/unittest/test_roofline.py
+++ b/tests/python/unittest/test_roofline.py
@@ -34,7 +34,6 @@
 from tvm.script import tir as T
 
 
-@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/12955")
 @tvm.testing.requires_llvm
 @pytest.mark.parametrize("dtype", ["float32", "int8", "int32"])
 def test_estimate_peak_flops_cpu(dtype):
@@ -50,7 +49,6 @@ def test_estimate_peak_flops_cpu(dtype):
 
 
-@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/12955")
 @tvm.testing.requires_cuda
 def test_estimate_peak_flops_gpu():
     server = rpc.Server(key="roofline_flops_gpu")

From 34291f7e5563dbfd469d5efcd59714938b285876 Mon Sep 17 00:00:00 2001
From: Tristan Konolige
Date: Fri, 14 Oct 2022 13:06:34 -0700
Subject: [PATCH 5/5] match features by name format

---
 python/tvm/utils/roofline/cuda.py | 35 +++++++++++++++----------------
 python/tvm/utils/roofline/x86.py  | 12 ++++-------
 2 files changed, 21 insertions(+), 26 deletions(-)

diff --git a/python/tvm/utils/roofline/cuda.py b/python/tvm/utils/roofline/cuda.py
index 45b10e9afb52..b6e8ae066459 100644
--- a/python/tvm/utils/roofline/cuda.py
+++ b/python/tvm/utils/roofline/cuda.py
@@ -16,6 +16,7 @@
 # under the License.
 """Estimation of peak flops and memory bandwidth for cuda devices"""
 import functools
+import re
 from typing import Dict, Optional, Tuple
 
 import numpy as np
@@ -306,23 +307,21 @@ def estimate_peak_bandwidth(
     name : str
         Name of the memory being used.
     """
-    loaded_bytes = 0.0
-
-    i = 0
-    while True:
-        # autoscheduler features do not take into account that 1.
-        # global and shared memory have very different performance
-        # characteristics -- both are included in the same bytes
-        # touched count 2. multiple threads accessing the same byte
-        # of memory does not use the same amount of bandwidth as
-        # multiple threads accessing different bytes of memory. We
-        # use unique bytes accessed here to avoid these two issues,
-        # but this does bias results towards being more compute
-        # bound.
-        key = f"B{i}.unique_bytes"
-        if not key in features.keys():
-            break
-        loaded_bytes += np.sum(features[key])
-        i += 1
+    # autoscheduler features do not take into account that 1.
+    # global and shared memory have very different performance
+    # characteristics -- both are included in the same bytes
+    # touched count 2. multiple threads accessing the same byte
+    # of memory does not use the same amount of bandwidth as
+    # multiple threads accessing different bytes of memory. We
+    # use unique bytes accessed here to avoid these two issues,
+    # but this does bias results towards being more compute
+    # bound.
+ loaded_bytes = sum( + [ + np.sum(x) + for (k, x) in features.items() + if re.match(r"^B[0-9]+\.unique_bytes$", k) is not None + ] + ) peak_bandwidth = estimate_peak_bandwidth_global_mem(target, dev, remote) return loaded_bytes, peak_bandwidth, "global" diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py index 1a056b09fbec..8ed7ac418f0c 100644 --- a/python/tvm/utils/roofline/x86.py +++ b/python/tvm/utils/roofline/x86.py @@ -16,6 +16,7 @@ # under the License. """Estimate peak flops and bandwidth for x86 devices""" import functools +import re from typing import Dict, Optional, Tuple import numpy as np @@ -324,12 +325,7 @@ def estimate_peak_bandwidth( # that has a very low arithmetic intensity and we haven't come up with one # yet. peak_bandwidth = estimate_peak_bandwidth_dram(target, dev, remote, vec_width) - loaded_bytes = 0.0 - i = 0 - while True: - key = f"B{i}.bytes" - if not key in features.keys(): - break - loaded_bytes += np.sum(features[key]) - i += 1 + loaded_bytes = sum( + [np.sum(x) for (k, x) in features.items() if re.match(r"^B[0-9]+\.bytes$", k) is not None] + ) return loaded_bytes, peak_bandwidth, "DRAM"
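
---

Usage note (a sketch, not part of the diffs above): after this series, the
registry entry points take the PrimFunc being analyzed plus its extracted
features, and return the per-function estimate together with the device peak
and a label for the dtype/intrinsic or memory that was measured. The snippet
below mirrors what roofline_from_existing does for each profiled call. The
helper name is hypothetical, and prim, target, dev, and remote are assumed to
already exist:

    from tvm import auto_scheduler, tir
    from tvm.utils.roofline import registry

    def primfunc_roofline_stats(prim: tir.PrimFunc, target, dev, remote=None):
        # Hypothetical helper; features are extracted the same way
        # roofline_from_existing extracts them.
        features = auto_scheduler.feature.named_features_from_primfunc(prim)
        # The generic functions dispatch on the target in scope.
        with target:
            flops, peak_flops, flops_name = registry.estimate_peak_flops(
                prim, features, target, dev, remote
            )
            loaded_bytes, peak_bandwidth, bandwidth_name = registry.estimate_peak_bandwidth(
                prim, features, target, dev, remote
            )
        # Same derived quantities the report columns are built from.
        ridge_point = peak_flops / peak_bandwidth
        arith_inten = flops / loaded_bytes
        return arith_inten, ridge_point, flops_name, bandwidth_name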