Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/contrib/ethosu/cascader/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,21 +225,21 @@ def choose_proposal(
return proposal_choice


def extract_memory_info(memory_pool: PoolInfo) -> MemoryRegion:
def extract_memory_info(memory_pool: PoolInfo, memory_pressure: int) -> MemoryRegion:
"Create a MemoryRegion based on the info in the memory pool"
size = int(memory_pool.size_hint_bytes)
size = int(memory_pool.size_hint_bytes - memory_pressure)
read_bandwidth = int(memory_pool.read_bandwidth_bytes_per_cycle)
write_bandwidth = int(memory_pool.write_bandwidth_bytes_per_cycle)

for param in (size, read_bandwidth, write_bandwidth):
assert param != -1, f"{param} needs to be specified for the cascader."

name_to_burst_lenght = {
name_to_burst_length = {
target.kind.name: burst for target, burst in memory_pool.target_burst_bytes.items()
}

try:
burst_length = int(name_to_burst_lenght["ethos-u"])
burst_length = int(name_to_burst_length["ethos-u"])
except KeyError:
burst_length = 1

Expand Down
59 changes: 48 additions & 11 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,46 @@ def _ethos_u55_cascader(sram, enable_striping) -> Callable:
)


def _calculate_memory_pressure(mod: tvm.ir.IRModule) -> int:
    """
    Estimate the worst-case memory consumed at the callsite of each
    microNPU function.

    The value hints the cascader how aggressively it needs to optimize
    the input module to fit into the memory remaining in the workspace.

    Parameters
    ----------
    mod : tvm.ir.IRModule
        The input module

    Returns
    -------
    int
        Memory pressure value for the module.
    """
    total_used = 0

    @util.create_npu_function_pass(opt_level=1)
    class CalculateMemoryPressure:
        """
        Accumulate the total memory used by all external NPU functions.
        """

        def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
            nonlocal total_used
            # Worst case for a callsite is the largest annotated value.
            total_used += max(func.attrs["used_memory"])
            return func

    CalculateMemoryPressure()(mod)  # pylint: disable=not-callable

    # When I/O tensors are not placed in the workspace, their footprint is
    # not part of the pressure, so subtract it out.
    if tvm.tir.usmp.utils.use_workspace_io_is_enabled():
        return total_used
    return total_used - int(mod["main"].attrs["io_used_memory"])


@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir")
def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
"""
Expand Down Expand Up @@ -413,21 +453,18 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
# Use the cascader if it is enabled for the U55 accelerator, otherwise use copy_constants
# scheduler
if util.is_cascader_enabled():
assert (
util.get_accelerator_config() != "ethos-u65-256"
), "Cascading is not supported for the U65 accelerator"
if util.get_accelerator_config() == "ethos-u65-256":
raise ValueError("Cascading is not supported for the U65 accelerator")

workspace_memory_pools = mod.attrs["workspace_memory_pools"]

assert (
workspace_memory_pools
), "Workspace memory pool needs to be provided for the U55 cascader"

assert (
len(workspace_memory_pools.pools) == 1
), "Exactly one workspace pool needs to be provided for the U55 cascader"
if not workspace_memory_pools:
raise ValueError("Workspace memory pool needs to be provided for the U55 cascader")
if len(workspace_memory_pools.pools) != 1:
raise ValueError("Exactly one workspace pool needs to be provided for the U55 cascader")

sram = extract_memory_info(workspace_memory_pools.pools[0])
memory_pressure = _calculate_memory_pressure(mod)
sram = extract_memory_info(workspace_memory_pools.pools[0], memory_pressure)
tir_mod = LowerToTIR(_ethos_u55_cascader(sram, util.is_striping_enabled()))(mod)
else:
tir_mod = LowerToTIR(copy_constants())(mod)
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/tir/usmp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from typing import Optional, List

import tvm
from tvm._ffi import register_object
from tvm.runtime import Object
from . import _ffi_api
Expand All @@ -31,6 +32,14 @@
CANDIDATE_MEMORY_POOL_ATTR = "candidate_memory_pools"


def use_workspace_io_is_enabled() -> bool:
    """Return True when placing I/O tensors in the workspace is enabled.

    Reads the "tir.usmp.use_workspace_io" option from the currently
    active PassContext, defaulting to False when it is not set.
    """
    current_config = tvm.transform.PassContext.current().config
    return bool(current_config.get("tir.usmp.use_workspace_io", False))


@register_object("tir.usmp.BufferInfo")
class BufferInfo(Object):
"""BufferInfo object holds information related to buffers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wrong-import-position

"""
Test memory pressure is calculated correctly from used memory annotations.
"""

import pytest

pytest.importorskip("ethosu.vela")

import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu.codegen import _calculate_memory_pressure
from tvm.contrib.ethosu.cascader.scheduler import extract_memory_info
from tvm import WorkspacePoolInfo, PoolInfoProperties


def _npu_and_non_npu_functions():
    """Module with two NPU external functions and one non-NPU one, called in a chain."""
    mod = tvm.IRModule({})

    def make_external_func(make_op, composite_name, compiler_name, used_memory):
        # Build a "Composite" inner function wrapped by an external
        # "Compiler" function carrying a "used_memory" annotation.
        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
        composite_func = relay.Function([x], make_op(x))
        composite_func = composite_func.with_attr("Composite", composite_name)
        inp = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
        compiler_func = relay.Function([inp], composite_func)
        compiler_func = compiler_func.with_attr("used_memory", used_memory)
        return compiler_func.with_attr("Compiler", compiler_name)

    # NPU function 1
    g1 = relay.GlobalVar("g1")
    mod[g1] = make_external_func(relay.nn.max_pool2d, "ethos-u.pooling", "ethos-u", [32])

    # Non-NPU function
    g2 = relay.GlobalVar("g2")
    mod[g2] = make_external_func(relay.abs, "foo.unary_elementwise", "foo", [32])

    # NPU function 2
    g3 = relay.GlobalVar("g3")
    mod[g3] = make_external_func(relay.abs, "ethos-u.unary_elementwise", "ethos-u", [32])

    # Main: g1 -> g2 -> g3 chained sequentially.
    inp = relay.var("main_input", shape=(1, 2, 2, 4), dtype="int8")
    out = relay.Call(g3, [relay.Call(g2, [relay.Call(g1, [inp])])])
    main_func = relay.Function([inp], out)
    main_func = main_func.with_attr("io_used_memory", 32)
    mod["main"] = main_func
    return mod


def _parallel_npu_functions():
    """Module with two NPU external functions called in parallel from main."""
    mod = tvm.IRModule({})

    def make_npu_func(make_op, composite_name, used_memory):
        # Build a "Composite" inner function wrapped by an external
        # ethos-u "Compiler" function carrying a "used_memory" annotation.
        x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
        composite_func = relay.Function([x], make_op(x))
        composite_func = composite_func.with_attr("Composite", composite_name)
        inp = relay.var("input", shape=(1, 2, 2, 4), dtype="int8")
        compiler_func = relay.Function([inp], composite_func)
        compiler_func = compiler_func.with_attr("used_memory", used_memory)
        return compiler_func.with_attr("Compiler", "ethos-u")

    # NPU function 1
    g1 = relay.GlobalVar("g1")
    mod[g1] = make_npu_func(relay.nn.max_pool2d, "ethos-u.pooling", [32])

    # NPU function 2
    g2 = relay.GlobalVar("g2")
    mod[g2] = make_npu_func(relay.abs, "ethos-u.unary_elementwise", [32 + 16])

    # Main: both functions consume the same input; results are concatenated.
    inp = relay.var("main_input", shape=(1, 2, 2, 4), dtype="int8")
    concat = relay.concatenate([relay.Call(g1, [inp]), relay.Call(g2, [inp])], axis=3)
    main_func = relay.Function([inp], concat)
    main_func = main_func.with_attr("io_used_memory", 32)
    mod["main"] = main_func
    return mod


def _full_offload():
    """Module where the whole graph is a single NPU external function."""
    mod = tvm.IRModule({})

    # NPU function: max-pool wrapped as a "Composite" function inside an
    # external ethos-u "Compiler" function with a "used_memory" annotation.
    x = relay.var("x", shape=(1, 4, 4, 16), dtype="int8")
    composite_func = relay.Function([x], relay.nn.max_pool2d(x))
    composite_func = composite_func.with_attr("Composite", "ethos-u.pooling")
    inp = relay.var("input", shape=(1, 4, 4, 16), dtype="int8")
    compiler_func = relay.Function([inp], composite_func)
    compiler_func = compiler_func.with_attr("used_memory", [256 + 256])
    g1 = relay.GlobalVar("g1")
    mod[g1] = compiler_func.with_attr("Compiler", "ethos-u")

    # Main: a single call to the offloaded function.
    main_input = relay.var("main_input", shape=(1, 4, 4, 16), dtype="int8")
    main_func = relay.Function([main_input], relay.Call(g1, [main_input]))
    main_func = main_func.with_attr("io_used_memory", 256 + 256)
    mod["main"] = main_func
    return mod


@pytest.mark.parametrize(
    "model_func,use_workspace_io,expected_memory_pressure",
    [
        (_npu_and_non_npu_functions, True, (16 + 16) + (16 + 16)),
        (_npu_and_non_npu_functions, False, (16 + 16) + (16 + 16) - (16 + 16)),
        (_parallel_npu_functions, True, (16 + 16) + (16 + 16 + 16)),
        (_parallel_npu_functions, False, (16 + 16) + (16 + 16 + 16) - (16 + 16)),
        (_full_offload, True, (256 + 256)),
        (_full_offload, False, (256 + 256) - (256 + 256)),
    ],
)
def test_calculate_memory_pressure_pass(model_func, use_workspace_io, expected_memory_pressure):
    """Memory pressure is correctly calculated for NPU external functions."""
    module = model_func()
    # _calculate_memory_pressure reads "tir.usmp.use_workspace_io" from the
    # active PassContext, so the call must happen inside the context.
    config = {"tir.usmp.use_workspace_io": use_workspace_io}
    with tvm.transform.PassContext(config=config):
        assert _calculate_memory_pressure(module) == expected_memory_pressure


def test_extract_memory_info():
    """The memory pressure value correctly reduces the reported workspace size."""
    pool_size = 2000
    pressure = 500

    sram_pool = WorkspacePoolInfo(
        "SRAM",
        [tvm.target.Target("c"), tvm.target.Target("ethos-u")],
        PoolInfoProperties(
            size_hint_bytes=pool_size,
            read_bandwidth_bytes_per_cycle=16,
            write_bandwidth_bytes_per_cycle=16,
            target_burst_bytes={tvm.target.Target("ethos-u"): 1},
        ),
    )

    # The extracted region size is the pool hint minus the pressure.
    region = extract_memory_info(sram_pool, pressure)
    assert region.size == pool_size - pressure
Loading