23 changes: 20 additions & 3 deletions python/tvm/contrib/hexagon/meta_schedule.py
@@ -17,7 +17,14 @@
"""Meta schedule tuning utilities for Hexagon."""
import os
import tempfile
from typing import Callable, List, Optional
from typing import Callable, Dict, List, Optional
import tvm

from tvm.ir.module import IRModule
from tvm.runtime import Module, NDArray
from tvm.target import Target
from tvm.driver import build as tvm_build
from tvm.tir.transform import RemoveWeightLayoutRewriteBlock
from tvm.contrib.popen_pool import PopenPoolExecutor
from tvm.meta_schedule.utils import cpu_count, derived_object
from tvm.meta_schedule.builder import LocalBuilder
@@ -121,14 +128,24 @@ def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path
return costs


def get_hexagon_local_builder():
def get_hexagon_local_builder(pass_context: tvm.transform.PassContext = None):
"""Return Hexagon-compatible Builder for meta schedule."""

def export_func(mod):
binary_path = export_module(mod, tempfile.mkdtemp())
return str(binary_path)

return LocalBuilder(f_export=export_func)
def default_build_with_context(
mod: IRModule, target: Target, _params: Optional[Dict[str, NDArray]]
) -> Module:
with pass_context:
mod = RemoveWeightLayoutRewriteBlock(skip_ndarray_rewrite=True)(mod)
return tvm_build(mod, target=target)
Member:
Use meta_schedule.builder.default_build here.

Contributor Author:
✅ Changed this to use the old strategy if a pass context is not present.


if pass_context is not None:
return LocalBuilder(f_build=default_build_with_context, f_export=export_func)
else:
return LocalBuilder(f_export=export_func)


def get_hexagon_rpc_runner(
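The new optional pass_context argument lets callers wrap the tuning build in their own PassContext; when it is omitted, the previous LocalBuilder behavior is kept unchanged. A minimal usage sketch (the surrounding tuning setup is omitted, and the opt_level shown is an illustrative assumption):

```python
# Hypothetical sketch: selecting the builder with or without a custom PassContext.
import tvm
from tvm.contrib.hexagon.meta_schedule import get_hexagon_local_builder

# A custom PassContext to apply during tuning builds (opt_level is illustrative).
pass_ctx = tvm.transform.PassContext(opt_level=3)

# With a PassContext, the builder uses the custom f_build that runs
# RemoveWeightLayoutRewriteBlock and tvm.driver.build inside the context.
builder = get_hexagon_local_builder(pass_ctx)

# Without one, the original strategy (f_export only) is used.
default_builder = get_hexagon_local_builder()
```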
178 changes: 90 additions & 88 deletions python/tvm/tir/tensor_intrin/hexagon.py
@@ -20,98 +20,100 @@
from .. import TensorIntrin


@T.prim_func
def dot_product_32x4_u8u8i32_desc(
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((32, 4), "uint8", offset_factor=1),
C: T.Buffer((32,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
for i in T.serial(0, 32):
for k in T.serial(0, 4):
with T.block("update"):
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")


@T.prim_func
def dot_product_32x4_u8u8i32_vrmpy(
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((32, 4), "uint8", offset_factor=1),
C: T.Buffer((32,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])

A_u8x4 = A.vload([0], "uint8x4")
A_i32 = T.reinterpret(A_u8x4, dtype="int32")

B_i8x128 = B.vload([0, 0], dtype="uint8x128")
B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")

C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
T.uint32(3),
C[T.ramp(T.int32(0), 1, 32)],
B_i32x32,
A_i32,
dtype="int32x32",
)


@T.prim_func
def dot_product_32x4_u8i8i32_desc(
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((32, 4), "int8", offset_factor=1),
C: T.Buffer((32,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
for i in T.serial(0, 32):
for k in T.serial(0, 4):
with T.block("update"):
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")


@T.prim_func
def dot_product_32x4_u8i8i32_vrmpy(
A: T.Buffer((4,), "uint8", offset_factor=1),
B: T.Buffer((32, 4), "int8", offset_factor=1),
C: T.Buffer((32,), "int32", offset_factor=1),
) -> None:
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])

A_u8x4 = A.vload([0], "uint8x4")
A_i32 = T.reinterpret(A_u8x4, dtype="int32")

B_i8x128 = B.vload([0, 0], dtype="int8x128")
B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")

C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
T.uint32(3),
C[T.ramp(T.int32(0), 1, 32)],
T.broadcast(A_i32, 32),
B_i32x32,
dtype="int32x32",
)
def generate_dot_product_32x4_u8u8i32(mem_scope="global"):
@T.prim_func
def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
for i in T.serial(0, 32):
for k in T.serial(0, 4):
with T.block("update"):
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")

@T.prim_func
def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])

A_u8x4 = A.vload([0], "uint8x4")
A_i32 = T.reinterpret(A_u8x4, dtype="int32")

B_i8x128 = B.vload([0, 0], dtype="uint8x128")
B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")

C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
T.uint32(3),
C[T.ramp(T.int32(0), 1, 32)],
B_i32x32,
A_i32,
dtype="int32x32",
)

return dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy


def generate_dot_product_32x4_u8i8i32(mem_scope="global"):
@T.prim_func
def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])
for i in T.serial(0, 32):
for k in T.serial(0, 4):
with T.block("update"):
vi, vk = T.axis.remap("SR", [i, k])
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")

@T.prim_func
def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope)
B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope)
C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
with T.block("root"):
T.reads(C[0:32], A[0:4], B[0:32, 0:4])
T.writes(C[0:32])

A_u8x4 = A.vload([0], "uint8x4")
A_i32 = T.reinterpret(A_u8x4, dtype="int32")

B_i8x128 = B.vload([0, 0], dtype="int8x128")
B_i32x32 = T.reinterpret(B_i8x128, dtype="int32x32")

C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
T.uint32(3),
C[T.ramp(T.int32(0), 1, 32)],
T.broadcast(A_i32, 32),
B_i32x32,
dtype="int32x32",
)

return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy


VRMPY_u8u8i32_INTRIN = "dot_32x4_u8u8i32_vrmpy"

TensorIntrin.register(
VRMPY_u8u8i32_INTRIN, dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy
)
TensorIntrin.register(VRMPY_u8u8i32_INTRIN, *generate_dot_product_32x4_u8u8i32())

VRMPY_u8i8i32_INTRIN = "dot_32x4_u8i8i32_vrmpy"

TensorIntrin.register(
VRMPY_u8i8i32_INTRIN, dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy
)
TensorIntrin.register(VRMPY_u8i8i32_INTRIN, *generate_dot_product_32x4_u8i8i32())

VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy"
TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm"))

VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy"
TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8i8i32("global.vtcm"))
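
The intrinsics are now produced by generator functions parameterized by memory scope, so the same description/implementation pair can be registered for both "global" and "global.vtcm" buffers. A hedged sketch of how one of the registered intrinsics might be applied in a schedule (the block name and loop structure are assumptions about the workload; only the intrinsic names come from this file):

```python
# Hypothetical sketch: tensorizing a u8u8->i32 dot-product block with the
# registered VRMPY intrinsic. Block/loop names are assumed, not from this PR.
from tvm import tir
from tvm.tir.tensor_intrin.hexagon import (
    VRMPY_u8u8i32_INTRIN,
    VRMPY_u8u8i32_VTCM_INTRIN,
)

def apply_vrmpy(sch: tir.Schedule, use_vtcm: bool = False) -> None:
    block = sch.get_block("compute")           # assumed block name
    i, j, k = sch.get_loops(block)             # assumed loop nest
    jo, ji = sch.split(j, factors=[None, 32])  # 32 int32 output lanes
    ko, ki = sch.split(k, factors=[None, 4])   # 4-way uint8 reduction per lane
    sch.reorder(jo, ko, ji, ki)
    # The VTCM variant expects the matched buffers to live in "global.vtcm" scope.
    intrin = VRMPY_u8u8i32_VTCM_INTRIN if use_vtcm else VRMPY_u8u8i32_INTRIN
    sch.tensorize(ji, intrin)
```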