Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions backends/cadence/aot/functions_hifi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
- op: add.out
kernels:
- arg_meta: null
kernel_name: impl::HiFi::add_out
kernel_name: cadence::impl::HiFi::add_out

- op: bmm.out
kernels:
Expand Down Expand Up @@ -61,11 +61,16 @@
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::full_out

- op: gt.Scalar_out
kernels:
- arg_meta: null
kernel_name: torch::executor::gt_scalar_out

- op: maximum.out
kernels:
- arg_meta: null
kernel_name: impl::HiFi::maximum_out
kernel_name: cadence::impl::HiFi::maximum_out

- op: mean.out
kernels:
Expand All @@ -75,7 +80,7 @@
- op: minimum.out
kernels:
- arg_meta: null
kernel_name: impl::HiFi::minimum_out
kernel_name: cadence::impl::HiFi::minimum_out

- op: mul.out
kernels:
Expand All @@ -90,22 +95,22 @@
- op: pow.Scalar_out
kernels:
- arg_meta: null
kernel_name: impl::HiFi::pow_Scalar_out
kernel_name: cadence::impl::HiFi::pow_Scalar_out

- op: pow.Tensor_Scalar_out
kernels:
- arg_meta: null
kernel_name: impl::HiFi::pow_Tensor_Scalar_out
kernel_name: cadence::impl::HiFi::pow_Tensor_Scalar_out

- op: pow.Tensor_Tensor_out
kernels:
- arg_meta: null
kernel_name: impl::HiFi::pow_Tensor_Tensor_out
kernel_name: cadence::impl::HiFi::pow_Tensor_Tensor_out

- op: rsqrt.out
kernels:
- arg_meta: null
kernel_name: impl::HiFi::rsqrt_out
kernel_name: cadence::impl::HiFi::rsqrt_out

- op: sigmoid.out
kernels:
Expand Down
5 changes: 5 additions & 0 deletions backends/cadence/hifi/kernels/kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ memcpy(void* dst, const void* src, size_t num_bytes) {
MEMCPY_8b(dst, src, num_bytes);
}

// Allocate `size` bytes of scratch memory from the kernel runtime
// context's temporary allocator. Returns nullptr when the allocation
// fails (Result not ok), so callers MUST null-check the returned
// pointer before use. Memory obtained here is owned by the runtime's
// temp allocator — it must NOT be released with free().
void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size) {
Result<void*> temp_mem_res = ctx.allocate_temp(size);
return temp_mem_res.ok() ? temp_mem_res.get() : nullptr;
}

// Quantize a fp32 value to an int8_t/uint8_t value
template <typename T>
__attribute__((always_inline)) T
Expand Down
8 changes: 7 additions & 1 deletion backends/cadence/hifi/kernels/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
/* For NNLIB APIs */
#include "xa_nnlib_kernels_api.h"

/* Potential NNLIB function/APIs */
#include <executorch/runtime/kernel/kernel_includes.h>

using executorch::runtime::KernelRuntimeContext;
using executorch::runtime::Result;

/* Potential NNLIB function/APIs */
extern "C" WORD32 xa_nn_broadcast_32_32(
WORD32* __restrict__ p_out,
const int* const out_shape,
Expand Down Expand Up @@ -149,6 +153,8 @@ namespace impl {
namespace HiFi {
namespace kernels {

void* allocate_temp_memory(KernelRuntimeContext& ctx, size_t size);

void memcpy(void* dst, const void* src, size_t num_bytes);

WORD32 matmul_asym8uxasym8u_asym8u(
Expand Down
1 change: 1 addition & 0 deletions backends/cadence/hifi/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_gt.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_softmax.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp"
Expand Down
2 changes: 2 additions & 0 deletions backends/cadence/hifi/operators/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using executorch::runtime::CppTypeToScalarType;
using executorch::runtime::KernelRuntimeContext;
using torch::executor::Error;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
Expand Down Expand Up @@ -202,3 +203,4 @@ Tensor& add_out(
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
2 changes: 2 additions & 0 deletions backends/cadence/hifi/operators/op_maximum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ using torch::executor::apply_binary_elementwise_fn;
using torch::executor::Error;
using torch::executor::resize_to_broadcast_target_size;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
Expand Down Expand Up @@ -170,3 +171,4 @@ Tensor& maximum_out(
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
4 changes: 3 additions & 1 deletion backends/cadence/hifi/operators/op_mean.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ Tensor& mean_dim_out(
int scratch_size = xa_nn_reduce_getsize_nhwc(
-3, inp_shape, num_inp_dims, p_axis, num_axis_dims, 1);

void* __restrict__ p_scratch_in = (void* __restrict__)malloc(scratch_size);
void* __restrict__ p_scratch_in =
(void* __restrict__)kernels::allocate_temp_memory(
ctx, scratch_size * sizeof(int));

xa_nn_reduce_mean_4D_f32_f32(
p_out,
Expand Down
2 changes: 2 additions & 0 deletions backends/cadence/hifi/operators/op_minimum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ using torch::executor::apply_binary_elementwise_fn;
using torch::executor::Error;
using torch::executor::resize_to_broadcast_target_size;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
Expand Down Expand Up @@ -169,3 +170,4 @@ Tensor& minimum_out(
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
14 changes: 10 additions & 4 deletions backends/cadence/hifi/operators/op_pow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ using executorch::runtime::promoteTypes;
using torch::executor::Error;
using torch::executor::resize_to_broadcast_target_size;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
Expand Down Expand Up @@ -119,9 +120,11 @@ Tensor& pow_Tensor_Tensor_out(
if (optimized) {
if (broadcast) {
WORD32* __restrict__ ptr1 =
(WORD32* __restrict__)malloc(num_elm * sizeof(WORD32));
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, num_elm * sizeof(int));
WORD32* __restrict__ ptr2 =
(WORD32* __restrict__)malloc(num_elm * sizeof(WORD32));
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, num_elm * sizeof(int));

WORD32* __restrict__ pin1 =
(WORD32* __restrict__)a.const_data_ptr<float>();
Expand Down Expand Up @@ -154,7 +157,8 @@ Tensor& pow_Tensor_Tensor_out(
free(ptr2);
} else if (a_is_broadcasted && (!b_is_broadcasted)) {
FLOAT32* __restrict__ ptr1 =
(FLOAT32* __restrict__)malloc((num_elm + 2) * sizeof(WORD32));
(FLOAT32* __restrict__)kernels::allocate_temp_memory(
ctx, num_elm * sizeof(int));

FLOAT32* __restrict__ pin1 =
(FLOAT32* __restrict__)a.const_data_ptr<float>();
Expand All @@ -181,7 +185,8 @@ Tensor& pow_Tensor_Tensor_out(
free(ptr1);
} else if (b_is_broadcasted && (!a_is_broadcasted)) {
WORD32* __restrict__ ptr1 =
(WORD32* __restrict__)malloc(num_elm * sizeof(WORD32));
(WORD32* __restrict__)kernels::allocate_temp_memory(
ctx, num_elm * sizeof(int));

WORD32* __restrict__ pin1 =
(WORD32* __restrict__)b.const_data_ptr<float>();
Expand Down Expand Up @@ -349,3 +354,4 @@ Tensor& pow_Scalar_out(
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
2 changes: 2 additions & 0 deletions backends/cadence/hifi/operators/op_rsqrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::RuntimeContext;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
Expand Down Expand Up @@ -51,3 +52,4 @@ Tensor& rsqrt_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
6 changes: 4 additions & 2 deletions backends/cadence/hifi/operators/op_where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ Tensor& where_out(

if (con_shape[0] != out_shape[0] || con_shape[1] != out_shape[1] ||
con_shape[2] != out_shape[2] || con_shape[3] != out_shape[3]) {
void* p_scratch =
malloc(out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]);
void* p_scratch = (void*)kernels::allocate_temp_memory(
ctx,
(out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3]) *
sizeof(int));
const unsigned char* p_brd_cond = (const unsigned char*)p_scratch;
xa_nn_broadcast_8_8(
(WORD8* __restrict__)p_brd_cond,
Expand Down
10 changes: 9 additions & 1 deletion examples/portable/executor_runner/executor_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

static uint8_t method_allocator_pool[4 * 1024U * 1024U]; // 4 MB

static uint8_t temp_allocator_pool[1024U * 1024U];

DEFINE_string(
model_path,
"model.pte",
Expand Down Expand Up @@ -120,6 +122,10 @@ int main(int argc, char** argv) {
MemoryAllocator method_allocator{
MemoryAllocator(sizeof(method_allocator_pool), method_allocator_pool)};

// Temporary memory required by kernels
MemoryAllocator temp_allocator{
MemoryAllocator(sizeof(temp_allocator_pool), temp_allocator_pool)};

// The memory-planned buffers will back the mutable tensors used by the
// method. The sizes of these buffers were determined ahead of time during the
// memory-planning passes.
Expand All @@ -144,7 +150,8 @@ int main(int argc, char** argv) {

// Assemble all of the allocators into the MemoryManager that the Executor
// will use.
MemoryManager memory_manager(&method_allocator, &planned_memory);
MemoryManager memory_manager(
&method_allocator, &planned_memory, &temp_allocator);

//
// Load the method from the program, using the provided allocators. Running
Expand Down Expand Up @@ -172,6 +179,7 @@ int main(int argc, char** argv) {

// Run the model.
Error status = method->execute();

ET_CHECK_MSG(
status == Error::Ok,
"Execution of method %s failed with status 0x%" PRIx32,
Expand Down