Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 25 additions & 39 deletions python/tvm/utils/roofline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
# specific language governing permissions and limitations
# under the License.
"""Utilities for computing an approximate roofline model"""
from typing import Dict, Union, Optional
from typing import Dict, Optional, Union

import numpy as np

from ... import auto_scheduler, relay, tir, nd, IRModule, build, topi, transform, get_global_func
from ...target import Target
from ...runtime import profiler_vm, profiling, Device, num_threads
from ...script import tir as T
from ...ir.instrument import pass_instrument
from ... import IRModule, auto_scheduler, build, get_global_func, nd, relay, tir, topi, transform
from ...contrib import utils
from ...ir.expr import GlobalVar
from ...ir.instrument import pass_instrument
from ...rpc.base import RPC_SESS_MASK
from ...rpc.client import RPCSession
from ...contrib import utils

from . import registry, cuda, x86
from ...runtime import Device, num_threads, profiler_vm, profiling
from ...script import tir as T
from ...target import Target
from . import cuda, registry, x86


def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=None):
Expand Down Expand Up @@ -131,43 +131,27 @@ def roofline_from_existing(
:py:func:`roofline_analysis` for more information on which metrics
are included.
"""
with target:
peak_bandwidth = registry.estimate_peak_bandwidth(target, dev, remote)
peak_flops = registry.estimate_peak_flops(target, dev, remote)

ridge_point = peak_flops / peak_bandwidth

all_features = {
prim.attrs["hash"]: (name, auto_scheduler.feature.named_features_from_primfunc(prim))
prim.attrs["hash"]: (name, prim, auto_scheduler.feature.named_features_from_primfunc(prim))
for name, prim in tir_functions.items()
if isinstance(prim, tir.PrimFunc) and "hash" in prim.attrs.keys()
}

new_calls = []
for call in report.calls:
if "Hash" in call.keys() and call["Hash"] in all_features:
_, features = all_features[call["Hash"]]

flops = np.sum(features["float_addsub"] + features["float_mul"] + features["float_mad"])
loaded_bytes = 0.0
# assume no more than 100 buffers
for i in range(100):
if str(target.kind) == "cuda":
# autoscheduler features do not take into account that 1.
# global and shared memory have very different performance
# characteristics -- both are included in the same bytes
# touched count 2. multiple threads accessing the same byte
# of memory does not use the same amount of bandwidth as
# multiple threads accessing different bytes of memory. We
# use unique bytes accessed here to avoid these two issues,
# but this does bias results towards being more compute
# bound.
key = f"B{i}.unique_bytes"
else:
key = f"B{i}.bytes"
if not key in features.keys():
break
loaded_bytes += np.sum(features[key])
_, prim, features = all_features[call["Hash"]]

with target:
flops, peak_flops, flops_name = registry.estimate_peak_flops(
prim, features, target, dev, remote
)
loaded_bytes, peak_bandwidth, bandwidth_name = registry.estimate_peak_bandwidth(
prim, features, target, dev, remote
)
ridge_point = peak_flops / peak_bandwidth

runtime = call["Duration (us)"].microseconds * 1e-6
arith_inten = flops / loaded_bytes
call = dict(call)
Expand All @@ -188,8 +172,10 @@ def roofline_from_existing(
else:
new_calls.append(call)
new_configuration = dict(report.configuration.items())
new_configuration["Estimated Peak FLOP/s"] = profiling.Ratio(peak_flops)
new_configuration["Estimated Peak Bandwidth (byte/second)"] = profiling.Ratio(peak_bandwidth)
new_configuration[f"Estimated Peak FLOP/s ({flops_name})"] = profiling.Ratio(peak_flops)
new_configuration[
f"Estimated Peak Bandwidth ({bandwidth_name}, byte/second)"
] = profiling.Ratio(peak_bandwidth)
return profiling.Report(new_calls, report.device_metrics, new_configuration)


Expand Down
167 changes: 129 additions & 38 deletions python/tvm/utils/roofline/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,31 @@
# specific language governing permissions and limitations
# under the License.
"""Estimation of peak flops and memory bandwidth for cuda devices"""
from typing import Optional
from ...script import tir as T
from ... import nd, build, transform
from ...runtime import Device
from ...target import Target
import functools
import re
from typing import Dict, Optional, Tuple

import numpy as np

from ... import build, nd, transform
from ...contrib import nvcc, utils
from ...rpc.base import RPC_SESS_MASK
from ...rpc.client import RPCSession
from ...runtime import Device
from ...script import tir as T
from ...target import Target
from ...tir import PrimFunc
from . import registry
from ...contrib import utils, nvcc


@registry.estimate_peak_flops.register("cuda")
@functools.lru_cache(maxsize=None)
def estimate_peak_flops_tensorcore(
target: Target,
dev: Device,
remote: Optional[RPCSession],
mat_dtype: str = "float16",
acc_dtype: str = "float32",
) -> float:
) -> Tuple[float, float, str]:
"""Estimate the peak FLOP/s of a cuda device with tensorcores.

This estimate should only be used to compare with operators that can use
Expand Down Expand Up @@ -64,12 +70,11 @@ def estimate_peak_flops_tensorcore(

Returns
-------
float
peak_flops : float
Approximate sustained FLOP/s of this target/device combo assuming
mma instructions. Addition and multiplications are each counted as
separate FLOPs.
"""
assert str(target.kind) == "cuda", "Only CUDA devices have tensorcores"

@T.prim_func
def peak_flops_tensorcore_tir(
Expand Down Expand Up @@ -161,6 +166,56 @@ def peak_flops_tensorcore_tir(
return n * 16 * 16 * 16 * 2 * sms * 8 / times.min


@registry.estimate_peak_flops.register("cuda")
def estimate_peak_flops(
    func: PrimFunc,  # pylint: disable=unused-argument
    features: Dict[str, np.ndarray],
    target: Target,
    dev: Device,
    remote: Optional[RPCSession],
) -> Tuple[float, float, str]:
    """Count the FLOPs of `func` and estimate the peak FLOP/s of a cuda device.

    Parameters
    ----------
    func : PrimFunc
        Function to estimate peak flops for. Not read by this implementation.
    features : Dict[str, np.ndarray]
        Features extracted from `func`. The entries "float_addsub",
        "float_mul", "float_mad" and "float_divmod" are read to count the
        floating point operations.
    target : Target
        Target to run on. This should be as specific to the actual hardware as
        possible.
    dev : Device
        Device to run on.
    remote : Optional[RPCSession]
        Remote session used to upload artifacts for runtime evaluation. Must be
        the same session used to create `dev`.

    Returns
    -------
    flops : float
        Estimated number of flops used by `func`.
    peak_flops : float
        Approximate sustained FLOP/s of this target/device combo, as reported
        by `estimate_peak_flops_tensorcore`. Addition and multiplications are
        each counted as separate FLOPs.
    name : str
        Dtype/intrinsic the peak was estimated for; always
        "float16 tensorcore" here.
    """
    has_tensorcore = nvcc.have_tensorcore(dev.compute_version)
    assert has_tensorcore, "CUDA roofline only works with devices that have tensorcores"

    # Combine the per-kind operation counts elementwise, left to right.
    op_counts = features["float_addsub"] + features["float_mul"]
    # A multiply-add is one multiplication plus one addition, hence the factor 2.
    op_counts = op_counts + features["float_mad"] * 2
    op_counts = op_counts + features["float_divmod"]
    flops = np.sum(op_counts)

    # The peak is always measured with the tensorcore micro-benchmark using
    # its default dtypes.
    peak_flops = estimate_peak_flops_tensorcore(target, dev, remote)
    return flops, peak_flops, "float16 tensorcore"


@T.prim_func
def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.int32) -> None:
# pylint: disable=invalid-name, missing-function-docstring
Expand All @@ -178,37 +233,13 @@ def peak_bandwidth_tir(a: T.handle, b: T.handle, blocks: T.int32, warp_size: T.i
B[i, l, j] += A[i, k, l, j]


@registry.estimate_peak_bandwidth.register("cuda")
def estimate_peak_bandwidth(
@functools.lru_cache(maxsize=None)
def estimate_peak_bandwidth_global_mem(
target: Target,
dev: Device,
remote: Optional[RPCSession] = None,
) -> float:
"""Estimate peak memory bandwidth of a target/device combo.

Peak bandwidth is estimated by running a small experiment on the underlying
hardware. The peak bandwidth measurement assumes that vector instructions
are being used to load the data.

Parameters
----------
target : Target
Target to use for measurement. This target should be as specific to the
underlying hardware as possible.
dev : Device
Device to measure peak bandwidth on.
remote : Optional[RPCSession]
Remote session used to upload artifacts for runtime evaluation. Must be
the same session used to create `dev`.

Returns
-------
float
Peak memory bandwidth in bytes/seconds.
"""
assert nvcc.have_tensorcore(
dev.compute_version
), "CUDA roofline only works with devices that have tensorcores"
) -> Tuple[float, float, str]:
"""Estimate peak bandwidth of global memory. See estimate_peak_bandwidth"""
warp_size = dev.warp_size
# These sizes seem large enough to give the card time to hit a fixpoint on memory bandwidth
blocks = 1024
Expand All @@ -234,3 +265,63 @@ def estimate_peak_bandwidth(
b = nd.empty((blocks, 4, warp_size), dtype="float32", device=dev)
times = f.time_evaluator(f.entry_name, dev, repeat=10, number=1)(a, b)
return a.numpy().size * 4 / times.min # 4 bytes per float32


@registry.estimate_peak_bandwidth.register("cuda")
def estimate_peak_bandwidth(
    func: PrimFunc,  # pylint: disable=unused-argument
    features: Dict[str, np.ndarray],
    target: Target,
    dev: Device,
    remote: Optional[RPCSession] = None,
) -> Tuple[float, float, str]:
    """Count the bytes loaded by `func` and estimate peak memory bandwidth.

    The byte count is taken from the extracted features; the peak bandwidth
    comes from `estimate_peak_bandwidth_global_mem`.

    Parameters
    ----------
    func : PrimFunc
        Function to estimate peak bandwidth for. Not read by this
        implementation.
    features : Dict[str, np.ndarray]
        Features extracted from `func`. Every entry whose name has the form
        ``B<digits>.unique_bytes`` contributes to the loaded byte count.
    target : Target
        Target to use for measurement. This target should be as specific to the
        underlying hardware as possible.
    dev : Device
        Device to measure peak bandwidth on.
    remote : Optional[RPCSession]
        Remote session used to upload artifacts for runtime evaluation. Must be
        the same session used to create `dev`.

    Returns
    -------
    loaded_bytes : float
        Estimated bytes loaded by `func`.
    peak_bandwidth : float
        Peak memory bandwidth in bytes/seconds.
    name : str
        Name of the memory being used; always "global" here.
    """
    # Autoscheduler features have two shortcomings for this purpose:
    #   1. global and shared memory have very different performance
    #      characteristics, yet both are included in the same bytes touched
    #      count;
    #   2. multiple threads accessing the same byte of memory does not use the
    #      same amount of bandwidth as multiple threads accessing different
    #      bytes of memory.
    # Unique bytes accessed are used here to avoid these two issues, but this
    # does bias results towards being more compute bound.
    unique_bytes_key = re.compile(r"^B[0-9]+\.unique_bytes$")
    loaded_bytes = 0
    for feature_name, values in features.items():
        if unique_bytes_key.match(feature_name) is None:
            continue
        loaded_bytes += np.sum(values)

    peak_bandwidth = estimate_peak_bandwidth_global_mem(target, dev, remote)
    return loaded_bytes, peak_bandwidth, "global"
Loading