From 070e7e9e39ed312b3bf1eb7a6baf210f1534b3e2 Mon Sep 17 00:00:00 2001
From: Alexander Nesterov
Date: Mon, 15 Dec 2025 22:56:53 +0100
Subject: [PATCH] Allow f16->f16 AArch64 JIT reorder and relax stride checks for f16 paths
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

AArch64 jit_uni_reorder now treats a pure f16->f16 reorder as valid
(previously only f32<->f16 passed), preventing an unnecessary fallback
to the reference implementation. For f16 cases the small-stride
requirement is relaxed, so blocked and large-stride layouts can stay on
the JIT path instead of degrading to the reference reorder. This should
reduce reference reorder usage and keep f16 workloads on optimized
kernels on AArch64.
---
 src/cpu/aarch64/acl_reorder.cpp | 52 -
 src/cpu/aarch64/acl_reorder.hpp | 251 --
 src/cpu/aarch64/acl_utils.cpp | 368 ++
 src/cpu/aarch64/acl_utils.hpp | 133 +
 src/cpu/aarch64/jit_uni_pooling.hpp | 2 +-
 src/cpu/aarch64/jit_uni_reorder.cpp | 3350 -----------------
 src/cpu/aarch64/jit_uni_reorder.hpp | 314 --
 .../matmul/brgemm_matmul_copy_utils.cpp | 76 +-
 .../matmul/brgemm_matmul_copy_utils.hpp | 8 +-
 .../aarch64/matmul/brgemm_matmul_reorders.cpp | 130 +-
 .../aarch64/matmul/brgemm_matmul_reorders.hpp | 10 +-
 .../aarch64/matmul/brgemm_matmul_utils.cpp | 173 +-
 .../aarch64/matmul/brgemm_matmul_utils.hpp | 16 +-
 src/cpu/aarch64/reorder/acl_reorder.cpp | 280 ++
 src/cpu/aarch64/reorder/acl_reorder.hpp | 97 +
 src/cpu/aarch64/reorder/jit_blk_reorder.cpp | 152 +
 src/cpu/aarch64/reorder/jit_blk_reorder.hpp | 69 +
 .../reorder/jit_blk_reorder_kernel.cpp | 560 +++
 src/cpu/aarch64/reorder/jit_uni_reorder.cpp | 525 +++
 src/cpu/aarch64/reorder/jit_uni_reorder.hpp | 104 +
 .../reorder/jit_uni_reorder_kernel.cpp | 2083 ++++++++++
 .../reorder/jit_uni_reorder_kernel.hpp | 424 +++
 .../{ => reorder}/jit_uni_reorder_utils.cpp | 284 +-
 .../aarch64/reorder/jit_uni_reorder_utils.hpp | 168 +
 src/cpu/platform.hpp | 10 +-
 src/cpu/reorder/cpu_reorder.hpp | 5 +-
 src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp | 304 +-
 src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp | 298 +-
 src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp | 302 +-
 src/cpu/reorder/cpu_reorder_regular_bf16.cpp | 22 +-
 src/cpu/reorder/cpu_reorder_regular_f16.cpp | 20 +-
 .../reorder/cpu_reorder_regular_f32_bf16.cpp | 10 +-
 .../reorder/cpu_reorder_regular_f32_f16.cpp | 7 +-
 .../reorder/cpu_reorder_regular_f32_f32.cpp | 80 +-
 .../reorder/cpu_reorder_regular_f32_fp8.cpp | 2 +-
 .../reorder/cpu_reorder_regular_f32_s32.cpp | 12 +-
 .../reorder/cpu_reorder_regular_f32_s8.cpp | 18 +-
 .../reorder/cpu_reorder_regular_f32_u8.cpp | 16 +-
 src/cpu/reorder/cpu_reorder_regular_fp4.cpp | 9 -
 src/cpu/reorder/cpu_reorder_regular_fp8.cpp | 10 +-
 src/cpu/reorder/cpu_reorder_regular_s32.cpp | 18 +-
 src/cpu/reorder/cpu_reorder_regular_s4.cpp | 32 +-
 src/cpu/reorder/cpu_reorder_regular_s8.cpp | 38 +-
 src/cpu/reorder/cpu_reorder_regular_u4.cpp | 36 +-
 src/cpu/reorder/cpu_reorder_regular_u8.cpp | 22 +-
 45 files changed, 6097 insertions(+), 4803 deletions(-)
 delete mode 100644 src/cpu/aarch64/acl_reorder.cpp
 delete mode 100644 src/cpu/aarch64/acl_reorder.hpp
 create mode 100644 src/cpu/aarch64/acl_utils.cpp
 create mode 100644 src/cpu/aarch64/acl_utils.hpp
 delete mode 100644 src/cpu/aarch64/jit_uni_reorder.cpp
 delete mode 100644
src/cpu/aarch64/jit_uni_reorder.hpp create mode 100644 src/cpu/aarch64/reorder/acl_reorder.cpp create mode 100644 src/cpu/aarch64/reorder/acl_reorder.hpp create mode 100644 src/cpu/aarch64/reorder/jit_blk_reorder.cpp create mode 100644 src/cpu/aarch64/reorder/jit_blk_reorder.hpp create mode 100644 src/cpu/aarch64/reorder/jit_blk_reorder_kernel.cpp create mode 100644 src/cpu/aarch64/reorder/jit_uni_reorder.cpp create mode 100644 src/cpu/aarch64/reorder/jit_uni_reorder.hpp create mode 100644 src/cpu/aarch64/reorder/jit_uni_reorder_kernel.cpp create mode 100644 src/cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp rename src/cpu/aarch64/{ => reorder}/jit_uni_reorder_utils.cpp (64%) create mode 100644 src/cpu/aarch64/reorder/jit_uni_reorder_utils.hpp diff --git a/src/cpu/aarch64/acl_reorder.cpp b/src/cpu/aarch64/acl_reorder.cpp deleted file mode 100644 index 73e38c0c4bb..00000000000 --- a/src/cpu/aarch64/acl_reorder.cpp +++ /dev/null @@ -1,52 +0,0 @@ -/******************************************************************************* -* Copyright 2023 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include "cpu/aarch64/acl_reorder.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace acl { - -status_t acl_reorder_fwd_t::execute_forward(const exec_ctx_t &ctx) const { - // Lock here is needed because resource_mapper does not support - // concurrent multithreaded access. - std::lock_guard _lock {this->mtx}; - - auto src = CTX_IN_MEM(const void *, DNNL_ARG_FROM); - auto dst = CTX_OUT_MEM(void *, DNNL_ARG_TO); - - // Retrieve primitive resource and configured Compute Library objects - auto *acl_resource - = ctx.get_resource_mapper()->get(this); - - acl_reorder_obj_t &acl_obj = acl_resource->get_acl_obj(); - - acl_obj.src_tensor.allocator()->import_memory(const_cast(src)); - acl_obj.dst_tensor.allocator()->import_memory(dst); - - acl_obj.reorder.run(); - - acl_obj.src_tensor.allocator()->free(); - acl_obj.dst_tensor.allocator()->free(); - - return status::success; -} - -} // namespace acl -} // namespace cpu -} // namespace impl -} // namespace dnnl diff --git a/src/cpu/aarch64/acl_reorder.hpp b/src/cpu/aarch64/acl_reorder.hpp deleted file mode 100644 index 617053841be..00000000000 --- a/src/cpu/aarch64/acl_reorder.hpp +++ /dev/null @@ -1,251 +0,0 @@ -/******************************************************************************* -* Copyright 2023-2025 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ -#ifndef CPU_ACL_REORDER_HPP -#define CPU_ACL_REORDER_HPP - -#include "arm_compute/core/Types.h" -#include "common/utils.hpp" -#include "cpu/acl/acl_utils.hpp" -#include "cpu/aarch64/cpu_isa_traits.hpp" -#include "cpu/reorder/cpu_reorder_pd.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace acl { - -struct acl_reorder_obj_t { - arm_compute::NEReorderLayer reorder; - arm_compute::Tensor src_tensor; - arm_compute::Tensor dst_tensor; - arm_compute::WeightFormat src_wf; - arm_compute::WeightFormat dst_wf; -}; - -struct acl_reorder_conf_t { - arm_compute::TensorInfo src_info; - arm_compute::TensorInfo dst_info; - arm_compute::WeightFormat src_wf; - arm_compute::WeightFormat dst_wf; -}; - -struct acl_reorder_resource_t : public resource_t { - acl_reorder_resource_t() - : acl_obj_(utils::make_unique()) {} - - status_t configure(const acl_reorder_conf_t &app) { - if (!acl_obj_) return status::out_of_memory; - - // Init Compute Library tensors based on info from descriptor - acl_obj_->src_tensor.allocator()->init(app.src_info); - acl_obj_->dst_tensor.allocator()->init(app.dst_info); - - // clang-format off - acl_obj_->reorder.configure( - &acl_obj_->src_tensor, - &acl_obj_->dst_tensor, - app.src_wf, - app.dst_wf - ); - // clang-format on - - return status::success; - } - - acl_reorder_obj_t &get_acl_obj() const { return *acl_obj_; } - DNNL_DISALLOW_COPY_AND_ASSIGN(acl_reorder_resource_t); - -private: - std::unique_ptr acl_obj_; -}; // acl_reorder_resource_t - -struct acl_reorder_fwd_t : public primitive_t { - using primitive_t::primitive_t; - struct pd_t : public cpu_reorder_pd_t { - - using cpu_reorder_pd_t::cpu_reorder_pd_t; - - DECLARE_COMMON_PD_T("acl", acl_reorder_fwd_t); - - static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, - const primitive_attr_t *attr, engine_t *src_engine, - const memory_desc_t *src_md, engine_t *dst_engine, - const memory_desc_t *dst_md) { - - using namespace acl_utils; - - // ACL reorder support f32->f32 and f32->bf16 - bool ok = src_md->data_type == data_type::f32 - && utils::one_of( - dst_md->data_type, data_type::f32, data_type::bf16) - && attr->has_default_values(); - - if (!ok) return status::unimplemented; - - if (!attr->scales_.has_default_values(DNNL_ARG_DST)) { - int mask = attr->scales_.get_mask(DNNL_ARG_DST); - const memory_desc_wrapper input_d(src_md); - if (input_d.has_runtime_dims_or_strides() && mask > 0) - return status::unimplemented; - } - - // Create and check primitive descriptor - auto _pd = make_unique_pd(attr, src_engine->kind(), src_md, - dst_engine->kind(), dst_md); - if (_pd == nullptr) return status::out_of_memory; - if (_pd->init(engine, src_engine, dst_engine) != status::success) { - return status::unimplemented; - } - - // In case we have two or four dimensions, we can't have one of the - // two first dimensions as 1. This is valid for f32->f32 and f32->bf16. 
- if (dst_md->dims[0] == 1 || dst_md->dims[1] == 1) { - return status::unimplemented; - } - - auto src_tag = memory_desc_matches_one_of_tag( - *src_md, format_tag::ab, format_tag::ba, format_tag::cdba); - ACL_CHECK_SUPPORT(format_tag::undef == src_tag, - "Only ab, ba or cdba source formats supported"); - - auto dst_tag = memory_desc_matches_one_of_tag(*dst_md, - format_tag::BA8b4a, format_tag::BA4b4a, format_tag::Ab4a, - format_tag::Ab8a, format_tag::Acdb8a, format_tag::Acdb4a); - ACL_CHECK_SUPPORT(format_tag::undef == dst_tag, - "Only Ab4a/Ab8a, BA8b4a/BA4b4a and Acdb8a/Acdb4a " - "destination formats supported"); - - if (dst_tag == format_tag::BA4b4a || dst_tag == format_tag::Acdb4a - || dst_tag == format_tag::Ab4a) { - _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo4; - } else if (aarch64::mayiuse(aarch64::sve_256) - && (dst_tag == format_tag::BA8b4a - || dst_tag == format_tag::Acdb8a - || dst_tag == format_tag::Ab8a)) { - _pd->app_.dst_wf = arm_compute::WeightFormat::OHWIo8; - } else { - return status::unimplemented; - } - - arm_compute::TensorShape acl_tensor_shape_in; - arm_compute::TensorShape acl_tensor_shape_out; - - // Switch for 2 or 4 dim tensors - switch (src_md->ndims) { - case 2: { - if (src_tag == format_tag::ab - && dst_md->data_type == data_type::bf16 - && utils::one_of(dst_tag, format_tag::BA8b4a, - format_tag::BA4b4a)) { // bf16 - acl_tensor_shape_in = arm_compute::TensorShape( - src_md->dims[0], src_md->dims[1]); - acl_tensor_shape_out = arm_compute::TensorShape( - dst_md->padded_dims[0], dst_md->padded_dims[1]); - } else if (src_tag == format_tag::ba - && dst_md->data_type == data_type::f32 - && !utils::one_of(dst_tag, format_tag::BA8b4a, - format_tag::BA4b4a)) { // f32 - acl_tensor_shape_in = arm_compute::TensorShape( - src_md->dims[1], src_md->dims[0]); - acl_tensor_shape_out = arm_compute::TensorShape( - dst_md->padded_dims[1], dst_md->padded_dims[0]); - } else { - return status::unimplemented; - } - } break; - case 4: { - // Currently only supporting AxBx1x1 cases - if (dst_md->dims[2] != 1 || dst_md->dims[3] != 1) { - return status::unimplemented; - } - - acl_tensor_shape_in = arm_compute::TensorShape( - src_md->dims[3], src_md->dims[2], src_md->dims[1], - src_md->dims[0]); - acl_tensor_shape_out = arm_compute::TensorShape( - dst_md->padded_dims[3], dst_md->padded_dims[2], - dst_md->padded_dims[1], dst_md->padded_dims[0]); - break; - } - default: return status::unimplemented; - } - - // Choose the data layout - const auto acl_layout = arm_compute::DataLayout::NCHW; - - // Set Source WeightFormat - _pd->app_.src_wf = arm_compute::WeightFormat::OHWI; - - // Create ACL tensor infos - const arm_compute::DataType src_acl_data_t - = acl_utils::get_acl_data_t(src_md->data_type); - _pd->app_.src_info = arm_compute::TensorInfo( - acl_tensor_shape_in, 1, src_acl_data_t, acl_layout); - - const arm_compute::DataType dst_acl_data_t - = acl_utils::get_acl_data_t(dst_md->data_type); - _pd->app_.dst_info = arm_compute::TensorInfo( - acl_tensor_shape_out, 1, dst_acl_data_t, acl_layout); - - ACL_CHECK_VALID(arm_compute::NEReorderLayer::validate( - &_pd->app_.src_info, &_pd->app_.dst_info, _pd->app_.src_wf, - _pd->app_.dst_wf)); - - // Init scratch memory, not used so 0 in this implementation - _pd->init_scratchpad_md(); - - return safe_ptr_assign(*reorder_pd, _pd.release()); - } // create - - friend dnnl::impl::impl_list_item_t; - acl_reorder_conf_t app_; - - }; // pd_t - - acl_reorder_fwd_t(const pd_t *apd) : primitive_t(apd) {} - - status_t create_resource( - engine_t 
*engine, resource_mapper_t &mapper) const override { - if (mapper.has_resource(this)) return status::success; - - auto r = utils::make_unique(); - if (!r) return status::out_of_memory; - - // Configure the resource based on information from primitive descriptor - CHECK(r->configure(pd()->app_)); - - mapper.add(this, std::move(r)); - return status::success; - } - - status_t execute(const exec_ctx_t &ctx) const override { - return execute_forward(ctx); - } - -private: - // To guard the const execute_forward, the mutex must be 'mutable' - mutable std::mutex mtx; - status_t execute_forward(const exec_ctx_t &ctx) const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - -}; // acl_reorder_fwd_t - -} // namespace acl -} // namespace cpu -} // namespace impl -} // namespace dnnl - -#endif // CPU_ACL_REORDER_HPP diff --git a/src/cpu/aarch64/acl_utils.cpp b/src/cpu/aarch64/acl_utils.cpp new file mode 100644 index 00000000000..ec7f162891f --- /dev/null +++ b/src/cpu/aarch64/acl_utils.cpp @@ -0,0 +1,368 @@ +/******************************************************************************* +* Copyright 2021-2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "cpu/aarch64/acl_utils.hpp" +#include + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +namespace acl_utils { + +using namespace dnnl::impl::alg_kind; +using namespace data_type; + +status_t safe_set_strides(arm_compute::Strides &strides, size_t dim, size_t val, + bool inc_dim = true) { + // ACL stride value is uint32, check for overflow + if (val > std::numeric_limits::max()) { + return status::unimplemented; + } + + strides.set(dim, val, inc_dim); + return status::success; +} + +arm_compute::DataType get_acl_data_t( + const dnnl_data_type_t dt, const bool is_quantized) { + switch (dt) { + case bf16: return arm_compute::DataType::BFLOAT16; + case f32: return arm_compute::DataType::F32; + case s32: return arm_compute::DataType::S32; + case f16: return arm_compute::DataType::F16; + case s8: + if (is_quantized) + return arm_compute::DataType::QASYMM8_SIGNED; + else + return arm_compute::DataType::S8; + case u8: + if (is_quantized) + return arm_compute::DataType::QASYMM8; + else + return arm_compute::DataType::U8; + default: return arm_compute::DataType::UNKNOWN; + } +} + +status_t convert_to_acl_act(alg_kind_t eltwise_alg, float alpha, float beta, + arm_compute::ActivationLayerInfo &act_info) { + + using namespace arm_compute; + using act_func = ActivationLayerInfo::ActivationFunction; + + switch (eltwise_alg) { + case eltwise_relu: + // oneDNN defines RELU: f(x) = (x > 0) ? x : a*x + // Compute Library defines LEAKY_RELU: f(x) = (x > 0) ? 
x : a*x + // whilst Compute Library RELU is defined as: f(x) = max(0,x) + if (alpha == 0) { + act_info = ActivationLayerInfo(act_func::RELU, alpha, beta); + } else { + act_info = ActivationLayerInfo( + act_func::LEAKY_RELU, alpha, beta); + } + break; + case eltwise_tanh: + // oneDNN defines TANH activation as: f(x) = tanh(x) + // Compute Library defines TANH activation as: f(x) = a*tanh(b*x) + // Setting a=b=1 makes the two equivalent + act_info = ActivationLayerInfo(act_func::TANH, 1.f, 1.f); + break; + case eltwise_elu: + act_info = ActivationLayerInfo(act_func::ELU, alpha, beta); + break; + case eltwise_square: + act_info = ActivationLayerInfo(act_func::SQUARE, alpha, beta); + break; + case eltwise_abs: + act_info = ActivationLayerInfo(act_func::ABS, alpha, beta); + break; + case eltwise_sqrt: + act_info = ActivationLayerInfo(act_func::SQRT, alpha, beta); + break; + case eltwise_linear: + act_info = ActivationLayerInfo(act_func::LINEAR, alpha, beta); + break; + case eltwise_soft_relu: + if (alpha == 1.f) { + act_info + = ActivationLayerInfo(act_func::SOFT_RELU, alpha, beta); + break; + } else { + return status::unimplemented; + } + case eltwise_logistic: + act_info = ActivationLayerInfo(act_func::LOGISTIC, alpha, beta); + break; + case eltwise_clip: + // oneDNN uses alpha in CLIP as lower bound and beta as upper bound. + // Compute Library uses beta as lower bound and alpha as upper in + // the equivalent function LU_BOUNDED_RELU. + // Switching order of alpha and beta makes the two equivalent. + act_info = ActivationLayerInfo( + act_func::LU_BOUNDED_RELU, beta, alpha); + break; + case eltwise_gelu_erf: + act_info = ActivationLayerInfo(act_func::GELU); + break; + default: act_info = ActivationLayerInfo(); return status::unimplemented; + } + + return status::success; +} + +status_t convert_to_acl_act( + const eltwise_desc_t &ed, arm_compute::ActivationLayerInfo &act_info) { + return convert_to_acl_act(ed.alg_kind, ed.alpha, ed.beta, act_info); +} + +status_t convert_to_acl_act(const post_ops_t::entry_t::eltwise_t &elt, + arm_compute::ActivationLayerInfo &act_info) { + return convert_to_acl_act(elt.alg, elt.alpha, elt.beta, act_info); +} + +status_t tensor_info(arm_compute::TensorInfo &info, const memory_desc_t &md) { + const memory_desc_wrapper md_wrap(&md); + return tensor_info(info, md_wrap); +} + +status_t tensor_info( + arm_compute::TensorInfo &info, const memory_desc_wrapper &md) { + + // All the cases we don't support + if (!md.is_blocking_desc() || !md.is_dense() || !md.is_plain() + || md.has_zero_dim()) + return status::unimplemented; + + // Set each of the dimensions in the TensorShape from the memory desc + // ACL indexes dimensions the opposite way to oneDNN + arm_compute::TensorShape shape; + size_t acl_dim_i = 0; + for (int i = md.ndims() - 1; i >= 0; --i) { + shape.set(acl_dim_i, md.dims()[i]); + acl_dim_i++; + } + + // Set each of the ACL Strides from the memory blocking desc + // ACL indexes strides the opposite way to oneDNN + arm_compute::Strides strides_in_bytes; + const blocking_desc_t &blocking_desc = md.blocking_desc(); + size_t acl_stride_i = 0; + for (int i = md.ndims() - 1; i >= 0; --i) { + // ACL strides are in bytes, oneDNN strides are in numbers of elements, + // multiply by data type size to convert + CHECK(safe_set_strides(strides_in_bytes, acl_stride_i, + blocking_desc.strides[i] * md.data_type_size())); + ++acl_stride_i; + } + + arm_compute::DataType data_type = get_acl_data_t(md.data_type()); + size_t num_channels = 1; + size_t 
offset_first_element_in_bytes = 0; + size_t total_size_in_bytes = md.size(); + + info.init(shape, num_channels, data_type, strides_in_bytes, + offset_first_element_in_bytes, total_size_in_bytes); + + return status::success; +} + +status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i) { + + // Max 6 dims in ACL, so we can't insert another + if (ti.num_dimensions() >= 6) return status::unimplemented; + + // Copy dimensions from old to new shape, inserting a dimension of size 1 + arm_compute::TensorShape shape = ti.tensor_shape(); + for (size_t old_i = 0, new_i = 0; old_i < ti.num_dimensions(); ++old_i) { + if (old_i == dim_i) { + shape.set(new_i, 1, false); + ++new_i; + } + shape.set(new_i, ti.tensor_shape()[old_i], false); + ++new_i; + } + + // Copy strides from old to new tensor, inserting a duplicate stride + arm_compute::Strides strides; + for (size_t old_i = 0, new_i = 0; old_i < ti.num_dimensions(); ++old_i) { + if (old_i == dim_i) { + CHECK(safe_set_strides( + strides, new_i, ti.strides_in_bytes()[old_i], false)); + ++new_i; + } + CHECK(safe_set_strides( + strides, new_i, ti.strides_in_bytes()[old_i], false)); + ++new_i; + } + + // Reinit TensorInfo with modified shape and strides + ti.init(shape, ti.num_channels(), ti.data_type(), strides, + ti.offset_first_element_in_bytes(), ti.total_size()); + + return status::success; +} + +int reorder_dimensions_by_stride(std::vector permuted_mds, + std::vector mds) { + + // Vectors must be the same length and not empty + if (permuted_mds.size() != mds.size() || mds.empty()) return 0; + + const dim_t ndims = mds[0]->ndims; + + for (const auto &md : mds) { + // Number of dimensions must match and must be blocked + if (md->ndims != ndims || md->format_kind != format_kind::blocked) + return 0; + } + + int reordered_dims = 0; + + // Create initial permutation which swaps nothing + std::vector perm(ndims); + std::iota(perm.begin(), perm.end(), 0); + + // For each dimension d1, find a dimension (d2) in which every md has the + // next smallest stride, then swap d2 into d1. Stride is initially 1 (i.e. + // dense) but will increase each time we find a dimension. The target + // strides may be different across dimensions if they are broadcasted. 
+ std::vector next_smallest_stride(mds.size(), 1); + for (dim_t d1 = ndims - 1; d1 >= 0; --d1) { + bool found_swap = false; + for (dim_t d2 = d1; d2 >= 0; --d2) { + // Check that all mds have the right stride + found_swap = true; + for (size_t i = 0; i < mds.size(); i++) { + auto &md_strides = mds[i]->format_desc.blocking.strides; + // Either it is the next smallest stride, or the dimensions is 1 + // so we can ignore it + bool can_swap = md_strides[perm[d2]] == next_smallest_stride[i] + || mds[i]->dims[perm[d2]] == 1; + if (!can_swap) { + found_swap = false; + break; + } + } + if (found_swap) { + // Multiply next smallest strides by dimension we just found + for (size_t i = 0; i < mds.size(); i++) + next_smallest_stride[i] *= mds[i]->dims[perm[d2]]; + + // Swap the found dimension (perm[d2]) into d1 + nstl::swap(perm[d2], perm[d1]); + ++reordered_dims; + break; + } + } + // We didn't find a swap for this dimension, we can't continue + if (!found_swap) break; + } + + // memory_desc_permute_axes applies the inverse of the permutation + // so we need to invert our permutation to get what we want + std::vector invperm(ndims); + for (dim_t d = 0; d < ndims; ++d) + invperm[perm[d]] = d; + + // Apply the inverse permutation to each dimension axis + for (size_t i = 0; i < mds.size(); i++) { + memory_desc_permute_axes(*permuted_mds[i], *mds[i], invperm.data()); + } + + return reordered_dims; +} + +status_t reorder_to_weight_format(arm_compute::TensorInfo &info, + memory_desc_t &md, arm_compute::WeightFormat wf, dim_t I_dim, + dim_t O_dim, const std::vector &spatial_dims, + const std::vector &batch_dims) { + + md.format_kind = format_kind::blocked; + md.format_desc.blocking = blocking_desc_t {}; + const int interleaved_by = arm_compute::interleave_by(wf); + const int block_by = arm_compute::block_by(wf); + + // I dimension becomes densest (apart from blocking) + md.format_desc.blocking.strides[I_dim] = interleaved_by * block_by; + md.padded_dims[I_dim] = utils::rnd_up(md.dims[I_dim], block_by); + + // Then any spatial dimensions (e.g. 
HW) + dim_t ldb = interleaved_by * md.padded_dims[I_dim]; + for (dim_t sd : spatial_dims) { + md.format_desc.blocking.strides[sd] = ldb; + ldb *= md.padded_dims[sd]; + } + + // O dim (which was the innermost) becomes the outermost (apart from batching) + md.format_desc.blocking.strides[O_dim] = ldb; + md.padded_dims[O_dim] = utils::rnd_up(md.dims[O_dim], interleaved_by); + + // Update the batch dimensions, starting with stride of the innermost batch + const dim_t innermost_batch_stride + = md.padded_dims[I_dim] * md.padded_dims[O_dim]; + dim_t batch_stride = innermost_batch_stride; + for (dim_t bd : batch_dims) { + md.format_desc.blocking.strides[bd] = batch_stride; + batch_stride *= md.padded_dims[bd]; + } + + // Weights can only be blocked if they are also interleaved + if (interleaved_by > 1) { + md.format_desc.blocking.inner_nblks = 1 + (block_by > 1); + + md.format_desc.blocking.inner_idxs[0] = O_dim; + md.format_desc.blocking.inner_blks[0] = interleaved_by; + if (block_by > 1) { + md.format_desc.blocking.inner_idxs[1] = I_dim; + md.format_desc.blocking.inner_blks[1] = block_by; + } + } + + if (arm_compute::is_fixed_format_fast_math(wf)) { + md.data_type = dnnl_bf16; + info.set_data_type(arm_compute::DataType::BFLOAT16); + } + + // The data layout is now determined by the manually set strides + info.set_data_layout(arm_compute::DataLayout::UNKNOWN); + + // x is ignored in fixed format kernels + // y is the leading dimension of b (ldb) in the GEMM d = a*b + c + // This is the stride of O_dim in the md + // z is the batch dimension (not strictly needed if there's only 1 batch) + // i.e. how much do I need to stride to get to the next matmul (ignoring + // the interleaving). Note that we use the innermost_batch_stride + // because all the batched dimensions are collapsed (as required by ACL). + arm_compute::Strides new_strides_in_bytes = info.strides_in_bytes(); + CHECK(safe_set_strides(new_strides_in_bytes, 1, ldb * info.element_size())); + CHECK(safe_set_strides(new_strides_in_bytes, 2, + innermost_batch_stride * info.element_size())); + + info.init(info.tensor_shape(), info.num_channels(), info.data_type(), + new_strides_in_bytes, info.offset_first_element_in_bytes(), + memory_desc_wrapper(md).size()); + return status::success; +} + +} // namespace acl_utils + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/acl_utils.hpp b/src/cpu/aarch64/acl_utils.hpp new file mode 100644 index 00000000000..b1ec3f345da --- /dev/null +++ b/src/cpu/aarch64/acl_utils.hpp @@ -0,0 +1,133 @@ +/******************************************************************************* +* Copyright 2021-2023, 2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_AARCH64_ACL_UTILS_HPP +#define CPU_AARCH64_ACL_UTILS_HPP + +#include + +#include "oneapi/dnnl/dnnl_types.h" + +#include "common/dnnl_thread.hpp" +#include "common/memory_tracking.hpp" +#include "common/primitive.hpp" +#include "common/resource.hpp" +#include "common/utils.hpp" + +#include "arm_compute/runtime/NEON/NEFunctions.h" +#include "arm_compute/runtime/Scheduler.h" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +namespace acl_utils { + +arm_compute::DataType get_acl_data_t( + const dnnl_data_type_t dt, const bool is_quantized = false); + +// Convert alg_kind_t, alpha and beta into an ACL ActivationLayerInfo. Will +// return unimplemented and a disabled ActivationLayerInfo if the conversion +// fails +status_t convert_to_acl_act(alg_kind_t eltwise_alg, float alpha, float beta, + arm_compute::ActivationLayerInfo &act_info); + +// Convert an eltwise_desc_t into an ACL ActivationLayerInfo. Will return +// unimplemented and a disabled ActivationLayerInfo if the conversion fails +status_t convert_to_acl_act( + const eltwise_desc_t &ed, arm_compute::ActivationLayerInfo &act_info); + +// Convert an eltwise post op into an ACL ActivationLayerInfo. Will return +// unimplemented and a disabled ActivationLayerInfo if the conversion fails +status_t convert_to_acl_act(const post_ops_t::entry_t::eltwise_t &elt, + arm_compute::ActivationLayerInfo &act_info); + +// Convert a memory desc to an arm_compute::TensorInfo. Note that memory desc +// must be blocking format, plain, dense and have no zero dimensions. +status_t tensor_info(arm_compute::TensorInfo &info, const memory_desc_t &md); +status_t tensor_info( + arm_compute::TensorInfo &info, const memory_desc_wrapper &md); + +// Insert a dimension of size 1 at the index dim_i of TensorInfo +status_t insert_singleton_dimension(arm_compute::TensorInfo &ti, size_t dim_i); + +// Reorder the logical dimensions of the memory descriptors (mds) by stride so +// that accessing the tensor elements in the natural order is dense. Note, this +// does not reorder the data, it just reorders the logical indices. The +// permutation is common to all mds, so the function returns when it cannot find +// a dimension with a common smallest stride. Returns the number of dimensions +// that we managed to reorder to be dense. +int reorder_dimensions_by_stride(std::vector permuted_mds, + std::vector mds); + +// Reorder a memory_desc_t and set the strides on a arm_compute::TensorInfo to +// match an arm_compute::WeightFormat. You are required to specify how various +// logical dimensions in oneDNN correspond to logical dimensions in arm_compute. +// info TensorInfo where the strides will be changed to match the reordering +// md memory descriptor where the stride and padded dimensions will be +// changed or reordering +// wf Describes the memory format/layout of the weights +// I_dim The logical dimension of md corresponding to the input channel of +// a convolution or the K dimension in a matmul +// O_dim The logical dimension of md corresponding to the output channel of a +//   convolution or the N dimension in a matmul +// spatial_dims The logical dimensions of md corresponding to the spatial +// dimensions of the weights (H, W, D for example). These will be +// the next densest after the inner blocks and the input channel. 
+// batch_dims The logical dimensions of md related to the batch in a batched +// matmul, ordered from innermost to outermost. ACL calls these +// the multi_stride_b. These will become the outermost (least dense) +// dimensions and will be collapsed. +status_t reorder_to_weight_format(arm_compute::TensorInfo &info, + memory_desc_t &md, arm_compute::WeightFormat wf, dim_t I_dim, + dim_t O_dim, const std::vector &spatial_dims, + const std::vector &batch_dims = {}); + +// Logs a custom 'info' line describing an unsupported case +#define LOG_ACL_UNSUPPORTED(msg) \ + do { \ + if (get_verbose(verbose_t::create_dispatch)) \ + verbose_printf("cpu,acl,unsupported: %s\n", (msg)); \ + } while (0) + +// Returns unimplemented if error code x is NOT OK +#define ACL_CHECK_VALID(x) \ + do { \ + arm_compute::Status s = x; \ + if (s.error_code() != arm_compute::ErrorCode::OK) { \ + LOG_ACL_UNSUPPORTED(s.error_description().c_str()); \ + return dnnl::impl::status::unimplemented; \ + } \ + } while (0) + +// Returns unimplemented on condition x == true +#define ACL_CHECK_SUPPORT(x, msg) \ + do { \ + if (x) { \ + LOG_ACL_UNSUPPORTED(msg); \ + return dnnl::impl::status::unimplemented; \ + } \ + } while (0) + +} // namespace acl_utils + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif // CPU_AARCH64_ACL_UTILS_HPP diff --git a/src/cpu/aarch64/jit_uni_pooling.hpp b/src/cpu/aarch64/jit_uni_pooling.hpp index ac854c75fce..4df7f22c115 100644 --- a/src/cpu/aarch64/jit_uni_pooling.hpp +++ b/src/cpu/aarch64/jit_uni_pooling.hpp @@ -30,7 +30,7 @@ #include "cpu/cpu_pooling_pd.hpp" #include "cpu/aarch64/jit_uni_pool_kernel.hpp" -#include "cpu/aarch64/jit_uni_reorder.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder.hpp" namespace dnnl { namespace impl { diff --git a/src/cpu/aarch64/jit_uni_reorder.cpp b/src/cpu/aarch64/jit_uni_reorder.cpp deleted file mode 100644 index f2aa1f42d2c..00000000000 --- a/src/cpu/aarch64/jit_uni_reorder.cpp +++ /dev/null @@ -1,3350 +0,0 @@ -/******************************************************************************* -* Copyright 2018-2023 Intel Corporation -* Copyright 2020-2024 FUJITSU LIMITED -* Copyright 2022-2025 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#include -#include -#include - -#include "oneapi/dnnl/dnnl_debug.h" - -#include "common/c_types_map.hpp" -#include "common/dnnl_thread.hpp" -#include "common/memory_desc_wrapper.hpp" -#include "common/nstl.hpp" -#include "common/primitive.hpp" -#include "common/type_helpers.hpp" -#include "common/utils.hpp" - -#include "cpu/cpu_primitive.hpp" -#include "cpu/reorder/cpu_reorder_pd.hpp" - -#include "cpu/aarch64/jit_uni_reorder.hpp" - -#include "cpu/aarch64/jit_generator.hpp" - -// #define DNNL_DEV_MODE -#if defined(DNNL_DEV_MODE) -#define DEBUg(...) \ - do { \ - if (get_verbose(verbose_t::debuginfo) > 1) { __VA_ARGS__ } \ - } while (0) -#else -#define DEBUg(...) 
-#endif -#define DEBUG(...) DEBUg(__VA_ARGS__) - -using namespace Xbyak_aarch64; -using namespace dnnl::impl::types; - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -namespace tr { - -static bool prb_has_small_strides(const prb_t &prb) { - constexpr ptrdiff_t max_stride = (1LL << 31) - 1; - for (int d = 0; d < prb.ndims; ++d) { - const ptrdiff_t cms = max_stride / prb.nodes[d].n; - const bool small_strides = true - && prb.nodes[d].is < cms / (int)data_type_size(prb.itype) - && prb.nodes[d].os < cms / (int)data_type_size(prb.otype); - if (!small_strides) return false; - } - return true; -} - -/** Minimal reasonable/desirable kernel size. - * The constant might be used to determine how a problem should be split - * between kernel and threading driver. */ -const size_t ker_prb_size_min = 64; - -/* kernel */ -struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) - - void operator()(const call_param_t *c) const override { - jit_generator::operator()(c); - } - void operator()(const tail_call_param_t *c) const override { - jit_generator::operator()(c); - } - - status_t create_kernel() override { return jit_generator::create_kernel(); } - - enum class scale_arg_t { NONE, SRC, DST }; - - enum { - len_unroll_max = 256, - ndims_jit_loop_max = 3, - }; - - struct simple_impl_desc_t { - int ndims_full_unroll = 0; - int len_last_dim_unroll = 0; - int tail_len_unroll = 0; - int len_unroll = 0; - }; - -#define PARAM(x) \ - abi_param1, \ - prb_.is_tail_present ? offsetof(tail_call_param_t, base_params) \ - + offsetof(call_param_t, x) \ - : offsetof(call_param_t, x) -#define TAIL_PARAM(x) abi_param1, offsetof(tail_call_param_t, x) - - static bool simple_impl_desc_init( - const prb_t &prb, simple_impl_desc_t *desc) { - const int ndims = prb.ndims; - - int ndims_full_unroll = 0; - int len_last_dim_unroll = 1; - int tail_len_unroll = 0; - int len_unroll = 1; - - // It is responsible for finding as many values - // as kernel can unroll. If tail is present then - // kernel will unroll only last node (possible improvement). - // If there is no tail kernel can unroll a few nodes without any loops etc. - // ndims_full_unroll - how many nodes will be unrolled - // len_last_dim_unroll - what piece of last unrolled node will be unrolled - if (prb.is_tail_present) { - ndims_full_unroll = 1; - len_unroll = prb.nodes[0].n; - tail_len_unroll = prb.nodes[0].is_zero_pad_needed - ? 
0 - : static_cast(prb.nodes[0].tail_size); - } else { - for (int d = 0; d < ndims; ++d) { - const auto &node = prb.nodes[d]; - if (len_unroll * node.n <= len_unroll_max) { - ndims_full_unroll++; - len_unroll *= node.n; - } else { - len_last_dim_unroll = len_unroll_max / len_unroll; - while (node.n % len_last_dim_unroll) - --len_last_dim_unroll; - len_unroll *= len_last_dim_unroll; - break; - } - } - } - - if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) return false; - - if (desc) { - desc->ndims_full_unroll = ndims_full_unroll; - desc->len_last_dim_unroll = len_last_dim_unroll; - desc->tail_len_unroll = tail_len_unroll; - desc->len_unroll = len_unroll; - } - - return true; - } - - static bool applicable(const prb_t &p) { - using namespace data_type; - - bool bf16_ok - = (mayiuse_bf16() && (p.itype == bf16) && (p.otype == bf16) - && !interim_f32_needed(p, false) && p.beta == 0.f) - || (p.itype != bf16 && p.otype != bf16) - || (p.itype == f32 && p.otype == bf16 && mayiuse_bf16() - && p.beta == 0.f) - || (p.itype == bf16 && p.otype == f32 && mayiuse_bf16() - && p.beta == 0.f); - - bool is_f16 = (p.itype == f16 || p.otype == f16); - bool f16_ok = (p.itype == f32 && p.otype == f16 && p.beta == 0.f) - || (p.itype == f16 && p.otype == f32 && p.beta == 0.f); - - bool ok = true && p.ndims > 0 - && utils::one_of( - p.itype, f32, f16, bf16, s32, data_type::s8, u8) - && utils::one_of( - p.otype, f32, f16, bf16, s32, data_type::s8, u8) - && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ - && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ - && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p) - && bf16_ok && IMPLICATION(is_f16, f16_ok); - - return ok; - } - - XReg o_addr(int o_off, bool with_type_multiplier = true) { - if (o_off) { - add_imm(X_DEFAULT_ADDR, x_ptr_out_off, - o_off * (with_type_multiplier ? 
otype_sz_ : 1), X_TMP); - return X_DEFAULT_ADDR; - } - - return x_ptr_out_off; - } - - XReg src_s_addr(int s_off) { - if (s_off) { - add_imm(X_DEFAULT_ADDR, x_ptr_src_scale_off, s_off * stype_sz_, - X_TMP); - return X_DEFAULT_ADDR; - } else { - return x_ptr_src_scale_off; - } - } - - XReg dst_s_addr(int s_off) { - if (s_off) { - add_imm(X_DEFAULT_ADDR, x_ptr_dst_scale_off, s_off * stype_sz_, - X_TMP); - return X_DEFAULT_ADDR; - } else { - return x_ptr_dst_scale_off; - } - } - - XReg c_addr(int c_off) { - if (c_off) { - add_imm(X_DEFAULT_ADDR, x_ptr_comp_off, c_off * sizeof(int32_t), - X_TMP); - return X_DEFAULT_ADDR; - } - - return x_ptr_comp_off; - } - - XReg data_chunk_addr(int node_id) { - add_imm(X_DEFAULT_ADDR, abi_param1, - offsetof(tail_call_param_t, curr_data_chunks) - + sizeof(int64_t) * (node_id), - X_TMP); - return X_DEFAULT_ADDR; - } - - void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, - int prev_c_off, int &i_off, int &o_off, int &s_off, int &c_off, - int step_size = 1) { - i_off = prev_i_off; - o_off = prev_o_off; - s_off = prev_s_off; - c_off = prev_c_off; - - if (off == 0) return; - - int start_dim = 0, dims_prod = 1; - for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) - dims_prod *= prb_.n(start_dim); - assert(start_dim < prb_.ndims); - off /= step_size; - - for (int dim_id = start_dim; dim_id < prb_.ndims; ++dim_id) { - i_off += prb_.is(dim_id); - o_off += prb_.os(dim_id); - s_off += prb_.ss(dim_id); - c_off += prb_.cs(dim_id); - - if (off % prb_.n(dim_id)) break; - - i_off += -prb_.n(dim_id) * prb_.is(dim_id); - o_off += -prb_.n(dim_id) * prb_.os(dim_id); - s_off += -prb_.n(dim_id) * prb_.ss(dim_id); - c_off += -prb_.n(dim_id) * prb_.cs(dim_id); - - off /= prb_.n(dim_id); - - if (off == 0) break; /* FIXME: is it really required? 
*/ - } - } - - void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, - int step_size = 1) { - int dummy = 0; - step(off, prev_i_off, prev_o_off, dummy, dummy, i_off, o_off, dummy, - dummy, step_size); - } - - void tr8x8_sve256(int i_off, int o_off) { - using namespace data_type; - - const auto cvt2ps - = [=](const int startIdx, const int regNum, data_type_t idt) { - switch (idt) { - case f32: - /* do nothing */ - break; - case f16: cvt_v_f16_f32(startIdx, regNum); break; - case s32: cvt_z_s32_f32(startIdx, regNum); break; - case bf16: cvt_v_bf16_fp32(startIdx, regNum); break; - case data_type::s8: - cvt_z_s8_s32(startIdx, regNum); - cvt_z_s32_f32(startIdx, regNum); - break; - case u8: - cvt_z_u8_s32(startIdx, regNum); - cvt_z_s32_f32(startIdx, regNum); - break; - default: assert(!"unreachable"); - } - }; - - const auto cvt2odt = [=](const int startIdx, const int regNum, - data_type_t odt, data_type_t idt) { - switch (odt) { - case s32: - if (idt == f32) - cvt_z_f32_s32(startIdx, regNum); - else if (idt == data_type::s8) - cvt_z_s8_s32(startIdx, regNum); - else if (idt == u8) - cvt_z_u8_s32(startIdx, regNum); - break; - case data_type::s8: - if (idt == f32) cvt_z_f32_s32(startIdx, regNum); - if (utils::one_of(idt, f32, s32)) - cvt_z_s32_s8(startIdx, regNum); - if (idt == u8) cvt_z_u8_s8(startIdx, regNum); - break; - case data_type::bf16: - if (idt == f32) cvt_v_f32_bf16(startIdx, regNum); - break; - case data_type::f16: - if (idt == f32) cvt_v_f32_f16(startIdx, regNum); - break; - case u8: - if (idt == f32) cvt_z_f32_s32(startIdx, regNum); - if (utils::one_of(idt, f32, s32)) - cvt_z_s32_u8(startIdx, regNum); - if (idt == data_type::s8) cvt_z_s8_u8(startIdx, regNum); - break; - default: assert(!"unreachable"); - } - }; - - const int unroll = 8; - - const bool interim_f32 = (prb_.itype != f32) - || utils::one_of(f32, prb_.itype, prb_.otype); - - const bool need_saturation - = (utils::one_of(prb_.otype, u8, data_type::s8, s32) - && interim_f32); - const uint64_t sveLen = get_sve_length(); - - PReg p_size(DUMMY_IDX); - switch (unroll * itype_sz_) { - case 32: p_size = p_lsb_256; break; - case 16: p_size = p_lsb_128; break; - case 8: p_size = p_lsb_64; break; - default: assert(!"unreachable"); - } - - const int node_0_input_stride = prb_.is(0); - add_imm(X_TMP_0, XReg(x_ptr_in_off), itype_sz_ * i_off, X_DEFAULT_ADDR); - for (int i = 1; i < unroll / 2; i++) - add_imm(x_tmp_vec[i], x_tmp_vec[i - 1], - itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR); - for (uint32_t i = 0; i < unroll / 2; i++) - ld1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i])); - for (int i = 0; i < unroll / 2; i++) - add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4], - itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR); - for (uint32_t i = 0; i < unroll / 2; i++) - ld1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i])); - - if (interim_f32) cvt2ps(0, unroll, prb_.itype); - -#if 0 - /* Debug code to forcedly set test pattern. 
*/ - index(z0.s, 0, 1); - mov(z0.s, P_NOT_256/T_m, 0); - mov(z_tmp_vec[0].s, 16); - for(uint32_t i=1; i<8; i++) { - add(ZRegS{i}, ZRegS{i-1}, z_tmp_vec[0].s); - mov(ZRegS{i}, P_NOT_256/T_m, 0); - } -#endif - - ptrue(p_tmp0.s, VL4); - /* 1st turn */ - for (uint32_t i = 0; i < unroll / 2; i++) { - trn1(z_tmp_vec[i].s, ZRegS {2 * i}, ZRegS {2 * i + 1}); - trn2(z_tmp_vec[unroll / 2 + i].s, ZRegS {2 * i}, ZRegS {2 * i + 1}); - } - - /* 2nd turn */ - trn1(z4.d, z_tmp_vec[0].d, z_tmp_vec[1].d); - trn1(z5.d, z_tmp_vec[4].d, z_tmp_vec[5].d); - trn2(z6.d, z_tmp_vec[0].d, z_tmp_vec[1].d); - trn2(z7.d, z_tmp_vec[4].d, z_tmp_vec[5].d); - trn1(z_tmp_vec[0].d, z_tmp_vec[2].d, z_tmp_vec[3].d); - trn1(z_tmp_vec[1].d, z_tmp_vec[6].d, z_tmp_vec[7].d); - trn2(z_tmp_vec[2].d, z_tmp_vec[2].d, z_tmp_vec[3].d); - trn2(z_tmp_vec[3].d, z_tmp_vec[6].d, z_tmp_vec[7].d); - - /* 3rd turn */ - for (uint32_t i = 0; i < unroll / 2; i++) { - mov(ZRegD {i}, ZRegD {unroll / 2 + i}); - mov(z_tmp_vec[unroll / 2 + i].d, z_tmp_vec[i].d); - } - - /* 4th turn */ - for (uint32_t i = 0; i < unroll / 2; i++) { - ZRegB z {unroll / 2 + i}; - ZRegB z_tmp = z_tmp_vec[unroll / 2 + i].b; - /* Move bit 0-127 to 128-255. */ - ext(z, z, 16); - /* Move bit 128-255 to 0-127. */ - ext(z_tmp, z_tmp, sveLen - 16); - } - - /* 5th turn */ - for (uint32_t i = 0; i < unroll / 2; i++) { - ZRegS z0 {i}; - ZRegS z1 {unroll / 2 + i}; - sel(z0, p_tmp0.s, z0, z_tmp_vec[unroll / 2 + i].s); - sel(z1, p_tmp0, z1, z_tmp_vec[i].s); - } - - if (need_saturation) { - init_saturate_f32(ymm_zero_, ymm_saturation_ubound_, X_TMP_0, - interim_f32 ? f32 : prb_.itype, prb_.otype); - for (int i = 0; i < unroll; i++) - saturate_f32(ZRegS(i), ymm_zero_, ymm_saturation_ubound_, - prb_.otype, P_ALL_ONE); - } - - if (prb_.otype != f32) - cvt2odt(0, unroll, prb_.otype, interim_f32 ? f32 : prb_.itype); - - const int node_1_output_stride = prb_.os(1); - - switch (unroll * otype_sz_) { - case 32: p_size = p_lsb_256; break; - case 16: p_size = p_lsb_128; break; - case 8: p_size = p_lsb_64; break; - default: assert(!"unreachable"); - } - - add_imm(X_TMP_0, XReg(x_ptr_out_off), otype_sz_ * o_off, - X_DEFAULT_ADDR); - for (int i = 1; i < unroll / 2; i++) - add_imm(x_tmp_vec[i], x_tmp_vec[i - 1], - otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR); - for (uint32_t i = 0; i < 4; i++) - st1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i])); - for (int i = 0; i < unroll / 2; i++) - add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4], - otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR); - - for (uint32_t i = 0; i < unroll / 2; i++) - st1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i])); - } - - bool can_do_tr8x8() { - using namespace data_type; - - static constexpr int desirable_node_size = 8; - static constexpr int desirable_stride = 1; - - // This processing is relied on swaping two innermost dimension. - // Therefore, input stride in second node and output stride in first node - // have to be equal to 1. 
- - return mayiuse(sve_256) && prb_.ndims >= 2 - && ((utils::one_of(prb_.itype, u8, data_type::s8, s32, f32) - && utils::one_of( - prb_.otype, u8, data_type::s8, s32, f32))) - && utils::everyone_is(desirable_node_size, prb_.n(0), prb_.n(1)) - && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(1)) - && !prb_.is_tail_present - && prb_.src_scale_type == scale_type_t::NONE - && prb_.dst_scale_type == scale_type_t::NONE - && prb_.beta == 0.f; - } - - bool process_unroll_tr8x8(const int ndims, const int len) { - if (!can_do_tr8x8()) return false; - - const int step_size = prb_.n(0) * prb_.n(1); - int i_off = 0, o_off = 0; - for (int off = 0; off < len; off += step_size) { - step(off, i_off, o_off, i_off, o_off, step_size); - tr8x8_sve256(i_off, o_off); - } - - return true; - } - - template - bool process_direct_copy(const int ndims, const int len) { - using namespace data_type; - - static constexpr int desirable_stride = 1; - using TRegS = - typename utils::conditional::type; - const int simd_w = cpu_isa_traits::vlen / itype_sz_; - - // TODO: support tail_processing for direct copy - - const bool do_src_zp = prb_.req_src_zp; - const bool do_dst_zp = prb_.req_dst_zp; - const bool zp_applicable = IMPLICATION( - (do_src_zp || do_dst_zp), utils::one_of(prb_.itype, s32, f32)); - const bool can_do = true && mayiuse(isa) - && compensation_needed_ == false - && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(0)) - && (false || (prb_.itype == prb_.otype ? zp_applicable : false) - || (prb_.itype == s32 && prb_.otype == f32) - || (prb_.itype == f32 && prb_.otype == s32)) - && len % simd_w == 0 && prb_.n(0) % len == 0 - && !prb_.is_tail_present - && prb_.src_scale_type == scale_type_t::NONE - && prb_.dst_scale_type == scale_type_t::NONE - && prb_.beta == 0.f; - if (!can_do) return false; - - static constexpr int vmm_zp_last_idx = 15; - const auto vmm_src_zp - = TRegS(do_dst_zp ? vmm_zp_last_idx - 1 : vmm_zp_last_idx); - if (do_src_zp) { - uni_ld1rw(vmm_src_zp, PARAM(src_zp)); - uni_scvtf(vmm_src_zp, vmm_src_zp); - } - const auto vmm_dst_zp = TRegS(vmm_zp_last_idx); - if (do_dst_zp) { - uni_ld1rw(vmm_dst_zp, PARAM(dst_zp)); - uni_scvtf(vmm_dst_zp, vmm_dst_zp); - } - - const auto apply_zp_ps = [&](const TRegS vmm) { - if (do_src_zp) fsub(vmm, vmm, vmm_src_zp); - if (do_dst_zp) fadd(vmm, vmm, vmm_dst_zp); - }; - - for (int off = 0; off < len;) { - // TODO: we need extra reg for proper saturation if otype == s32 - int unroll - = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w); - unroll = (do_src_zp || do_dst_zp) - ? 
nstl::min(unroll, 16 - do_src_zp - do_dst_zp) - : unroll; - - int ur = 0; - int tmp_ur = 0; - while (ur < unroll) { - int count = 0; - const int vlen = cpu_isa_traits::vlen; - - do { - add_imm(x_tmp_vec[count++], x_ptr_in_off, - (off + ur * simd_w) * itype_sz_, X_DEFAULT_ADDR); - ur++; - } while (ur < unroll && count < x_tmp_vec_size); - - for (int i = 0; i < count; i++) { - if (vlen == 64 || vlen == 32) - ld1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z, - ptr(x_tmp_vec[i])); - else if (vlen == 16) - ldr(QReg(tmp_ur + i), ptr(x_tmp_vec[i])); - else - assert(!"unreachable"); - } - tmp_ur += count; - } - - if (prb_.itype != prb_.otype) { - for (int ur = 0; ur < unroll; ++ur) { - TRegS r(ur); - if (prb_.itype == s32 && prb_.otype == f32) { - uni_scvtf(r, r); - apply_zp_ps(r); - } else if (prb_.itype == f32 && prb_.otype == s32) { - apply_zp_ps(r); - uni_frinti(r, r); - uni_fcvtzs(r, r); - } else - assert(!"unreachable"); - } - } else if (do_src_zp || do_dst_zp) { - for (int ur = 0; ur < unroll; ++ur) { - const auto vmm = TRegS(ur); - if (prb_.otype == f32) { - apply_zp_ps(vmm); - } else if (prb_.otype == s32) { - uni_scvtf(vmm, vmm); - apply_zp_ps(vmm); - uni_frinti(vmm, vmm); - uni_fcvtzs(vmm, vmm); - } - } - } - - ur = 0; - tmp_ur = 0; - while (ur < unroll) { - int count = 0; - const int vlen = cpu_isa_traits::vlen; - - do { - add_imm(x_tmp_vec[count++], x_ptr_out_off, - (off + ur * simd_w) * otype_sz_, X_DEFAULT_ADDR); - ur++; - } while (ur < unroll && count < x_tmp_vec_size); - - for (int i = 0; i < count; i++) { - if (vlen == 64 || vlen == 32) - st1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z, - ptr(x_tmp_vec[i])); - else if (vlen == 16) - str(QReg(tmp_ur + i), ptr(x_tmp_vec[i])); - else - assert(!"unreachable"); - } - tmp_ur += count; - } - - off += unroll * simd_w; - } - - return true; - } - - void process_unroll_generic_step(int reg_unroll, const int *i_off, - const int *o_off, const int *s_off, const int *c_off, - const int *zero_padding, const bool tail_processing) { - using namespace data_type; - - auto cvt2ps - = [=](const int startIdx, const int regNum, data_type_t idt) { - switch (idt) { - case f32: - /* do nothing */ - break; - case s32: cvt_v_s32_f32(startIdx, regNum); break; - case bf16: cvt_v_bf16_fp32(startIdx, regNum); break; - case f16: cvt_v_f16_f32(startIdx, regNum); break; - case data_type::s8: - cvt_v_s8_s32(startIdx, regNum); - cvt_v_s32_f32(startIdx, regNum); - break; - case u8: - cvt_v_u8_s32(startIdx, regNum); - cvt_v_s32_f32(startIdx, regNum); - break; - default: assert(!"unreachable"); - } - }; - - auto cvt2odt = [=](const int startIdx, const int regNum, - data_type_t odt, data_type_t idt) { - switch (odt) { - case f32: - if (idt == bf16) cvt_v_bf16_fp32(startIdx, regNum); - if (idt == f16) cvt_v_f16_f32(startIdx, regNum); - break; - case s32: - if (idt == f32) - cvt_v_f32_s32(startIdx, regNum); - else if (idt == data_type::s8) - cvt_v_s8_s32(startIdx, regNum); - else if (idt == u8) - cvt_v_u8_s32(startIdx, regNum); - break; - case data_type::s8: - if (idt == f32) cvt_v_f32_s32(startIdx, regNum); - if (idt == f32 || idt == s32) - cvt_v_s32_s8(startIdx, regNum); - if (idt == u8) { cvt_v_u8_s8(startIdx, regNum); } - break; - case u8: - if (idt == f32) cvt_v_f32_s32(startIdx, regNum); - if (idt == f32 || idt == s32) - cvt_v_s32_u8(startIdx, regNum); - if (idt == data_type::s8) cvt_v_s8_u8(startIdx, regNum); - break; - case bf16: - if (idt == f32) cvt_v_f32_bf16(startIdx, regNum); - break; - case f16: - if (idt == f32) cvt_v_f32_f16(startIdx, regNum); - break; - default: 
assert(!"unreachable"); - } - }; - - auto load_bytes_addr = [=](const int ur, const int r) { - add_imm(x_tmp_vec[r], x_ptr_in_off, i_off[ur + r] * itype_sz_, - X_DEFAULT_ADDR); - }; - auto load_bytes = [=](const int ur, int size, int r) { - switch (size) { - case 4: ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); break; - case 2: ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); break; - case 1: ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); break; - default: assert(!"unreachable"); - } - }; - - auto store = [=](const XReg &addr, const VReg ymm, int size) { - const uint32_t xmm = ymm.getIdx(); - switch (size) { - case 16: str(QReg(xmm), ptr(addr)); break; - case 8: str(DReg(xmm), ptr(addr)); break; - case 4: str(SReg(xmm), ptr(addr)); break; - case 2: str(HReg(xmm), ptr(addr)); break; - case 1: str(BReg(xmm), ptr(addr)); break; - default: assert(!"unreachable"); - } - }; - - /* check whether loading 4 values at once is possible */ - static constexpr int xmm_vlen = 4; - bool can_load_xmm = reg_unroll % xmm_vlen == 0; - for (int ur = 1; ur < reg_unroll; ++ur) - if (i_off[ur] != i_off[ur - 1] + 1) { - can_load_xmm = false; - break; - } - const int load_step = can_load_xmm ? xmm_vlen : 1; - - /* check whether storing 4 values at once is possible */ - bool can_store_xmm = reg_unroll % xmm_vlen == 0; - for (int ur = 1; ur < reg_unroll; ++ur) - if (o_off[ur] != o_off[ur - 1] + 1) { - can_store_xmm = false; - break; - } - const int ur_step = can_store_xmm ? 4 : 1; - const int load_tail_step - = !can_load_xmm && can_store_xmm ? ur_step : load_step; - - const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_); - - const bool need_saturation - = (utils::one_of(prb_.otype, u8, data_type::s8, s32) - && interim_f32); - - std::vector store_masks; - if (tail_processing) { - for (int ur = 0; ur < reg_unroll; ur += load_tail_step) { - uni_clear(VReg(ur)); - store_masks.push_back(0); - for (int r = 0; r < load_tail_step; ++r) { - if (zero_padding[ur + r] == 0) { - store_masks.back() += 1 << r; - load_bytes_addr(ur, r); - } - } - - for (int r = 0; r < load_tail_step; ++r) - if (zero_padding[ur + r] == 0) load_bytes(ur, itype_sz_, r); - } - } else { - if (!can_load_xmm && can_store_xmm) { - assert(ur_step == xmm_vlen); - /* load with stride */ - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - for (int r = 0; r < ur_step; ++r) { - load_bytes_addr(ur, r); - } - for (int r = 0; r < ur_step; ++r) - load_bytes(ur, itype_sz_, r); - } - } else { - int ur = 0; - int tmp_ur = 0; - while (ur < reg_unroll) { - int count = 0; - - do { - add_imm(x_tmp_vec[count++], x_ptr_in_off, - i_off[ur] * itype_sz_, X_DEFAULT_ADDR); - ur += load_step; - } while (ur < reg_unroll && count < x_tmp_vec_size); - - for (int i = 0; i < count; i++) { - - switch (load_step * itype_sz_) { - case 16: - ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); - break; - case 8: ldr(DReg(tmp_ur), ptr(x_tmp_vec[i])); break; - case 4: ldr(SReg(tmp_ur), ptr(x_tmp_vec[i])); break; - case 2: ldr(HReg(tmp_ur), ptr(x_tmp_vec[i])); break; - case 1: ldr(BReg(tmp_ur), ptr(x_tmp_vec[i])); break; - default: assert(!"unreachable"); - } - tmp_ur += load_step; - } - } - } - } - - /* xmm[:] <-- (f32)xmm[:] */ - if (interim_f32) { - const int cvt_step = nstl::max(load_step, ur_step); - for (int ur = 0; ur < reg_unroll; ur += cvt_step) - cvt2ps(ur, 1, prb_.itype); - } - - if (can_load_xmm && !can_store_xmm) { - // transposition on the fly - const bool fast_return = prb_.src_scale_type != scale_type_t::MANY - && prb_.dst_scale_type != scale_type_t::MANY - && prb_.beta == 0.f && 
!prb_.req_src_zp && !prb_.req_dst_zp; - if (fast_return) { - if (prb_.src_scale_type == scale_type_t::COMMON) - for (int ur = 0; ur < reg_unroll; ur += load_step) - fmul(VReg4S(ur), VReg4S(ur), xmm_src_scales_); - if (prb_.dst_scale_type == scale_type_t::COMMON) - for (int ur = 0; ur < reg_unroll; ur += load_step) - fmul(VReg4S(ur), VReg4S(ur), xmm_dst_scales_); - if (prb_.otype != f32) { - init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, - X_TMP_0, interim_f32 ? f32 : prb_.itype, - prb_.otype); - for (int ur = 0; ur < reg_unroll; ur += load_step) { - if (need_saturation) - saturate_f32(VReg4S(ur), xmm_zero_, - xmm_saturation_ubound_, prb_.otype, - P_ALL_ONE); - } - - for (int ur = 0; ur < reg_unroll; ur += load_step) - cvt2odt(ur, 1, prb_.otype, - interim_f32 ? f32 : prb_.itype); - } - for (int ur = 0; ur < reg_unroll; ur += load_step) { - for (int r = 0; r < load_step; ++r) { - add_imm(x_tmp_vec[r], x_ptr_out_off, - o_off[ur + r] * otype_sz_, X_DEFAULT_ADDR); - } - - for (int r = 0; r < load_step; ++r) { - if (otype_sz_ == 4) - st1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); - else if (otype_sz_ == 2) - st1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); - else - st1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); - } - } - return; - } - - /* scatter elements of xmm into 4 xmms */ - if (itype_sz_ == 4 || interim_f32) { - for (int ur = 0; ur < reg_unroll; ur += load_step) - for (int r = 1; r < load_step; ++r) { - VReg4S v(ur); - VReg4S v_r(ur + r); - dup(VReg16B(ur + r), VReg16B(ur)[0]); - ins(VReg4S(ur + r)[0], VReg4S(ur)[r]); - } - } else { - for (int ur = 0; ur < reg_unroll; ur += load_step) - for (int r = 1; r < load_step; ++r) - ext(VReg16B(ur + r), VReg16B(ur), VReg16B(ur), - itype_sz_ * r); - } - } - - /* src zero point application */ - if (prb_.req_src_zp) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - const auto xmm = VReg4S(ur); - if (interim_f32) - fsub(xmm, xmm, xmm_src_zp_); - else - sub(xmm, xmm, xmm_src_zp_); - } - } - - /* scale and beta processing */ - if (can_store_xmm) { - const auto apply_scales = [&](const VReg4S &vreg_scales, - scale_arg_t scale_arg, - scale_type_t scale_type) { - if (scale_type == scale_type_t::COMMON) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) - fmul(VReg4S(ur), VReg4S(ur), vreg_scales); - } else if (scale_type == scale_type_t::MANY) { - enum class scale_load_type_t { bcast, load, gather }; - const uint32_t idx = vreg_scales.getIdx(); - - uni_clear(VReg(idx)); - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - scale_load_type_t scale_load_type - = scale_load_type_t::bcast; // the best case - - for (int r = ur + 1; r < ur + ur_step; ++r) - if (s_off[r] != s_off[r - 1] + 0) - scale_load_type = scale_load_type_t::load; - - if (scale_load_type == scale_load_type_t::bcast - && !tail_processing) { - if (scale_arg == scale_arg_t::SRC) - ld1r(vreg_scales, ptr(src_s_addr(s_off[ur]))); - else - ld1r(vreg_scales, ptr(dst_s_addr(s_off[ur]))); - fmul(VReg4S(ur), VReg4S(ur), vreg_scales); - continue; - } - - // bcast doesn't work, the next try -- load - for (int r = ur + 1; r < ur + ur_step; ++r) - if (s_off[r] != s_off[r - 1] + 1) - scale_load_type = scale_load_type_t::gather; - - if (scale_load_type == scale_load_type_t::load - && !tail_processing) { - if (scale_arg == scale_arg_t::SRC) - ldr(QReg {idx}, ptr(src_s_addr(s_off[ur]))); - else - ldr(QReg {idx}, ptr(dst_s_addr(s_off[ur]))); - - fmul(VReg4S(ur), VReg4S(ur), VReg4S {idx}); - continue; - } - - // load doesn't work as well - // so gather the scale factors one by one - for (int r = ur; r < ur + ur_step; ++r) 
- if (zero_padding[r] == 0 || !tail_processing) { - if (scale_arg == scale_arg_t::SRC) - mov(x_tmp_vec[r - ur], - src_s_addr(s_off[r])); - else - mov(x_tmp_vec[r - ur], - dst_s_addr(s_off[r])); - } - for (int r = ur; r < ur + ur_step; ++r) - if (zero_padding[r] == 0 || !tail_processing) - ld1(vreg_scales[r - ur], - ptr(x_tmp_vec[r - ur])); - fmul(VReg4S(ur), VReg4S(ur), vreg_scales); - } - } - }; - /* xmm <-- src_scales * xmm[:] */ - apply_scales( - xmm_src_scales_, scale_arg_t::SRC, prb_.src_scale_type); - - /* xmm[:] <-- beta * dst + xmm[:] */ - assert(prb_.beta == 0.f || prb_.beta == 1.f); - if (prb_.beta == 1.f) { - int ur = 0; - int tmp_ur = 0; - - while (ur < reg_unroll) { - int count = 0; - - do { - add_imm(x_tmp_vec[count++], x_ptr_out_off, - o_off[ur] * otype_sz_, X_DEFAULT_ADDR); - ur += ur_step; - } while (ur < reg_unroll && count < x_tmp_vec_size); - - assert(count <= z_tmp_vec_size); - /* Firstly, data is loaded. */ - for (int i = 0; i < count; i++) { - - if (prb_.otype == f32 || prb_.otype == s32) { - ldr(QReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); // bug - } else if (prb_.otype == data_type::s8 - || prb_.otype == u8) { - ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); // bug - } else - assert(!"unreachable"); - } - - /* Secondly, it is added. */ - if (prb_.otype == f32) { - for (int i = 0; i < count; i++) { - VReg4S v(tmp_ur); - fadd(v, v, VReg4S(tmp_vec_idx[i])); - tmp_ur += ur_step; - } - } else { - for (int i = 0; i < count; i++) { - /* cvt2ps() generate successive instructions - which have save destination operand, - but out of order can be expected. */ - cvt2ps(tmp_vec_idx[i], 1, prb_.otype); - } - for (int i = 0; i < count; i++) { - VReg4S v(tmp_ur); - fadd(v, v, VReg4S(tmp_vec_idx[i])); - tmp_ur += ur_step; - } - } - } - } - - /* dst <-- dst_scales * xmm[:] */ - apply_scales( - xmm_dst_scales_, scale_arg_t::DST, prb_.dst_scale_type); - } else { - const auto apply_scales = [&](const VReg4S &vreg_scales, - scale_arg_t scale_arg, - scale_type_t scale_type) { - if (scale_type == scale_type_t::COMMON) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) - fmul(VReg4S(ur), VReg4S(ur), vreg_scales); - } else if (scale_type == scale_type_t::MANY) { -#define DUMMY_IDX_ (99) - std::vector idx_list; - std::vector offt_list; - std::vector vec_reg; - std::vector addr_reg; - const size_t max_cnt_per_loop - = std::min(tmp_vec_idx.size(), x_tmp_vec.size()); - size_t cnt = 0; // valid unroll steps count - - // 1. Listing up valid steps - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - if (zero_padding[ur] == 0 || !tail_processing) { - idx_list.push_back(ur); - offt_list.push_back(s_off[ur]); - vec_reg.push_back( - tmp_vec_idx[cnt % max_cnt_per_loop]); - if (s_off[ur]) - addr_reg.push_back( - x_tmp_vec[cnt % max_cnt_per_loop]); - else - addr_reg.push_back(scale_arg == scale_arg_t::SRC - ? x_ptr_src_scale_off - : x_ptr_dst_scale_off); - cnt++; - } - } - /* 2. Generate instructions considering instruction order. - If cnt > max_cnt_per_loop, the following instruction sets are - generated several times. - add x?, ..., add x? for calculating address - ldr s?, ..., ldr s? for loading data - fmul v?, ..., fmul v? for scaling */ - for (size_t ur = 0; ur < cnt;) { - // Calculating address - for (size_t i = ur; - i < cnt && i - ur < max_cnt_per_loop; i++) - add_imm(addr_reg[i], - scale_arg == scale_arg_t::SRC - ? 
x_ptr_src_scale_off - : x_ptr_dst_scale_off, - offt_list[i] * stype_sz_, X_TMP); - // Loading data - for (size_t i = ur; - i < cnt && i - ur < max_cnt_per_loop; i++) - ldr(SReg(vec_reg[i]), ptr(addr_reg[i])); - // Scaling - for (size_t i = ur; - i < cnt && i - ur < max_cnt_per_loop; i++) { - VReg4S v(idx_list[i]); - fmul(v, v, VReg4S(vec_reg[i])); - } - ur += std::min(cnt, max_cnt_per_loop); - } - } -#undef DUMMY_IDX_ - }; - - /* xmm[0] <-- src_scales * xmm[0] */ - apply_scales( - xmm_src_scales_, scale_arg_t::SRC, prb_.src_scale_type); - - /* xmm[0] <-- beta * dst + xmm[0] */ - assert(prb_.beta == 0.f || prb_.beta == 1.f); - if (prb_.beta == 1.f) { - int ur = 0; - int tmp_ur = 0; - while (ur < reg_unroll) { - int count = 0; - - do { - add_imm(x_tmp_vec[count++], x_ptr_out_off, - o_off[ur] * otype_sz_, X_DEFAULT_ADDR); - ur += ur_step; - } while (ur < reg_unroll && count < (x_tmp_vec_size / 2)); - - assert(static_cast(count) <= z_tmp_vec.size()); - - if (prb_.otype == f32) { - /* addss: dest[31:0] <- src1[31:0] + src2[31:0] - dset[MAXVL-1:32] (Unmodified) */ - for (int i = 0; i < count; i++) { - ld1(VReg4S(z_tmp_vec[i].getIdx())[0], - ptr(x_tmp_vec[i])); - } - for (int i = 0; i < count; i++) { - SReg s {tmp_vec_idx[i]}; - fadd(s, s, SReg(tmp_ur + ur_step * i)); - } - for (int i = 0; i < count; i++) { - mov(VReg4S(tmp_ur)[0], VReg4S(tmp_vec_idx[i])[0]); - tmp_ur += ur_step; - } - } else { - for (int i = 0; i < count; i++) { - if (prb_.otype == s32) { - ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); - } else if (utils::one_of( - prb_.otype, data_type::s8, u8)) { - ldr(BReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); - } else { - assert(!"unsupported o_type"); - } - cvt2ps(tmp_vec_idx[i], 1, prb_.otype); - } - for (int i = 0; i < count; i++) { - VReg4S v(tmp_ur); - fadd(v, v, VReg4S(tmp_vec_idx[i])); - tmp_ur += ur_step; - } - } - } - } - - /* dst <-- dst_scales * xmm[0] */ - apply_scales( - xmm_dst_scales_, scale_arg_t::DST, prb_.dst_scale_type); - } - - /* dst zero point application */ - if (prb_.req_dst_zp) { - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - const auto xmm = VReg4S(ur); - if (interim_f32) - fadd(xmm, xmm, xmm_dst_zp_); - else - add(xmm, xmm, xmm_dst_zp_); - } - } - - /* adjust scale application */ - if (prb_.scale_adjust != 1.f) { - dup(xmm_tmp_, reg_scale_adjust_); - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - fmul(VReg4S(ur), VReg4S(ur), xmm_tmp_); - } - } - - if (need_saturation) { - init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, X_TMP_0, f32, - prb_.otype, compensation_needed_); - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - saturate_f32(VReg4S(ur), xmm_zero_, xmm_saturation_ubound_, - prb_.otype, P_ALL_ONE, compensation_needed_); - } - - // reset back xmm_zero_ if needed. 
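        // (The saturation step above may clobber xmm_zero_ when compensation is
        // enabled; the uni_clear below restores it because the padded-lane
        // masking further down uses xmm_zero_ as its source of zero lanes.)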
- if (compensation_needed_ && (prb_.req_src_zp || prb_.req_dst_zp)) - uni_clear(VReg(xmm_zero_.getIdx())); - } - - if (compensation_needed_) { - const uint32_t xmm_begin = 9; - const uint32_t xmm_end = 11; - uint32_t xmm_id = xmm_begin; - const auto get_temp_xmm = [&] { - const Xbyak_aarch64::VReg temp {xmm_id++}; - - if (xmm_id > xmm_end) { xmm_id = xmm_begin; } - - return temp; - }; - if (can_store_xmm) { - enum class comp_load_type_t { bcast, load, gather }; - - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - - bool all_ip_padding_one = true; - bool all_ip_padding_zero = true; - for (int r = ur; r < ur + ur_step; r++) { - if (zero_padding[r] != 1) - all_ip_padding_one = false; - else - all_ip_padding_zero = false; - } - if (all_ip_padding_one) continue; - - comp_load_type_t comp_load_type = comp_load_type_t::bcast; - - for (int r = ur + 1; r < ur + ur_step; ++r) - if (c_off[r] != c_off[r - 1] + 0) { - comp_load_type = comp_load_type_t::load; - break; - } - - if (comp_load_type == comp_load_type_t::bcast - && all_ip_padding_zero) { - frinti(xmm_compensation, VReg4S(ur)); - fcvtzs(xmm_compensation, xmm_compensation); - addv(SReg(xmm_compensation.getIdx()), xmm_compensation); - addv(SReg(xmm_compensation.getIdx()), xmm_compensation); - const auto comp_addr = c_addr(c_off[ur]); - const auto xmm_tmp_ = get_temp_xmm().s4; - ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); - add(xmm_tmp_, xmm_tmp_, xmm_compensation); - str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); - continue; - } - - if (comp_load_type == comp_load_type_t::load) - for (int r = ur + 1; r < ur + ur_step; ++r) - if (c_off[r] != c_off[r - 1] + 1) { - comp_load_type = comp_load_type_t::gather; - break; - } - - if (comp_load_type == comp_load_type_t::load - && all_ip_padding_zero) { - const auto xmm_reorder_result = VReg4S(ur); - const auto comp_addr = c_addr(c_off[ur]); - frinti(xmm_compensation, xmm_reorder_result); - fcvtzs(xmm_compensation, xmm_compensation); - const auto xmm_tmp_ = get_temp_xmm().s4; - ldr(QReg(xmm_tmp_.getIdx()), ptr(comp_addr)); - add(xmm_compensation, xmm_compensation, xmm_tmp_); - str(QReg(xmm_compensation.getIdx()), ptr(comp_addr)); - continue; - } - - frinti(xmm_compensation, VReg4S(ur)); - fcvtzs(xmm_compensation, xmm_compensation); - for (int r = ur; r < ur + ur_step; ++r) { - if (zero_padding[r] == 0 || !tail_processing) { - mov(W_TMP_0, xmm_compensation[r % 4]); - const auto comp_addr = c_addr(c_off[r]); - ldr(W_TMP_1, ptr(comp_addr)); - add(W_TMP_0, W_TMP_0, W_TMP_1); - str(W_TMP_0, ptr(comp_addr)); - } - } - } - } else { - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - if (zero_padding[ur] == 0 || !tail_processing) { - const auto comp_addr = c_addr(c_off[ur]); - frinti(xmm_compensation, VReg4S(ur)); - fcvtzs(xmm_compensation, xmm_compensation); - const auto xmm_tmp_ = get_temp_xmm().s4; - ld1(xmm_tmp_, ptr(comp_addr)); - add(xmm_compensation, xmm_compensation, xmm_tmp_); - st1(VReg(xmm_compensation.getIdx()).s[0], - ptr(comp_addr)); - } - } - } - } - - for (int ur = 0; ur < reg_unroll; ur += ur_step) { - if (prb_.req_src_zp || prb_.req_dst_zp) { - const bool use_store_masks = !store_masks.empty(); - if (use_store_masks) { - const auto mask = (~store_masks[ur / ur_step]) & 0xF; - switch (mask) { - case 0x0: - /* Do nothing */ - break; - case 0x1: ins(VReg4S(ur)[0], xmm_zero_[0]); break; - case 0x2: ins(VReg4S(ur)[1], xmm_zero_[1]); break; - case 0x3: - ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); - break; - case 0x4: ins(VReg4S(ur)[2], xmm_zero_[2]); break; - case 0x5: - 
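                        // 0x5 = 0b0101: lanes 0 and 2 are padding, so only those
                        // two lanes are overwritten with zero.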
ins(VReg4S(ur)[0], xmm_zero_[0]); - ins(VReg4S(ur)[2], xmm_zero_[2]); - break; - case 0x6: - ins(VReg4S(ur)[1], xmm_zero_[1]); - ins(VReg4S(ur)[2], xmm_zero_[2]); - break; - case 0x7: - ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); - ins(VReg4S(ur)[2], xmm_zero_[2]); - break; - case 0x8: ins(VReg4S(ur)[3], xmm_zero_[3]); break; - case 0x9: - ins(VReg4S(ur)[0], xmm_zero_[0]); - ins(VReg4S(ur)[3], xmm_zero_[3]); - break; - case 0xa: - ins(VReg4S(ur)[1], xmm_zero_[1]); - ins(VReg4S(ur)[3], xmm_zero_[3]); - break; - case 0xb: - ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); - ins(VReg4S(ur)[3], xmm_zero_[3]); - break; - case 0xc: - ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); - break; - case 0xd: - ins(VReg4S(ur)[0], xmm_zero_[0]); - ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); - break; - case 0xe: - ins(VReg4S(ur)[1], xmm_zero_[1]); - ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); - break; - case 0xf: movi(VReg16B(ur), 0); break; - default: assert(!"unreachable"); - } - } - } - if (prb_.otype != f32) - cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype); - - store(o_addr(o_off[ur]), VReg(ur), ur_step * otype_sz_); - } - } - - static bool interim_f32_needed(const prb_t &prb, bool compensation_needed) { - using namespace data_type; - bool ret = utils::one_of(f32, prb.itype, prb.otype) - || prb.src_scale_type != scale_type_t::NONE - || prb.dst_scale_type != scale_type_t::NONE || prb.beta != 0.f - || ((prb.req_src_zp || prb.req_dst_zp) - ? !(prb.itype == s32 && prb.otype == s32) - : false) - || (prb.itype != f32 && compensation_needed) - || prb.scale_adjust != 1.f; - return ret; - } - - void process_unroll_generic( - const int ndims, int len, const bool tail_processing) { - assert(IMPLICATION(prb_.nodes[0].tail_size > 0, - len == static_cast(prb_.nodes[0].n) - || len == static_cast(prb_.nodes[0].tail_size))); - - const int blk = 8; - - int i_off[2 * blk] = {0}; - int o_off[2 * blk] = {0}; - int s_off[2 * blk] = {0}; - int c_off[2 * blk] = {0}; - - int curr = 0; // will switch between 0 and 1 - - const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_); - - if (prb_.req_src_zp) { - add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0); - ld1r(xmm_src_zp_, ptr(X_DEFAULT_ADDR)); - if (interim_f32) scvtf(xmm_src_zp_, xmm_src_zp_); - } - if (prb_.req_dst_zp) { - add_imm(X_DEFAULT_ADDR, PARAM(dst_zp), X_TMP_0); - ld1r(xmm_dst_zp_, ptr(X_DEFAULT_ADDR)); - if (interim_f32) scvtf(xmm_dst_zp_, xmm_dst_zp_); - } - - for (int off = 0; off < len; off += blk) { - const int reg_unroll = nstl::min(off + blk, len) - off; - int zero_padding[blk] = {0}; - const auto curr_blk = curr * blk; - - /* compute offsets and tail*/ - for (int ur = off != 0 ? 
0 : 1; ur < reg_unroll; ++ur) { - const int ur_c = curr_blk + ur; - const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur - const bool is_tail - = off + ur >= static_cast(prb_.nodes[0].tail_size); - step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p], - c_off[ur_p], i_off[ur_c], o_off[ur_c], s_off[ur_c], - c_off[ur_c]); - if (tail_processing && is_tail) zero_padding[ur] = 1; - } - - process_unroll_generic_step(reg_unroll, i_off + curr_blk, - o_off + curr_blk, s_off + curr_blk, c_off + curr_blk, - zero_padding, tail_processing); - - curr = 1 - curr; - } - } - - void compute_ker( - const int ndims, const int len_unroll, const bool tail_processing) { - bool optimized = false; - optimized = optimized || process_direct_copy(ndims, len_unroll) - || process_direct_copy(ndims, len_unroll) - || process_unroll_tr8x8(ndims, len_unroll); - if (!optimized) - process_unroll_generic(ndims, len_unroll, tail_processing); - } - - void loop_begin(Label &l, XReg reg_cnt, int len) { - mov(reg_cnt, len); - L(l); - } - - void check_if_this_is_last_chunk(const XReg reg_curr_chunk, int node_id) { - // Chunks are backwards numered i.e: - // [0] -> [node_size] - // [1] -> [node_size - 1] - // ... - // [node_size - 1] -> [1] - - // It is done like this, because it is easier to decrement counter - // and check if it is equal to zero than increment and check - // if it is equal to node_size. - static constexpr int64_t last_chunk = 1; - cmp(reg_curr_chunk, last_chunk); - } - - void zero_dst_memory(const int bytes_to_zeroing) { - static constexpr int num_of_bytes_in_xmm = 128 / 8; - - const int xmms_to_zeroing - = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).quot; - const int tail_to_zeroing - = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).rem; - - movi(xmm_tmp_, 0); - - if (xmms_to_zeroing > 0) { - Label loop; - - mov(X_TMP_4, xmms_to_zeroing); - L(loop); - str(QReg(xmm_tmp_.getIdx()), ptr(o_addr(0))); - add_imm(reg_off_out_, reg_off_out_, num_of_bytes_in_xmm, X_TMP_0); - add_imm(x_ptr_out_off, x_ptr_out_off, num_of_bytes_in_xmm, X_TMP_0); - subs(X_TMP_4, X_TMP_4, 1); - b(NE, loop); - } - - if (tail_to_zeroing) mov_imm(W_TMP_4, 0); - for (int i = 0; i < tail_to_zeroing; i++) - strb(W_TMP_4, ptr(o_addr(i, false))); - - // Restore dst offset to initial value - if (xmms_to_zeroing > 0) { - sub_imm(reg_off_out_, reg_off_out_, - num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0); - sub_imm(x_ptr_out_off, x_ptr_out_off, - num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0); - } - } - - void finalize_tail_loop(int i_step, int o_step, int s_step, int c_step, - const int curr_node_id) { - static constexpr int empty_chunk_info = -1; - - mov(X_TMP_0, empty_chunk_info); - str(X_TMP_0, ptr(data_chunk_addr(curr_node_id))); - - const int padded_area = prb_.nodes[curr_node_id].n - - prb_.nodes[curr_node_id].tail_size; - - if (prb_.nodes[curr_node_id].is_zero_pad_needed) { - int num_of_zero_padded_values = padded_area; - for (int i = curr_node_id - 1; i >= 0; i--) { - num_of_zero_padded_values *= prb_.nodes[i].n; - } - - const int bytes_to_zeroing = num_of_zero_padded_values * otype_sz_; - zero_dst_memory(bytes_to_zeroing); - } - - // This function is called by loop_end. At the end - // of loop_end is section that is responsible for - // restoring offset values. Restoring is based on - // len value which is equal to prb.nodes[x].n. - // If fill_zero_padded_area is called then it means - // offsets were shifted prb.nodes[x].tail_size times. - // Therefore, this function has to shift offsets by - // zero pad area. 
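        // Illustration (hypothetical sizes): for a node with n = 24 and
        // tail_size = 20, padded_area = 4, so the input offset advances here by
        // 4 * i_step * itype_sz_ bytes; loop_end then subtracts the full
        // len (= 24) worth of stride and lands back at the loop's start offset.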
- add_imm(reg_off_in_, reg_off_in_, padded_area * i_step * itype_sz_, - X_TMP_0); - add_imm(reg_off_out_, reg_off_out_, padded_area * o_step * otype_sz_, - X_TMP_0); - add_imm(x_ptr_in_off, x_ptr_in_off, padded_area * i_step * itype_sz_, - X_TMP_0); - add_imm(x_ptr_out_off, x_ptr_out_off, padded_area * o_step * otype_sz_, - X_TMP_0); - if (prb_.src_scale_type == scale_type_t::MANY) - add_imm(x_ptr_src_scale_off, x_ptr_src_scale_off, - padded_area * s_step * stype_sz_, X_TMP_0); - if (prb_.dst_scale_type == scale_type_t::MANY) - add_imm(x_ptr_dst_scale_off, x_ptr_dst_scale_off, - padded_area * s_step * stype_sz_, X_TMP_0); - - if (compensation_needed_) { - add_imm(reg_off_comp_, reg_off_comp_, - padded_area * c_step * sizeof(int32_t), X_TMP_0); - add_imm(x_ptr_comp_off, x_ptr_comp_off, - padded_area * c_step * sizeof(int32_t), X_TMP_0); - } - } - - void loop_end(Label &l, XReg reg_cnt, int len, int i_step, int o_step, - int s_step, int c_step, const int curr_node_id) { - add_imm(reg_off_in_, reg_off_in_, i_step * itype_sz_, X_TMP_0); - add_imm(reg_off_out_, reg_off_out_, o_step * otype_sz_, X_TMP_0); - add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz_, X_TMP_0); - add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz_, X_TMP_0); - - if (prb_.src_scale_type == scale_type_t::MANY) - add_imm(x_ptr_src_scale_off, x_ptr_src_scale_off, - s_step * stype_sz_, X_TMP_0); - if (prb_.dst_scale_type == scale_type_t::MANY) - add_imm(x_ptr_dst_scale_off, x_ptr_dst_scale_off, - s_step * stype_sz_, X_TMP_0); - - if (compensation_needed_) { - add_imm(reg_off_comp_, reg_off_comp_, c_step * sizeof(int32_t), - X_TMP_0); - add_imm(x_ptr_comp_off, x_ptr_comp_off, c_step * sizeof(int32_t), - X_TMP_0); - } - - subs(reg_cnt, reg_cnt, 1); - b(NE, l); - - if (prb_.tail(curr_node_id) != 0) { - Label if_end; - - // On the stack should be an information if node - // was processed with tail or not. - ldr(X_TMP_0, post_ptr(X_SP, X_TMP_0.getBit() / 8)); - - cmp(X_TMP_0, with_tail_info_); - b(NE, if_end); - finalize_tail_loop(i_step, o_step, s_step, c_step, curr_node_id); - L(if_end); - } - - // Restore offset to initial values. It means before - // loop execution. - sub_imm(reg_off_in_, reg_off_in_, len * i_step * itype_sz_, X_TMP_0); - sub_imm(reg_off_out_, reg_off_out_, len * o_step * otype_sz_, X_TMP_0); - sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz_, X_TMP_0); - sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz_, - X_TMP_0); - - if (prb_.src_scale_type == scale_type_t::MANY) - sub_imm(x_ptr_src_scale_off, x_ptr_src_scale_off, - len * s_step * stype_sz_, X_TMP_0); - if (prb_.dst_scale_type == scale_type_t::MANY) - sub_imm(x_ptr_dst_scale_off, x_ptr_dst_scale_off, - len * s_step * stype_sz_, X_TMP_0); - if (compensation_needed_) { - sub_imm(reg_off_comp_, reg_off_comp_, - len * c_step * sizeof(int32_t), X_TMP_0); - sub_imm(x_ptr_comp_off, x_ptr_comp_off, - len * c_step * sizeof(int32_t), X_TMP_0); - } - } - - void compute_blk_ker(const simple_impl_desc_t &desc) { - static constexpr bool with_tail_processing = true; - Label no_last_chunk, end_label; - int omp_ndims = prb_.full_ndims - prb_.ndims; - - if (prb_.nodes[0].tail_size > 0) { - if (!prb_.nodes[0].is_parent_empty()) { - const int parent_node_id = prb_.nodes[0].parent_node_id; - ldr(X_TMP_0, ptr(data_chunk_addr(parent_node_id))); - check_if_this_is_last_chunk(X_TMP_0, parent_node_id); - b(NE, no_last_chunk); - } - - const int len_unroll = desc.tail_len_unroll > 0 - ? 
desc.tail_len_unroll - : desc.len_unroll; - compute_ker(omp_ndims, len_unroll, with_tail_processing); - b(end_label); - } - - L(no_last_chunk); - compute_ker(omp_ndims, desc.len_unroll, !with_tail_processing); - L(end_label); - } - - void create_loops(const simple_impl_desc_t &desc, - const std::array ®_cnt, int jit_loop) { - assert(jit_loop <= ndims_jit_loop_max); - - if (jit_loop > 0) { - const int nfu = desc.ndims_full_unroll; - const int unroll_factor - = jit_loop == 1 ? desc.len_last_dim_unroll : 1; - const int curr_node_id = nfu + (jit_loop - 1); - const int parent_node_id = prb_.nodes[curr_node_id].parent_node_id; - const int tail_size = prb_.tail(curr_node_id) / unroll_factor; - const int node_size = prb_.n(curr_node_id) / unroll_factor; - const XReg reg_loop_cnt = reg_cnt[jit_loop - 1]; - const bool curr_node_has_tail = prb_.tail(curr_node_id) != 0; - Label loop, if_no_tail, if_end; - - if (curr_node_has_tail) { - const size_t reg_bytes = X_TMP_0.getBit() / 8; - if (prb_.nodes[curr_node_id].is_parent_empty()) { - mov(reg_loop_cnt, tail_size); - // Put info that node is being processed with tail. - mov(X_TMP_0, with_tail_info_); - str(X_TMP_0, pre_ptr(X_SP, -reg_bytes)); - } else { - ldr(X_TMP_0, ptr(data_chunk_addr(parent_node_id))); - check_if_this_is_last_chunk(X_TMP_0, parent_node_id); - b(NE, if_no_tail); - mov(reg_loop_cnt, tail_size); - // Put info that node is being processed with tail. - mov(X_TMP_0, with_tail_info_); - str(X_TMP_0, pre_ptr(X_SP, -reg_bytes)); - b(if_end); - - L(if_no_tail); - mov(reg_loop_cnt, node_size); - // Put info that node is being processed without tail. - mov(X_TMP_0, without_tail_info_); - str(X_TMP_0, pre_ptr(X_SP, -reg_bytes)); - L(if_end); - } - } - - if (prb_.is_tail_in_one_of_child_nodes(curr_node_id)) { - if (!curr_node_has_tail) { - mov(reg_loop_cnt, node_size); - str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); - } - L(loop); - if (!prb_.nodes[curr_node_id].is_parent_empty()) { - Label if_no_tail_in_child_node; - ldr(X_TMP_0, ptr(data_chunk_addr(parent_node_id))); - check_if_this_is_last_chunk(X_TMP_0, parent_node_id); - b(NE, if_no_tail_in_child_node); - str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); - L(if_no_tail_in_child_node); - } else { - str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); - } - } else if (curr_node_has_tail) { - L(loop); - } else { - loop_begin(loop, reg_loop_cnt, node_size); - } - - create_loops(desc, reg_cnt, jit_loop - 1); - - loop_end(loop, reg_loop_cnt, node_size, - prb_.is(curr_node_id) * unroll_factor, - prb_.os(curr_node_id) * unroll_factor, - prb_.ss(curr_node_id) * unroll_factor, - prb_.cs(curr_node_id) * unroll_factor, curr_node_id); - } else { - compute_blk_ker(desc); - } - } - - bool simple_impl() { - simple_impl_desc_t d; - if (!simple_impl_desc_init(prb_, &d)) return false; - - eor(reg_off_in_, reg_off_in_, reg_off_in_); - eor(reg_off_out_, reg_off_out_, reg_off_out_); - - if (prb_.src_scale_type == scale_type_t::MANY) - mov(x_ptr_src_scale_off, reg_ptr_src_scales_); - if (prb_.dst_scale_type == scale_type_t::MANY) - mov(x_ptr_dst_scale_off, reg_ptr_dst_scales_); - - if (compensation_needed_) - eor(reg_off_comp_, reg_off_comp_, reg_off_comp_); - - std::array reg_cnt({{x15, x14, x13}}); - - const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; - create_loops(d, reg_cnt, n_jit_loops); - - return true; - } - - void impl() { - if (simple_impl()) return; - assert(!"no implementation available"); - } - -#define UNROLL_INST(inst, reg, ...) 
\ - for (size_t i = startIdx; i < startIdx + regNum; i++) { \ - reg tmp(i); \ - inst(__VA_ARGS__); \ - } -#define UNROLL_INST2(inst, ...) \ - for (size_t i = startIdx; i < startIdx + regNum; i++) \ - inst(__VA_ARGS__); - - void cvt_z_s32_f32(const size_t startIdx, const size_t regNum) { - UNROLL_INST(scvtf, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - } - - void cvt_v_s32_f32(const size_t startIdx, const size_t regNum) { - UNROLL_INST(scvtf, VReg4S, tmp, tmp); - } - - void cvt_z_f32_s32(const size_t startIdx, const size_t regNum) { - UNROLL_INST(frinti, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - UNROLL_INST(fcvtzs, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - } - - void cvt_v_f32_s32(const size_t startIdx, const size_t regNum) { - UNROLL_INST(frinti, VReg4S, tmp, tmp); - UNROLL_INST(fcvtzs, VReg4S, tmp, tmp); - } - - void cvt_v_f32_bf16(const size_t startIdx, const size_t regNum) { - UNROLL_INST2(bfcvtn, VReg4H(i), VReg4S(i)); - } - - void cvt_v_bf16_fp32(const size_t startIdx, const size_t regNum) { - UNROLL_INST2(shll, VReg4S(i), VReg4H(i), 16); - } - - void cvt_v_f16_f32(const size_t startIdx, const size_t regNum) { - UNROLL_INST2(fcvtl, VReg4S(i), VReg4H(i)); - } - - void cvt_v_f32_f16(const size_t startIdx, const size_t regNum) { - UNROLL_INST2(fcvtn, VReg4H(i), VReg4S(i)); - } - - void cvt_z_s8_s32(const size_t startIdx, const size_t regNum) { - cvt_z_b_s(startIdx, regNum); - UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - } - - void cvt_v_s8_s32(const size_t startIdx, const size_t regNum) { - UNROLL_INST(sxtl, VReg, tmp.h8, tmp.b8); - UNROLL_INST(sxtl, VReg, tmp.s4, tmp.h4); - } - - void cvt_z_s8_f32(const size_t startIdx, const size_t regNum) { - cvt_z_b_s(startIdx, regNum); - cvt_z_s32_f32(startIdx, regNum); - } - - void cvt_v_s8_f32(const size_t startIdx, const size_t regNum) { - cvt_v_b_s(startIdx, regNum); - cvt_v_s32_f32(startIdx, regNum); - } - - void cvt_z_b_s(const size_t startIdx, const size_t regNum) { - assert(z_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < z_tmp7.getIdx()); - - dup(z_tmp7.b, 0); - UNROLL_INST(zip1, ZRegB, tmp, tmp, z_tmp7.b); - UNROLL_INST(zip1, ZRegH, tmp, tmp, z_tmp7.h); - } - - void cvt_v_b_s(const size_t startIdx, const size_t regNum) { - assert(v_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < v_tmp7.getIdx()); - - mov_imm(W_TMP_0, 0); - dup(v_tmp7.b16, W_TMP_0); - UNROLL_INST(zip1, VReg16B, tmp, tmp, v_tmp7.b16); - UNROLL_INST(zip1, VReg8H, tmp, tmp, v_tmp7.h8); - } - - void cvt_z_u8_s32(const size_t startIdx, const size_t regNum) { - cvt_z_b_s(startIdx, regNum); - UNROLL_INST(uxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); - } - - void cvt_v_u8_s32(const size_t startIdx, const size_t regNum) { - UNROLL_INST(uxtl, VReg, tmp.h8, tmp.b8); - UNROLL_INST(uxtl, VReg, tmp.s4, tmp.h4); - } - - void cvt_z_s32_s8(const size_t startIdx, const size_t regNum) { - assert(z_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < z_tmp7.getIdx()); - - dup(z_tmp7.s, 0); - UNROLL_INST2(smin, ZRegS(i), 127); - UNROLL_INST2(smax, ZRegS(i), -128); - UNROLL_INST(uzp1, ZRegH, tmp, tmp, z_tmp7.h); - UNROLL_INST(uzp1, ZRegB, tmp, tmp, z_tmp7.b); - } - - void cvt_v_s32_s8(const size_t startIdx, const size_t regNum) { - assert(v_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < v_tmp7.getIdx()); - - mov_imm(W_TMP_0, 127); - dup(v_tmp7.s4, W_TMP_0); - mov_imm(W_TMP_0, -128); - UNROLL_INST2(smin, VReg4S(i), VReg4S(i), v_tmp7.s4); - dup(v_tmp7.s4, W_TMP_0); - UNROLL_INST2(smax, VReg4S(i), VReg4S(i), v_tmp7.s4); - mov_imm(W_TMP_0, 0); - dup(v_tmp7.s4, W_TMP_0); - UNROLL_INST(uzp1, 
VReg8H, tmp, tmp, v_tmp7.h8); - UNROLL_INST(uzp1, VReg16B, tmp, tmp, v_tmp7.b16); - } - - void cvt_z_u8_s8(const size_t startIdx, const size_t regNum) { - UNROLL_INST2(umin, ZRegB(i), 127); - } - - void cvt_v_u8_s8(const size_t startIdx, const size_t regNum) { - assert(v_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < v_tmp7.getIdx()); - - mov_imm(W_TMP_0, 127); - dup(v_tmp7.b16, W_TMP_0); - UNROLL_INST(umin, VReg16B, tmp, tmp, v_tmp7.b16); - } - - void cvt_z_u32_u8(const size_t startIdx, const size_t regNum) { - UNROLL_INST2(umin, ZRegS(i), 255); - UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp); - UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp); - } - - void cvt_v_u32_u8(const size_t startIdx, const size_t regNum) { - assert(v_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < v_tmp7.getIdx()); - - mov_imm(W_TMP_0, 255); - dup(v_tmp7.s4, W_TMP_0); - UNROLL_INST(umin, VReg4S, tmp, tmp, v_tmp7.s4); - UNROLL_INST(uzp1, VReg8H, tmp, tmp, tmp); - UNROLL_INST(uzp1, VReg16B, tmp, tmp, tmp); - } - - void cvt_z_s32_u8(const size_t startIdx, const size_t regNum) { - assert(z_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < z_tmp7.getIdx()); - - dupm(z_tmp7.s, 255); - UNROLL_INST2(smax, ZRegS(i), 0); - UNROLL_INST2(smin, ZRegS(i), P_ALL_ONE / T_m, z_tmp7.s); - UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp); - UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp); - UNROLL_INST2(mov, ZRegB(i), P_NOT_128 / T_m, 0); - } - - void cvt_v_s32_u8(const size_t startIdx, const size_t regNum) { - assert(v_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < v_tmp7.getIdx()); - - mov_imm(W_TMP_0, 0); - dup(v_tmp7.s4, W_TMP_0); - mov_imm(W_TMP_0, 255); - UNROLL_INST(smax, VReg4S, tmp, tmp, v_tmp7.s4); - dup(v_tmp7.s4, W_TMP_0); - UNROLL_INST(smin, VReg4S, tmp, tmp, v_tmp7.s4); - UNROLL_INST(uzp1, VReg8H, tmp, tmp, tmp); - UNROLL_INST(uzp1, VReg16B, tmp, tmp, tmp); - } - - void cvt_z_s8_u8(const size_t startIdx, const size_t regNum) { - UNROLL_INST2(smax, ZRegB(i), 0); - } - - void cvt_v_s8_u8(const size_t startIdx, const size_t regNum) { - assert(v_tmp7.getIdx() < startIdx - || startIdx + regNum - 1 < v_tmp7.getIdx()); - - mov_imm(W_TMP_0, 0); - dup(v_tmp7.b16, W_TMP_0); - UNROLL_INST(smax, VReg16B, tmp, tmp, v_tmp7.b16); - } -#undef UNROLL_INST -#undef UNROLL_INST - - jit_uni_reorder_kernel_f32_t(const desc_t &desc) - : kernel_t(desc), isa_(get_max_cpu_isa()) { - assert(!utils::one_of(isa_, isa_undef, isa_all)); - itype_sz_ = data_type_size(prb_.itype); - otype_sz_ = data_type_size(prb_.otype); - stype_sz_ = sizeof(float); - } - - void generate() override { - using namespace Xbyak_aarch64::util; - uint64_t sveLen = get_sve_length(); - Label end_of_kernel; - - preamble(); - - if (prb_.src_scale_type == scale_type_t::COMMON) { - add_imm(X_DEFAULT_ADDR, PARAM(src_scales), X_TMP_1); - ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); - ld1r(xmm_src_scales_, ptr(X_TMP_0)); - } else if (prb_.src_scale_type == scale_type_t::MANY) { - add_imm(X_DEFAULT_ADDR, PARAM(src_scales), X_TMP_0); - ldr(reg_ptr_src_scales_, ptr(X_DEFAULT_ADDR)); - } - - if (prb_.dst_scale_type == scale_type_t::COMMON) { - add_imm(X_DEFAULT_ADDR, PARAM(dst_scales), X_TMP_1); - ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); - ld1r(xmm_dst_scales_, ptr(X_TMP_0)); - } else if (prb_.dst_scale_type == scale_type_t::MANY) { - add_imm(X_DEFAULT_ADDR, PARAM(dst_scales), X_TMP_0); - ldr(reg_ptr_dst_scales_, ptr(X_DEFAULT_ADDR)); - } - - if (compensation_needed_) { - add_imm(X_DEFAULT_ADDR, PARAM(compensation_scratch), X_TMP_0); - ldr(reg_ptr_comp_, ptr(X_DEFAULT_ADDR)); - } - if (prb_.scale_adjust == 
0.5f) { mov(reg_scale_adjust_, 0x3f000000); } - add_imm(X_TMP_0, PARAM(in), X_TMP_2); - add_imm(X_TMP_1, PARAM(out), X_TMP_2); - ldr(reg_ptr_in_, ptr(X_TMP_0)); - ldr(reg_ptr_out_, ptr(X_TMP_1)); - - if (sveLen) { /* SVE is available. */ - ptrue(p_lsb_256.b, VL32); - ptrue(p_lsb_128.b, VL16); - ptrue(p_lsb_64.b, VL8); - } - - bool is_tail_in_drv_dims = false; - for (int i = prb_.ndims; i < prb_.full_ndims; i++) - if (prb_.nodes[i].tail_size > 0) { - is_tail_in_drv_dims = true; - break; - } - - if (is_tail_in_drv_dims) { - Label reorder_kernel; - add_imm(X_DEFAULT_ADDR, TAIL_PARAM(skip_kernel_execution), X_TMP_0); - ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); - cmp(X_TMP_0, static_cast(true)); - b(EQ, end_of_kernel); - - add_imm(X_DEFAULT_ADDR, TAIL_PARAM(zeroing_data), X_TMP_0); - ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); - cmp(X_TMP_0, static_cast(false)); - b(EQ, reorder_kernel); - // If zeroing data is set then all dst memory - // will be zeroed and nothing more will be done. - int bytes_to_zeroing = otype_sz_; - for (int i = 0; i < prb_.ndims; i++) { - bytes_to_zeroing *= prb_.nodes[i].n; - } - eor(reg_off_out_, reg_off_out_, reg_off_out_); - mov(x_ptr_out_off, reg_ptr_out_); - zero_dst_memory(bytes_to_zeroing); - b(end_of_kernel); - L(reorder_kernel); - } - - if (can_do_tr8x8()) { - dup(ymm_zero_, 0); - } else { - movi(xmm_zero_, 0); - } - - impl(); - - L(end_of_kernel); - postamble(); - } - - ~jit_uni_reorder_kernel_f32_t() override = default; - -#undef TAIL_PARAM -#undef PARAM - -private: - static constexpr int64_t with_tail_info_ = static_cast(true); - static constexpr int64_t without_tail_info_ = static_cast(false); - - int itype_sz_; - int otype_sz_; - int stype_sz_; - - const cpu_isa_t isa_; - - const XReg reg_ptr_in_ = x6; - const XReg reg_ptr_out_ = x2; - const XReg reg_ptr_src_scales_ = x1; - const XReg reg_ptr_dst_scales_ = x12; - const XReg reg_ptr_comp_ = x3; - const WReg reg_scale_adjust_ = w5; - - const XReg reg_off_in_ = x8; - const XReg reg_off_out_ = x9; - const XReg reg_off_comp_ = x11; - - /* X_TMP is required to set address to - x_tmp_vec(X_TMP_0 - X_TMP_4). */ - XReg X_TMP = x20; - - VReg4S xmm_src_scales_ = v15.s; - VReg4S xmm_dst_scales_ = v11.s; - VReg4S xmm_zero_ = v14.s; - ZRegS ymm_zero_ = z14.s; - VReg4S xmm_tmp_ = v12.s; - const VReg4S xmm_src_zp_ = v9.s; - const VReg4S xmm_dst_zp_ = v10.s; - const VReg4S xmm_compensation = v8.s; - VReg4S xmm_saturation_ubound_ = v12.s; - ZRegS ymm_saturation_ubound_ = z12.s; - - /* Note: x22 - x28 are already used as temporal registgers - in jit_generator.hpp. - x_ptr_(in|out|scale|comp)_off keeps (base + offset) address. */ - XReg x_ptr_in_off = reg_ptr_in_; - XReg x_ptr_out_off = reg_ptr_out_; - XReg x_ptr_comp_off = reg_ptr_comp_; - XReg x_ptr_src_scale_off = x19; - XReg x_ptr_dst_scale_off = x29; - - /* Caution: Chose predicate registers not used by x64's implementation. 
*/ - PReg p_lsb_256 = p7; - PReg p_lsb_128 = p6; - PReg p_lsb_64 = p4; - PReg p_tmp0 = p5; - - const std::vector tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27}; - VReg v_tmp0 = v20; - ZReg z_tmp0 = z20; - ZReg z_tmp1 = z21; - ZReg z_tmp2 = z22; - ZReg z_tmp3 = z23; - ZReg z_tmp4 = z24; - ZReg z_tmp5 = z25; - ZReg z_tmp6 = z26; - ZReg z_tmp7 = z27; - VReg v_tmp7 = v27; - - const std::vector z_tmp_vec - = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7}; - constexpr static int z_tmp_vec_size = 8; -}; - -// Seperate class for no unroll/threading burden -struct jit_single_blk_kernel_t : public jit_generator { - DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_single_blk_kernel) - static bool applicable(const prb_t &p) { - - using namespace data_type; - - bool ok = p.ndims >= 2 && mayiuse(sve_256) - && p.src_scale_type == scale_type_t::NONE - && p.dst_scale_type == scale_type_t::NONE - && utils::one_of(p.itype, f32) && utils::one_of(p.otype, f32) - && utils::everyone_is(0, p.ioff, p.ooff) && p.beta == 0.f - && prb_has_small_strides(p); - if (!ok) return false; - - int64_t n0 = p.nodes[0].n; - auto i0 = p.nodes[0].is; - auto o0 = p.nodes[0].os; - int64_t n1 = p.nodes[1].n; - auto i1 = p.nodes[1].is; - auto o1 = p.nodes[1].os; - - /* - * for a transpose of plain to 8c case, nodes would be like: - * n is os - * m 1 8 - * 8 m 1 - * or - * 8 m 1 - * m 1 8 - */ - ok = (utils::one_of(n0, 8, 16, 32, 64) - || utils::one_of(n1, 8, 16, 32, 64)) - && ((i0 == 1 && o1 == 1 && n0 == i1 && o0 == n1) - || (o0 == 1 && i1 == 1 && n0 == o1 && i0 == n1)); - if (!ok) return false; - - // Do not handle transpose of dimensions other than last 2 - for (int i = 2; i < p.ndims; ++i) { - if (p.nodes[i].is != p.nodes[i].os) { - ok = false; - break; - } - } - - return ok; - } - - jit_single_blk_kernel_t(const tr::prb_t &prb) - : jit_generator() - , prb_(prb) - , itype_sz_(data_type_size(prb_.itype)) - , otype_sz_(data_type_size(prb_.otype)) - , block_sz(prb.nodes[0].n) {} - - void generate() override { - auto input_stride - = prb_.nodes[0].is != 1 ? prb_.nodes[0].is : prb_.nodes[1].is; - auto output_stride - = prb_.nodes[0].os != 1 ? prb_.nodes[0].os : prb_.nodes[1].os; - - Label tail_processing; - - const auto load_zp = [&](const ZRegS ymm_zp, const XReg reg_zp) { - dup(ymm_zp, WReg(reg_zp.getIdx())); - scvtf(ymm_zp, P_ALL_ONE / T_m, ymm_zp); - }; - - set_preg(p_tmp2.s, 4, X_TMP_0, X_TMP_1); - rev(p_tmp1.s, p_tmp2.s); - - preamble(); - - if (prb_.req_src_zp) load_zp(ymm_src_zp, reg_src_zp); - - if (prb_.req_dst_zp) load_zp(ymm_dst_zp, reg_dst_zp); - - cmp(reg_ptr_tail, true); - b(EQ, tail_processing); - - if (block_sz == 8) { - gen_ker8x8(0, 0, input_stride, output_stride, 8, 8); - block_sz = 8; - } else if (block_sz == 16) { - gen_ker16x16_in_8x8(0, 0, input_stride, output_stride); - block_sz = 16; - } else if (block_sz == 32) { - gen_ker32x32_in_16x16(0, 0, input_stride, output_stride); - block_sz = 32; - } else if (block_sz == 64) { - gen_ker64x64_in_32x32(0, 0, input_stride, output_stride); - block_sz = 64; - } else { - assert(!"unimplemented"); - } - - postamble(); - - L(tail_processing); - - if (block_sz == 8) { - auto i_tail = input_stride % 8 != 0 ? input_stride % 8 : 8; - auto o_tail = output_stride % 8 != 0 ? output_stride % 8 : 8; - if (i_tail != o_tail) { - auto t_mask = i_tail == 8 ? o_tail : i_tail; - gen_setmask(t_mask); - gen_ker8x8(0, 0, input_stride, output_stride, i_tail, o_tail); - } - } else if (block_sz == 16) { - auto i_tail = input_stride % 16 != 0 ? 
input_stride % 16 : 16; - auto o_tail = output_stride % 16 != 0 ? output_stride % 16 : 16; - if (i_tail != o_tail) { - auto t_mask = i_tail == 16 ? o_tail : i_tail; - t_mask %= 8; - if (t_mask != 0) gen_setmask(t_mask); - gen_ker16x16_in_8x8( - 0, 0, input_stride, output_stride, i_tail, o_tail); - } - } else if (block_sz == 32) { - auto i_tail = input_stride % 32 != 0 ? input_stride % 32 : 32; - auto o_tail = output_stride % 32 != 0 ? output_stride % 32 : 32; - if (i_tail != o_tail) { - auto t_mask = i_tail == 32 ? o_tail : i_tail; - t_mask %= 8; - if (t_mask != 0) gen_setmask(t_mask); - gen_ker32x32_in_16x16( - 0, 0, input_stride, output_stride, i_tail, o_tail); - } - } else if (block_sz == 64) { - auto i_tail = input_stride % 64 != 0 ? input_stride % 64 : 64; - auto o_tail = output_stride % 64 != 0 ? output_stride % 64 : 64; - if (i_tail != o_tail) { - auto t_mask = i_tail == 64 ? o_tail : i_tail; - t_mask %= 8; - if (t_mask != 0) gen_setmask(t_mask); - gen_ker64x64_in_32x32( - 0, 0, input_stride, output_stride, i_tail, o_tail); - } - } else { - assert(!"unimplemented"); - } - - postamble(); - } - - void gen_loadu(const ZRegS ymm, const XReg &addr, int size) { - QReg xmm(ymm.getIdx()); - switch (size) { - case 32: ld1w(ymm, p_lsb_256 / T_z, ptr(addr)); break; - case 16: ldr(xmm, ptr(addr)); break; - default: assert(!"unreachable"); - } - } - - void gen_storeu(const XReg &addr, const ZRegS ymm, int size) { - QReg xmm(ymm.getIdx()); - switch (size) { - case 32: st1w(ymm, p_lsb_256, ptr(addr)); break; - case 16: str(xmm, ptr(addr)); break; - default: assert(!"unreachable"); - } - } - - void gen_maskloadu( - const ZRegS ymm, const XReg &addr, const PReg mask, int size) { - switch (size) { - case 32: - case 16: ld1w(ymm, mask / T_z, ptr(addr)); break; - default: assert(!"unreachable"); - } - } - - void gen_maskstoreu( - const XReg &addr, const ZRegS ymm, const PReg mask, int size) { - switch (size) { - case 32: - case 16: st1w(ymm, mask, ptr(addr)); break; - default: assert(!"unreachable"); - } - } - - // Register allocation xmm0~11 - void gen_transpose_8x8() { - const uint64_t sveLen = get_sve_length(); - constexpr int lane = 8; - -#if 0 - /* Debug code - z0: 7, 6, 5, 4, 3, 2, 1, 0 - z1: 15, 14, 13, 12, 11, 10, 9, 8 - ... - z17: 63, 62, 61, 60, 59, 58, 57, 56 - */ - ptrue(P_ALL_ONE.b); - ptrue(P_TMP.s, VL8); - not_(P_TMP.b, P_ALL_ONE/T_z, P_TMP.b); - index(z0.s, 0, 1); - mov(z0.s, P_TMP/T_m, 0); - mov(z_tmp_vec[0].s, 8); - mov(z_tmp_vec[0].s, P_TMP/T_m, 0); - for(uint32_t i=1; i nChw()C - // or nChw()C -> nchw - void gen_setmask(int mask) { set_preg(p_mask.s, mask, x_tmp_0, x_tmp_1); } - - // TODO: Mark parameter with type information - // XXX: ! 
- // offset in byte offset - // stride in element number - // - // Gen specific 8x8 transform respect to certain tail condition - void gen_tr8x8(int i_off, int o_off, int input_stride, int output_stride, - int in_tail, int out_tail) { - - constexpr int lane = 8; - - if (in_tail == 0 || out_tail == 0) return; - - for (int i = 0; i < out_tail; ++i) { - if (in_tail != lane) { - add_imm(x_addr, reg_ptr_in_, - i_off + i * input_stride * itype_sz_, x_tmp_0); - gen_maskloadu(ZRegS(i), x_addr, p_mask, lane * itype_sz_); - } else { - add_imm(x_addr, reg_ptr_in_, - i_off + i * input_stride * itype_sz_, x_tmp_0); - gen_loadu(ZRegS(i), x_addr, lane * itype_sz_); - } - if (prb_.req_src_zp) { fsub(ZRegS(i), ZRegS(i), ymm_src_zp); } - } - - gen_transpose_8x8(); - - for (int i = 0; i < in_tail; ++i) { - if (prb_.req_dst_zp) { fadd(ZRegS(i), ZRegS(i), ymm_dst_zp); } - if (out_tail == lane) { - add_imm(x_addr, reg_ptr_out_, - o_off + i * output_stride * otype_sz_, x_tmp_0); - gen_storeu(x_addr, ZRegS(i), lane * otype_sz_); - } else { - add_imm(x_addr, reg_ptr_out_, - o_off + i * output_stride * otype_sz_, x_tmp_0); - gen_maskstoreu(x_addr, ZRegS(i), p_mask, lane * otype_sz_); - } - } - } - - // tail: 0 ~ 8 - // support: either in_tail or out_tail is not 8, but not both - void gen_ker8x8(int i_off, int o_off, int input_stride, int output_stride, - int in_tail, int out_tail) { - gen_tr8x8(i_off, o_off, input_stride, output_stride, in_tail, out_tail); - } - - void gen_ker16x16_in_8x8( - int i_off, int o_off, int input_stride, int output_stride) { - const auto lane = 16; - const auto sub_lane = lane / 2; - - i_off *= itype_sz_; - o_off *= otype_sz_; - - gen_tr8x8( - i_off, o_off, input_stride, output_stride, sub_lane, sub_lane); - gen_tr8x8(i_off + input_stride * sub_lane * itype_sz_, - o_off + sub_lane * otype_sz_, input_stride, output_stride, - sub_lane, sub_lane); - gen_tr8x8(i_off + sub_lane * itype_sz_, - o_off + output_stride * sub_lane * otype_sz_, input_stride, - output_stride, sub_lane, sub_lane); - gen_tr8x8(i_off + (input_stride * sub_lane + sub_lane) * itype_sz_, - o_off + (output_stride * sub_lane + sub_lane) * otype_sz_, - input_stride, output_stride, sub_lane, sub_lane); - } - - // tail can be 1 ~ 16, using sve2 for now - void gen_ker16x16_in_8x8(int i_off, int o_off, int input_stride, - int output_stride, int in_tail, int out_tail) { - constexpr auto lane = 16; - constexpr auto sub_lane = lane / 2; - auto tail = in_tail != lane ? in_tail : out_tail; - - const auto l_tail = tail < sub_lane ? tail : sub_lane; - const auto u_tail = tail < sub_lane ? 
0 : tail - sub_lane; - - i_off *= itype_sz_; - o_off *= otype_sz_; - - if (tail == in_tail) { - gen_tr8x8(i_off, o_off, input_stride, output_stride, l_tail, - sub_lane); - gen_tr8x8(i_off + input_stride * sub_lane * itype_sz_, - o_off + sub_lane * otype_sz_, input_stride, output_stride, - l_tail, sub_lane); - gen_tr8x8(i_off + sub_lane * itype_sz_, - o_off + output_stride * sub_lane * otype_sz_, input_stride, - output_stride, u_tail, sub_lane); - gen_tr8x8(i_off + itype_sz_ * (input_stride * sub_lane + sub_lane), - o_off + otype_sz_ * (output_stride * sub_lane + sub_lane), - input_stride, output_stride, u_tail, sub_lane); - } else { - gen_tr8x8(i_off, o_off, input_stride, output_stride, sub_lane, - l_tail); - gen_tr8x8(i_off + input_stride * sub_lane * itype_sz_, - o_off + sub_lane * otype_sz_, input_stride, output_stride, - sub_lane, u_tail); - gen_tr8x8(i_off + sub_lane * itype_sz_, - o_off + output_stride * sub_lane * itype_sz_, input_stride, - output_stride, sub_lane, l_tail); - gen_tr8x8(i_off + itype_sz_ * (input_stride * sub_lane + sub_lane), - o_off + otype_sz_ * (output_stride * sub_lane + sub_lane), - input_stride, output_stride, sub_lane, u_tail); - } - } - - void gen_ker32x32_in_16x16( - int i_off, int o_off, int input_stride, int output_stride) { - - const auto lane = 32; - const auto sub_lane = lane / 2; - gen_ker16x16_in_8x8(i_off, o_off, input_stride, output_stride); - gen_ker16x16_in_8x8(i_off + sub_lane * input_stride, o_off + sub_lane, - input_stride, output_stride); - gen_ker16x16_in_8x8(i_off + sub_lane, o_off + output_stride * sub_lane, - input_stride, output_stride); - gen_ker16x16_in_8x8(i_off + input_stride * sub_lane + sub_lane, - o_off + output_stride * sub_lane + sub_lane, input_stride, - output_stride); - } - - void gen_ker32x32_in_16x16(int i_off, int o_off, int input_stride, - int output_stride, int in_tail, int out_tail) { - - constexpr auto lane = 32; - constexpr auto sub_lane = lane / 2; - auto tail = in_tail != lane ? in_tail : out_tail; - - const auto l_tail = tail < sub_lane ? tail : sub_lane; - const auto u_tail = tail < sub_lane ? 
0 : tail - sub_lane; - - if (tail == in_tail) { - gen_ker16x16_in_8x8(i_off, o_off, input_stride, output_stride, - l_tail, sub_lane); - gen_ker16x16_in_8x8(i_off + sub_lane * input_stride, - o_off + sub_lane, input_stride, output_stride, l_tail, - sub_lane); - gen_ker16x16_in_8x8(i_off + sub_lane, - o_off + output_stride * sub_lane, input_stride, - output_stride, u_tail, sub_lane); - gen_ker16x16_in_8x8(i_off + input_stride * sub_lane + sub_lane, - o_off + output_stride * sub_lane + sub_lane, input_stride, - output_stride, u_tail, sub_lane); - } else { - gen_ker16x16_in_8x8(i_off, o_off, input_stride, output_stride, - sub_lane, l_tail); - gen_ker16x16_in_8x8(i_off + sub_lane * input_stride, - o_off + sub_lane, input_stride, output_stride, sub_lane, - u_tail); - gen_ker16x16_in_8x8(i_off + sub_lane, - o_off + output_stride * sub_lane, input_stride, - output_stride, sub_lane, l_tail); - gen_ker16x16_in_8x8(i_off + input_stride * sub_lane + sub_lane, - o_off + output_stride * sub_lane + sub_lane, input_stride, - output_stride, sub_lane, u_tail); - } - } - - void gen_ker64x64_in_32x32( - int i_off, int o_off, int input_stride, int output_stride) { - - const auto lane = 64; - const auto sub_lane = lane / 2; - gen_ker32x32_in_16x16(i_off, o_off, input_stride, output_stride); - gen_ker32x32_in_16x16(i_off + sub_lane * input_stride, o_off + sub_lane, - input_stride, output_stride); - gen_ker32x32_in_16x16(i_off + sub_lane, - o_off + output_stride * sub_lane, input_stride, output_stride); - gen_ker32x32_in_16x16(i_off + input_stride * sub_lane + sub_lane, - o_off + output_stride * sub_lane + sub_lane, input_stride, - output_stride); - } - - void gen_ker64x64_in_32x32(int i_off, int o_off, int input_stride, - int output_stride, int in_tail, int out_tail) { - constexpr auto lane = 64; - constexpr auto sub_lane = lane / 2; - auto tail = in_tail != lane ? in_tail : out_tail; - - const auto l_tail = tail < sub_lane ? tail : sub_lane; - const auto u_tail = tail < sub_lane ? 
0 : tail - sub_lane; - - if (tail == in_tail) { - gen_ker32x32_in_16x16(i_off, o_off, input_stride, output_stride, - l_tail, sub_lane); - gen_ker32x32_in_16x16(i_off + sub_lane * input_stride, - o_off + sub_lane, input_stride, output_stride, l_tail, - sub_lane); - gen_ker32x32_in_16x16(i_off + sub_lane, - o_off + output_stride * sub_lane, input_stride, - output_stride, u_tail, sub_lane); - gen_ker32x32_in_16x16(i_off + input_stride * sub_lane + sub_lane, - o_off + output_stride * sub_lane + sub_lane, input_stride, - output_stride, u_tail, sub_lane); - } else { - gen_ker32x32_in_16x16(i_off, o_off, input_stride, output_stride, - sub_lane, l_tail); - gen_ker32x32_in_16x16(i_off + sub_lane * input_stride, - o_off + sub_lane, input_stride, output_stride, sub_lane, - u_tail); - gen_ker32x32_in_16x16(i_off + sub_lane, - o_off + output_stride * sub_lane, input_stride, - output_stride, sub_lane, l_tail); - gen_ker32x32_in_16x16(i_off + input_stride * sub_lane + sub_lane, - o_off + output_stride * sub_lane + sub_lane, input_stride, - output_stride, sub_lane, u_tail); - } - } - -private: - // 6 ~ 12 - constexpr static int xmm_save_start_from = 6; - constexpr static int xmm_width = 16; - - void preamble() { ptrue(p_lsb_256.b, VL32); } - - void postamble() { ret(); } - - const prb_t &prb_; - - int itype_sz_; - int otype_sz_; - int block_sz; - - XReg reg_ptr_in_ = abi_param1; - XReg reg_ptr_out_ = abi_param2; - XReg reg_ptr_tail = abi_param3; - XReg reg_src_zp = abi_param4; - XReg reg_dst_zp = abi_param5; - - /* Because the callee-saved registers are not restored blk_reorder, - the temporary registers (x9-x15) must be assigned. - Must be selected from the temporary registers (x9-x15). */ - XReg x_addr = x10; - XReg x_tmp_0 = x11; - XReg x_tmp_1 = x12; - - /* Avoid P_TMP(p7) in jit_generator.hpp. */ - PReg p_lsb_256 = p6; - PReg p_mask = p5; - PReg p_tmp1 = p4; - PReg p_tmp2 = p3; - - ZRegS ymm_tmp = z0.s; - ZRegS ymm_src_zp = z14.s; - ZRegS ymm_dst_zp = z15.s; - - const std::vector tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27}; - VReg v_tmp0 = v20; - ZReg z_tmp0 = z20; - ZReg z_tmp1 = z21; - ZReg z_tmp2 = z22; - ZReg z_tmp3 = z23; - ZReg z_tmp4 = z24; - ZReg z_tmp5 = z25; - ZReg z_tmp6 = z26; - ZReg z_tmp7 = z27; - VReg v_tmp7 = v27; - - const std::vector z_tmp_vec - = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7}; - constexpr static int z_tmp_vec_size = 8; -}; - -status_t kernel_t::desc_init( - kernel_t::desc_t &desc, const prb_t &prb, int ndims_ker_max) { - - desc.prb = prb; - desc.prb.ioff = desc.prb.ooff = 0; - - if (ndims_ker_max > prb.ndims) return status::invalid_arguments; - - auto ndims_ker_max_f = [&]() { - size_t cur_size = 1; - for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) - if (cur_size >= ker_prb_size_min) return d; - return prb.ndims; - }; - - if (ndims_ker_max <= 0) ndims_ker_max = ndims_ker_max_f(); - - /* traverse through kernel implementations */ - /* TODO: find a better way to do that... 
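       for now, the loop below simply decreases ndims_ker from ndims_ker_max
       until jit_uni_reorder_kernel_f32_t::applicable() accepts the truncated
       problem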
*/ - desc.id = 0; - for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { - desc.prb.ndims = ndims_ker; - if (jit_uni_reorder_kernel_f32_t::applicable(desc.prb)) - return status::success; - } - - return status::unimplemented; -} - -kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { - switch (desc.id) { - case 0: return new jit_uni_reorder_kernel_f32_t(desc); - default: assert(!"unknown kernel id"); return nullptr; - } - - return nullptr; -} - -} // namespace tr - -static void prb_block_for_cache(tr::prb_t &prb) { - /* If strides for 0th and 1st nodes are cache friendly - * then one can altogether do away with blocking ! */ - static constexpr int num_elems_thr = 16; - const bool stride_cache_friendly - = ((prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > num_elems_thr) - || (prb.ndims > 1 && prb.nodes[1].is % num_elems_thr == 0 - && prb.nodes[1].n > num_elems_thr)) - && !prb.is_tail_present; - - // performance improvement for shapes with large inner-most dimension - const size_t L1_cache_sz - = size_t(3) * platform::get_per_core_cache_size(1) / 4; - const size_t itype_sz_ = data_type_size(prb.itype); - const size_t inner_block_sz = prb.nodes[0].n * itype_sz_; - const bool requires_inner_blocking = inner_block_sz > L1_cache_sz - // 'is_tail_present' is not supported for cache_blocking when - // asymmetric_comp is executed. - && IMPLICATION(prb.req_asymmetric_comp, !prb.is_tail_present); - - const bool cache_blocking_needed - = stride_cache_friendly || requires_inner_blocking; - if (!cache_blocking_needed) return; - - int unit_input_stride_idx = -1; - for (auto idx = 0; idx < prb.ndims; ++idx) { - if (prb.nodes[idx].is == 1) unit_input_stride_idx = idx; - } - - /* Re-prioritize the sequential read over sequential write: - * /-> [n0:is0:1][16n1:1:osk]... - * [n0:is0:1]...[nk:1:osk] --> or - * \-> [16n1:1:osk][n0:is0:1]... */ - if (unit_input_stride_idx != -1) { - const auto output_stride = prb.nodes[unit_input_stride_idx].os; - const auto num_elems = prb.nodes[unit_input_stride_idx].n; - - const bool split_needed = (num_elems > num_elems_thr) - && (num_elems % num_elems_thr == 0); - const int move_location = (output_stride % 4 != 0) ? 0 : 1; - if (split_needed) - prb_node_split(prb, unit_input_stride_idx, num_elems_thr); - - /* Because of cache-unfriendly nature of unit-output stride node, let - * us move unit-input stride node on or near front! 
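     * (move_location, computed above, is 0 when that node's output stride is
     * not a multiple of 4, i.e. the node becomes the new innermost one;
     * otherwise it is placed second.)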
*/ - if (unit_input_stride_idx != move_location) - prb_node_move(prb, unit_input_stride_idx, move_location); - } - - /* Potentially, split the node with os=1 in two and pull in the node with - * is=1 between them for better cache reuse: - * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */ - if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) { - const auto num_elems = prb.nodes[0].n; - - const bool split_needed = (num_elems > num_elems_thr) - && (num_elems % num_elems_thr == 0); - if (split_needed) { - prb_node_split(prb, 0, num_elems_thr); - prb_node_move(prb, 1, 2); - - // Update node information - prb_node_dependency(prb); - - // heuristics - looping over the unrolled dims should maximize reuse - // of the already cached data; observation is choosing the smallest - // dim from the remaining (from 2 up to ndims) gives good results - constexpr int new_position = 2; - const auto dim_beg_it = std::begin(prb.nodes); - const auto dim_two_it = dim_beg_it + new_position; - const auto dim_last_it = dim_beg_it + prb.ndims; - const auto min_n_node_it = std::min_element(dim_two_it, dim_last_it, - [](const tr::node_t &lhs, const tr::node_t &rhs) { - return lhs.n < rhs.n; - }); - const auto min_idx = std::distance(dim_beg_it, min_n_node_it); - // check if min_idx node is parent of node with tail processing which - // is currently unsupported (i.e. tail processing can only be handled - // at the inner-most dimension) - bool inner_block_has_tail = false; - for (int idx = min_idx - 1; idx >= new_position; idx--) { - if (prb.nodes[idx].parent_node_id == min_idx) { - inner_block_has_tail = true; - break; - } - } - - if (min_idx > new_position && (!inner_block_has_tail)) - prb_node_move(prb, min_idx, new_position); - } - } -} - -/** finds the maximum number of dimension the kernel should process and - * optionally splits one of the dimension to achieve better balance between - * parallel driver and the kernel. */ -static void prb_thread_kernel_balance( - tr::prb_t &prb, int &ndims_ker_max, int nthr) { - size_t size_total = 1; - for (int d = 0; d < prb.ndims; ++d) - size_total *= prb.nodes[d].n; - - /* The general expression for size_drv_thr can be written as - * size_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1) - * where FC and VC are fixed and variable costs respectively. - * Though for now, the below heuristic seems to be good enough */ - const size_t size_drv_thr = (nthr > 1) ? 16 * nthr : 1; - - /* size_drv_min is the minimal size for the parallel - * driver required for good parallelization */ - const size_t size_drv_min - = nstl::min(size_drv_thr, utils::div_up(size_total, 1024)); - - /* kdims -- # of dimensions processed by a kernel - * size_ker_cur -- product of the dimension processed by a kernel - * size_drv_cur -- product of the dimension processed by a driver */ - - int kdims = prb.ndims; - size_t size_drv_cur = 1; - for (; kdims > 1 && size_drv_cur < size_drv_min; --kdims) - size_drv_cur *= prb.nodes[kdims - 1].n; - - size_t size_ker_cur = 1; - for (int d = 0; d < kdims; ++d) - size_ker_cur *= prb.nodes[d].n; - - /* Initially kdims is chosen so that size_drv_cur >= size_drv_min. - * - * It might happen that for chosen kdims the size_ker_cur is too small - * (less than tr::ker_prb_size_min). In that case try to split the - * innermost driver dimension into two, to increase size_ker_cur. 
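     * Illustration (hypothetical sizes): if tr::ker_prb_size_min / size_ker_cur
     * rounds up to 4 and the innermost driver dimension has n = 10, then
     * size_want_borrow grows from 4 to 5 (the smallest value >= 4 that divides
     * 10) and prb_node_split() splits that dimension by 5 before kdims is
     * bumped.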
*/ - const bool want_borrow_ker_from_drv = kdims < prb.ndims - && size_ker_cur < tr::ker_prb_size_min - && size_drv_cur > size_drv_min; - if (want_borrow_ker_from_drv) { - /* size_want_borrow is the minimal size, so that: - * o) size_ker_cur * size_want_borrow >= tr::ker_prb_size_min - * o) current innermost driver dimension is divisible by - * size_want_borrow (so that we can evenly split that - * dimension into two) - * - * In the worst case the minimal size_want_borrow is equal - * to the innermost driver dimension itself. In that case - * we will sacrifice it in favor of kernel (is it fine?). */ - size_t size_want_borrow - = utils::div_up(tr::ker_prb_size_min, size_ker_cur); - for (; prb.nodes[kdims].n % size_want_borrow; ++size_want_borrow) - ; - - if (size_want_borrow != prb.nodes[kdims].n) - prb_node_split(prb, kdims, size_want_borrow); - kdims += 1; - } - - /* On the other hand it might happen that for chosen kdims - * the size_drv_cur is too small (less than size_drv_min). In that case - * try to split the outermost kernel dimension into two, to increase - * size_drv_cur. */ - const bool want_borrow_drv_from_ker = size_ker_cur > tr::ker_prb_size_min - && size_drv_cur < size_drv_min; - if (want_borrow_drv_from_ker) { - size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur); - for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow) - ; - - if (size_want_borrow != prb.nodes[kdims - 1].n) - prb_node_split( - prb, kdims - 1, prb.nodes[kdims - 1].n / size_want_borrow); - } - - ndims_ker_max = kdims; - - if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { - DEBUG({ - verbose_printf( - verbose_t::debuginfo, "split: %s\n", prb_dump(prb).c_str()); - verbose_printf(verbose_t::debuginfo, "ndims_ker_max = %d\n", - ndims_ker_max); - }); - } -} - -status_t jit_uni_reorder_t::pd_t::init( - engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { - CHECK(cpu_reorder_pd_t::init(engine, src_engine, dst_engine)); - - CHECK(init_scratchpad()); - - return status::success; -} - -status_t jit_uni_reorder_t::pd_t::init_scratchpad() { - auto scratchpad = scratchpad_registry().registrar(); - - const bool compensation_needed - = prb_.req_s8s8_comp || prb_.req_asymmetric_comp; - if (compensation_needed) { - const memory_desc_wrapper od(dst_md()); - const auto G = with_groups_ ? od.padded_dims()[0] : 1; - const auto N = od.padded_dims()[with_groups_ ? 1 : 0]; - static constexpr int cache_line_size = 16; - const auto wspace_per_thr_size - = utils::rnd_up(G * N, cache_line_size) * sizeof(int32_t); - - const auto compensation_reduce_size = wspace_per_thr_size * nthr_; - - // Every thread gets its own scratchpad space for each N. 
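        // Illustrative sizing (hypothetical shape): G = 1, N = 100 gives
        // rnd_up(100, 16) = 112 int32 slots = 448 bytes per thread; with
        // nthr_ = 8 the booked compensation buffer is 8 * 448 = 3584 bytes.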
- scratchpad.template book( - memory_tracking::names::key_reorder_space, - compensation_reduce_size); - } - - if (!attr()->scales_.has_default_values(DNNL_ARG_DST)) { - const memory_desc_wrapper input_d(src_md()); - int mask = attr()->scales_.get_mask(DNNL_ARG_DST); - get_D_values(input_d, mask, nullptr, &D_mask_, nullptr); - if (D_mask_ > 1) { - scratchpad.template book( - memory_tracking::names::key_reorder_precomputed_dst_scales, - D_mask_); - } - } - - return status::success; -} - -status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, - const memory_desc_t *src_md, engine_t *dst_engine, - const memory_desc_t *dst_md) { - if (!impl::is_dense_format_kind({src_md, dst_md})) - return status::unimplemented; - auto prb = tr::prb_t(); - - status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); - if (prb_init_status != status::success) return prb_init_status; - - prb_block_for_cache(prb); - DEBUG({ - verbose_printf( - verbose_t::debuginfo, "cache: %s\n", prb_dump(prb).c_str()); - }); - - int ndims_ker_max {}; - int nthr = dnnl_get_max_threads(); - prb_thread_kernel_balance(prb, ndims_ker_max, nthr); - - if (prb.is_tail_present) prb_node_dependency(prb); - - tr::kernel_t::desc_t ker_desc; - status_t ker_init_status - = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); - if (ker_init_status != status::success) return ker_init_status; - - const int ndims_driver = prb.ndims - ker_desc.prb.ndims; - if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) - return status::unimplemented; - - DEBUG({ - verbose_printf(verbose_t::debuginfo, "ker : %s\n", - prb_dump(ker_desc.prb).c_str()); - }); - - auto _pd = make_unique_pd( - attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); - if (_pd == nullptr) return status::out_of_memory; - - _pd->nthr_ = nthr; - _pd->prb_ = prb; - _pd->with_groups_ - = prb.compensation_mask == tr::prb_t::comp_mask_with_groups; - CHECK(_pd->init(engine, src_engine, dst_engine)); - _pd->ker_desc_ = ker_desc; - CHECK(_pd->init_scratchpad_md()); - - return safe_ptr_assign(*reorder_pd, _pd.release()); -} - -void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out, - const float *src_scales, const float *dst_scales, int src_zp, - int dst_zp, int32_t *compensation_scratch) const { - const tr::prb_t &prb = pd()->prb_; - - tr::call_param_t base_params; - base_params.in = in; - base_params.out = out; - base_params.src_scales = src_scales; - base_params.dst_scales = dst_scales; - base_params.src_zp = src_zp; - base_params.dst_zp = dst_zp; - base_params.compensation_scratch = compensation_scratch; - - if (prb.is_tail_present) { - tr::tail_call_param_t tail_params; - tail_params.base_params = base_params; - - static constexpr int omp_ndims = 0; - fill_curr_data_chunks(prb, off, nullptr, omp_ndims, tail_params); - - (*kernel_)(&tail_params); - } else { - (*kernel_)(&base_params); - } -} - -void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, - const char *in, char *out, const float *src_scales, - const float *dst_scales, int src_zp, int dst_zp, - int32_t *compensation_scratch) const { - const tr::prb_t &prb = pd()->prb_; - const tr::node_t *ns = prb.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { - tr::call_param_t base_params; - base_params.in = in + d0 * ns[0].is * data_type_size(prb.itype); - base_params.out = out + d0 * ns[0].os * data_type_size(prb.otype); - base_params.src_scales = src_scales + d0 * ns[0].ss; - 
base_params.dst_scales = dst_scales + d0 * ns[0].ss; - base_params.src_zp = src_zp; - base_params.dst_zp = dst_zp; - base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs; - - if (prb.is_tail_present) { - tr::tail_call_param_t tail_params; - tail_params.base_params = base_params; - - static constexpr int omp_ndims = 1; - const ptrdiff_t omp_data_chunks[omp_ndims] = {d0}; - fill_curr_data_chunks( - prb, off, omp_data_chunks, omp_ndims, tail_params); - - (*kernel_)(&tail_params); - } else { - (*kernel_)(&base_params); - } - }); -} - -void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, - const char *in, char *out, const float *src_scales, - const float *dst_scales, int src_zp, int dst_zp, - int32_t *compensation_scratch) const { - const tr::prb_t &prb = pd()->prb_; - const tr::node_t *ns = prb.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d1, ptrdiff_t d0) { - tr::call_param_t base_params; - base_params.in = in - + (d0 * ns[0].is + d1 * ns[1].is) - * data_type_size(prb.itype); - base_params.out = out - + (d0 * ns[0].os + d1 * ns[1].os) - * data_type_size(prb.otype); - base_params.src_scales - = src_scales + d0 * ns[0].ss + d1 * ns[1].ss; - base_params.dst_scales - = dst_scales + d0 * ns[0].ss + d1 * ns[1].ss; - base_params.src_zp = src_zp; - base_params.dst_zp = dst_zp; - base_params.compensation_scratch - = compensation_scratch + d0 * ns[0].cs + d1 * ns[1].cs; - - if (prb.is_tail_present) { - tr::tail_call_param_t tail_params; - tail_params.base_params = base_params; - - static constexpr int omp_ndims = 2; - const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1}; - fill_curr_data_chunks( - prb, off, omp_data_chunks, omp_ndims, tail_params); - - (*kernel_)(&tail_params); - } else { - (*kernel_)(&base_params); - } - }); -} - -void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, - const char *in, char *out, const float *src_scales, - const float *dst_scales, int src_zp, int dst_zp, - int32_t *compensation_scratch) const { - const tr::prb_t &prb = pd()->prb_; - const tr::node_t *ns = prb.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, - (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { - tr::call_param_t base_params; - base_params.in = in - + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) - * data_type_size(prb.itype); - base_params.out = out - + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) - * data_type_size(prb.otype); - base_params.src_scales = src_scales + d0 * ns[0].ss - + d1 * ns[1].ss + d2 * ns[2].ss; - base_params.dst_scales = dst_scales + d0 * ns[0].ss - + d1 * ns[1].ss + d2 * ns[2].ss; - base_params.src_zp = src_zp; - base_params.dst_zp = dst_zp; - base_params.compensation_scratch = compensation_scratch - + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs; - - if (prb.is_tail_present) { - tr::tail_call_param_t tail_params; - tail_params.base_params = base_params; - - static constexpr int omp_ndims = 3; - const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1, d2}; - fill_curr_data_chunks( - prb, off, omp_data_chunks, omp_ndims, tail_params); - - (*kernel_)(&tail_params); - } else { - (*kernel_)(&base_params); - } - }); -} - -void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, - const char *in, char *out, const float *src_scales, - const float *dst_scales, int src_zp, int dst_zp, - int32_t *compensation_scratch) const { - const tr::prb_t &prb = pd()->prb_; - const tr::node_t *ns = prb.nodes + off; - for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, 
(ptrdiff_t)ns[2].n, - (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, - [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { - tr::call_param_t base_params; - base_params.in = in - + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is - + d3 * ns[3].is) - * data_type_size(prb.itype); - base_params.out = out - + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os - + d3 * ns[3].os) - * data_type_size(prb.otype); - base_params.src_scales = src_scales + d0 * ns[0].ss - + d1 * ns[1].ss + d2 * ns[2].ss + d3 * ns[3].ss; - base_params.dst_scales = dst_scales + d0 * ns[0].ss - + d1 * ns[1].ss + d2 * ns[2].ss + d3 * ns[3].ss; - base_params.src_zp = src_zp; - base_params.dst_zp = dst_zp; - base_params.compensation_scratch = compensation_scratch - + d0 * ns[0].cs + d1 * ns[1].cs + d2 * ns[2].cs - + d3 * ns[3].cs; - - if (prb.is_tail_present) { - tr::tail_call_param_t tail_params; - tail_params.base_params = base_params; - - static constexpr int omp_ndims = 4; - const ptrdiff_t omp_data_chunks[omp_ndims] - = {d0, d1, d2, d3}; - fill_curr_data_chunks( - prb, off, omp_data_chunks, omp_ndims, tail_params); - - (*kernel_)(&tail_params); - } else { - (*kernel_)(&base_params); - } - }); -} - -void jit_uni_reorder_t::omp_driver(const char *in, char *out, - const float *src_scales, const float *dst_scales, int src_zp, - int dst_zp, const memory_tracking::grantor_t &scratchpad) const { - in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); - out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); - - DEBUG({ - verbose_printf(verbose_t::debuginfo, "prb : %s\n", - tr::prb_dump(pd()->prb_).c_str()); - }); - DEBUG({ - verbose_printf(verbose_t::debuginfo, "ker : %s\n", - tr::prb_dump(pd()->ker_desc_.prb).c_str()); - }); - - int ndims = pd()->prb_.ndims; - int ndims_ker = pd()->ker_desc_.prb.ndims; - const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; - const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; - const bool req_compensation = req_s8s8_comp || req_asymmetric_comp; - assert(ndims - ndims_ker <= ndims_driver_max); - - int32_t *compensation_reduce_scratch = scratchpad.template get( - memory_tracking::names::key_reorder_space); - - const memory_desc_wrapper od(pd()->dst_md()); - const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; - const auto N = od.padded_dims()[pd()->with_groups_ ? 
1 : 0]; - static constexpr int cache_line_size = 16; - const auto wspace_per_thr_size = utils::rnd_up(G * N, cache_line_size); - const auto wspace_per_thr_bytes = wspace_per_thr_size * sizeof(int32_t); - - if (ndims - ndims_ker == 0) { - if (req_compensation) - std::memset(compensation_reduce_scratch, 0, wspace_per_thr_bytes); - - omp_driver_0d(ndims_ker, in, out, src_scales, dst_scales, src_zp, - dst_zp, compensation_reduce_scratch); - } else { - parallel(pd()->nthr_, [&](const int ithr, const int nthr) { - int32_t *compensation_scratch = nullptr; - if (req_compensation) { - compensation_scratch = &compensation_reduce_scratch[ithr - * wspace_per_thr_size]; - std::memset(compensation_scratch, 0, wspace_per_thr_bytes); - } - - switch (ndims - ndims_ker) { - case 1: - omp_driver_1d(ithr, nthr, ndims_ker, in, out, src_scales, - dst_scales, src_zp, dst_zp, compensation_scratch); - break; - case 2: - omp_driver_2d(ithr, nthr, ndims_ker, in, out, src_scales, - dst_scales, src_zp, dst_zp, compensation_scratch); - break; - case 3: - omp_driver_3d(ithr, nthr, ndims_ker, in, out, src_scales, - dst_scales, src_zp, dst_zp, compensation_scratch); - break; - case 4: - omp_driver_4d(ithr, nthr, ndims_ker, in, out, src_scales, - dst_scales, src_zp, dst_zp, compensation_scratch); - break; - default: assert(!"unimplemented"); - } - }); - } - - //reduction of intermediate compensation results to the final output - if (req_compensation) { - const int nthr = ndims - ndims_ker == 0 ? 1 : pd()->nthr_; - reduce_compensation( - out, compensation_reduce_scratch, nthr, wspace_per_thr_size); - } -} - -void jit_uni_reorder_t::reduce_compensation(char *out, - const int32_t *compensation_reduce_scratch, const int nthr, - const dim_t wspace_per_thr_size) const { - - const memory_desc_wrapper od(pd()->dst_md()); - const size_t offset = od.size() - od.additional_buffer_size(); - - static constexpr auto comp_dt_size = sizeof(int32_t); - static constexpr int32_t comp_s8s8_shift = 128; - - // Note: We do not need to explicitly zero-out compensation buffer, as the - // per_thread buffers are already zeroed out in the padded area. - const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; - const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0]; - const auto GN = G * N; - const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; - const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; - const size_t zp_offset - = offset + (pd()->prb_.req_s8s8_comp ? GN * comp_dt_size : 0); - - parallel_nd(GN, [&](int idx) { - int32_t acc = 0; - for (int ithr = 0; ithr < nthr; ithr++) { - acc -= compensation_reduce_scratch[ithr * wspace_per_thr_size - + idx]; - } - if (req_s8s8_comp) { - int32_t *out_comp = reinterpret_cast(&out[offset]); - out_comp[idx] = comp_s8s8_shift * acc; - } - if (req_asymmetric_comp) { - int32_t *out_asym_comp - = reinterpret_cast(&out[zp_offset]); - out_asym_comp[idx] = acc; - } - }); -} - -void jit_uni_reorder_t::fill_curr_data_chunks(const tr::prb_t &prb, - const int off, const ptrdiff_t *omp_data_chunks, const int omp_ndims, - tr::tail_call_param_t &c) const { - // Chunks are backwards numered i.e: - // [0] -> [node_size] - // [1] -> [node_size - 1] - // ... - // [node_size - 1] -> [1] - - // It is done like this, because it is easier to decrement counter - // and check if it is equal to zero than increment and check - // if it is equal to node_size in jit kernel. 
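A small self-contained model of the reversed chunk numbering described in the comment above; encode_chunk is an illustrative helper, not the driver's actual code.

    #include <cassert>
    #include <cstdint>

    // The driver iterates chunks 0, 1, ..., node_size - 1, but hands the
    // kernel the countdown value node_size - chunk, so the kernel only has to
    // decrement and compare against zero to detect the last (tail) chunk.
    static int64_t encode_chunk(int64_t node_size, int64_t omp_chunk) {
        assert(omp_chunk >= 0 && omp_chunk < node_size);
        return node_size - omp_chunk; // node_size, node_size - 1, ..., 1
    }

    // encode_chunk(4, 0) == 4 (first chunk), encode_chunk(4, 3) == 1 (last).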
- - static constexpr int64_t empty_chunk_info = -1; - static constexpr int64_t last_chunk = 1; - - for (int curr_node_id = prb.ndims - 1; curr_node_id >= 0; curr_node_id--) { - const int parent_node_id = prb.nodes[curr_node_id].parent_node_id; - const bool is_drv_processing_this_node - = curr_node_id >= off && curr_node_id <= off + omp_ndims - 1; - const bool is_tail_processing - = prb.is_tail_in_one_of_child_nodes(curr_node_id) - || prb.nodes[curr_node_id].tail_size > 0; - - if (is_drv_processing_this_node && is_tail_processing) { - const int inner_idx = curr_node_id - off; - assert(inner_idx < omp_ndims); - const int64_t node_size = prb.nodes[curr_node_id].tail_size > 0 - ? prb.nodes[curr_node_id].tail_size - : prb.nodes[curr_node_id].n; - const int64_t data_chunk = node_size - omp_data_chunks[inner_idx]; - - if (!prb.nodes[curr_node_id].is_parent_empty()) { - const bool is_parent_chunk_last - = c.curr_data_chunks[parent_node_id] == last_chunk; - c.curr_data_chunks[curr_node_id] - = is_parent_chunk_last ? data_chunk : empty_chunk_info; - c.zeroing_data = static_cast( - is_parent_chunk_last && data_chunk <= 0); - } else { - c.curr_data_chunks[curr_node_id] = data_chunk; - c.zeroing_data = static_cast(data_chunk <= 0); - } - c.skip_kernel_execution = static_cast(c.zeroing_data - && !prb.nodes[curr_node_id].is_zero_pad_needed); - if (c.zeroing_data || c.skip_kernel_execution) break; - } else - c.curr_data_chunks[curr_node_id] = empty_chunk_info; - } -} - -status_t jit_uni_reorder_t::init(engine_t *engine) { - CHECK(safe_ptr_assign(kernel_, tr::kernel_t::create(pd()->ker_desc_))); - return kernel_->create_kernel(); -} - -status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const { - const auto &scratchpad = ctx.get_scratchpad_grantor(); - auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); - auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); - DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); - DEFINE_ARG_SCALES_BUFFER(dst_scales_, DNNL_ARG_DST); - - const float *dst_scales = pd()->precompute_scales( - scratchpad, pd()->attr(), pd()->D_mask_, dst_scales_); - assert(dst_scales); - - DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM); - DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO); - - omp_driver(in, out, src_scales, dst_scales, src_zp, dst_zp, scratchpad); - - return status::success; -} - -status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, - engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, - const memory_desc_t *src_md, engine_t *dst_engine, - const memory_desc_t *dst_md) { - if (!impl::is_dense_format_kind({src_md, dst_md})) - return status::unimplemented; - auto prb = tr::prb_t(); - // For shapes with dimension greater than thres it is found that jit:uni is better that jit:blk - auto thres = 1920 * 4096; - auto src_d = memory_desc_wrapper(src_md); - auto prd = 1; - - for (int d = 0; d < src_d.ndims(); ++d) { - const auto dim = src_d.dims()[d]; - prd *= dim; - if (prd > thres) return status::unimplemented; - } - - status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); - if (prb_init_status != status::success) return prb_init_status; - // only uni_reorder supports tail processing now - // TODO: Add tail processing support in blk_reorder - if (prb.is_tail_present) return status::unimplemented; - - prb_tile_normalize(prb); - DEBUG({ - verbose_printf( - verbose_t::debuginfo, "tile : %s\n", prb_dump(prb).c_str()); - }); - - if (!tr::jit_single_blk_kernel_t::applicable(prb)) { - return status::unimplemented; - } - - auto _pd = make_unique_pd( - attr, 
src_engine->kind(), src_md, dst_engine->kind(), dst_md); - if (_pd == nullptr) return status::out_of_memory; - _pd->prb_ = prb; - CHECK(_pd->init(engine, src_engine, dst_engine)); - CHECK(_pd->init_scratchpad_md()); - - return safe_ptr_assign(*reorder_pd, _pd.release()); -} - -void jit_blk_reorder_t::pd_t::prb_tile_normalize(tr::prb_t &p) { - if (!utils::one_of(p.nodes[0].n, 8ul, 16ul, 32ul, 64ul) - && utils::one_of(p.nodes[1].n, 8ul, 16ul, 32ul, 64ul)) { - nstl::swap(p.nodes[0], p.nodes[1]); - } -} - -jit_blk_reorder_t::jit_blk_reorder_t(const pd_t *apd) : primitive_t(apd) {} -jit_blk_reorder_t::~jit_blk_reorder_t() = default; - -status_t jit_blk_reorder_t::init(engine_t *engine) { - kernel_ = utils::make_unique(pd()->prb_); - return kernel_->create_kernel(); -} - -status_t jit_blk_reorder_t::execute(const exec_ctx_t &ctx) const { - const auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); - auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); - DEFINE_ZERO_POINT_VALUE(src_zp, DNNL_ARG_FROM); - DEFINE_ZERO_POINT_VALUE(dst_zp, DNNL_ARG_TO); - - // kernel handle 2-dimension tiles, a tail is possible - auto &prb = this->pd()->prb_; - ptrdiff_t BH = 1; - for (int i = 2; i < prb.ndims; ++i) { - BH *= prb.nodes[i].n; - } - - auto block_sz = prb.n(0); - auto n1 = prb.n(1); - auto i1 = prb.is(1); - auto o1 = prb.os(1); - auto FL = (n1 + block_sz - 1) / block_sz; - auto bh_stride = BH == 1 ? 0 : prb.is(2); - - auto itype_sz_ = data_type_size(pd()->prb_.itype); - auto otype_sz_ = data_type_size(pd()->prb_.otype); - - parallel_nd(BH, FL, [&](dim_t bh, dim_t fl) { - auto fl_b = fl * block_sz; - auto bh_b = bh_stride * bh; - auto *i = in + (bh_b + fl_b * i1) * itype_sz_; - auto *o = out + (bh_b + fl_b * o1) * otype_sz_; - (*kernel_)(i, o, n1 - fl_b < block_sz, src_zp, dst_zp); - }); - - return status::success; -} - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl diff --git a/src/cpu/aarch64/jit_uni_reorder.hpp b/src/cpu/aarch64/jit_uni_reorder.hpp deleted file mode 100644 index 83ac55ed855..00000000000 --- a/src/cpu/aarch64/jit_uni_reorder.hpp +++ /dev/null @@ -1,314 +0,0 @@ -/******************************************************************************* -* Copyright 2018-2023 Intel Corporation -* Copyright 2020-2023 FUJITSU LIMITED -* Copyright 2022 Arm Ltd. and affiliates -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. 
-*******************************************************************************/ - -#ifndef CPU_AARCH64_JIT_UNI_REORDER_HPP -#define CPU_AARCH64_JIT_UNI_REORDER_HPP - -#include - -#include "common/c_types_map.hpp" -#include "common/type_helpers.hpp" - -#include "cpu/reorder/cpu_reorder_pd.hpp" - -namespace dnnl { -namespace impl { -namespace cpu { -namespace aarch64 { - -namespace tr { - -constexpr int max_ndims = DNNL_MAX_NDIMS; - -struct node_t { - static constexpr int64_t empty_field = -1; - - size_t n = 0; - size_t tail_size = 0; - int dim_id = empty_field; - int parent_node_id = empty_field; - bool is_zero_pad_needed = false; - ptrdiff_t is = 0; // input stride - ptrdiff_t os = 0; // output stride - ptrdiff_t ss = 0; // scale stride - ptrdiff_t cs = 0; // compensation stride - - bool is_dim_id_empty() const { return dim_id == empty_field; } - bool is_parent_empty() const { return parent_node_id == empty_field; } -}; - -enum class scale_type_t { NONE, COMMON, MANY }; - -struct prb_t { - /* The compensation mask value indicates how big an additional buffer should be. - * Possible values for reorder: - * 1) standard compensation = 1 = 0b01 - * 2) asymmetric compensation = 2 = 0b10 - * 3) compensation if tensor contains group = 3 = 0b11 */ - static constexpr int invalid_comp_mask = 0; - static constexpr int standard_comp_mask = 0b1; - static constexpr int asymmetric_comp_mask = 0b10; - static constexpr int comp_mask_with_groups - = standard_comp_mask + asymmetric_comp_mask; - - bool is_tail_in_one_of_child_nodes(int parent_node_id) const { - for (int i = parent_node_id; i >= 0; i--) { - if (nodes[i].parent_node_id == parent_node_id) { - if (nodes[i].tail_size != 0) - return true; - else - parent_node_id = i; - } - } - - return false; - } - - int tail(int d) const { - assert(d < ndims); - return static_cast(nodes[d].tail_size); - } - - int n(int d) const { - assert(d < ndims); - return static_cast(nodes[d].n); - } - int is(int d) const { - assert(d < ndims); - return static_cast(nodes[d].is); - } - int os(int d) const { - assert(d < ndims); - return static_cast(nodes[d].os); - } - int ss(int d) const { - assert(d < ndims); - return static_cast(nodes[d].ss); - } - - int cs(int d) const { - assert(d < ndims); - return static_cast(nodes[d].cs); - } - - data_type_t itype; - data_type_t otype; - int ndims; - node_t nodes[max_ndims]; - ptrdiff_t ioff; - ptrdiff_t ooff; - scale_type_t src_scale_type; - scale_type_t dst_scale_type; - float beta; - int full_ndims; - bool is_tail_present = false; - float scale_adjust = 1.f; - int compensation_mask = invalid_comp_mask; - bool req_s8s8_comp = false; - bool req_asymmetric_comp = false; - bool req_src_zp = false; - bool req_dst_zp = false; -}; - -status_t prb_init(prb_t &prb, const memory_desc_t &imd, - const memory_desc_t &omd, const primitive_attr_t *attr); - -/** sorts the problem nodes so that output strides come in ascending order */ -void prb_normalize(prb_t &p); - -/** fill parent node info for blocked nodes */ -void prb_node_dependency(prb_t &p); - -/** folds nodes together if possible */ -void prb_simplify(prb_t &p); - -/** splits the node dim into two of sizes n1 and n / n1 - * @warning n must be multiple of n1 */ -void prb_node_split(prb_t &p, int dim, size_t n1); - -/** swaps d0 and d1 nodes */ -void prb_node_swap(prb_t &p, int d0, int d1); - -/** moves node d0 to the d1 position. 
- * nodes (d0, d1] are shifted to the left if d0 < d1 or - * to the right if d0 > d1 */ -void prb_node_move(prb_t &p, int d0, int d1); - -/** dumps the problem to a string */ -std::string prb_dump(const prb_t &p); - -struct call_param_t { - const void *in = nullptr; - void *out = nullptr; - const float *src_scales = nullptr; - const float *dst_scales = nullptr; - int32_t src_zp = 0; - int32_t dst_zp = 0; - int32_t *compensation_scratch = nullptr; -}; - -// The additional structure is needed because -// using a data structure with tail processing -// data for non-tail cases reduces kernel -// performance. This is because there is too -// much data that has to be transferred to the kernel. -struct tail_call_param_t { - call_param_t base_params; - int64_t curr_data_chunks[DNNL_MAX_NDIMS] = {-1}; - int64_t zeroing_data = static_cast(false); - int64_t skip_kernel_execution = static_cast(false); -}; - -struct kernel_t { - struct desc_t { - int id; - prb_t prb; - }; - - kernel_t(const desc_t &desc) - : desc_(desc) - , compensation_needed_( - desc.prb.req_s8s8_comp || desc.prb.req_asymmetric_comp) {} - virtual void operator()(const call_param_t *c) const = 0; - virtual void operator()(const tail_call_param_t *c) const = 0; - virtual status_t create_kernel() = 0; - virtual ~kernel_t() {} - - /** inits kernel descriptor: - * desc -- kernel descriptor (output) - * prb -- transposition problem (input) - * ndims_ker_max -- limit the maximum number of dimensions kernel - * will process (optional, 0 -- no limitation) */ - static status_t desc_init( - desc_t &desc, const prb_t &prb, int ndims_ker_max = 0); - - /** creates kernel for the problem described in desc */ - static kernel_t *create(const desc_t &desc); - -protected: - const desc_t desc_; - const prb_t &prb_ = desc_.prb; - bool compensation_needed_ = false; -}; - -/* TODO: add trans_t class */ - -struct jit_single_blk_kernel_t; - -} // namespace tr - -struct jit_uni_reorder_t : public primitive_t { - using primitive_t::primitive_t; - struct pd_t : public cpu_reorder_pd_t { - using cpu_reorder_pd_t::cpu_reorder_pd_t; - - DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t); - - tr::prb_t prb_; - tr::kernel_t::desc_t ker_desc_; - int nthr_; - bool with_groups_ = false; - dim_t D_mask_ = 0; - - status_t init( - engine_t *engine, engine_t *src_engine, engine_t *dst_engine); - - private: - status_t init_scratchpad(); - static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, - const primitive_attr_t *attr, engine_t *src_engine, - const memory_desc_t *src_md, engine_t *dst_engine, - const memory_desc_t *dst_md); - - friend dnnl::impl::impl_list_item_t; - }; - - status_t init(engine_t *engine) override; - status_t execute(const exec_ctx_t &ctx) const override; - - enum { ndims_driver_max = 4 }; - -private: - void omp_driver_0d(int off, const char *in, char *out, - const float *src_scales, const float *dst_scales, int src_zp, - int dst_zp, int32_t *compensation_scratch) const; - void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, - const float *src_scales, const float *dst_scales, int src_zp, - int dst_zp, int32_t *compensation_scratch) const; - void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, - const float *src_scales, const float *dst_scales, int src_zp, - int dst_zp, int32_t *compensation_scratch) const; - void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, - const float *src_scales, const float *dst_scales, int src_zp, - int dst_zp, int32_t *compensation_scratch) const; - 
void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, - const float *src_scales, const float *dst_scales, int src_zp, - int dst_zp, int32_t *compensation_scratch) const; - - void omp_driver(const char *in, char *out, const float *src_scales, - const float *dst_scales, int src_zp, int dst_zp, - const memory_tracking::grantor_t &scratchpad) const; - - void fill_curr_data_chunks(const tr::prb_t &prb, const int off, - const ptrdiff_t *omp_data_chunks, const int omp_ndims, - tr::tail_call_param_t &c) const; - - void reduce_compensation(char *out, - const int32_t *compensation_reduce_scratch, const int nthr, - const dim_t wspace_per_thr_size) const; - - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - std::unique_ptr kernel_; -}; - -struct jit_blk_reorder_t : public primitive_t { - using primitive_t::primitive_t; - struct pd_t : public cpu_reorder_pd_t { - using cpu_reorder_pd_t::cpu_reorder_pd_t; - DECLARE_COMMON_PD_T("jit:blk", jit_blk_reorder_t); - - tr::prb_t prb_; - - private: - static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, - const primitive_attr_t *attr, engine_t *src_engine, - const memory_desc_t *src_md, engine_t *dst_engine, - const memory_desc_t *dst_md); - - // Swap last two nodes, put block 4, 8, 16 nodes to first - static void prb_tile_normalize(tr::prb_t &p); - friend dnnl::impl::impl_list_item_t; - }; - - status_t init(engine_t *engine) override; - status_t execute(const exec_ctx_t &ctx) const override; - - jit_blk_reorder_t(const pd_t *apd); - ~jit_blk_reorder_t(); - -private: - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - std::unique_ptr kernel_; -}; - -} // namespace aarch64 -} // namespace cpu -} // namespace impl -} // namespace dnnl - -#endif diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.cpp b/src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.cpp index 3d734f3b536..928497ed476 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.cpp @@ -1,6 +1,8 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2025 Arm Ltd. and affiliates +* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at @@ -19,6 +21,7 @@ #include "common/type_helpers.hpp" #include "common/utils.hpp" #include "cpu/aarch64/jit_generator.hpp" +#include "xbyak_aarch64/xbyak_aarch64/xbyak_aarch64_reg.h" #include "cpu/aarch64/matmul/brgemm_matmul_copy_utils.hpp" @@ -37,7 +40,7 @@ using namespace Xbyak_aarch64; #define LDR_IMM(reg, addr, off) \ { \ const uint64_t IMM12_MASK = ~uint64_t(0xfff); \ - if ((off & IMM12_MASK) == 0) { \ + if (((off) & IMM12_MASK) == 0) { \ ldr(reg, ptr(addr, off)); \ } else { \ add_imm(X_DEFAULT_ADDR, addr, off, X_TMP_0); \ @@ -48,7 +51,7 @@ using namespace Xbyak_aarch64; #define STR_IMM(reg, addr, off) \ { \ const uint64_t IMM12_MASK = ~uint64_t(0xfff); \ - if ((off & IMM12_MASK) == 0) { \ + if (((off) & IMM12_MASK) == 0) { \ str(reg, ptr(addr, off)); \ } else { \ add_imm(X_DEFAULT_ADDR, addr, off, X_TMP_0); \ @@ -63,7 +66,6 @@ struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t, jit_brgemm_matmul_copy_a_impl_t(const brgemm_matmul_conf_t *conf) : jit_brgemm_matmul_copy_a_t(conf) - , jit_generator() , typesize_(conf_->a_dt_sz) , tr_typesize_(conf_->tr_a_dt_sz) , vnni_granularity_(data_type_vnni_granularity(conf_->src_dt)) @@ -78,7 +80,9 @@ struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t, , vmm_copy_idx_(29) {} void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } - status_t create_kernel() override { return jit_generator::create_kernel(); } + status_t create_kernel() override { + return jit_generator::create_kernel(); + } private: using reg64_t = const Xbyak_aarch64::XReg; @@ -235,7 +239,6 @@ struct jit_brgemm_matmul_copy_a_transposed_impl_t jit_brgemm_matmul_copy_a_transposed_impl_t(const brgemm_matmul_conf_t *conf) : jit_brgemm_matmul_copy_a_t(conf) - , jit_generator() , typesize(conf_->a_dt_sz) , tr_typesize(conf_->tr_a_dt_sz) , src_stride(conf_->copy_A_src_stride) @@ -256,7 +259,9 @@ struct jit_brgemm_matmul_copy_a_transposed_impl_t } void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } - status_t create_kernel() override { return jit_generator::create_kernel(); } + status_t create_kernel() override { + return jit_generator::create_kernel(); + } private: using reg64_t = const Xbyak_aarch64::XReg; @@ -361,7 +366,6 @@ struct jit_brgemm_matmul_copy_b_int8_t : public jit_brgemm_matmul_copy_b_t, jit_brgemm_matmul_copy_b_int8_t(const brgemm_matmul_conf_t *conf) : jit_brgemm_matmul_copy_b_t(conf) - , jit_generator() , src_stride_(conf->wei_tag == format_tag::acbd ? conf->copy_B_wei_stride : conf->N * sizeof(int8_t)) @@ -371,7 +375,9 @@ struct jit_brgemm_matmul_copy_b_int8_t : public jit_brgemm_matmul_copy_b_t, , comp_acc_idx_(25) {} void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } - status_t create_kernel() override { return jit_generator::create_kernel(); } + status_t create_kernel() override { + return jit_generator::create_kernel(); + } protected: using reg64_t = const Xbyak_aarch64::XReg; @@ -501,7 +507,6 @@ struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, jit_brgemm_matmul_copy_b_f32_t(const brgemm_matmul_conf_t *conf) : jit_brgemm_matmul_copy_b_t(conf) - , jit_generator() , dt_in_(data_type::f32) , typesize_in_(types::data_type_size(dt_in_)) , src_stride_(conf_->wei_tag == acbd ? 
conf_->copy_B_wei_stride @@ -512,7 +517,9 @@ struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, } void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } - status_t create_kernel() override { return jit_generator::create_kernel(); } + status_t create_kernel() override { + return jit_generator::create_kernel(); + } private: using reg64_t = const Xbyak_aarch64::XReg; @@ -525,7 +532,6 @@ struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, const size_t typesize_in_; const size_t typesize_out_ = sizeof(float); dim_t src_stride_, tr_src_stride_; - const bool is_sve_256 = !mayiuse(sve_512); opmask_t kTail = p7; opmask_t kFFFF = p6; @@ -550,7 +556,7 @@ struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t, void jit_brgemm_matmul_copy_b_f32_t::copy_16_8_x_n_block( int nrows, int ncolumns) { - int n_blk_step = is_sve_256 ? 8 : 16; + int n_blk_step = get_sve_length() / typesize_in_; auto get_zmm = [](int reg_idx) { assert(reg_idx >= 0 && reg_idx < max_regs_available); @@ -581,7 +587,7 @@ void jit_brgemm_matmul_copy_b_f32_t::copy_16_8_x_n_block( continue; } - const opmask_t curr_msk = zero_padding < n_blk_step ? kTail : kFFFF; + const opmask_t curr_msk = zero_padding < n_blk_step ? kTail : P_ALL_ONE; const int blk_idx = iter % max_regs_available; load(blk_idx, k, n, curr_msk); add_imm(X_DEFAULT_ADDR, reg_tr_src, tr_src_off, X_TMP_0); @@ -610,12 +616,13 @@ void jit_brgemm_matmul_copy_b_f32_t::compute_k_loop(int ncolumns) { L(K_end_label); }; - int k_unroll = is_sve_256 ? 8 : 16; + int k_unroll = get_sve_length() / typesize_in_; compute_uni_k_loop(k_unroll); compute_uni_k_loop(1); } void jit_brgemm_matmul_copy_b_f32_t::generate() { + preamble(); eor(zmm_zero.d, zmm_zero.d, zmm_zero.d); LDR_IMM(reg_src, param1, GET_OFF(src)); @@ -648,7 +655,6 @@ struct jit_brgemm_matmul_copy_b_transposed_t jit_brgemm_matmul_copy_b_transposed_t(const brgemm_matmul_conf_t *conf) : jit_brgemm_matmul_copy_b_t(conf) - , jit_generator() , typesize_(conf_->b_dt_sz) , tr_typesize_(conf_->tr_b_dt_sz) , vnni_granularity_(data_type_vnni_granularity(conf_->wei_dt)) @@ -665,13 +671,14 @@ struct jit_brgemm_matmul_copy_b_transposed_t , tr_src_stride_(conf_->LDB * vnni_granularity_ * tr_typesize_) {} void operator()(ctx_t *ctx) override { jit_generator::operator()(ctx); } - status_t create_kernel() override { return jit_generator::create_kernel(); } + status_t create_kernel() override { + return jit_generator::create_kernel(); + } private: using reg64_t = const Xbyak_aarch64::XReg; using reg32_t = const Xbyak_aarch64::WReg; using opmask_t = const Xbyak_aarch64::PReg; - using ZReg = const Xbyak_aarch64::ZReg; static constexpr bool is_sve256_ = isa == sve_256; static constexpr cpu_isa_t isa_ = isa; @@ -719,20 +726,20 @@ struct jit_brgemm_matmul_copy_b_transposed_t // Note: for the SVE256 implementation, reserve ZReg(8) and ZReg(9) as // temporary compute registers. 
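As an aside on the kTail / all-true predicate selection in copy_16_8_x_n_block above: the same vector-length-agnostic tail handling expressed with plain ACLE SVE intrinsics, assuming an SVE-enabled toolchain. This is an analogy for readers, not the code the JIT kernel emits.

    #include <arm_sve.h>
    #include <cstdint>

    // Copy n floats with a predicated tail: full vectors run under an
    // all-true predicate, the final partial vector is masked by svwhilelt,
    // which is the role kTail plays in the generated kernel.
    void copy_f32_sve(const float *src, float *dst, int64_t n) {
        for (int64_t i = 0; i < n; i += svcntw()) {
            const svbool_t pg = svwhilelt_b32(i, n);
            svst1_f32(pg, dst + i, svld1_f32(pg, src + i));
        }
    }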
- ZReg vmm_comp_mul = Xbyak_aarch64::ZReg(max_vmm_regs_ - 1); - ZReg vmm_comp_acc = Xbyak_aarch64::ZReg(max_vmm_regs_ - 2); - ZReg vmm_zp_a_neg_val = Xbyak_aarch64::ZReg(max_vmm_regs_ - 3); - ZReg vmm_s8s8_comp_acc = Xbyak_aarch64::ZReg(max_vmm_regs_ - 4); - ZReg vmm_all_bits_1 = Xbyak_aarch64::ZReg(max_vmm_regs_ - 5); - ZReg vmm_one_s32 = Xbyak_aarch64::ZReg(max_vmm_regs_ - 6); - - ZReg vmm_ones_words = ZReg(max_vmm_regs_ - 7); - ZReg vmm_dot_product_temp = ZReg(max_vmm_regs_ - 8); - - ZReg z_tmp_0 = ZReg(28); - ZReg z_tmp_1 = ZReg(29); - ZReg z_tmp_3 = ZReg(30); - ZReg z_tmp_2 = ZReg(27); + const ZReg vmm_comp_mul {max_vmm_regs_ - 1}; + const ZReg vmm_comp_acc {max_vmm_regs_ - 2}; + const ZReg vmm_zp_a_neg_val {max_vmm_regs_ - 3}; + const ZReg vmm_s8s8_comp_acc {max_vmm_regs_ - 4}; + const ZReg vmm_all_bits_1 {max_vmm_regs_ - 5}; + const ZReg vmm_one_s32 {max_vmm_regs_ - 6}; + const ZReg vmm_ones_words {max_vmm_regs_ - 7}; + const ZReg vmm_dot_product_temp {max_vmm_regs_ - 8}; + + const ZReg z_tmp_0 {28}; + const ZReg z_tmp_1 {29}; + const ZReg z_tmp_3 {30}; + const ZReg z_tmp_2 {27}; + PReg p_tmp_0 = p7; PReg p_02 = p8; PReg p_AA = p9; @@ -746,11 +753,11 @@ struct jit_brgemm_matmul_copy_b_transposed_t void kmovw(Xbyak_aarch64::PReg k, unsigned w) { assert(!"under construction"); - }; + } void kmovq(Xbyak_aarch64::PReg k, size_t q) { assert(!"under construction"); - }; + } ZReg src_vmm(int i) { assert(i >= 0 && i < n_blk_step_); @@ -1081,6 +1088,7 @@ void jit_brgemm_matmul_copy_b_transposed_t::generate() { template struct jit_brgemm_matmul_copy_b_transposed_t; template struct jit_brgemm_matmul_copy_b_transposed_t; +template struct jit_brgemm_matmul_copy_b_transposed_t; status_t create_brgemm_matmul_copy_b( std::unique_ptr ©_ker, diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.hpp b/src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.hpp index f3a73b53cbc..365a61db661 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.hpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_copy_utils.hpp @@ -1,6 +1,8 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2025 Arm Ltd. and affiliates +* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -43,7 +45,7 @@ struct jit_brgemm_matmul_copy_b_t { jit_brgemm_matmul_copy_b_t(const brgemm_matmul_conf_t *conf) : conf_(conf) {} - virtual ~jit_brgemm_matmul_copy_b_t() {} + virtual ~jit_brgemm_matmul_copy_b_t() = default; const brgemm_matmul_conf_t *conf_; }; @@ -68,7 +70,7 @@ struct jit_brgemm_matmul_copy_a_t { jit_brgemm_matmul_copy_a_t(const brgemm_matmul_conf_t *conf) : conf_(conf) {} - virtual ~jit_brgemm_matmul_copy_a_t() {} + virtual ~jit_brgemm_matmul_copy_a_t() = default; const brgemm_matmul_conf_t *conf_; }; diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp b/src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp index bfb917afbbd..081dc33e39b 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_reorders.cpp @@ -1,6 +1,8 @@ /******************************************************************************* -* Copyright 2022-2023 Intel Corporation +* Copyright 2022 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2025 Arm Ltd. 
and affiliates +* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -15,6 +17,7 @@ *******************************************************************************/ #include "common/dnnl_thread.hpp" +#include "cpu/aarch64/cpu_isa_traits.hpp" #include "cpu/aarch64/matmul/brgemm_matmul_reorders.hpp" @@ -23,7 +26,7 @@ namespace impl { namespace cpu { namespace aarch64 { -status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init( +status_t brgemm_matmul_copy_reorder_t::pd_t::init( engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { using namespace status; using namespace format_tag; @@ -110,7 +113,10 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init( : brgemm_broadcast_t::none; matmul_conf_for_reorder_.has_zero_point_a = matmul_conf_for_reorder_.src_zp_type != brgemm_broadcast_t::none; - matmul_conf_for_reorder_.isa = (!mayiuse(sve_512)) ? sve_256 : sve_512; + + // asimd not supported, so we need >sve_128 + if (!mayiuse(sve_128)) return status::unimplemented; + matmul_conf_for_reorder_.isa = get_max_cpu_isa(); auto mask_ok = [&](bool check, int mask) { return IMPLICATION( @@ -128,9 +134,8 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::init( return status::success; } -status_t brgemm_matmul_matrix_B_reorder_t::pd_t::create( - reorder_pd_t **reorder_pd, engine_t *engine, - const primitive_attr_t *attr, engine_t *src_engine, +status_t brgemm_matmul_copy_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, const memory_desc_t *src_md, engine_t *dst_engine, const memory_desc_t *dst_md) { using namespace status; @@ -143,7 +148,7 @@ status_t brgemm_matmul_matrix_B_reorder_t::pd_t::create( return safe_ptr_assign(*reorder_pd, _pd.release()); } -status_t brgemm_matmul_matrix_B_reorder_t::execute_body( +status_t brgemm_matmul_copy_reorder_t::execute_body( const exec_ctx_t &ctx) const { using namespace utils; @@ -160,7 +165,7 @@ status_t brgemm_matmul_matrix_B_reorder_t::execute_body( = dst_d.size() - dst_d.additional_buffer_size(); const size_t s8s8_comp_size_bytes = kernel_conf.s8s8_compensation_required ? dst_d.additional_buffer_size( - memory_extra_flags::compensation_conv_s8s8) + memory_extra_flags::compensation_conv_s8s8) : 0; const size_t zp_comp_offset_bytes = comp_offset_bytes + s8s8_comp_size_bytes; @@ -178,63 +183,58 @@ status_t brgemm_matmul_matrix_B_reorder_t::execute_body( parallel_nd(kernel_conf.batch, div_up(kernel_conf.N, kernel_conf.N_blk), [&](dim_t batch, dim_t n_blk_idx) { - const auto n = n_blk_idx * kernel_conf.N_blk; - const bool is_N_tail = (kernel_conf.N - n < kernel_conf.N_blk); - auto ker_exec_ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); - ker_exec_ctx.current_N_blk - = is_N_tail ? kernel_conf.N_tail : kernel_conf.N_blk; - - const auto comp_offset = batch * kernel_conf.s8s8_comp_b_str - + n_blk_idx * kernel_conf.s8s8_comp_n_str; - - ker_exec_ctx.zp_a_compensation_ptr - = kernel_conf.has_zero_point_a - ? (void *)&zp[comp_offset] - : nullptr; - ker_exec_ctx.compensation_ptr - = kernel_conf.s8s8_compensation_required - ? 
(void *)&cp[comp_offset] - : nullptr; - - // required to compute zp compensation - int tmp_neg_a_zp_val = -1; - ker_exec_ctx.zp_a_neg_value_ptr = &tmp_neg_a_zp_val; - - int k_blk_idx = 0; - for (; k_blk_idx < kernel_conf.K / kernel_conf.K_blk; - k_blk_idx++) { - const auto k = k_blk_idx * kernel_conf.K_blk; - ker_exec_ctx.src = (void *)&src[get_blk_off( - src_d, sdt_sz, batch, k, n)]; - ker_exec_ctx.tr_src = (void *)&dst[get_blk_off( - dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx)]; - ker_exec_ctx.current_K_start = k; - ker_exec_ctx.current_K_iters = kernel_conf.K_blk; - (*kernel_)(&ker_exec_ctx); - } - if (kernel_conf.K_tail > 0) { - const auto k = k_blk_idx * kernel_conf.K_blk; - ker_exec_ctx.src = (void *)&src[get_blk_off( - src_d, sdt_sz, batch, k, n)]; - const auto dst_offset = get_blk_off( - dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx); - ker_exec_ctx.tr_src = (void *)&dst[dst_offset]; - ker_exec_ctx.current_K_start = k; - ker_exec_ctx.current_K_iters = kernel_conf.K_tail; - (*kernel_)(&ker_exec_ctx); - const auto vnni_granularity - = data_type_vnni_granularity(type_o); - const auto dst_zero_out_offset - = rnd_up(kernel_conf.K_tail, vnni_granularity) - * kernel_conf.N_blk * ddt_sz; - const auto elems_to_zero - = rnd_dn(kernel_conf.K_blk - kernel_conf.K_tail, - vnni_granularity) - * kernel_conf.N_blk * ddt_sz; - array_set(&dst[dst_offset + dst_zero_out_offset], 0, - elems_to_zero); - } - }); + const auto n = n_blk_idx * kernel_conf.N_blk; + const bool is_N_tail = (kernel_conf.N - n < kernel_conf.N_blk); + auto ker_exec_ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); + ker_exec_ctx.current_N_blk + = is_N_tail ? kernel_conf.N_tail : kernel_conf.N_blk; + + const auto comp_offset = batch * kernel_conf.s8s8_comp_b_str + + n_blk_idx * kernel_conf.s8s8_comp_n_str; + + ker_exec_ctx.zp_a_compensation_ptr = kernel_conf.has_zero_point_a + ? (void *)&zp[comp_offset] + : nullptr; + ker_exec_ctx.compensation_ptr = kernel_conf.s8s8_compensation_required + ? 
(void *)&cp[comp_offset] + : nullptr; + + // required to compute zp compensation + int tmp_neg_a_zp_val = -1; + ker_exec_ctx.zp_a_neg_value_ptr = &tmp_neg_a_zp_val; + + int k_blk_idx = 0; + for (; k_blk_idx < kernel_conf.K / kernel_conf.K_blk; k_blk_idx++) { + const auto k = k_blk_idx * kernel_conf.K_blk; + ker_exec_ctx.src + = (void *)&src[get_blk_off(src_d, sdt_sz, batch, k, n)]; + ker_exec_ctx.tr_src = (void *)&dst[get_blk_off( + dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx)]; + ker_exec_ctx.current_K_start = k; + ker_exec_ctx.current_K_iters = kernel_conf.K_blk; + (*kernel_)(&ker_exec_ctx); + } + if (kernel_conf.K_tail > 0) { + const auto k = k_blk_idx * kernel_conf.K_blk; + ker_exec_ctx.src + = (void *)&src[get_blk_off(src_d, sdt_sz, batch, k, n)]; + const auto dst_offset + = get_blk_off(dst_d, ddt_sz, batch, k_blk_idx, n_blk_idx); + ker_exec_ctx.tr_src = (void *)&dst[dst_offset]; + ker_exec_ctx.current_K_start = k; + ker_exec_ctx.current_K_iters = kernel_conf.K_tail; + (*kernel_)(&ker_exec_ctx); + const auto vnni_granularity = data_type_vnni_granularity(type_o); + const auto dst_zero_out_offset + = rnd_up(kernel_conf.K_tail, vnni_granularity) + * kernel_conf.N_blk * ddt_sz; + const auto elems_to_zero + = rnd_dn(kernel_conf.K_blk - kernel_conf.K_tail, + vnni_granularity) + * kernel_conf.N_blk * ddt_sz; + array_set(&dst[dst_offset + dst_zero_out_offset], 0, elems_to_zero); + } + }); #undef get_blk_off diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_reorders.hpp b/src/cpu/aarch64/matmul/brgemm_matmul_reorders.hpp index ccf0a12a89c..e556116f9ca 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_reorders.hpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_reorders.hpp @@ -1,6 +1,8 @@ /******************************************************************************* * Copyright 2022 Intel Corporation * Copyright 2024 FUJITSU LIMITED +* Copyright 2025 Arm Ltd. and affiliates +* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -25,12 +27,12 @@ namespace impl { namespace cpu { namespace aarch64 { -struct brgemm_matmul_matrix_B_reorder_t : public primitive_t { +struct brgemm_matmul_copy_reorder_t : public primitive_t { struct pd_t : public cpu_reorder_pd_t { using cpu_reorder_pd_t::cpu_reorder_pd_t; - DECLARE_COMMON_PD_T("brgemm_matmul_matrix_B_reorder_t", - brgemm_matmul_matrix_B_reorder_t); + DECLARE_COMMON_PD_T( + "brgemm_matmul_copy_reorder_t", brgemm_matmul_copy_reorder_t); // required to re-use brgemm matmul copy_b jit kernels matmul::brgemm_matmul_conf_t matmul_conf_for_reorder_; @@ -47,7 +49,7 @@ struct brgemm_matmul_matrix_B_reorder_t : public primitive_t { friend dnnl::impl::impl_list_item_t; }; - brgemm_matmul_matrix_B_reorder_t(const pd_t *apd) : primitive_t(apd) {} + brgemm_matmul_copy_reorder_t(const pd_t *apd) : primitive_t(apd) {} status_t init(engine_t *engine) override { CHECK(matmul::create_brgemm_matmul_copy_b( kernel_, &pd()->matmul_conf_for_reorder_)); diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp b/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp index 0610147c752..6f95065da32 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_utils.cpp @@ -1,7 +1,7 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021 Intel Corporation * Copyright 2023-2024 FUJITSU LIMITED -* Copyright 2024 Arm Ltd. 
and affiliates +* Copyright 2024-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,16 +16,14 @@ * limitations under the License. *******************************************************************************/ -#include - +#include "cpu/aarch64/matmul/brgemm_matmul_utils.hpp" +#include "common/c_types_map.hpp" #include "common/dnnl_thread.hpp" +#include "common/type_helpers.hpp" #include "cpu/aarch64/injectors/jit_uni_postops_injector.hpp" -#include "cpu/aarch64/matmul/brgemm_matmul_utils.hpp" -#include "cpu/platform.hpp" #include "cpu/binary_injector_utils.hpp" #include "cpu/matmul/matmul_utils.hpp" -#include "oneapi/dnnl/dnnl_debug.h" // TODO add a method to print brgemm conf info #define VCONDCHECK_BG(cond, msg, ...) \ @@ -215,23 +213,22 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag( } else { bgmmc.wei_tag = blocked_B_layouts_allowed ? memory_desc_matches_one_of_tag(B_md, plain_tensor_layout_tag, - transposed_tensor_layout_tag, blocked_64n_B_layout_tag, - blocked_48n_B_layout_tag, blocked_32n_B_layout_tag, - blocked_16n_B_layout_tag) + transposed_tensor_layout_tag, + blocked_64n_B_layout_tag, blocked_48n_B_layout_tag, + blocked_32n_B_layout_tag, blocked_16n_B_layout_tag) : memory_desc_matches_one_of_tag(B_md, plain_tensor_layout_tag, - transposed_tensor_layout_tag, acbd, adbc); - - // For cases when the weights tensor is transposed but has - // 'dim_size == 1', we can ignore transposition and compute as a plain - // format tensor. This removes the need of allocating a scratchpad for - // copy_B. - if (transposed_tensor_layout_tag == bgmmc.wei_tag) { - memory_desc_t B_md_plain; - const status_t status - = memory_desc_init_by_tag(B_md_plain, B_md.ndims, B_md.dims, - B_md.data_type, plain_tensor_layout_tag); - if (status != status::success) return status; - if (B_md_plain == B_md) bgmmc.wei_tag = plain_tensor_layout_tag; + transposed_tensor_layout_tag, acbd, adbc); + + // If the B memory descriptor matches both the transposed and plain + // version that means that for dims = [P, Q, K, N] in the weight matrix, + // then (K || N) == 1. That is, B is either a row or column vector. In + // this case it makes no difference if we treat B as row-major or + // column-major since they are identical for a vector. Therefore we + // chose to treat it as "plain" since that saves us the extra time and + // scratchpad memory we would need for an unnecessary transpose. + if (memory_desc_matches_tag(B_md, transposed_tensor_layout_tag) + && memory_desc_matches_tag(B_md, plain_tensor_layout_tag)) { + bgmmc.wei_tag = plain_tensor_layout_tag; } if (format_tag::undef == bgmmc.wei_tag) return status::unimplemented; @@ -263,16 +260,16 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md, bgmmc.src_tag = (this->is_bf16() || this->is_f32() || this->is_bf32() || this->is_f16()) ? memory_desc_matches_one_of_tag(A_md, plain_tensor_layout_tag, - transposed_tensor_layout_tag, acbd, adbc) + transposed_tensor_layout_tag, acbd, adbc) // Enable support of int8 problems with formally transposed A // layout which can be treated as plain. // TODO: remove this extra code path after transposed A is // supported for int8 : (this->is_int8() && can_treat_transposed_A_as_plain) ? 
memory_desc_matches_one_of_tag(A_md, plain_tensor_layout_tag, - transposed_tensor_layout_tag, acbd) + transposed_tensor_layout_tag, acbd) : memory_desc_matches_one_of_tag( - A_md, plain_tensor_layout_tag, acbd); + A_md, plain_tensor_layout_tag, acbd); } if (C_any_layout) { @@ -934,10 +931,17 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, // We need to correct A_strides if batched dimensions are merged in M and // A layout is formally transposed but could be treated as plain - if (merge_batch_dims_into_M && treat_transposed_A_as_plain) { + if (merge_batch_dims_into_M + && (src_d.matches_tag(acbd) || treat_transposed_A_as_plain)) { bgmmc.A_strides[1] = bgmmc.A_strides[2]; } + // We need to correct C_strides if batched dimensions are merged in M and + // C layout is formally transposed but could be treated as plain + if (merge_batch_dims_into_M && dst_d.matches_tag(acbd)) { + bgmmc.C_strides[1] = bgmmc.C_strides[2]; + } + // BF32 'Hint' Heuristic: // Under the following conditions, F32 through SVE512 performs better // than using BF32 arithmetic. @@ -985,6 +989,67 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, return status::success; } +status_t init_conf(brgemm_matmul_conf_t &conf, dim_t batch, dim_t M, dim_t K, + dim_t N, dim_t in_ld, dim_t n_blk, data_type_t in_type, + data_type_t out_type, format_tag_t in_tag) { + if (n_blk <= 0 && M <= 0) return status::invalid_arguments; + + const auto vnni_granularity = data_type_vnni_granularity(out_type); + if (vnni_granularity <= 0) return status::invalid_arguments; + + // Zero initialize the `conf` to avoid access to 'garbage' in members. + conf = brgemm_matmul_conf_t(); + + const bool is_bf16 = one_of(in_type, bf16) || one_of(out_type, bf16); + const bool is_s8u8 = one_of(in_type, s8, u8) || one_of(out_type, s8, u8); + + VCONDCHECK_BG(!(is_bf16 || is_s8u8), VERBOSE_UNSUPPORTED_DT); + + const bool is_copyB = N > 0; + conf.isa = get_max_cpu_isa(); // Just use the best ISA possible. + conf.is_bf32 = false; + conf.batch = batch; + conf.src_dt = conf.wei_dt = out_type; + conf.orig_src_dt = conf.orig_wei_dt = in_type; + // Note: will need to change `tr_a_dt_sz` for copyA in cases where src_dt != dst_dt + conf.a_dt_sz = conf.tr_a_dt_sz = types::data_type_size(conf.src_dt); + conf.N = N; + conf.M = M; + conf.K = K; + const dim_t copyA_K_blk = isa_num_vregs(conf.isa) / 2; + const dim_t copyB_K_blk = 16 * vnni_granularity; + conf.K_blk = is_copyB ? copyB_K_blk : copyA_K_blk; + conf.K_tail = conf.K % conf.K_blk; + if (!is_copyB) { + // Note: current implementation always calls the transposed kernel. 
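For orientation, a standalone model of the block-size arithmetic in the copy-A / copy-B branches of init_conf above, assuming a 32-entry vector register file; pick_blocks and blk_sizes_t are illustrative names and the SVE-256 numbers in the final comment are an example, not a guarantee.

    #include <cstdint>

    struct blk_sizes_t {
        int64_t K_blk;
        int64_t M_blk; // only meaningful for the copy-A path
    };

    // copy-A keeps half of the (assumed 32-entry) register file for K and
    // blocks M by the number of elements per vector; copy-B blocks K by
    // 16 * vnni_granularity, keeping K blocks a multiple of the vnni packing.
    static blk_sizes_t pick_blocks(bool is_copyB, int64_t vlen_bytes,
            int64_t dt_size, int64_t vnni_granularity) {
        const int64_t num_vregs = 32;
        blk_sizes_t b {};
        b.K_blk = is_copyB ? 16 * vnni_granularity : num_vregs / 2;
        b.M_blk = is_copyB ? 0 : vlen_bytes / dt_size;
        return b;
    }

    // Example: pick_blocks(false, 32, 4, 1) -> {K_blk = 16, M_blk = 8},
    // i.e. copy-A of f32 data on a 256-bit SVE machine.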
+ conf.transposed_A = true; + conf.M_blk = (dim_t)isa_max_vlen(conf.isa) / conf.a_dt_sz; + conf.M_tail = conf.M % conf.M_blk; + conf.copy_A_src_stride = in_ld * conf.a_dt_sz; + // setting LDA parameter required for plain transpose + conf.LDA = conf.K; + } else { + conf.blocked_B = !utils::one_of(in_tag, ab, ba, abc, acb); + conf.transposed_B = utils::one_of(in_tag, ba, acb); + conf.wei_tag = in_tag; + conf.wei_n_blk = conf.N_blk = conf.LDB = n_blk; + conf.N_tail = conf.N % conf.N_blk; + conf.b_dt_sz = types::data_type_size(in_type); + conf.tr_b_dt_sz = types::data_type_size(conf.wei_dt); + conf.copy_B_wei_stride = in_ld * conf.b_dt_sz; + conf.N_chunk_elems = conf.N; + conf.s8s8_comp_b_str = utils::rnd_up(conf.N, conf.wei_n_blk); + conf.s8s8_comp_n_str = conf.wei_n_blk; + } + + conf.s8s8_compensation_required = false; + conf.src_zp_type = brgemm_broadcast_t::none; + conf.has_zero_point_a = false; + conf.has_zero_point_b = false; + + return status::success; +} + void init_aux_values(brgemm_matmul_conf_t &bgmmc, const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) { @@ -1031,14 +1096,24 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc, bgmmc.A_ptr_shift_b = 0; bgmmc.copy_A_src_stride = bgmmc.a_dt_sz * (bgmmc.transposed_A ? bgmmc.M : bgmmc.K); - if (bgmmc.src_tag == acbd || bgmmc.src_tag == adbc) { - const dim_t factor = bgmmc.src_dt == f32 ? 2 : 1; - const dim_t src_stride = bgmmc.src_tag == acbd ? bgmmc.A_strides[1] - : bgmmc.A_strides[0]; - bgmmc.copy_A_src_stride = nstl::min(src_d.blocking_desc().strides[0], - src_stride / factor) - * factor; - const dim_t bcast_shift_b = bgmmc.src_tag == acbd ? bgmmc.K : bgmmc.M; + + // If src have dimensions equal to 1, multiple tags can be matched so + // we need to make sure: + // - A_ptr_shift_b is set for acbd and adbc even if bgmmc.src_tag is abcd + // - Plain md that matches acbd or adbc does not dispatch into their codepath + if (src_d.matches_one_of_tag(acbd, adbc) != format_tag::undef) { + if (src_d.matches_one_of_tag(abcd, abdc) == format_tag::undef) { + const dim_t factor = bgmmc.src_dt == f32 ? 2 : 1; + const dim_t src_stride = src_d.matches_tag(acbd) + ? bgmmc.A_strides[1] + : bgmmc.A_strides[0]; + bgmmc.copy_A_src_stride + = nstl::min(src_d.blocking_desc().strides[0], + src_stride / factor) + * factor; + } + + const dim_t bcast_shift_b = src_d.matches_tag(acbd) ? bgmmc.K : bgmmc.M; bgmmc.A_ptr_shift_b = (bgmmc.bcast_A_desc.bcast_mask == 2 ? bcast_shift_b @@ -1048,14 +1123,24 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc, bgmmc.B_ptr_shift_b = 0; bgmmc.copy_B_wei_stride = 0; - if (one_of(bgmmc.wei_tag, acbd, adbc)) { - const dim_t factor = bgmmc.wei_dt == f32 ? 2 : 1; - const dim_t wei_stride = bgmmc.wei_tag == acbd ? bgmmc.B_strides[1] - : bgmmc.B_strides[0]; - bgmmc.copy_B_wei_stride = nstl::min(wei_d.blocking_desc().strides[0], - wei_stride / factor) - * factor; - const dim_t bcast_shift_b = bgmmc.wei_tag == acbd ? bgmmc.N : bgmmc.K; + // If weights have dimensions equal to 1, multiple tags can be matched so + // we need to make sure: + // - B_ptr_shift_b is set for acbd and adbc even if bgmmc.wei_tag is abcd + // - Plain md that matches acbd or adbc does not dispatch into their codepath + // - Plain md that matches transposed tag does not dispatch into its codepath + if (wei_d.matches_one_of_tag(acbd, adbc) != format_tag::undef) { + if (wei_d.matches_one_of_tag(abcd, abdc) == format_tag::undef) { + const dim_t factor = bgmmc.wei_dt == f32 ? 
2 : 1; + const dim_t wei_stride = wei_d.matches_tag(acbd) + ? bgmmc.B_strides[1] + : bgmmc.B_strides[0]; + bgmmc.copy_B_wei_stride + = nstl::min(wei_d.blocking_desc().strides[0], + wei_stride / factor) + * factor; + } + + const dim_t bcast_shift_b = wei_d.matches_tag(acbd) ? bgmmc.N : bgmmc.K; bgmmc.B_ptr_shift_b = (bgmmc.bcast_B_desc.bcast_mask == 2 ? bcast_shift_b @@ -1063,7 +1148,7 @@ void init_aux_values(brgemm_matmul_conf_t &bgmmc, * bgmmc.b_dt_sz; } - bgmmc.C_ptr_shift_b = bgmmc.dst_tag == acbd + bgmmc.C_ptr_shift_b = dst_d.matches_tag(acbd) ? dst_d.blocking_desc().strides[0] * bgmmc.c_dt_sz : 0; diff --git a/src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp b/src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp index ec4e1b75a27..13f8c50d5ff 100644 --- a/src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp +++ b/src/cpu/aarch64/matmul/brgemm_matmul_utils.hpp @@ -1,6 +1,7 @@ /******************************************************************************* -* Copyright 2021-2023 Intel Corporation +* Copyright 2021 Intel Corporation * Copyright 2023-2024 FUJITSU LIMITED +* Copyright 2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -121,6 +122,8 @@ struct brgemm_matmul_conf_t { data_type_t wei_dt; data_type_t acc_dt; data_type_t bia_dt; + data_type_t orig_src_dt; + data_type_t orig_wei_dt; int nthr; int nthr_k; @@ -166,6 +169,7 @@ struct brgemm_matmul_conf_t { bool has_zero_point_a, has_zero_point_b, has_zero_point_c; bool post_ops_applicable; bool transposed_A; + bool transposed_B; bool blocked_B; dim_t zp_a_comp_shift_n; @@ -210,6 +214,12 @@ struct brgemm_matmul_conf_utils_t { } inline bool use_buffer_b(bool use_heuristic = true) const { + // In the case of 1xK gemmv, we should avoid copying the weights if + // they are in BA format, since the copy would be more expensive than + // the gemv itself. + if (bgmmc.M == 1 && bgmmc.N > 1 && bgmmc.wei_tag == format_tag::ba) { + return false; + } // Values based on measured performance difference // between plain and copy-to-blocked routine. size_t big_LDB = bgmmc.N > 256; @@ -301,6 +311,10 @@ struct brgemm_matmul_conf_utils_t { const cpu_isa_t isa_; }; +status_t init_conf(brgemm_matmul_conf_t &conf, dim_t batch, dim_t M, dim_t K, + dim_t N, dim_t in_ld, dim_t n_blk, data_type_t in_type, + data_type_t out_type, format_tag_t in_tag); + void init_aux_values(brgemm_matmul_conf_t &bgmmc, const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d); diff --git a/src/cpu/aarch64/reorder/acl_reorder.cpp b/src/cpu/aarch64/reorder/acl_reorder.cpp new file mode 100644 index 00000000000..7934f58a536 --- /dev/null +++ b/src/cpu/aarch64/reorder/acl_reorder.cpp @@ -0,0 +1,280 @@ +/******************************************************************************* +* Copyright 2023, 2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include "cpu/aarch64/reorder/acl_reorder.hpp" +#include "cpu/aarch64/cpu_isa_traits.hpp" + +namespace { +/* +* Find the index of the dense dimension. +* The stride of the inner most dense block will be +* multiplied by the blocking of all prior blocks. +*/ +int find_innermost_dense_idx(const dnnl::impl::memory_desc_t *md) { + uint32_t dense_blk = 1; + for (int i = 0; i < md->format_desc.blocking.inner_nblks; i++) { + dense_blk *= md->format_desc.blocking.inner_blks[i]; + } + + int dense_idx = -1; + for (int i = 0; i < md->ndims; i++) { + if (md->format_desc.blocking.strides[i] == dense_blk) dense_idx = i; + } + return dense_idx; +} +} // namespace + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +status_t acl_reorder_resource_t::configure(const acl_reorder_conf_t &app) { + if (!acl_obj_) return status::out_of_memory; + + // Init Compute Library tensors based on info from descriptor + acl_obj_->src_tensor.allocator()->init(app.src_info); + acl_obj_->dst_tensor.allocator()->init(app.dst_info); + + // clang-format off + acl_obj_->reorder.configure( + &acl_obj_->src_tensor, + &acl_obj_->dst_tensor, + app.src_wf, + app.dst_wf, + app.transpose + ); + // clang-format on + + return status::success; +} + +status_t acl_reorder_fwd_t::pd_t::create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, + const memory_desc_t *dst_md) { + using namespace acl_utils; + using namespace dnnl::impl; + + // ComputeLibrary reorders support f32->f32 and f32->bf16 + bool ok = src_md->data_type == data_type::f32 + && utils::one_of(dst_md->data_type, data_type::f32, data_type::bf16) + && attr->has_default_values(); + + VDISPATCH_REORDER_IC(ok, "unsupported datatype"); + + // Create and check primitive descriptor + auto _pd = make_unique_pd( + attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); + if (_pd == nullptr) return status::out_of_memory; + VDISPATCH_REORDER_IC( + _pd->init(engine, src_engine, dst_engine) == status::success, + "pd initialization failed"); + + // In case we have two or four dimensions, we can't have one of the + // two first dimensions as 1. This is valid for f32->f32 and f32->bf16. 
+    VDISPATCH_REORDER_IC(dst_md->dims[0] != 1 && dst_md->dims[1] != 1,
+            "first two dimensions of the reorder being 1 is not supported");
+
+    auto src_tag = memory_desc_matches_one_of_tag(
+            *src_md, format_tag::ab, format_tag::ba, format_tag::cdba);
+    VDISPATCH_REORDER_IC(format_tag::undef != src_tag,
+            "Only ab, ba or cdba source formats supported");
+
+    auto dst_tag = memory_desc_matches_one_of_tag(*dst_md, format_tag::BA8b4a,
+            format_tag::BA4b4a, format_tag::Ab4a, format_tag::Ab8a,
+            format_tag::Acdb8a, format_tag::Acdb4a);
+    ACL_CHECK_SUPPORT(format_tag::undef == dst_tag,
+            "Only Ab4a/Ab8a, BA8b4a/BA4b4a and Acdb8a/Acdb4a "
+            "destination formats supported");
+
+    auto &transpose = _pd->app_.transpose;
+    auto &dst_blocking = dst_md->format_desc.blocking;
+
+    VDISPATCH_REORDER_IC(src_md->ndims == dst_md->ndims,
+            "Number of dimensions in src and dst do not match");
+    VDISPATCH_REORDER_IC((dst_md->ndims == 2 || dst_md->ndims == 4),
+            "ACL only supports 2D and 4D reorders");
+    // Check if a transpose is needed during the reorder
+    if (src_md->ndims == 4) {
+        VDISPATCH_REORDER_IC(
+                memory_desc_matches_tag(*src_md, dnnl::impl::format_tag::cdba)
+                        && (memory_desc_matches_one_of_tag(*dst_md,
+                                    dnnl::impl::format_tag::Acdb4a,
+                                    dnnl::impl::format_tag::Acdb8a)
+                                != format_tag::undef),
+                VERBOSE_UNSUPPORTED_TAG);
+        transpose = true;
+    } else {
+        int src_dense_idx = find_innermost_dense_idx(src_md);
+        int dst_dense_idx = find_innermost_dense_idx(dst_md);
+
+        transpose = src_dense_idx != dst_dense_idx;
+    }
+
+    // Return unimplemented for non-transposed reorders for now
+    // as they are faster in JIT for most cases.
+    VDISPATCH_REORDER_IC(
+            transpose, "non-transposed reorders are not supported");
+
+    // Optimised f32:bf16 ab->BA8b4a SVE-256 JIT reorder available
+    VDISPATCH_REORDER_IC(
+            !(mayiuse(sve_256) && transpose && src_md->ndims == 2
+                    && src_md->data_type == data_type::f32
+                    && dst_md->data_type == data_type::bf16
+                    && memory_desc_matches_one_of_tag(*dst_md,
+                            format_tag::BA8b4a, format_tag::AB8a4b)),
+            "skipping in favour of optimised JIT implementation");
+
+    auto &dst_wf = _pd->app_.dst_wf;
+
+    VDISPATCH_REORDER_IC(
+            dst_blocking.inner_nblks <= 2, VERBOSE_UNSUPPORTED_TAG);
+    // Offsets to calculate the enum for ComputeLibrary weight formats
+    // defined in arm_compute/core/CoreTypes.h
+    const auto interleave_offset = 0x000100;
+    const auto block_by_offset = 0x100000;
+    for (int i = 0; i < dst_blocking.inner_nblks; i++) {
+        auto blk = dst_blocking.inner_blks[i];
+        if (i == 0) {
+            auto offset = interleave_offset;
+            dst_wf = (arm_compute::WeightFormat)(
+                    static_cast(dst_wf) + offset * (blk - 1));
+        } else if (i == 1) {
+            auto offset = block_by_offset;
+            // Set block_by
+            dst_wf = (arm_compute::WeightFormat)(
+                    static_cast(dst_wf) + offset * (blk - 1));
+        }
+    }
+
+    arm_compute::TensorShape acl_tensor_shape_in;
+    arm_compute::TensorShape acl_tensor_shape_out;
+
+    // Switch for 2 or 4 dim tensors
+    switch (src_md->ndims) {
+        case 2: {
+            if ((src_tag == format_tag::ab && transpose)
+                    || (src_tag == format_tag::ba && !transpose)) {
+                acl_tensor_shape_in = arm_compute::TensorShape(
+                        src_md->dims[0], src_md->dims[1]);
+                acl_tensor_shape_out = arm_compute::TensorShape(
+                        dst_md->padded_dims[0], dst_md->padded_dims[1]);
+            } else if ((src_tag == format_tag::ba && transpose)
+                    || (src_tag == format_tag::ab && !transpose)) {
+                acl_tensor_shape_in = arm_compute::TensorShape(
+                        src_md->dims[1], src_md->dims[0]);
+                acl_tensor_shape_out = arm_compute::TensorShape(
+                        dst_md->padded_dims[1],
dst_md->padded_dims[0]); + } else { + VINFO(primitive, create, dispatch, reorder, + "Unsupported source tag for 2D reorder"); + return status::unimplemented; + } + } break; + case 4: { + // Currently only supporting AxBx1x1 cases + VDISPATCH_REORDER_IC(dst_md->dims[2] == 1 && dst_md->dims[3] == 1, + "currently only AxBx1x1 4d reorders are supported"); + + acl_tensor_shape_in = arm_compute::TensorShape(src_md->dims[3], + src_md->dims[2], src_md->dims[1], src_md->dims[0]); + acl_tensor_shape_out = arm_compute::TensorShape( + dst_md->padded_dims[3], dst_md->padded_dims[2], + dst_md->padded_dims[1], dst_md->padded_dims[0]); + break; + } + default: { + VINFO(primitive, create, dispatch, reorder, + VERBOSE_UNSUPPORTED_TAG); + return status::unimplemented; + } + } + + // Choose the data layout + const auto acl_layout = arm_compute::DataLayout::NCHW; + + // Set Source WeightFormat + _pd->app_.src_wf = arm_compute::WeightFormat::OHWI; + + // Create ACL tensor infos + const arm_compute::DataType src_acl_data_t + = acl_utils::get_acl_data_t(src_md->data_type); + _pd->app_.src_info = arm_compute::TensorInfo( + acl_tensor_shape_in, 1, src_acl_data_t, acl_layout); + + const arm_compute::DataType dst_acl_data_t + = acl_utils::get_acl_data_t(dst_md->data_type); + _pd->app_.dst_info = arm_compute::TensorInfo( + acl_tensor_shape_out, 1, dst_acl_data_t, acl_layout); + + ACL_CHECK_VALID(arm_compute::NEReorderLayer::validate(&_pd->app_.src_info, + &_pd->app_.dst_info, _pd->app_.src_wf, dst_wf, + _pd->app_.transpose)); + // Init scratch memory, not used so 0 in this implementation + _pd->init_scratchpad_md(); + + return safe_ptr_assign(*reorder_pd, _pd.release()); +} + +status_t acl_reorder_fwd_t::create_resource( + engine_t *engine, resource_mapper_t &mapper) const { + if (mapper.has_resource(this)) return status::success; + + auto r = utils::make_unique(); + if (!r) return status::out_of_memory; + + // Configure the resource based on information from primitive descriptor + CHECK(r->configure(pd()->app_)); + + mapper.add(this, std::move(r)); + return status::success; +} + +status_t acl_reorder_fwd_t::execute(const exec_ctx_t &ctx) const { + return execute_forward(ctx); +} + +status_t acl_reorder_fwd_t::execute_forward(const exec_ctx_t &ctx) const { + // Lock here is needed because resource_mapper does not support + // concurrent multithreaded access. + std::lock_guard _lock {this->mtx}; + + auto src = CTX_IN_MEM(const void *, DNNL_ARG_FROM); + auto dst = CTX_OUT_MEM(void *, DNNL_ARG_TO); + + // Retrieve primitive resource and configured Compute Library objects + auto *acl_resource + = ctx.get_resource_mapper()->get(this); + + acl_reorder_obj_t &acl_obj = acl_resource->get_acl_obj(); + + acl_obj.src_tensor.allocator()->import_memory(const_cast(src)); + acl_obj.dst_tensor.allocator()->import_memory(dst); + + acl_obj.reorder.run(); + + acl_obj.src_tensor.allocator()->free(); + acl_obj.dst_tensor.allocator()->free(); + + return status::success; +} + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/reorder/acl_reorder.hpp b/src/cpu/aarch64/reorder/acl_reorder.hpp new file mode 100644 index 00000000000..83a6ed88636 --- /dev/null +++ b/src/cpu/aarch64/reorder/acl_reorder.hpp @@ -0,0 +1,97 @@ +/******************************************************************************* +* Copyright 2023-2025 Arm Ltd. 
and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ +#ifndef CPU_AARCH64_REORDER_ACL_REORDER_HPP +#define CPU_AARCH64_REORDER_ACL_REORDER_HPP + +#include "common/utils.hpp" +#include "cpu/aarch64/acl_utils.hpp" +#include "cpu/reorder/cpu_reorder_pd.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +struct acl_reorder_obj_t { + arm_compute::NEReorderLayer reorder; + arm_compute::Tensor src_tensor; + arm_compute::Tensor dst_tensor; + arm_compute::WeightFormat src_wf; + arm_compute::WeightFormat dst_wf; +}; + +struct acl_reorder_conf_t { + arm_compute::TensorInfo src_info; + arm_compute::TensorInfo dst_info; + arm_compute::WeightFormat src_wf = arm_compute::WeightFormat::OHWI; + arm_compute::WeightFormat dst_wf = arm_compute::WeightFormat::OHWI; + bool transpose; +}; + +struct acl_reorder_resource_t : public resource_t { + acl_reorder_resource_t() + : acl_obj_(utils::make_unique()) {} + + status_t configure(const acl_reorder_conf_t &app); + + acl_reorder_obj_t &get_acl_obj() const { return *acl_obj_; } + DNNL_DISALLOW_COPY_AND_ASSIGN(acl_reorder_resource_t); + +private: + std::unique_ptr acl_obj_; +}; // acl_reorder_resource_t + +struct acl_reorder_fwd_t : public primitive_t { + using primitive_t::primitive_t; + struct pd_t : public cpu_reorder_pd_t { + + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("acl", acl_reorder_fwd_t); + + static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, + const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, + const memory_desc_t *dst_md); + + friend dnnl::impl::impl_list_item_t; + acl_reorder_conf_t app_; + + }; // pd_t + + acl_reorder_fwd_t(const pd_t *apd) : primitive_t(apd) {} + + status_t create_resource( + engine_t *engine, resource_mapper_t &mapper) const override; + + status_t execute(const exec_ctx_t &ctx) const override; + +private: + // To guard the const execute_forward, the mutex must be 'mutable' + mutable std::mutex mtx; + status_t execute_forward(const exec_ctx_t &ctx) const; + inline const pd_t *pd() const { + return (const pd_t *)primitive_t::pd().get(); + } + +}; // acl_reorder_fwd_t + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif // CPU_AARCH64_REORDER_ACL_REORDER_HPP diff --git a/src/cpu/aarch64/reorder/jit_blk_reorder.cpp b/src/cpu/aarch64/reorder/jit_blk_reorder.cpp new file mode 100644 index 00000000000..a19d58c4e7e --- /dev/null +++ b/src/cpu/aarch64/reorder/jit_blk_reorder.cpp @@ -0,0 +1,152 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* Copyright 2020-2024 FUJITSU LIMITED +* Copyright 2022-2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/memory_desc_wrapper.hpp" +#include "common/nstl.hpp" +#include "common/primitive.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/aarch64/jit_generator.hpp" +#include "cpu/aarch64/reorder/jit_blk_reorder.hpp" + +// #define DNNL_DEV_MODE +#if defined(DNNL_DEV_MODE) +#define DEBUg(...) \ + do { \ + if (get_verbose(verbose_t::debuginfo) > 1) { __VA_ARGS__ } \ + } while (0) +#else +#define DEBUg(...) +#endif +#define DEBUG(...) DEBUg(__VA_ARGS__) + +using namespace Xbyak_aarch64; +using namespace dnnl::impl::types; + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +status_t jit_blk_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, + const memory_desc_t *dst_md) { + if (!impl::is_dense_format_kind({src_md, dst_md})) + return status::unimplemented; + auto prb = tr::prb_t(); + // For shapes with dimension greater than thres it is found that jit:uni is better that jit:blk + auto upper_thres = 1920 * 4096; + auto src_d = memory_desc_wrapper(src_md); + auto prd = 1; + + for (int d = 0; d < src_d.ndims(); ++d) { + const auto dim = src_d.dims()[d]; + prd *= dim; + if (prd > upper_thres) return status::unimplemented; + } + + // Very small shapes are faster on jit uni for SVE-128 + auto lower_thres = 128 * 128; + + if (get_max_cpu_isa() == sve_128 && prd < lower_thres) { + return status::unimplemented; + } + + status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); + if (prb_init_status != status::success) return prb_init_status; + // only uni_reorder supports tail processing now + // TODO: Add tail processing support in blk_reorder + if (prb.is_tail_present) return status::unimplemented; + + prb_tile_normalize(prb); + DEBUG({ + verbose_printf( + verbose_t::debuginfo, "tile : %s\n", prb_dump(prb).c_str()); + }); + + if (!tr::jit_single_blk_kernel_t::applicable(prb)) { + return status::unimplemented; + } + + auto _pd = make_unique_pd( + attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); + if (_pd == nullptr) return status::out_of_memory; + _pd->prb_ = prb; + CHECK(_pd->init(engine, src_engine, dst_engine)); + CHECK(_pd->init_scratchpad_md()); + + return safe_ptr_assign(*reorder_pd, _pd.release()); +} + +void jit_blk_reorder_t::pd_t::prb_tile_normalize(tr::prb_t &p) { + if (!utils::one_of(p.nodes[0].n, 4ul, 8ul, 16ul, 32ul, 64ul) + && utils::one_of(p.nodes[1].n, 4ul, 8ul, 16ul, 32ul, 64ul)) { + nstl::swap(p.nodes[0], p.nodes[1]); + } +} + +jit_blk_reorder_t::jit_blk_reorder_t(const pd_t *apd) : primitive_t(apd) {} +jit_blk_reorder_t::~jit_blk_reorder_t() = default; + +status_t jit_blk_reorder_t::init(engine_t *engine) { + kernel_ = utils::make_unique(pd()->prb_); + return kernel_->create_kernel(); +} + +status_t jit_blk_reorder_t::execute(const exec_ctx_t &ctx) const { + const auto in = CTX_IN_MEM(const 
char *, DNNL_ARG_FROM); + auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); + + // kernel handle 2-dimension tiles, a tail is possible + auto &prb = this->pd()->prb_; + ptrdiff_t BH = 1; + for (int i = 2; i < prb.ndims; ++i) { + BH *= prb.nodes[i].n; + } + + auto block_sz = prb.n(0); + auto n1 = prb.n(1); + auto i1 = prb.is(1); + auto o1 = prb.os(1); + auto FL = (n1 + block_sz - 1) / block_sz; + auto bh_stride = BH == 1 ? 0 : prb.is(2); + + auto itype_sz_ = data_type_size(pd()->prb_.itype); + auto otype_sz_ = data_type_size(pd()->prb_.otype); + + parallel_nd(BH, FL, [&](dim_t bh, dim_t fl) { + auto fl_b = fl * block_sz; + auto bh_b = bh_stride * bh; + auto *i = in + (bh_b + fl_b * i1) * itype_sz_; + auto *o = out + (bh_b + fl_b * o1) * otype_sz_; + (*kernel_)(i, o, n1 - fl_b < block_sz); + }); + + return status::success; +} + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/reorder/jit_blk_reorder.hpp b/src/cpu/aarch64/reorder/jit_blk_reorder.hpp new file mode 100644 index 00000000000..01f77add571 --- /dev/null +++ b/src/cpu/aarch64/reorder/jit_blk_reorder.hpp @@ -0,0 +1,69 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* Copyright 2020-2023 FUJITSU LIMITED +* Copyright 2022, 2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_AARCH64_REORDER_JIT_BLK_REORDER_HPP +#define CPU_AARCH64_REORDER_JIT_BLK_REORDER_HPP + +#include + +#include "common/c_types_map.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder_utils.hpp" +#include "cpu/reorder/cpu_reorder_pd.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +struct jit_blk_reorder_t : public primitive_t { + using primitive_t::primitive_t; + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + DECLARE_COMMON_PD_T("jit:blk", jit_blk_reorder_t); + + tr::prb_t prb_; + + private: + static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, + const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, + const memory_desc_t *dst_md); + + // Swap last two nodes, put block 4, 8, 16 nodes to first + static void prb_tile_normalize(tr::prb_t &p); + friend dnnl::impl::impl_list_item_t; + }; + + status_t init(engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override; + + jit_blk_reorder_t(const pd_t *apd); + ~jit_blk_reorder_t() override; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::unique_ptr kernel_; +}; + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/aarch64/reorder/jit_blk_reorder_kernel.cpp b/src/cpu/aarch64/reorder/jit_blk_reorder_kernel.cpp new file mode 100644 index 00000000000..7740b926fd5 --- /dev/null +++ b/src/cpu/aarch64/reorder/jit_blk_reorder_kernel.cpp @@ -0,0 +1,560 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* Copyright 2020-2024 FUJITSU LIMITED +* Copyright 2022-2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#include + +#include "common/c_types_map.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +namespace tr { +using namespace Xbyak_aarch64; +using namespace dnnl::impl::types; + +// Seperate class for no unroll/threading burden +bool jit_single_blk_kernel_t::applicable(const prb_t &p) { + + using namespace data_type; + + bool ok = p.ndims >= 2 && utils::one_of(get_max_cpu_isa(), sve_128, sve_256) + && p.src_scale_type == scale_type_t::NONE + && p.dst_scale_type == scale_type_t::NONE + && utils::one_of(p.itype, f32) && utils::one_of(p.otype, f32) + && utils::everyone_is(0, p.ioff, p.ooff) && p.beta == 0.f + && prb_has_small_strides(p); + if (!ok) return false; + + int64_t n0 = p.nodes[0].n; + auto i0 = p.nodes[0].is; + auto o0 = p.nodes[0].os; + int64_t n1 = p.nodes[1].n; + auto i1 = p.nodes[1].is; + auto o1 = p.nodes[1].os; + + /* + * for a transpose of plain to 8c case, nodes would be like: + * n is os + * m 1 8 + * 8 m 1 + * or + * 8 m 1 + * m 1 8 + */ + ok = (utils::one_of(n0, 4, 8, 16, 32, 64) + || utils::one_of(n1, 4, 8, 16, 32, 64)) + && ((i0 == 1 && o1 == 1 && n0 == i1 && o0 == n1) + || (o0 == 1 && i1 == 1 && n0 == o1 && i0 == n1)); + if (!ok) return false; + + // The 128-bit version only supports blocking of exactly 4, while the + // 256-bit version only suppports the larger block sizes. + if (get_max_cpu_isa() == sve_128) { + if (n0 != 4 && n1 != 4) { return false; } + } else if (get_max_cpu_isa() == sve_256) { + if (n0 == 4 || n1 == 4) return false; + } + + // Do not handle transpose of dimensions other than last 2 + for (int i = 2; i < p.ndims; ++i) { + if (p.nodes[i].is != p.nodes[i].os) { + ok = false; + break; + } + } + + return ok; +} + +jit_single_blk_kernel_t::jit_single_blk_kernel_t(const prb_t &prb) + : prb_(prb) + , itype_sz_(data_type_size(prb_.itype)) + , otype_sz_(data_type_size(prb_.otype)) + , block_sz(prb.nodes[0].n) {} + +void jit_single_blk_kernel_t::preamble() { + if (get_sve_length() == 32) { + ptrue(p_lsb_256.b, VL32); + } else { + ptrue(p_lsb_256.b, VL16); + } +} + +void jit_single_blk_kernel_t::postamble() { + ret(); +} + +void jit_single_blk_kernel_t::generate() { + auto input_stride + = prb_.nodes[0].is != 1 ? prb_.nodes[0].is : prb_.nodes[1].is; + auto output_stride + = prb_.nodes[0].os != 1 ? prb_.nodes[0].os : prb_.nodes[1].os; + + Label tail_processing; + + set_preg(p_tmp2.s, 4, X_TMP_0, X_TMP_1); + rev(p_tmp1.s, p_tmp2.s); + + preamble(); + + cmp(reg_ptr_tail, true); + b(EQ, tail_processing); + + if (block_sz == 4) { + gen_ker4x4(0, 0, input_stride, output_stride, 4, 4); + } else if (block_sz == 8) { + gen_ker8x8(0, 0, input_stride, output_stride, 8, 8); + } else if (block_sz == 16) { + gen_ker16x16_in_8x8(0, 0, input_stride, output_stride); + } else if (block_sz == 32) { + gen_ker32x32_in_16x16(0, 0, input_stride, output_stride); + } else if (block_sz == 64) { + gen_ker64x64_in_32x32(0, 0, input_stride, output_stride); + } else { + assert(!"unimplemented"); + } + + postamble(); + + L(tail_processing); + + if (block_sz == 4) { + auto i_tail = input_stride % 4 != 0 ? input_stride % 4 : 4; + auto o_tail = output_stride % 4 != 0 ? output_stride % 4 : 4; + auto t_mask = i_tail == 4 ? 
o_tail : i_tail; + gen_setmask(t_mask); + gen_ker4x4(0, 0, input_stride, output_stride, i_tail, o_tail); + } else if (block_sz == 8) { + auto i_tail = input_stride % 8 != 0 ? input_stride % 8 : 8; + auto o_tail = output_stride % 8 != 0 ? output_stride % 8 : 8; + if (i_tail != o_tail) { + auto t_mask = i_tail == 8 ? o_tail : i_tail; + gen_setmask(t_mask); + gen_ker8x8(0, 0, input_stride, output_stride, i_tail, o_tail); + } + } else if (block_sz == 16) { + auto i_tail = input_stride % 16 != 0 ? input_stride % 16 : 16; + auto o_tail = output_stride % 16 != 0 ? output_stride % 16 : 16; + if (i_tail != o_tail) { + auto t_mask = i_tail == 16 ? o_tail : i_tail; + t_mask %= 8; + if (t_mask != 0) gen_setmask(t_mask); + gen_ker16x16_in_8x8( + 0, 0, input_stride, output_stride, i_tail, o_tail); + } + } else if (block_sz == 32) { + auto i_tail = input_stride % 32 != 0 ? input_stride % 32 : 32; + auto o_tail = output_stride % 32 != 0 ? output_stride % 32 : 32; + if (i_tail != o_tail) { + auto t_mask = i_tail == 32 ? o_tail : i_tail; + t_mask %= 8; + if (t_mask != 0) gen_setmask(t_mask); + gen_ker32x32_in_16x16( + 0, 0, input_stride, output_stride, i_tail, o_tail); + } + } else if (block_sz == 64) { + auto i_tail = input_stride % 64 != 0 ? input_stride % 64 : 64; + auto o_tail = output_stride % 64 != 0 ? output_stride % 64 : 64; + if (i_tail != o_tail) { + auto t_mask = i_tail == 64 ? o_tail : i_tail; + t_mask %= 8; + if (t_mask != 0) gen_setmask(t_mask); + gen_ker64x64_in_32x32( + 0, 0, input_stride, output_stride, i_tail, o_tail); + } + } else { + assert(!"unimplemented"); + } + + postamble(); +} + +void jit_single_blk_kernel_t::gen_loadu( + const ZRegS ymm, const XReg &addr, int size) { + QReg xmm(ymm.getIdx()); + switch (size) { + case 32: ld1w(ymm, p_lsb_256 / T_z, ptr(addr)); break; + case 16: ldr(xmm, ptr(addr)); break; + default: assert(!"unreachable"); + } +} + +void jit_single_blk_kernel_t::gen_storeu( + const XReg &addr, const ZRegS ymm, int size) { + QReg xmm(ymm.getIdx()); + switch (size) { + case 32: st1w(ymm, p_lsb_256, ptr(addr)); break; + case 16: str(xmm, ptr(addr)); break; + default: assert(!"unreachable"); + } +} + +void jit_single_blk_kernel_t::gen_maskloadu( + const ZRegS ymm, const XReg &addr, const PReg mask, int size) { + switch (size) { + case 32: + case 16: ld1w(ymm, mask / T_z, ptr(addr)); break; + default: assert(!"unreachable"); + } +} + +void jit_single_blk_kernel_t::gen_maskstoreu( + const XReg &addr, const ZRegS ymm, const PReg mask, int size) { + switch (size) { + case 32: + case 16: st1w(ymm, mask, ptr(addr)); break; + default: assert(!"unreachable"); + } +} + +// Register allocation xmm0~11 +void jit_single_blk_kernel_t::gen_transpose_8x8() { + const uint64_t sveLen = get_sve_length(); + constexpr int lane = 8; + +#if 0 + /* Debug code + z0: 7, 6, 5, 4, 3, 2, 1, 0 + z1: 15, 14, 13, 12, 11, 10, 9, 8 + ... 
+ z17: 63, 62, 61, 60, 59, 58, 57, 56 + */ + ptrue(P_ALL_ONE.b); + ptrue(P_TMP.s, VL8); + not_(P_TMP.b, P_ALL_ONE/T_z, P_TMP.b); + index(z0.s, 0, 1); + mov(z0.s, P_TMP/T_m, 0); + mov(z_tmp_vec[0].s, 8); + mov(z_tmp_vec[0].s, P_TMP/T_m, 0); + for(uint32_t i=1; i nChw()C +// or nChw()C -> nchw +void jit_single_blk_kernel_t::gen_setmask(int mask) { + set_preg(p_mask.s, mask, x_tmp_0, x_tmp_1); +} + +void jit_single_blk_kernel_t::gen_transpose_4x4() { + auto &z_tmp4 = z_tmp_vec[0]; + auto &z_tmp5 = z_tmp_vec[1]; + auto &z_tmp6 = z_tmp_vec[2]; + auto &z_tmp7 = z_tmp_vec[3]; + + /* 1st turn */ + trn1(z_tmp4.s, z0.s, z1.s); + trn1(z_tmp5.s, z2.s, z3.s); + trn2(z_tmp6.s, z0.s, z1.s); + trn2(z_tmp7.s, z2.s, z3.s); + + trn1(z0.d, z_tmp4.d, z_tmp5.d); + trn1(z1.d, z_tmp6.d, z_tmp7.d); + trn2(z2.d, z_tmp4.d, z_tmp5.d); + trn2(z3.d, z_tmp6.d, z_tmp7.d); +} + +void jit_single_blk_kernel_t::gen_tr4x4(int i_off, int o_off, int input_stride, + int output_stride, int in_tail, int out_tail) { + + constexpr int lane = 4; + + if (in_tail == 0 || out_tail == 0) return; + + for (int i = 0; i < out_tail; ++i) { + if (in_tail != lane) { + add_imm(x_addr, reg_ptr_in_, i_off + i * input_stride * itype_sz_, + x_tmp_0); + gen_maskloadu(ZRegS(i), x_addr, p_mask, lane * itype_sz_); + } else { + add_imm(x_addr, reg_ptr_in_, i_off + i * input_stride * itype_sz_, + x_tmp_0); + gen_loadu(ZRegS(i), x_addr, lane * itype_sz_); + } + } + + gen_transpose_4x4(); + + for (int i = 0; i < in_tail; ++i) { + if (out_tail == lane) { + add_imm(x_addr, reg_ptr_out_, o_off + i * output_stride * otype_sz_, + x_tmp_0); + gen_storeu(x_addr, ZRegS(i), lane * otype_sz_); + } else { + add_imm(x_addr, reg_ptr_out_, o_off + i * output_stride * otype_sz_, + x_tmp_0); + gen_maskstoreu(x_addr, ZRegS(i), p_mask, lane * otype_sz_); + } + } +} + +void jit_single_blk_kernel_t::gen_ker4x4(int i_off, int o_off, int input_stride, + int output_stride, int in_tail, int out_tail) { + gen_tr4x4(i_off, o_off, input_stride, output_stride, in_tail, out_tail); +} + +void jit_single_blk_kernel_t::gen_tr8x8(int i_off, int o_off, int input_stride, + int output_stride, int in_tail, int out_tail) { + + constexpr int lane = 8; + + if (in_tail == 0 || out_tail == 0) return; + + for (int i = 0; i < out_tail; ++i) { + if (in_tail != lane) { + add_imm(x_addr, reg_ptr_in_, i_off + i * input_stride * itype_sz_, + x_tmp_0); + gen_maskloadu(ZRegS(i), x_addr, p_mask, lane * itype_sz_); + } else { + add_imm(x_addr, reg_ptr_in_, i_off + i * input_stride * itype_sz_, + x_tmp_0); + gen_loadu(ZRegS(i), x_addr, lane * itype_sz_); + } + } + + gen_transpose_8x8(); + + for (int i = 0; i < in_tail; ++i) { + if (out_tail == lane) { + add_imm(x_addr, reg_ptr_out_, o_off + i * output_stride * otype_sz_, + x_tmp_0); + gen_storeu(x_addr, ZRegS(i), lane * otype_sz_); + } else { + add_imm(x_addr, reg_ptr_out_, o_off + i * output_stride * otype_sz_, + x_tmp_0); + gen_maskstoreu(x_addr, ZRegS(i), p_mask, lane * otype_sz_); + } + } +} + +// tail: 0 ~ 8 +// support: either in_tail or out_tail is not 8, but not both +void jit_single_blk_kernel_t::gen_ker8x8(int i_off, int o_off, int input_stride, + int output_stride, int in_tail, int out_tail) { + gen_tr8x8(i_off, o_off, input_stride, output_stride, in_tail, out_tail); +} + +void jit_single_blk_kernel_t::gen_ker16x16_in_8x8( + int i_off, int o_off, int input_stride, int output_stride) { + const auto lane = 16; + const auto sub_lane = lane / 2; + + i_off *= itype_sz_; + o_off *= otype_sz_; + + gen_tr8x8(i_off, o_off, input_stride, output_stride, 
sub_lane, sub_lane); + gen_tr8x8(i_off + input_stride * sub_lane * itype_sz_, + o_off + sub_lane * otype_sz_, input_stride, output_stride, sub_lane, + sub_lane); + gen_tr8x8(i_off + sub_lane * itype_sz_, + o_off + output_stride * sub_lane * otype_sz_, input_stride, + output_stride, sub_lane, sub_lane); + gen_tr8x8(i_off + (input_stride * sub_lane + sub_lane) * itype_sz_, + o_off + (output_stride * sub_lane + sub_lane) * otype_sz_, + input_stride, output_stride, sub_lane, sub_lane); +} + +// tail can be 1 ~ 16, using sve2 for now +void jit_single_blk_kernel_t::gen_ker16x16_in_8x8(int i_off, int o_off, + int input_stride, int output_stride, int in_tail, int out_tail) { + constexpr auto lane = 16; + constexpr auto sub_lane = lane / 2; + auto tail = in_tail != lane ? in_tail : out_tail; + + const auto l_tail = tail < sub_lane ? tail : sub_lane; + const auto u_tail = tail < sub_lane ? 0 : tail - sub_lane; + + i_off *= itype_sz_; + o_off *= otype_sz_; + + if (tail == in_tail) { + gen_tr8x8(i_off, o_off, input_stride, output_stride, l_tail, sub_lane); + gen_tr8x8(i_off + input_stride * sub_lane * itype_sz_, + o_off + sub_lane * otype_sz_, input_stride, output_stride, + l_tail, sub_lane); + gen_tr8x8(i_off + sub_lane * itype_sz_, + o_off + output_stride * sub_lane * otype_sz_, input_stride, + output_stride, u_tail, sub_lane); + gen_tr8x8(i_off + itype_sz_ * (input_stride * sub_lane + sub_lane), + o_off + otype_sz_ * (output_stride * sub_lane + sub_lane), + input_stride, output_stride, u_tail, sub_lane); + } else { + gen_tr8x8(i_off, o_off, input_stride, output_stride, sub_lane, l_tail); + gen_tr8x8(i_off + input_stride * sub_lane * itype_sz_, + o_off + sub_lane * otype_sz_, input_stride, output_stride, + sub_lane, u_tail); + gen_tr8x8(i_off + sub_lane * itype_sz_, + o_off + output_stride * sub_lane * itype_sz_, input_stride, + output_stride, sub_lane, l_tail); + gen_tr8x8(i_off + itype_sz_ * (input_stride * sub_lane + sub_lane), + o_off + otype_sz_ * (output_stride * sub_lane + sub_lane), + input_stride, output_stride, sub_lane, u_tail); + } +} + +void jit_single_blk_kernel_t::gen_ker32x32_in_16x16( + int i_off, int o_off, int input_stride, int output_stride) { + + const auto lane = 32; + const auto sub_lane = lane / 2; + gen_ker16x16_in_8x8(i_off, o_off, input_stride, output_stride); + gen_ker16x16_in_8x8(i_off + sub_lane * input_stride, o_off + sub_lane, + input_stride, output_stride); + gen_ker16x16_in_8x8(i_off + sub_lane, o_off + output_stride * sub_lane, + input_stride, output_stride); + gen_ker16x16_in_8x8(i_off + input_stride * sub_lane + sub_lane, + o_off + output_stride * sub_lane + sub_lane, input_stride, + output_stride); +} + +void jit_single_blk_kernel_t::gen_ker32x32_in_16x16(int i_off, int o_off, + int input_stride, int output_stride, int in_tail, int out_tail) { + + constexpr auto lane = 32; + constexpr auto sub_lane = lane / 2; + auto tail = in_tail != lane ? in_tail : out_tail; + + const auto l_tail = tail < sub_lane ? tail : sub_lane; + const auto u_tail = tail < sub_lane ? 
0 : tail - sub_lane; + + if (tail == in_tail) { + gen_ker16x16_in_8x8( + i_off, o_off, input_stride, output_stride, l_tail, sub_lane); + gen_ker16x16_in_8x8(i_off + sub_lane * input_stride, o_off + sub_lane, + input_stride, output_stride, l_tail, sub_lane); + gen_ker16x16_in_8x8(i_off + sub_lane, o_off + output_stride * sub_lane, + input_stride, output_stride, u_tail, sub_lane); + gen_ker16x16_in_8x8(i_off + input_stride * sub_lane + sub_lane, + o_off + output_stride * sub_lane + sub_lane, input_stride, + output_stride, u_tail, sub_lane); + } else { + gen_ker16x16_in_8x8( + i_off, o_off, input_stride, output_stride, sub_lane, l_tail); + gen_ker16x16_in_8x8(i_off + sub_lane * input_stride, o_off + sub_lane, + input_stride, output_stride, sub_lane, u_tail); + gen_ker16x16_in_8x8(i_off + sub_lane, o_off + output_stride * sub_lane, + input_stride, output_stride, sub_lane, l_tail); + gen_ker16x16_in_8x8(i_off + input_stride * sub_lane + sub_lane, + o_off + output_stride * sub_lane + sub_lane, input_stride, + output_stride, sub_lane, u_tail); + } +} + +void jit_single_blk_kernel_t::gen_ker64x64_in_32x32( + int i_off, int o_off, int input_stride, int output_stride) { + + const auto lane = 64; + const auto sub_lane = lane / 2; + gen_ker32x32_in_16x16(i_off, o_off, input_stride, output_stride); + gen_ker32x32_in_16x16(i_off + sub_lane * input_stride, o_off + sub_lane, + input_stride, output_stride); + gen_ker32x32_in_16x16(i_off + sub_lane, o_off + output_stride * sub_lane, + input_stride, output_stride); + gen_ker32x32_in_16x16(i_off + input_stride * sub_lane + sub_lane, + o_off + output_stride * sub_lane + sub_lane, input_stride, + output_stride); +} + +void jit_single_blk_kernel_t::gen_ker64x64_in_32x32(int i_off, int o_off, + int input_stride, int output_stride, int in_tail, int out_tail) { + constexpr auto lane = 64; + constexpr auto sub_lane = lane / 2; + auto tail = in_tail != lane ? in_tail : out_tail; + + const auto l_tail = tail < sub_lane ? tail : sub_lane; + const auto u_tail = tail < sub_lane ? 
0 : tail - sub_lane; + + if (tail == in_tail) { + gen_ker32x32_in_16x16( + i_off, o_off, input_stride, output_stride, l_tail, sub_lane); + gen_ker32x32_in_16x16(i_off + sub_lane * input_stride, o_off + sub_lane, + input_stride, output_stride, l_tail, sub_lane); + gen_ker32x32_in_16x16(i_off + sub_lane, + o_off + output_stride * sub_lane, input_stride, output_stride, + u_tail, sub_lane); + gen_ker32x32_in_16x16(i_off + input_stride * sub_lane + sub_lane, + o_off + output_stride * sub_lane + sub_lane, input_stride, + output_stride, u_tail, sub_lane); + } else { + gen_ker32x32_in_16x16( + i_off, o_off, input_stride, output_stride, sub_lane, l_tail); + gen_ker32x32_in_16x16(i_off + sub_lane * input_stride, o_off + sub_lane, + input_stride, output_stride, sub_lane, u_tail); + gen_ker32x32_in_16x16(i_off + sub_lane, + o_off + output_stride * sub_lane, input_stride, output_stride, + sub_lane, l_tail); + gen_ker32x32_in_16x16(i_off + input_stride * sub_lane + sub_lane, + o_off + output_stride * sub_lane + sub_lane, input_stride, + output_stride, sub_lane, u_tail); + } +} + +} // namespace tr + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/reorder/jit_uni_reorder.cpp b/src/cpu/aarch64/reorder/jit_uni_reorder.cpp new file mode 100644 index 00000000000..6851223ea2f --- /dev/null +++ b/src/cpu/aarch64/reorder/jit_uni_reorder.cpp @@ -0,0 +1,525 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* Copyright 2020-2024 FUJITSU LIMITED +* Copyright 2022-2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include + +#include "common/c_types_map.hpp" +#include "common/dnnl_thread.hpp" +#include "common/memory_desc_wrapper.hpp" +#include "common/primitive.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/cpu_primitive.hpp" +#include "cpu/reorder/cpu_reorder_pd.hpp" + +#include "cpu/aarch64/reorder/jit_uni_reorder.hpp" + +// #define DNNL_DEV_MODE +#if defined(DNNL_DEV_MODE) +#define DEBUg(...) \ + do { \ + if (get_verbose(verbose_t::debuginfo) > 1) { __VA_ARGS__ } \ + } while (0) +#else +#define DEBUg(...) +#endif +#define DEBUG(...) DEBUg(__VA_ARGS__) + +using namespace Xbyak_aarch64; +using namespace dnnl::impl::types; + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +status_t jit_uni_reorder_t::pd_t::init( + engine_t *engine, engine_t *src_engine, engine_t *dst_engine) { + CHECK(cpu_reorder_pd_t::init(engine, src_engine, dst_engine)); + + CHECK(init_scratchpad()); + + return status::success; +} + +status_t jit_uni_reorder_t::pd_t::init_scratchpad() { + auto scratchpad = scratchpad_registry().registrar(); + + const bool compensation_needed + = prb_.req_s8s8_comp || prb_.req_asymmetric_comp; + if (compensation_needed) { + const memory_desc_wrapper od(dst_md()); + const auto G = with_groups_ ? 
od.padded_dims()[0] : 1; + const auto N = od.padded_dims()[with_groups_ ? 1 : 0]; + static constexpr int cache_line_size = 16; + const auto wspace_per_thr_size + = utils::rnd_up(G * N, cache_line_size) * sizeof(int32_t); + + const auto compensation_reduce_size = wspace_per_thr_size * nthr_; + + // Every thread gets its own scratchpad space for each N. + scratchpad.template book( + memory_tracking::names::key_reorder_space, + compensation_reduce_size); + } + + if (!attr()->scales_.has_default_values(DNNL_ARG_DST)) { + const memory_desc_wrapper input_d(src_md()); + int mask = attr()->scales_.get_mask(DNNL_ARG_DST); + get_D_values(input_d, mask, nullptr, &D_mask_, nullptr); + if (D_mask_ > 1) { + scratchpad.template book( + memory_tracking::names::key_reorder_precomputed_dst_scales, + D_mask_); + } + } + + return status::success; +} + +status_t jit_uni_reorder_t::pd_t::create(reorder_pd_t **reorder_pd, + engine_t *engine, const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, + const memory_desc_t *dst_md) { + if (!impl::is_dense_format_kind({src_md, dst_md})) + return status::unimplemented; + auto prb = tr::prb_t(); + + status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr); + if (prb_init_status != status::success) return prb_init_status; + + tr::prb_block_for_cache(prb); + DEBUG({ + verbose_printf( + verbose_t::debuginfo, "cache: %s\n", prb_dump(prb).c_str()); + }); + + int ndims_ker_max {}; + int nthr = dnnl_get_max_threads(); + tr::prb_thread_kernel_balance(prb, ndims_ker_max, nthr); + + if (prb.is_tail_present) prb_node_dependency(prb); + + tr::kernel_t::desc_t ker_desc; + status_t ker_init_status + = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max); + if (ker_init_status != status::success) return ker_init_status; + + const int ndims_driver = prb.ndims - ker_desc.prb.ndims; + if (ndims_driver > jit_uni_reorder_t::ndims_driver_max) + return status::unimplemented; + + DEBUG({ + verbose_printf(verbose_t::debuginfo, "ker : %s\n", + prb_dump(ker_desc.prb).c_str()); + }); + + auto _pd = make_unique_pd( + attr, src_engine->kind(), src_md, dst_engine->kind(), dst_md); + if (_pd == nullptr) return status::out_of_memory; + + _pd->nthr_ = nthr; + _pd->prb_ = prb; + _pd->with_groups_ + = prb.compensation_mask == tr::prb_t::comp_mask_with_groups; + CHECK(_pd->init(engine, src_engine, dst_engine)); + _pd->ker_desc_ = ker_desc; + CHECK(_pd->init_scratchpad_md()); + + return safe_ptr_assign(*reorder_pd, _pd.release()); +} + +void jit_uni_reorder_t::omp_driver_0d(int off, const char *in, char *out, + const float *src_scales, const float *dst_scales, int src_zp, + int dst_zp, int32_t *compensation_scratch) const { + const tr::prb_t &prb = pd()->prb_; + + tr::call_param_t base_params; + base_params.in = in; + base_params.out = out; + base_params.src_scales = src_scales; + base_params.dst_scales = dst_scales; + base_params.src_zp = src_zp; + base_params.dst_zp = dst_zp; + base_params.compensation_scratch = compensation_scratch; + + if (prb.is_tail_present) { + tr::tail_call_param_t tail_params; + tail_params.base_params = base_params; + + static constexpr int omp_ndims = 0; + fill_curr_data_chunks(prb, off, nullptr, omp_ndims, tail_params); + + (*kernel_)(&tail_params); + } else { + (*kernel_)(&base_params); + } +} + +void jit_uni_reorder_t::omp_driver_1d(int ithr, int nthr, int off, + const char *in, char *out, const float *src_scales, + const float *dst_scales, int src_zp, int dst_zp, + int32_t *compensation_scratch) const { + const 
tr::prb_t &prb = pd()->prb_; + const tr::node_t *ns = prb.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) { + tr::call_param_t base_params; + base_params.in = in + d0 * ns[0].is * data_type_size(prb.itype); + base_params.out = out + d0 * ns[0].os * data_type_size(prb.otype); + base_params.src_scales = src_scales + d0 * ns[0].ss; + base_params.dst_scales = dst_scales + d0 * ns[0].ss; + base_params.src_zp = src_zp; + base_params.dst_zp = dst_zp; + base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs; + + if (prb.is_tail_present) { + tr::tail_call_param_t tail_params; + tail_params.base_params = base_params; + + static constexpr int omp_ndims = 1; + const ptrdiff_t omp_data_chunks[omp_ndims] = {d0}; + fill_curr_data_chunks( + prb, off, omp_data_chunks, omp_ndims, tail_params); + + (*kernel_)(&tail_params); + } else { + (*kernel_)(&base_params); + } + }); +} + +void jit_uni_reorder_t::omp_driver_2d(int ithr, int nthr, int off, + const char *in, char *out, const float *src_scales, + const float *dst_scales, int src_zp, int dst_zp, + int32_t *compensation_scratch) const { + const tr::prb_t &prb = pd()->prb_; + const tr::node_t *ns = prb.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d1, ptrdiff_t d0) { + tr::call_param_t base_params; + base_params.in = in + + (d0 * ns[0].is + d1 * ns[1].is) * data_type_size(prb.itype); + base_params.out = out + + (d0 * ns[0].os + d1 * ns[1].os) * data_type_size(prb.otype); + base_params.src_scales = src_scales + d0 * ns[0].ss + d1 * ns[1].ss; + base_params.dst_scales = dst_scales + d0 * ns[0].ss + d1 * ns[1].ss; + base_params.src_zp = src_zp; + base_params.dst_zp = dst_zp; + base_params.compensation_scratch + = compensation_scratch + d0 * ns[0].cs + d1 * ns[1].cs; + + if (prb.is_tail_present) { + tr::tail_call_param_t tail_params; + tail_params.base_params = base_params; + + static constexpr int omp_ndims = 2; + const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1}; + fill_curr_data_chunks( + prb, off, omp_data_chunks, omp_ndims, tail_params); + + (*kernel_)(&tail_params); + } else { + (*kernel_)(&base_params); + } + }); +} + +void jit_uni_reorder_t::omp_driver_3d(int ithr, int nthr, int off, + const char *in, char *out, const float *src_scales, + const float *dst_scales, int src_zp, int dst_zp, + int32_t *compensation_scratch) const { + const tr::prb_t &prb = pd()->prb_; + const tr::node_t *ns = prb.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n, + (ptrdiff_t)ns[0].n, [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + tr::call_param_t base_params; + base_params.in = in + + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is) + * data_type_size(prb.itype); + base_params.out = out + + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os) + * data_type_size(prb.otype); + base_params.src_scales + = src_scales + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; + base_params.dst_scales + = dst_scales + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss; + base_params.src_zp = src_zp; + base_params.dst_zp = dst_zp; + base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs + + d1 * ns[1].cs + d2 * ns[2].cs; + + if (prb.is_tail_present) { + tr::tail_call_param_t tail_params; + tail_params.base_params = base_params; + + static constexpr int omp_ndims = 3; + const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1, d2}; + fill_curr_data_chunks( + prb, off, omp_data_chunks, omp_ndims, tail_params); + + (*kernel_)(&tail_params); + } else { + (*kernel_)(&base_params); + } 
+ }); +} + +void jit_uni_reorder_t::omp_driver_4d(int ithr, int nthr, int off, + const char *in, char *out, const float *src_scales, + const float *dst_scales, int src_zp, int dst_zp, + int32_t *compensation_scratch) const { + const tr::prb_t &prb = pd()->prb_; + const tr::node_t *ns = prb.nodes + off; + for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n, + (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n, + [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) { + tr::call_param_t base_params; + base_params.in = in + + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is + + d3 * ns[3].is) + * data_type_size(prb.itype); + base_params.out = out + + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os + + d3 * ns[3].os) + * data_type_size(prb.otype); + base_params.src_scales = src_scales + d0 * ns[0].ss + d1 * ns[1].ss + + d2 * ns[2].ss + d3 * ns[3].ss; + base_params.dst_scales = dst_scales + d0 * ns[0].ss + d1 * ns[1].ss + + d2 * ns[2].ss + d3 * ns[3].ss; + base_params.src_zp = src_zp; + base_params.dst_zp = dst_zp; + base_params.compensation_scratch = compensation_scratch + d0 * ns[0].cs + + d1 * ns[1].cs + d2 * ns[2].cs + d3 * ns[3].cs; + + if (prb.is_tail_present) { + tr::tail_call_param_t tail_params; + tail_params.base_params = base_params; + + static constexpr int omp_ndims = 4; + const ptrdiff_t omp_data_chunks[omp_ndims] = {d0, d1, d2, d3}; + fill_curr_data_chunks( + prb, off, omp_data_chunks, omp_ndims, tail_params); + + (*kernel_)(&tail_params); + } else { + (*kernel_)(&base_params); + } + }); +} + +void jit_uni_reorder_t::omp_driver(const char *in, char *out, + const float *src_scales, const float *dst_scales, + const int32_t *src_zero_points, const int32_t *dst_zero_points, + const memory_tracking::grantor_t &scratchpad) const { + in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype); + out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype); + + DEBUG({ + verbose_printf(verbose_t::debuginfo, "prb : %s\n", + tr::prb_dump(pd()->prb_).c_str()); + }); + DEBUG({ + verbose_printf(verbose_t::debuginfo, "ker : %s\n", + tr::prb_dump(pd()->ker_desc_.prb).c_str()); + }); + + int ndims = pd()->prb_.ndims; + int ndims_ker = pd()->ker_desc_.prb.ndims; + const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; + const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; + const bool req_compensation = req_s8s8_comp || req_asymmetric_comp; + assert(ndims - ndims_ker <= ndims_driver_max); + + auto src_zp = src_zero_points ? src_zero_points[0] : 0; + auto dst_zp = dst_zero_points ? dst_zero_points[0] : 0; + int32_t *compensation_reduce_scratch = scratchpad.template get( + memory_tracking::names::key_reorder_space); + + const memory_desc_wrapper od(pd()->dst_md()); + const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; + const auto N = od.padded_dims()[pd()->with_groups_ ? 
1 : 0]; + static constexpr int cache_line_size = 16; + const auto wspace_per_thr_size = utils::rnd_up(G * N, cache_line_size); + const auto wspace_per_thr_bytes = wspace_per_thr_size * sizeof(int32_t); + + if (ndims - ndims_ker == 0) { + if (req_compensation) + std::memset(compensation_reduce_scratch, 0, wspace_per_thr_bytes); + + omp_driver_0d(ndims_ker, in, out, src_scales, dst_scales, src_zp, + dst_zp, compensation_reduce_scratch); + } else { + parallel(pd()->nthr_, [&](const int ithr, const int nthr) { + int32_t *compensation_scratch = nullptr; + if (req_compensation) { + compensation_scratch = &compensation_reduce_scratch[ithr + * wspace_per_thr_size]; + std::memset(compensation_scratch, 0, wspace_per_thr_bytes); + } + + switch (ndims - ndims_ker) { + case 1: + omp_driver_1d(ithr, nthr, ndims_ker, in, out, src_scales, + dst_scales, src_zp, dst_zp, compensation_scratch); + break; + case 2: + omp_driver_2d(ithr, nthr, ndims_ker, in, out, src_scales, + dst_scales, src_zp, dst_zp, compensation_scratch); + break; + case 3: + omp_driver_3d(ithr, nthr, ndims_ker, in, out, src_scales, + dst_scales, src_zp, dst_zp, compensation_scratch); + break; + case 4: + omp_driver_4d(ithr, nthr, ndims_ker, in, out, src_scales, + dst_scales, src_zp, dst_zp, compensation_scratch); + break; + default: assert(!"unimplemented"); + } + }); + } + + //reduction of intermediate compensation results to the final output + if (req_compensation) { + const int nthr = ndims - ndims_ker == 0 ? 1 : pd()->nthr_; + reduce_compensation( + out, compensation_reduce_scratch, nthr, wspace_per_thr_size); + } +} + +void jit_uni_reorder_t::reduce_compensation(char *out, + const int32_t *compensation_reduce_scratch, const int nthr, + const dim_t wspace_per_thr_size) const { + + const memory_desc_wrapper od(pd()->dst_md()); + const size_t offset = od.size() - od.additional_buffer_size(); + + static constexpr auto comp_dt_size = sizeof(int32_t); + static constexpr int32_t comp_s8s8_shift = 128; + + // Note: We do not need to explicitly zero-out compensation buffer, as the + // per_thread buffers are already zeroed out in the padded area. + const auto G = pd()->with_groups_ ? od.padded_dims()[0] : 1; + const auto N = od.padded_dims()[pd()->with_groups_ ? 1 : 0]; + const auto GN = G * N; + const bool req_s8s8_comp = pd()->prb_.req_s8s8_comp; + const bool req_asymmetric_comp = pd()->prb_.req_asymmetric_comp; + const size_t zp_offset + = offset + (pd()->prb_.req_s8s8_comp ? GN * comp_dt_size : 0); + + parallel_nd(GN, [&](int idx) { + int32_t acc = 0; + for (int ithr = 0; ithr < nthr; ithr++) { + acc -= compensation_reduce_scratch[ithr * wspace_per_thr_size + + idx]; + } + if (req_s8s8_comp) { + int32_t *out_comp = reinterpret_cast(&out[offset]); + out_comp[idx] = comp_s8s8_shift * acc; + } + if (req_asymmetric_comp) { + int32_t *out_asym_comp + = reinterpret_cast(&out[zp_offset]); + out_asym_comp[idx] = acc; + } + }); +} + +void jit_uni_reorder_t::fill_curr_data_chunks(const tr::prb_t &prb, + const int off, const ptrdiff_t *omp_data_chunks, const int omp_ndims, + tr::tail_call_param_t &c) const { + // Chunks are backwards numered i.e: + // [0] -> [node_size] + // [1] -> [node_size - 1] + // ... + // [node_size - 1] -> [1] + + // It is done like this, because it is easier to decrement counter + // and check if it is equal to zero than increment and check + // if it is equal to node_size in jit kernel. 
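+    //
+    // For example, with node_size == 4 the driver chunks 0, 1, 2, 3 are
+    // passed to the kernel as the counters 4, 3, 2, 1 (node_size - chunk),
+    // so the kernel only has to decrement and compare against zero.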
+ + static constexpr int64_t empty_chunk_info = -1; + static constexpr int64_t last_chunk = 1; + + for (int curr_node_id = prb.ndims - 1; curr_node_id >= 0; curr_node_id--) { + const int parent_node_id = prb.nodes[curr_node_id].parent_node_id; + const bool is_drv_processing_this_node + = curr_node_id >= off && curr_node_id <= off + omp_ndims - 1; + const bool is_tail_processing + = prb.is_tail_in_one_of_child_nodes(curr_node_id) + || prb.nodes[curr_node_id].tail_size > 0; + + if (is_drv_processing_this_node && is_tail_processing) { + const int inner_idx = curr_node_id - off; + assert(inner_idx < omp_ndims); + const int64_t node_size = prb.nodes[curr_node_id].tail_size > 0 + ? prb.nodes[curr_node_id].tail_size + : prb.nodes[curr_node_id].n; + const int64_t data_chunk = node_size - omp_data_chunks[inner_idx]; + + if (!prb.nodes[curr_node_id].is_parent_empty()) { + const bool is_parent_chunk_last + = c.curr_data_chunks[parent_node_id] == last_chunk; + c.curr_data_chunks[curr_node_id] + = is_parent_chunk_last ? data_chunk : empty_chunk_info; + c.zeroing_data = static_cast( + is_parent_chunk_last && data_chunk <= 0); + } else { + c.curr_data_chunks[curr_node_id] = data_chunk; + c.zeroing_data = static_cast(data_chunk <= 0); + } + c.skip_kernel_execution = static_cast(c.zeroing_data + && !prb.nodes[curr_node_id].is_zero_pad_needed); + if (c.zeroing_data || c.skip_kernel_execution) break; + } else + c.curr_data_chunks[curr_node_id] = empty_chunk_info; + } +} + +status_t jit_uni_reorder_t::init(engine_t *engine) { + CHECK(safe_ptr_assign(kernel_, tr::kernel_t::create(pd()->ker_desc_))); + return kernel_->create_kernel(); +} + +status_t jit_uni_reorder_t::execute(const exec_ctx_t &ctx) const { + const auto &scratchpad = ctx.get_scratchpad_grantor(); + auto in = CTX_IN_MEM(const char *, DNNL_ARG_FROM); + auto out = CTX_OUT_MEM(char *, DNNL_ARG_TO); + DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); + DEFINE_ARG_SCALES_BUFFER(dst_scales_, DNNL_ARG_DST); + + const float *dst_scales = pd()->precompute_scales( + scratchpad, pd()->attr(), pd()->D_mask_, dst_scales_); + assert(dst_scales); + + const int32_t *src_zero_points = CTX_IN_MEM( + const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); + const int32_t *dst_zero_points = CTX_IN_MEM( + const int32_t *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); + + omp_driver(in, out, src_scales, dst_scales, src_zero_points, + dst_zero_points, scratchpad); + + return status::success; +} + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/reorder/jit_uni_reorder.hpp b/src/cpu/aarch64/reorder/jit_uni_reorder.hpp new file mode 100644 index 00000000000..24fc9e30943 --- /dev/null +++ b/src/cpu/aarch64/reorder/jit_uni_reorder.hpp @@ -0,0 +1,104 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* Copyright 2020-2023 FUJITSU LIMITED +* Copyright 2022, 2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_AARCH64_REORDER_JIT_UNI_REORDER_HPP +#define CPU_AARCH64_REORDER_JIT_UNI_REORDER_HPP + +#include + +#include "common/c_types_map.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder_utils.hpp" +#include "cpu/reorder/cpu_reorder_pd.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +struct jit_uni_reorder_t : public primitive_t { + using primitive_t::primitive_t; + struct pd_t : public cpu_reorder_pd_t { + using cpu_reorder_pd_t::cpu_reorder_pd_t; + + DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t); + + tr::prb_t prb_; + tr::kernel_t::desc_t ker_desc_; + int nthr_; + bool with_groups_ = false; + dim_t D_mask_ = 0; + + status_t init( + engine_t *engine, engine_t *src_engine, engine_t *dst_engine); + + private: + status_t init_scratchpad(); + static status_t create(reorder_pd_t **reorder_pd, engine_t *engine, + const primitive_attr_t *attr, engine_t *src_engine, + const memory_desc_t *src_md, engine_t *dst_engine, + const memory_desc_t *dst_md); + + friend dnnl::impl::impl_list_item_t; + }; + + status_t init(engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override; + + enum { ndims_driver_max = 4 }; + +private: + void omp_driver_0d(int off, const char *in, char *out, + const float *src_scales, const float *dst_scales, int src_zp, + int dst_zp, int32_t *compensation_scratch) const; + void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out, + const float *src_scales, const float *dst_scales, int src_zp, + int dst_zp, int32_t *compensation_scratch) const; + void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out, + const float *src_scales, const float *dst_scales, int src_zp, + int dst_zp, int32_t *compensation_scratch) const; + void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out, + const float *src_scales, const float *dst_scales, int src_zp, + int dst_zp, int32_t *compensation_scratch) const; + void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out, + const float *src_scales, const float *dst_scales, int src_zp, + int dst_zp, int32_t *compensation_scratch) const; + + void omp_driver(const char *in, char *out, const float *src_scales, + const float *dst_scales, const int32_t *src_zero_points, + const int32_t *dst_zero_points, + const memory_tracking::grantor_t &scratchpad) const; + + void fill_curr_data_chunks(const tr::prb_t &prb, const int off, + const ptrdiff_t *omp_data_chunks, const int omp_ndims, + tr::tail_call_param_t &c) const; + + void reduce_compensation(char *out, + const int32_t *compensation_reduce_scratch, const int nthr, + const dim_t wspace_per_thr_size) const; + + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + std::unique_ptr kernel_; +}; + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/aarch64/reorder/jit_uni_reorder_kernel.cpp b/src/cpu/aarch64/reorder/jit_uni_reorder_kernel.cpp new file mode 100644 index 00000000000..01b27b0e5ac --- /dev/null +++ b/src/cpu/aarch64/reorder/jit_uni_reorder_kernel.cpp @@ -0,0 +1,2083 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* Copyright 2020-2024 FUJITSU LIMITED +* Copyright 2022-2025 Arm Ltd. 
and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include +#include + +#include "common/c_types_map.hpp" +#include "common/nstl.hpp" +#include "common/type_helpers.hpp" +#include "common/utils.hpp" + +#include "cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp" + +#include "cpu/aarch64/jit_generator.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { +namespace tr { + +using namespace Xbyak_aarch64; +using namespace dnnl::impl::types; + +status_t kernel_t::desc_init( + kernel_t::desc_t &desc, const prb_t &prb, int ndims_ker_max) { + + desc.prb = prb; + desc.prb.ioff = desc.prb.ooff = 0; + + if (ndims_ker_max > prb.ndims) return status::invalid_arguments; + + auto ndims_ker_max_f = [&]() { + size_t cur_size = 1; + for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n) + if (cur_size >= ker_prb_size_min) return d; + return prb.ndims; + }; + + if (ndims_ker_max <= 0) ndims_ker_max = ndims_ker_max_f(); + + /* traverse through kernel implementations */ + /* TODO: find a better way to do that... */ + desc.id = 0; + for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) { + desc.prb.ndims = ndims_ker; + if (jit_uni_reorder_kernel_f32_t::applicable(desc.prb)) + return status::success; + } + + return status::unimplemented; +} + +kernel_t *kernel_t::create(const kernel_t::desc_t &desc) { + switch (desc.id) { + case 0: return new jit_uni_reorder_kernel_f32_t(desc); + default: assert(!"unknown kernel id"); return nullptr; + } + + return nullptr; +} + +/* kernel */ +void jit_uni_reorder_kernel_f32_t::operator()(const call_param_t *c) const { + jit_generator::operator()(c); +} + +void jit_uni_reorder_kernel_f32_t::operator()( + const tail_call_param_t *c) const { + jit_generator::operator()(c); +} + +status_t jit_uni_reorder_kernel_f32_t::create_kernel() { + return jit_generator::create_kernel(); +} + +#define PARAM(x) \ + abi_param1, \ + prb_.is_tail_present ? offsetof(tail_call_param_t, base_params) \ + + offsetof(call_param_t, x) \ + : offsetof(call_param_t, x) +#define TAIL_PARAM(x) abi_param1, offsetof(tail_call_param_t, x) + +bool jit_uni_reorder_kernel_f32_t::simple_impl_desc_init( + const prb_t &prb, simple_impl_desc_t *desc) { + const int ndims = prb.ndims; + + int ndims_full_unroll = 0; + int len_last_dim_unroll = 1; + int tail_len_unroll = 0; + int len_unroll = 1; + + // It is responsible for finding as many values + // as kernel can unroll. If tail is present then + // kernel will unroll only last node (possible improvement). + // If there is no tail kernel can unroll a few nodes without any loops etc. + // ndims_full_unroll - how many nodes will be unrolled + // len_last_dim_unroll - what piece of last unrolled node will be unrolled + if (prb.is_tail_present) { + ndims_full_unroll = 1; + len_unroll = prb.nodes[0].n; + tail_len_unroll = prb.nodes[0].is_zero_pad_needed + ? 
0 + : static_cast(prb.nodes[0].tail_size); + } else { + for (int d = 0; d < ndims; ++d) { + const auto &node = prb.nodes[d]; + if (len_unroll * node.n <= len_unroll_max) { + ndims_full_unroll++; + len_unroll *= node.n; + } else { + len_last_dim_unroll = len_unroll_max / len_unroll; + while (node.n % len_last_dim_unroll) + --len_last_dim_unroll; + len_unroll *= len_last_dim_unroll; + break; + } + } + } + + if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max) return false; + + if (desc) { + desc->ndims_full_unroll = ndims_full_unroll; + desc->len_last_dim_unroll = len_last_dim_unroll; + desc->tail_len_unroll = tail_len_unroll; + desc->len_unroll = len_unroll; + } + + return true; +} + +bool jit_uni_reorder_kernel_f32_t::applicable(const prb_t &p) { + using namespace data_type; + + bool bf16_ok = (mayiuse_bf16() && (p.itype == bf16) && (p.otype == bf16) + && !interim_f32_needed(p, false) && p.beta == 0.f) + || (p.itype != bf16 && p.otype != bf16) + || (p.itype == f32 && p.otype == bf16 && mayiuse_bf16() + && p.beta == 0.f) + || (p.itype == bf16 && p.otype == f32 && mayiuse_bf16() + && p.beta == 0.f); + + bool is_f16 = (p.itype == f16 || p.otype == f16); + bool f16_ok = (p.itype == f32 && p.otype == f16 && p.beta == 0.f) + || (p.itype == f16 && p.otype == f32 && p.beta == 0.f) + || (p.itype == f16 && p.otype == f16 && p.beta == 0.f); + + bool ok = true && p.ndims > 0 + && utils::one_of(p.itype, f32, f16, bf16, s32, data_type::s8, u8) + && utils::one_of(p.otype, f32, f16, bf16, s32, data_type::s8, u8) + && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */ + && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */ + && simple_impl_desc_init(p, nullptr) && prb_has_small_strides(p) + && bf16_ok && IMPLICATION(is_f16, f16_ok); + + return ok; +} + +XReg jit_uni_reorder_kernel_f32_t::o_addr( + int o_off, bool with_type_multiplier) { + if (o_off) { + add_imm(X_DEFAULT_ADDR, x_ptr_out_off, + o_off * (with_type_multiplier ? 
otype_sz_ : 1), X_TMP); + return X_DEFAULT_ADDR; + } + + return x_ptr_out_off; +} + +XReg jit_uni_reorder_kernel_f32_t::src_s_addr(int s_off) { + if (s_off) { + add_imm(X_DEFAULT_ADDR, x_ptr_src_scale_off, s_off * stype_sz_, X_TMP); + return X_DEFAULT_ADDR; + } else { + return x_ptr_src_scale_off; + } +} + +XReg jit_uni_reorder_kernel_f32_t::dst_s_addr(int s_off) { + if (s_off) { + add_imm(X_DEFAULT_ADDR, x_ptr_dst_scale_off, s_off * stype_sz_, X_TMP); + return X_DEFAULT_ADDR; + } else { + return x_ptr_dst_scale_off; + } +} + +XReg jit_uni_reorder_kernel_f32_t::c_addr(int c_off) { + if (c_off) { + add_imm(X_DEFAULT_ADDR, x_ptr_comp_off, c_off * sizeof(int32_t), X_TMP); + return X_DEFAULT_ADDR; + } + + return x_ptr_comp_off; +} + +XReg jit_uni_reorder_kernel_f32_t::data_chunk_addr(int node_id) { + add_imm(X_DEFAULT_ADDR, abi_param1, + offsetof(tail_call_param_t, curr_data_chunks) + + sizeof(int64_t) * (node_id), + X_TMP); + return X_DEFAULT_ADDR; +} + +void jit_uni_reorder_kernel_f32_t::step(int off, int prev_i_off, int prev_o_off, + int prev_s_off, int prev_c_off, int &i_off, int &o_off, int &s_off, + int &c_off, int step_size) { + i_off = prev_i_off; + o_off = prev_o_off; + s_off = prev_s_off; + c_off = prev_c_off; + + if (off == 0) return; + + int start_dim = 0, dims_prod = 1; + for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim) + dims_prod *= prb_.n(start_dim); + assert(start_dim < prb_.ndims); + off /= step_size; + + for (int dim_id = start_dim; dim_id < prb_.ndims; ++dim_id) { + i_off += prb_.is(dim_id); + o_off += prb_.os(dim_id); + s_off += prb_.ss(dim_id); + c_off += prb_.cs(dim_id); + + if (off % prb_.n(dim_id)) break; + + i_off += -prb_.n(dim_id) * prb_.is(dim_id); + o_off += -prb_.n(dim_id) * prb_.os(dim_id); + s_off += -prb_.n(dim_id) * prb_.ss(dim_id); + c_off += -prb_.n(dim_id) * prb_.cs(dim_id); + + off /= prb_.n(dim_id); + + if (off == 0) break; /* FIXME: is it really required? */ + } +} + +void jit_uni_reorder_kernel_f32_t::step(int off, int prev_i_off, int prev_o_off, + int &i_off, int &o_off, int step_size) { + int dummy = 0; + step(off, prev_i_off, prev_o_off, dummy, dummy, i_off, o_off, dummy, dummy, + step_size); +} + +bool jit_uni_reorder_kernel_f32_t::can_do_tr4x8() { + using namespace data_type; + + // The kernel is specialised for f32 -> bf16 reorders. + // + // This process relies on swapping the two innermost dimensions. + // Therefore, the input stride in the second node and output stride in + // first node have to be equal to 1. 
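+    // For illustration only (the exact conditions are the checks below):
+    // a qualifying problem is a 4x8 f32->bf16 tile, i.e.
+    //   n(0) == 4, n(1) == 8, os(0) == 1, is(1) == 1,
+    // with no tail, no scales, beta == 0 and no compensation, so the two
+    // innermost dimensions can be swapped and converted in a single pass.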
+ return mayiuse(sve_256) && prb_.ndims >= 2 + && (prb_.itype == f32 && prb_.otype == bf16) && prb_.n(0) == 4 + && prb_.n(1) == 8 && utils::everyone_is(1, prb_.os(0), prb_.is(1)) + && !prb_.is_tail_present + && prb_.src_scale_type == scale_type_t::NONE + && prb_.dst_scale_type == scale_type_t::NONE && prb_.beta == 0.f + && !compensation_needed_; +} + +bool jit_uni_reorder_kernel_f32_t::process_unroll_tr4x8( + const int ndims, const int len) { + if (!can_do_tr4x8()) return false; + + const int step_size = prb_.n(0) * prb_.n(1); + int i_off = 0, o_off = 0; + for (int off = 0; off < len; off += step_size) { + step(off, i_off, o_off, i_off, o_off, step_size); + tr4x8_sve256(i_off, o_off); + } + + return true; +} + +void jit_uni_reorder_kernel_f32_t::tr4x8_sve256(int i_off, int o_off) { + using namespace data_type; + + auto z0 = ZRegS(0); + auto z1 = ZRegS(1); + auto z2 = ZRegS(2); + auto z3 = ZRegS(3); + + assert(x_tmp_vec.size() >= 4); + auto x_tmp_0 = x_tmp_vec[0]; + auto x_tmp_1 = x_tmp_vec[1]; + auto x_tmp_2 = x_tmp_vec[2]; + auto x_tmp_3 = x_tmp_vec[3]; + + // Load + auto in_ptr_diff = itype_sz_ * prb_.is(0); + add_imm(x_tmp_0, XReg(x_ptr_in_off), itype_sz_ * i_off, X_DEFAULT_ADDR); + add_imm(x_tmp_1, x_tmp_0, 1 * in_ptr_diff, X_DEFAULT_ADDR); + add_imm(x_tmp_2, x_tmp_0, 2 * in_ptr_diff, X_DEFAULT_ADDR); + add_imm(x_tmp_3, x_tmp_0, 3 * in_ptr_diff, X_DEFAULT_ADDR); + + ld1w(z0, P_ALL_ONE, ptr(x_tmp_0)); + ld1w(z1, P_ALL_ONE, ptr(x_tmp_1)); + ld1w(z2, P_ALL_ONE, ptr(x_tmp_2)); + ld1w(z3, P_ALL_ONE, ptr(x_tmp_3)); + + // Transpose + auto z4 = ZReg(4); + auto z5 = ZReg(5); + auto z6 = ZReg(6); + auto z7 = ZReg(7); + + // Interleaving two vectors containing rows of a tile is the same as + // transposing pairs of elements. + // + // If you start with: + // vec0: 0 1 2 3 4 5 6 7 + // vec1: 8 9 10 11 12 13 14 15 + // vec2: 16 17 18 19 20 21 22 23 + // vec3: 24 25 26 27 28 29 30 31 + // + // Then after two zips you have: + // vec4 = zip1(vec0, vec2): + // vec4: 0 16 1 17 2 18 3 19 + // vec5 = zip1(vec1, vec3): + // vec5: 8 24 9 25 10 26 11 27 + // + // Notice that if you convert and interleave these then you are done. That's + // what the subsequent bfcvt-bfcvtnt block of instructions does. + zip1(z4.s, z0, z2); + zip1(z5.s, z1, z3); + zip2(z6.s, z0, z2); + zip2(z7.s, z1, z3); + + // bfcvt converts one f32 vector to bf16 but leaves 0s in every alternate + // position within the destination vector (dst1). bfcvtnt then converts the + // second f32 vector to bf16 while filling in the zeroed spots left by bfcvt + // within dst1. + // + // With the two vectors from above: + // vec4: 0 16 1 17 2 18 3 19 + // vec5: 8 24 9 25 10 26 11 27 + // + // vec4 = bfcvt(vec4) + // vec4: 0 0 16 0 1 0 ... + // ^----------^-----------^ + // zeroed gaps left by bfcvt + // + // Now convert vec5 and fill the gaps in vec4 with a single instruction + // (storing the result in vec4): + // vec4: 0 8 16 24 1 9 ... + // + // Which contains the first 4 transposed columns of the original tile as + // required. 
+ bfcvt(z4.h, P_ALL_ONE / T_z, z4.s); + bfcvtnt(z4.h, P_ALL_ONE / T_m, z5.s); + bfcvt(z6.h, P_ALL_ONE / T_z, z6.s); + bfcvtnt(z6.h, P_ALL_ONE / T_m, z7.s); + + // Store + auto out_ptr_diff = get_sve_length(); + add_imm(x_tmp_0, XReg(x_ptr_out_off), otype_sz_ * o_off, X_DEFAULT_ADDR); + add_imm(x_tmp_1, x_tmp_0, out_ptr_diff, X_DEFAULT_ADDR); + + st1h(z4.h, P_ALL_ONE, ptr(x_tmp_0)); + st1h(z6.h, P_ALL_ONE, ptr(x_tmp_1)); +} + +void jit_uni_reorder_kernel_f32_t::tr8x8_sve256(int i_off, int o_off) { + using namespace data_type; + + const auto cvt2ps + = [=](const int startIdx, const int regNum, data_type_t idt) { + switch (idt) { + case f32: + /* do nothing */ + break; + case f16: cvt_v_f16_f32(startIdx, regNum); break; + case s32: cvt_z_s32_f32(startIdx, regNum); break; + case bf16: cvt_v_bf16_fp32(startIdx, regNum); break; + case data_type::s8: + cvt_z_s8_s32(startIdx, regNum); + cvt_z_s32_f32(startIdx, regNum); + break; + case u8: + cvt_z_u8_s32(startIdx, regNum); + cvt_z_s32_f32(startIdx, regNum); + break; + default: assert(!"unreachable"); + } + }; + + const auto cvt2odt = [=](const int startIdx, const int regNum, + data_type_t odt, data_type_t idt) { + switch (odt) { + case s32: + if (idt == f32) + cvt_z_f32_s32(startIdx, regNum); + else if (idt == data_type::s8) + cvt_z_s8_s32(startIdx, regNum); + else if (idt == u8) + cvt_z_u8_s32(startIdx, regNum); + break; + case data_type::s8: + if (idt == f32) cvt_z_f32_s32(startIdx, regNum); + if (utils::one_of(idt, f32, s32)) + cvt_z_s32_s8(startIdx, regNum); + if (idt == u8) cvt_z_u8_s8(startIdx, regNum); + break; + case data_type::bf16: + if (idt == f32) cvt_v_f32_bf16(startIdx, regNum); + break; + case data_type::f16: + if (idt == f32) cvt_v_f32_f16(startIdx, regNum); + break; + case u8: + if (idt == f32) cvt_z_f32_s32(startIdx, regNum); + if (utils::one_of(idt, f32, s32)) + cvt_z_s32_u8(startIdx, regNum); + if (idt == data_type::s8) cvt_z_s8_u8(startIdx, regNum); + break; + default: assert(!"unreachable"); + } + }; + + const int unroll = 8; + + const bool interim_f32 + = (prb_.itype != f32) || utils::one_of(f32, prb_.itype, prb_.otype); + + const bool need_saturation + = (utils::one_of(prb_.otype, u8, data_type::s8, s32) + && interim_f32); + const uint64_t sveLen = get_sve_length(); + + PReg p_size(DUMMY_IDX); + switch (unroll * itype_sz_) { + case 32: p_size = p_lsb_256; break; + case 16: p_size = p_lsb_128; break; + case 8: p_size = p_lsb_64; break; + default: assert(!"unreachable"); + } + + const int node_0_input_stride = prb_.is(0); + add_imm(X_TMP_0, XReg(x_ptr_in_off), itype_sz_ * i_off, X_DEFAULT_ADDR); + for (int i = 1; i < unroll / 2; i++) + add_imm(x_tmp_vec[i], x_tmp_vec[i - 1], itype_sz_ * node_0_input_stride, + X_DEFAULT_ADDR); + for (uint32_t i = 0; i < unroll / 2; i++) + ld1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i])); + for (int i = 0; i < unroll / 2; i++) + add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4], + itype_sz_ * node_0_input_stride, X_DEFAULT_ADDR); + for (uint32_t i = 0; i < unroll / 2; i++) + ld1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i])); + + if (interim_f32) cvt2ps(0, unroll, prb_.itype); + +#if 0 + /* Debug code to forcedly set test pattern. 
*/ + index(z0.s, 0, 1); + mov(z0.s, P_NOT_256/T_m, 0); + mov(z_tmp_vec[0].s, 16); + for(uint32_t i=1; i<8; i++) { + add(ZRegS{i}, ZRegS{i-1}, z_tmp_vec[0].s); + mov(ZRegS{i}, P_NOT_256/T_m, 0); + } +#endif + + ptrue(p_tmp0.s, VL4); + /* 1st turn */ + for (uint32_t i = 0; i < unroll / 2; i++) { + trn1(z_tmp_vec[i].s, ZRegS {2 * i}, ZRegS {2 * i + 1}); + trn2(z_tmp_vec[unroll / 2 + i].s, ZRegS {2 * i}, ZRegS {2 * i + 1}); + } + + /* 2nd turn */ + trn1(z4.d, z_tmp_vec[0].d, z_tmp_vec[1].d); + trn1(z5.d, z_tmp_vec[4].d, z_tmp_vec[5].d); + trn2(z6.d, z_tmp_vec[0].d, z_tmp_vec[1].d); + trn2(z7.d, z_tmp_vec[4].d, z_tmp_vec[5].d); + trn1(z_tmp_vec[0].d, z_tmp_vec[2].d, z_tmp_vec[3].d); + trn1(z_tmp_vec[1].d, z_tmp_vec[6].d, z_tmp_vec[7].d); + trn2(z_tmp_vec[2].d, z_tmp_vec[2].d, z_tmp_vec[3].d); + trn2(z_tmp_vec[3].d, z_tmp_vec[6].d, z_tmp_vec[7].d); + + /* 3rd turn */ + for (uint32_t i = 0; i < unroll / 2; i++) { + mov(ZRegD {i}, ZRegD {unroll / 2 + i}); + mov(z_tmp_vec[unroll / 2 + i].d, z_tmp_vec[i].d); + } + + /* 4th turn */ + for (uint32_t i = 0; i < unroll / 2; i++) { + ZRegB z {unroll / 2 + i}; + ZRegB z_tmp = z_tmp_vec[unroll / 2 + i].b; + /* Move bit 0-127 to 128-255. */ + ext(z, z, 16); + /* Move bit 128-255 to 0-127. */ + ext(z_tmp, z_tmp, sveLen - 16); + } + + /* 5th turn */ + for (uint32_t i = 0; i < unroll / 2; i++) { + ZRegS z0 {i}; + ZRegS z1 {unroll / 2 + i}; + sel(z0, p_tmp0.s, z0, z_tmp_vec[unroll / 2 + i].s); + sel(z1, p_tmp0, z1, z_tmp_vec[i].s); + } + + if (need_saturation) { + init_saturate_f32(ymm_zero_, ymm_saturation_ubound_, X_TMP_0, + interim_f32 ? f32 : prb_.itype, prb_.otype); + for (int i = 0; i < unroll; i++) + saturate_f32(ZRegS(i), ymm_zero_, ymm_saturation_ubound_, + prb_.otype, P_ALL_ONE); + } + + if (prb_.otype != f32) + cvt2odt(0, unroll, prb_.otype, interim_f32 ? f32 : prb_.itype); + + const int node_1_output_stride = prb_.os(1); + + switch (unroll * otype_sz_) { + case 32: p_size = p_lsb_256; break; + case 16: p_size = p_lsb_128; break; + case 8: p_size = p_lsb_64; break; + default: assert(!"unreachable"); + } + + add_imm(X_TMP_0, XReg(x_ptr_out_off), otype_sz_ * o_off, X_DEFAULT_ADDR); + for (int i = 1; i < unroll / 2; i++) + add_imm(x_tmp_vec[i], x_tmp_vec[i - 1], + otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR); + for (uint32_t i = 0; i < 4; i++) + st1w(ZRegS {i}, p_size / T_z, ptr(x_tmp_vec[i])); + for (int i = 0; i < unroll / 2; i++) + add_imm(x_tmp_vec[i], x_tmp_vec[(i + 3) % 4], + otype_sz_ * node_1_output_stride, X_DEFAULT_ADDR); + + for (uint32_t i = 0; i < unroll / 2; i++) + st1w(ZRegS {4 + i}, p_size / T_z, ptr(x_tmp_vec[i])); +} + +bool jit_uni_reorder_kernel_f32_t::can_do_tr8x8() { + using namespace data_type; + + static constexpr int desirable_node_size = 8; + static constexpr int desirable_stride = 1; + + // This process relies on swapping the two innermost dimensions. + // Therefore, the input stride in the second node and output stride in + // first node have to be equal to 1. 
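+    // For illustration only (the exact conditions are the checks below):
+    // this path covers e.g. an 8x8 tile where both types are one of
+    // u8/s8/s32/f32 and
+    //   n(0) == n(1) == 8, os(0) == 1, is(1) == 1,
+    // i.e. a plain transpose of the two innermost dimensions with no tail,
+    // no scales, beta == 0 and no compensation.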
+ return mayiuse(sve_256) && prb_.ndims >= 2 + && ((utils::one_of(prb_.itype, u8, data_type::s8, s32, f32) + && utils::one_of(prb_.otype, u8, data_type::s8, s32, f32))) + && utils::everyone_is(desirable_node_size, prb_.n(0), prb_.n(1)) + && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(1)) + && !prb_.is_tail_present + && prb_.src_scale_type == scale_type_t::NONE + && prb_.dst_scale_type == scale_type_t::NONE && prb_.beta == 0.f + && !compensation_needed_; +} + +bool jit_uni_reorder_kernel_f32_t::process_unroll_tr8x8( + const int ndims, const int len) { + if (!can_do_tr8x8()) return false; + + const int step_size = prb_.n(0) * prb_.n(1); + int i_off = 0, o_off = 0; + for (int off = 0; off < len; off += step_size) { + step(off, i_off, o_off, i_off, o_off, step_size); + tr8x8_sve256(i_off, o_off); + } + + return true; +} + +template +bool jit_uni_reorder_kernel_f32_t::process_direct_copy( + const int ndims, const int len) { + using namespace data_type; + + static constexpr int desirable_stride = 1; + using TRegS = + typename utils::conditional::type; + const int simd_w = cpu_isa_traits::vlen / itype_sz_; + + // TODO: support tail_processing for direct copy + + const bool do_src_zp = prb_.req_src_zp; + const bool do_dst_zp = prb_.req_dst_zp; + const bool zp_applicable = IMPLICATION( + (do_src_zp || do_dst_zp), utils::one_of(prb_.itype, s32, f32)); + const bool can_do = true && mayiuse(isa) && compensation_needed_ == false + && utils::everyone_is(desirable_stride, prb_.os(0), prb_.is(0)) + && (false || (prb_.itype == prb_.otype ? zp_applicable : false) + || (prb_.itype == s32 && prb_.otype == f32) + || (prb_.itype == f32 && prb_.otype == s32)) + && len % simd_w == 0 && prb_.n(0) % len == 0 + && !prb_.is_tail_present + && prb_.src_scale_type == scale_type_t::NONE + && prb_.dst_scale_type == scale_type_t::NONE && prb_.beta == 0.f; + if (!can_do) return false; + + static constexpr int vmm_zp_last_idx = 15; + const auto vmm_src_zp + = TRegS(do_dst_zp ? vmm_zp_last_idx - 1 : vmm_zp_last_idx); + if (do_src_zp) { + uni_ld1rw(vmm_src_zp, PARAM(src_zp)); + uni_scvtf(vmm_src_zp, vmm_src_zp); + } + const auto vmm_dst_zp = TRegS(vmm_zp_last_idx); + if (do_dst_zp) { + uni_ld1rw(vmm_dst_zp, PARAM(dst_zp)); + uni_scvtf(vmm_dst_zp, vmm_dst_zp); + } + + const auto apply_zp_ps = [&](const TRegS vmm) { + if (do_src_zp) fsub(vmm, vmm, vmm_src_zp); + if (do_dst_zp) fadd(vmm, vmm, vmm_dst_zp); + }; + + for (int off = 0; off < len;) { + // TODO: we need extra reg for proper saturation if otype == s32 + int unroll = nstl::min(16 - (prb_.otype == s32), (len - off) / simd_w); + unroll = (do_src_zp || do_dst_zp) + ? 
nstl::min(unroll, 16 - do_src_zp - do_dst_zp) + : unroll; + + int ur = 0; + int tmp_ur = 0; + while (ur < unroll) { + int count = 0; + const int vlen = cpu_isa_traits::vlen; + + do { + add_imm(x_tmp_vec[count++], x_ptr_in_off, + (off + ur * simd_w) * itype_sz_, X_DEFAULT_ADDR); + ur++; + } while (ur < unroll && count < x_tmp_vec_size); + + for (int i = 0; i < count; i++) { + if (vlen == 64 || vlen == 32) + ld1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z, ptr(x_tmp_vec[i])); + else if (vlen == 16) + ldr(QReg(tmp_ur + i), ptr(x_tmp_vec[i])); + else + assert(!"unreachable"); + } + tmp_ur += count; + } + + if (prb_.itype != prb_.otype) { + for (int ur = 0; ur < unroll; ++ur) { + TRegS r(ur); + if (prb_.itype == s32 && prb_.otype == f32) { + uni_scvtf(r, r); + apply_zp_ps(r); + } else if (prb_.itype == f32 && prb_.otype == s32) { + apply_zp_ps(r); + uni_frinti(r, r); + uni_fcvtzs(r, r); + } else + assert(!"unreachable"); + } + } else if (do_src_zp || do_dst_zp) { + for (int ur = 0; ur < unroll; ++ur) { + const auto vmm = TRegS(ur); + if (prb_.otype == f32) { + apply_zp_ps(vmm); + } else if (prb_.otype == s32) { + uni_scvtf(vmm, vmm); + apply_zp_ps(vmm); + uni_frinti(vmm, vmm); + uni_fcvtzs(vmm, vmm); + } + } + } + + ur = 0; + tmp_ur = 0; + while (ur < unroll) { + int count = 0; + const int vlen = cpu_isa_traits::vlen; + + do { + add_imm(x_tmp_vec[count++], x_ptr_out_off, + (off + ur * simd_w) * otype_sz_, X_DEFAULT_ADDR); + ur++; + } while (ur < unroll && count < x_tmp_vec_size); + + for (int i = 0; i < count; i++) { + if (vlen == 64 || vlen == 32) + st1w(ZRegS(tmp_ur + i), p_lsb_256 / T_z, ptr(x_tmp_vec[i])); + else if (vlen == 16) + str(QReg(tmp_ur + i), ptr(x_tmp_vec[i])); + else + assert(!"unreachable"); + } + tmp_ur += count; + } + + off += unroll * simd_w; + } + + return true; +} + +void jit_uni_reorder_kernel_f32_t::process_unroll_generic_step(int reg_unroll, + const int *i_off, const int *o_off, const int *s_off, const int *c_off, + const int *zero_padding, const bool tail_processing) { + using namespace data_type; + + auto cvt2ps = [=](const int startIdx, const int regNum, data_type_t idt) { + switch (idt) { + case f32: + /* do nothing */ + break; + case s32: cvt_v_s32_f32(startIdx, regNum); break; + case bf16: cvt_v_bf16_fp32(startIdx, regNum); break; + case f16: cvt_v_f16_f32(startIdx, regNum); break; + case data_type::s8: + cvt_v_s8_s32(startIdx, regNum); + cvt_v_s32_f32(startIdx, regNum); + break; + case u8: + cvt_v_u8_s32(startIdx, regNum); + cvt_v_s32_f32(startIdx, regNum); + break; + default: assert(!"unreachable"); + } + }; + + auto cvt2odt = [=](const int startIdx, const int regNum, data_type_t odt, + data_type_t idt) { + switch (odt) { + case f32: + if (idt == bf16) cvt_v_bf16_fp32(startIdx, regNum); + if (idt == f16) cvt_v_f16_f32(startIdx, regNum); + break; + case s32: + if (idt == f32) + cvt_v_f32_s32(startIdx, regNum); + else if (idt == data_type::s8) + cvt_v_s8_s32(startIdx, regNum); + else if (idt == u8) + cvt_v_u8_s32(startIdx, regNum); + break; + case data_type::s8: + if (idt == f32) cvt_v_f32_s32(startIdx, regNum); + if (idt == f32 || idt == s32) cvt_v_s32_s8(startIdx, regNum); + if (idt == u8) { cvt_v_u8_s8(startIdx, regNum); } + break; + case u8: + if (idt == f32) cvt_v_f32_s32(startIdx, regNum); + if (idt == f32 || idt == s32) cvt_v_s32_u8(startIdx, regNum); + if (idt == data_type::s8) cvt_v_s8_u8(startIdx, regNum); + break; + case bf16: + if (idt == f32) cvt_v_f32_bf16(startIdx, regNum); + break; + case f16: + if (idt == f32) cvt_v_f32_f16(startIdx, regNum); + break; + 
default: assert(!"unreachable"); + } + }; + + auto load_bytes_addr = [=](const int ur, const int r) { + add_imm(x_tmp_vec[r], x_ptr_in_off, i_off[ur + r] * itype_sz_, + X_DEFAULT_ADDR); + }; + auto load_bytes = [=](const int ur, int size, int r) { + switch (size) { + case 4: ld1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); break; + case 2: ld1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); break; + case 1: ld1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); break; + default: assert(!"unreachable"); + } + }; + + auto store = [=](const XReg &addr, const VReg ymm, int size) { + const uint32_t xmm = ymm.getIdx(); + switch (size) { + case 16: str(QReg(xmm), ptr(addr)); break; + case 8: str(DReg(xmm), ptr(addr)); break; + case 4: str(SReg(xmm), ptr(addr)); break; + case 2: str(HReg(xmm), ptr(addr)); break; + case 1: str(BReg(xmm), ptr(addr)); break; + default: assert(!"unreachable"); + } + }; + + /* check whether loading 4 values at once is possible */ + static constexpr int xmm_vlen = 4; + bool can_load_xmm = reg_unroll % xmm_vlen == 0; + int registers_total = reg_unroll / 4; + for (int reg = 0; reg < registers_total; reg++) { + for (int ur = 1 + (reg * 4); ur < ((reg + 1) * 4); ur++) + if (i_off[ur] != i_off[ur - 1] + 1) { + can_load_xmm = false; + break; + } + } + const int load_step = can_load_xmm ? xmm_vlen : 1; + + /* check whether storing 4 values at once is possible */ + bool can_store_xmm = reg_unroll % xmm_vlen == 0; + for (int ur = 1; ur < reg_unroll; ++ur) + if (o_off[ur] != o_off[ur - 1] + 1) { + can_store_xmm = false; + break; + } + const int ur_step = can_store_xmm ? 4 : 1; + const int load_tail_step + = !can_load_xmm && can_store_xmm ? ur_step : load_step; + + const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_); + + const bool need_saturation + = (utils::one_of(prb_.otype, u8, data_type::s8, s32) + && interim_f32); + + std::vector store_masks; + if (tail_processing) { + for (int ur = 0; ur < reg_unroll; ur += load_tail_step) { + uni_clear(VReg(ur)); + store_masks.push_back(0); + for (int r = 0; r < load_tail_step; ++r) { + if (zero_padding[ur + r] == 0) { + store_masks.back() += 1 << r; + load_bytes_addr(ur, r); + } + } + + for (int r = 0; r < load_tail_step; ++r) + if (zero_padding[ur + r] == 0) load_bytes(ur, itype_sz_, r); + } + } else { + if (!can_load_xmm && can_store_xmm) { + assert(ur_step == xmm_vlen); + /* load with stride */ + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + for (int r = 0; r < ur_step; ++r) { + load_bytes_addr(ur, r); + } + for (int r = 0; r < ur_step; ++r) + load_bytes(ur, itype_sz_, r); + } + } else { + int ur = 0; + int tmp_ur = 0; + while (ur < reg_unroll) { + int count = 0; + + do { + add_imm(x_tmp_vec[count++], x_ptr_in_off, + i_off[ur] * itype_sz_, X_DEFAULT_ADDR); + ur += load_step; + } while (ur < reg_unroll && count < x_tmp_vec_size); + + for (int i = 0; i < count; i++) { + + switch (load_step * itype_sz_) { + case 16: ldr(QReg(tmp_ur), ptr(x_tmp_vec[i])); break; + case 8: ldr(DReg(tmp_ur), ptr(x_tmp_vec[i])); break; + case 4: ldr(SReg(tmp_ur), ptr(x_tmp_vec[i])); break; + case 2: ldr(HReg(tmp_ur), ptr(x_tmp_vec[i])); break; + case 1: ldr(BReg(tmp_ur), ptr(x_tmp_vec[i])); break; + default: assert(!"unreachable"); + } + tmp_ur += load_step; + } + } + } + } + + /* xmm[:] <-- (f32)xmm[:] */ + if (interim_f32) { + const int cvt_step = nstl::max(load_step, ur_step); + for (int ur = 0; ur < reg_unroll; ur += cvt_step) + cvt2ps(ur, 1, prb_.itype); + } + + if (can_load_xmm && !can_store_xmm) { + // transposition on the fly + const bool fast_return = 
prb_.src_scale_type != scale_type_t::MANY + && prb_.dst_scale_type != scale_type_t::MANY && prb_.beta == 0.f + && !prb_.req_src_zp && !prb_.req_dst_zp + && !compensation_needed_; + if (fast_return) { + if (prb_.src_scale_type == scale_type_t::COMMON) + for (int ur = 0; ur < reg_unroll; ur += load_step) + fmul(VReg4S(ur), VReg4S(ur), xmm_src_scales_); + if (prb_.dst_scale_type == scale_type_t::COMMON) + for (int ur = 0; ur < reg_unroll; ur += load_step) + fmul(VReg4S(ur), VReg4S(ur), xmm_dst_scales_); + if (prb_.otype != f32) { + init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, X_TMP_0, + interim_f32 ? f32 : prb_.itype, prb_.otype); + for (int ur = 0; ur < reg_unroll; ur += load_step) { + if (need_saturation) + saturate_f32(VReg4S(ur), xmm_zero_, + xmm_saturation_ubound_, prb_.otype, P_ALL_ONE); + } + + for (int ur = 0; ur < reg_unroll; ur += load_step) + cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype); + } + for (int ur = 0; ur < reg_unroll; ur += load_step) { + for (int r = 0; r < load_step; ++r) { + add_imm(x_tmp_vec[r], x_ptr_out_off, + o_off[ur + r] * otype_sz_, X_DEFAULT_ADDR); + } + + for (int r = 0; r < load_step; ++r) { + if (otype_sz_ == 4) + st1(VReg4S(ur)[r], ptr(x_tmp_vec[r])); + else if (otype_sz_ == 2) + st1(VReg8H(ur)[r], ptr(x_tmp_vec[r])); + else + st1(VReg16B(ur)[r], ptr(x_tmp_vec[r])); + } + } + return; + } + + /* scatter elements of xmm into 4 xmms */ + if (itype_sz_ == 4 || interim_f32) { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) { + VReg4S v(ur); + VReg4S v_r(ur + r); + dup(VReg16B(ur + r), VReg16B(ur)[0]); + ins(VReg4S(ur + r)[0], VReg4S(ur)[r]); + } + } else { + for (int ur = 0; ur < reg_unroll; ur += load_step) + for (int r = 1; r < load_step; ++r) + ext(VReg16B(ur + r), VReg16B(ur), VReg16B(ur), + itype_sz_ * r); + } + } + + /* src zero point application */ + if (prb_.req_src_zp) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + const auto xmm = VReg4S(ur); + if (interim_f32) + fsub(xmm, xmm, xmm_src_zp_); + else + sub(xmm, xmm, xmm_src_zp_); + } + } + + /* scale and beta processing */ + if (can_store_xmm) { + const auto apply_scales + = [&](const VReg4S &vreg_scales, scale_arg_t scale_arg, + scale_type_t scale_type) { + if (scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + fmul(VReg4S(ur), VReg4S(ur), vreg_scales); + } else if (scale_type == scale_type_t::MANY) { + enum class scale_load_type_t { bcast, load, gather }; + const uint32_t idx = vreg_scales.getIdx(); + + uni_clear(VReg(idx)); + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + scale_load_type_t scale_load_type + = scale_load_type_t::bcast; // the best case + + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 0) + scale_load_type = scale_load_type_t::load; + + if (scale_load_type == scale_load_type_t::bcast + && !tail_processing) { + if (scale_arg == scale_arg_t::SRC) + ld1r(vreg_scales, ptr(src_s_addr(s_off[ur]))); + else + ld1r(vreg_scales, ptr(dst_s_addr(s_off[ur]))); + fmul(VReg4S(ur), VReg4S(ur), vreg_scales); + continue; + } + + // bcast doesn't work, the next try -- load + for (int r = ur + 1; r < ur + ur_step; ++r) + if (s_off[r] != s_off[r - 1] + 1) + scale_load_type = scale_load_type_t::gather; + + if (scale_load_type == scale_load_type_t::load + && !tail_processing) { + if (scale_arg == scale_arg_t::SRC) + ldr(QReg {idx}, ptr(src_s_addr(s_off[ur]))); + else + ldr(QReg {idx}, ptr(dst_s_addr(s_off[ur]))); + + fmul(VReg4S(ur), VReg4S(ur), VReg4S {idx}); 
+ continue; + } + + // load doesn't work as well + // so gather the scale factors one by one + for (int r = ur; r < ur + ur_step; ++r) + if (zero_padding[r] == 0 || !tail_processing) { + if (scale_arg == scale_arg_t::SRC) + mov(x_tmp_vec[r - ur], src_s_addr(s_off[r])); + else + mov(x_tmp_vec[r - ur], dst_s_addr(s_off[r])); + } + for (int r = ur; r < ur + ur_step; ++r) + if (zero_padding[r] == 0 || !tail_processing) + ld1(vreg_scales[r - ur], ptr(x_tmp_vec[r - ur])); + fmul(VReg4S(ur), VReg4S(ur), vreg_scales); + } + } + }; + /* xmm <-- src_scales * xmm[:] */ + apply_scales(xmm_src_scales_, scale_arg_t::SRC, prb_.src_scale_type); + + /* xmm[:] <-- beta * dst + xmm[:] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + int ur = 0; + int tmp_ur = 0; + + while (ur < reg_unroll) { + int count = 0; + + do { + add_imm(x_tmp_vec[count++], x_ptr_out_off, + o_off[ur] * otype_sz_, X_DEFAULT_ADDR); + ur += ur_step; + } while (ur < reg_unroll && count < x_tmp_vec_size); + + assert(count <= z_tmp_vec_size); + /* Firstly, data is loaded. */ + for (int i = 0; i < count; i++) { + + if (prb_.otype == f32 || prb_.otype == s32) { + ldr(QReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); // bug + } else if (prb_.otype == data_type::s8 + || prb_.otype == u8) { + ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); // bug + } else + assert(!"unreachable"); + } + + /* Secondly, it is added. */ + if (prb_.otype == f32) { + for (int i = 0; i < count; i++) { + VReg4S v(tmp_ur); + fadd(v, v, VReg4S(tmp_vec_idx[i])); + tmp_ur += ur_step; + } + } else { + for (int i = 0; i < count; i++) { + /* cvt2ps() generate successive instructions + which have save destination operand, + but out of order can be expected. */ + cvt2ps(tmp_vec_idx[i], 1, prb_.otype); + } + for (int i = 0; i < count; i++) { + VReg4S v(tmp_ur); + fadd(v, v, VReg4S(tmp_vec_idx[i])); + tmp_ur += ur_step; + } + } + } + } + + /* dst <-- dst_scales * xmm[:] */ + apply_scales(xmm_dst_scales_, scale_arg_t::DST, prb_.dst_scale_type); + } else { + const auto apply_scales + = [&](const VReg4S &vreg_scales, scale_arg_t scale_arg, + scale_type_t scale_type) { + if (scale_type == scale_type_t::COMMON) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) + fmul(VReg4S(ur), VReg4S(ur), vreg_scales); + } else if (scale_type == scale_type_t::MANY) { +#define DUMMY_IDX_ (99) + std::vector idx_list; + std::vector offt_list; + std::vector vec_reg; + std::vector addr_reg; + const size_t max_cnt_per_loop + = std::min(tmp_vec_idx.size(), x_tmp_vec.size()); + size_t cnt = 0; // valid unroll steps count + + // 1. Listing up valid steps + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (zero_padding[ur] == 0 || !tail_processing) { + idx_list.push_back(ur); + offt_list.push_back(s_off[ur]); + vec_reg.push_back(tmp_vec_idx[cnt % max_cnt_per_loop]); + if (s_off[ur]) + addr_reg.push_back( + x_tmp_vec[cnt % max_cnt_per_loop]); + else + addr_reg.push_back(scale_arg == scale_arg_t::SRC + ? x_ptr_src_scale_off + : x_ptr_dst_scale_off); + cnt++; + } + } + /* 2. Generate instructions considering instruction order. + If cnt > max_cnt_per_loop, the following instruction sets are + generated several times. + add x?, ..., add x? for calculating address + ldr s?, ..., ldr s? for loading data + fmul v?, ..., fmul v? for scaling */ + for (size_t ur = 0; ur < cnt;) { + // Calculating address + for (size_t i = ur; i < cnt && i - ur < max_cnt_per_loop; + i++) + add_imm(addr_reg[i], + scale_arg == scale_arg_t::SRC + ? 
x_ptr_src_scale_off + : x_ptr_dst_scale_off, + offt_list[i] * stype_sz_, X_TMP); + // Loading data + for (size_t i = ur; i < cnt && i - ur < max_cnt_per_loop; + i++) + ldr(SReg(vec_reg[i]), ptr(addr_reg[i])); + // Scaling + for (size_t i = ur; i < cnt && i - ur < max_cnt_per_loop; + i++) { + VReg4S v(idx_list[i]); + fmul(v, v, VReg4S(vec_reg[i])); + } + ur += std::min(cnt, max_cnt_per_loop); + } + } +#undef DUMMY_IDX_ + }; + + /* xmm[0] <-- src_scales * xmm[0] */ + apply_scales(xmm_src_scales_, scale_arg_t::SRC, prb_.src_scale_type); + + /* xmm[0] <-- beta * dst + xmm[0] */ + assert(prb_.beta == 0.f || prb_.beta == 1.f); + if (prb_.beta == 1.f) { + int ur = 0; + int tmp_ur = 0; + while (ur < reg_unroll) { + int count = 0; + + do { + add_imm(x_tmp_vec[count++], x_ptr_out_off, + o_off[ur] * otype_sz_, X_DEFAULT_ADDR); + ur += ur_step; + } while (ur < reg_unroll && count < (x_tmp_vec_size / 2)); + + assert(static_cast(count) <= z_tmp_vec.size()); + + if (prb_.otype == f32) { + /* addss: dest[31:0] <- src1[31:0] + src2[31:0] + dset[MAXVL-1:32] (Unmodified) */ + for (int i = 0; i < count; i++) { + ld1(VReg4S(z_tmp_vec[i].getIdx())[0], + ptr(x_tmp_vec[i])); + } + for (int i = 0; i < count; i++) { + SReg s {tmp_vec_idx[i]}; + fadd(s, s, SReg(tmp_ur + ur_step * i)); + } + for (int i = 0; i < count; i++) { + mov(VReg4S(tmp_ur)[0], VReg4S(tmp_vec_idx[i])[0]); + tmp_ur += ur_step; + } + } else { + for (int i = 0; i < count; i++) { + if (prb_.otype == s32) { + ldr(SReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); + } else if (utils::one_of( + prb_.otype, data_type::s8, u8)) { + ldr(BReg(tmp_vec_idx[i]), ptr(x_tmp_vec[i])); + } else { + assert(!"unsupported o_type"); + } + cvt2ps(tmp_vec_idx[i], 1, prb_.otype); + } + for (int i = 0; i < count; i++) { + VReg4S v(tmp_ur); + fadd(v, v, VReg4S(tmp_vec_idx[i])); + tmp_ur += ur_step; + } + } + } + } + + /* dst <-- dst_scales * xmm[0] */ + apply_scales(xmm_dst_scales_, scale_arg_t::DST, prb_.dst_scale_type); + } + + /* dst zero point application */ + if (prb_.req_dst_zp) { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + const auto xmm = VReg4S(ur); + if (interim_f32) + fadd(xmm, xmm, xmm_dst_zp_); + else + add(xmm, xmm, xmm_dst_zp_); + } + } + + /* adjust scale application */ + if (prb_.scale_adjust != 1.f) { + dup(xmm_tmp_, reg_scale_adjust_); + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + fmul(VReg4S(ur), VReg4S(ur), xmm_tmp_); + } + } + + if (need_saturation) { + init_saturate_f32(xmm_zero_, xmm_saturation_ubound_, X_TMP_0, f32, + prb_.otype, compensation_needed_); + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + saturate_f32(VReg4S(ur), xmm_zero_, xmm_saturation_ubound_, + prb_.otype, P_ALL_ONE, compensation_needed_); + } + + // reset back xmm_zero_ if needed. 
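+        // Note (assumption, for clarity): when compensation is combined with
+        // zero points, the saturation helpers above may have reused
+        // xmm_zero_ as a temporary, so it is cleared again here before the
+        // zero-padding masks further down rely on it as an all-zero source.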
+ if (compensation_needed_ && (prb_.req_src_zp || prb_.req_dst_zp)) + uni_clear(VReg(xmm_zero_.getIdx())); + } + + if (compensation_needed_) { + uint32_t xmm_id = 0; + const auto get_temp_xmm = [&] { + const Xbyak_aarch64::VReg temp {tmp_vec_idx[xmm_id]}; + + xmm_id = (xmm_id + 1) % tmp_vec_idx.size(); + + return temp; + }; + if (can_store_xmm) { + enum class comp_load_type_t { bcast, load, gather }; + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + + bool all_ip_padding_one = true; + bool all_ip_padding_zero = true; + for (int r = ur; r < ur + ur_step; r++) { + if (zero_padding[r] != 1) + all_ip_padding_one = false; + else + all_ip_padding_zero = false; + } + if (all_ip_padding_one) continue; + + comp_load_type_t comp_load_type = comp_load_type_t::bcast; + + for (int r = ur + 1; r < ur + ur_step; ++r) + if (c_off[r] != c_off[r - 1] + 0) { + comp_load_type = comp_load_type_t::load; + break; + } + + if (comp_load_type == comp_load_type_t::bcast + && all_ip_padding_zero) { + frinti(xmm_compensation, VReg4S(ur)); + fcvtzs(xmm_compensation, xmm_compensation); + addv(SReg(xmm_compensation.getIdx()), xmm_compensation); + addv(SReg(xmm_compensation.getIdx()), xmm_compensation); + const auto comp_addr = c_addr(c_off[ur]); + const auto xmm_tmp_ = get_temp_xmm().s4; + ldr(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); + add(xmm_tmp_, xmm_tmp_, xmm_compensation); + str(SReg(xmm_tmp_.getIdx()), ptr(comp_addr)); + continue; + } + + if (comp_load_type == comp_load_type_t::load) + for (int r = ur + 1; r < ur + ur_step; ++r) + if (c_off[r] != c_off[r - 1] + 1) { + comp_load_type = comp_load_type_t::gather; + break; + } + + if (comp_load_type == comp_load_type_t::load + && all_ip_padding_zero) { + const auto xmm_reorder_result = VReg4S(ur); + const auto comp_addr = c_addr(c_off[ur]); + frinti(xmm_compensation, xmm_reorder_result); + fcvtzs(xmm_compensation, xmm_compensation); + const auto xmm_tmp_ = get_temp_xmm().s4; + ldr(QReg(xmm_tmp_.getIdx()), ptr(comp_addr)); + add(xmm_compensation, xmm_compensation, xmm_tmp_); + str(QReg(xmm_compensation.getIdx()), ptr(comp_addr)); + continue; + } + + frinti(xmm_compensation, VReg4S(ur)); + fcvtzs(xmm_compensation, xmm_compensation); + for (int r = ur; r < ur + ur_step; ++r) { + if (zero_padding[r] == 0 || !tail_processing) { + mov(W_TMP_0, xmm_compensation[r % 4]); + const auto comp_addr = c_addr(c_off[r]); + ldr(W_TMP_1, ptr(comp_addr)); + add(W_TMP_0, W_TMP_0, W_TMP_1); + str(W_TMP_0, ptr(comp_addr)); + } + } + } + } else { + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (zero_padding[ur] == 0 || !tail_processing) { + const auto comp_addr = c_addr(c_off[ur]); + frinti(xmm_compensation, VReg4S(ur)); + fcvtzs(xmm_compensation, xmm_compensation); + const auto xmm_tmp_ = get_temp_xmm().s4; + ld1(xmm_tmp_, ptr(comp_addr)); + add(xmm_compensation, xmm_compensation, xmm_tmp_); + st1(VReg(xmm_compensation.getIdx()).s[0], ptr(comp_addr)); + } + } + } + } + + for (int ur = 0; ur < reg_unroll; ur += ur_step) { + if (prb_.req_src_zp || prb_.req_dst_zp) { + const bool use_store_masks = !store_masks.empty(); + if (use_store_masks) { + const auto mask = (~store_masks[ur / ur_step]) & 0xF; + switch (mask) { + case 0x0: + /* Do nothing */ + break; + case 0x1: ins(VReg4S(ur)[0], xmm_zero_[0]); break; + case 0x2: ins(VReg4S(ur)[1], xmm_zero_[1]); break; + case 0x3: + ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); + break; + case 0x4: ins(VReg4S(ur)[2], xmm_zero_[2]); break; + case 0x5: + ins(VReg4S(ur)[0], xmm_zero_[0]); + ins(VReg4S(ur)[2], xmm_zero_[2]); + break; 
+ case 0x6: + ins(VReg4S(ur)[1], xmm_zero_[1]); + ins(VReg4S(ur)[2], xmm_zero_[2]); + break; + case 0x7: + ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); + ins(VReg4S(ur)[2], xmm_zero_[2]); + break; + case 0x8: ins(VReg4S(ur)[3], xmm_zero_[3]); break; + case 0x9: + ins(VReg4S(ur)[0], xmm_zero_[0]); + ins(VReg4S(ur)[3], xmm_zero_[3]); + break; + case 0xa: + ins(VReg4S(ur)[1], xmm_zero_[1]); + ins(VReg4S(ur)[3], xmm_zero_[3]); + break; + case 0xb: + ins(VReg2D(ur)[0], VReg2D(xmm_zero_.getIdx())[0]); + ins(VReg4S(ur)[3], xmm_zero_[3]); + break; + case 0xc: + ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); + break; + case 0xd: + ins(VReg4S(ur)[0], xmm_zero_[0]); + ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); + break; + case 0xe: + ins(VReg4S(ur)[1], xmm_zero_[1]); + ins(VReg2D(ur)[1], VReg2D(xmm_zero_.getIdx())[1]); + break; + case 0xf: movi(VReg16B(ur), 0); break; + default: assert(!"unreachable"); + } + } + } + if (prb_.otype != f32) + cvt2odt(ur, 1, prb_.otype, interim_f32 ? f32 : prb_.itype); + + store(o_addr(o_off[ur]), VReg(ur), ur_step * otype_sz_); + } +} + +bool jit_uni_reorder_kernel_f32_t::interim_f32_needed( + const prb_t &prb, bool compensation_needed) { + using namespace data_type; + bool ret = utils::one_of(f32, prb.itype, prb.otype) + || prb.src_scale_type != scale_type_t::NONE + || prb.dst_scale_type != scale_type_t::NONE || prb.beta != 0.f + || ((prb.req_src_zp || prb.req_dst_zp) + ? !(prb.itype == s32 && prb.otype == s32) + : false) + || (prb.itype != f32 && compensation_needed) + || prb.scale_adjust != 1.f; + return ret; +} + +void jit_uni_reorder_kernel_f32_t::process_unroll_generic( + const int ndims, int len, const bool tail_processing) { + assert(IMPLICATION(prb_.nodes[0].tail_size > 0, + len == static_cast(prb_.nodes[0].n) + || len == static_cast(prb_.nodes[0].tail_size))); + + const int blk = 8; + + int i_off[2 * blk] = {0}; + int o_off[2 * blk] = {0}; + int s_off[2 * blk] = {0}; + int c_off[2 * blk] = {0}; + + int curr = 0; // will switch between 0 and 1 + + const bool interim_f32 = interim_f32_needed(prb_, compensation_needed_); + + if (prb_.req_src_zp) { + add_imm(X_DEFAULT_ADDR, PARAM(src_zp), X_TMP_0); + ld1r(xmm_src_zp_, ptr(X_DEFAULT_ADDR)); + if (interim_f32) scvtf(xmm_src_zp_, xmm_src_zp_); + } + if (prb_.req_dst_zp) { + add_imm(X_DEFAULT_ADDR, PARAM(dst_zp), X_TMP_0); + ld1r(xmm_dst_zp_, ptr(X_DEFAULT_ADDR)); + if (interim_f32) scvtf(xmm_dst_zp_, xmm_dst_zp_); + } + + for (int off = 0; off < len; off += blk) { + const int reg_unroll = nstl::min(off + blk, len) - off; + int zero_padding[blk] = {0}; + const auto curr_blk = curr * blk; + + /* compute offsets and tail*/ + for (int ur = off != 0 ? 
0 : 1; ur < reg_unroll; ++ur) { + const int ur_c = curr_blk + ur; + const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur + const bool is_tail + = off + ur >= static_cast(prb_.nodes[0].tail_size); + step(off + ur, i_off[ur_p], o_off[ur_p], s_off[ur_p], c_off[ur_p], + i_off[ur_c], o_off[ur_c], s_off[ur_c], c_off[ur_c]); + if (tail_processing && is_tail) zero_padding[ur] = 1; + } + + process_unroll_generic_step(reg_unroll, i_off + curr_blk, + o_off + curr_blk, s_off + curr_blk, c_off + curr_blk, + zero_padding, tail_processing); + + curr = 1 - curr; + } +} + +void jit_uni_reorder_kernel_f32_t::compute_ker( + const int ndims, const int len_unroll, const bool tail_processing) { + bool optimized = false; + optimized = optimized || process_direct_copy(ndims, len_unroll) + || process_direct_copy(ndims, len_unroll) + || process_unroll_tr8x8(ndims, len_unroll) + || process_unroll_tr4x8(ndims, len_unroll); + + if (!optimized) process_unroll_generic(ndims, len_unroll, tail_processing); +} + +void jit_uni_reorder_kernel_f32_t::loop_begin(Label &l, XReg reg_cnt, int len) { + mov(reg_cnt, len); + L(l); +} + +void jit_uni_reorder_kernel_f32_t::check_if_this_is_last_chunk( + const XReg reg_curr_chunk, int node_id) { + // Chunks are backwards numered i.e: + // [0] -> [node_size] + // [1] -> [node_size - 1] + // ... + // [node_size - 1] -> [1] + + // It is done like this, because it is easier to decrement counter + // and check if it is equal to zero than increment and check + // if it is equal to node_size. + static constexpr int64_t last_chunk = 1; + cmp(reg_curr_chunk, last_chunk); +} + +void jit_uni_reorder_kernel_f32_t::zero_dst_memory(const int bytes_to_zeroing) { + static constexpr int num_of_bytes_in_xmm = 128 / 8; + + const int xmms_to_zeroing + = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).quot; + const int tail_to_zeroing + = std::div(bytes_to_zeroing, num_of_bytes_in_xmm).rem; + + movi(xmm_tmp_, 0); + + if (xmms_to_zeroing > 0) { + Label loop; + + mov(X_TMP_4, xmms_to_zeroing); + L(loop); + str(QReg(xmm_tmp_.getIdx()), ptr(o_addr(0))); + add_imm(reg_off_out_, reg_off_out_, num_of_bytes_in_xmm, X_TMP_0); + add_imm(x_ptr_out_off, x_ptr_out_off, num_of_bytes_in_xmm, X_TMP_0); + subs(X_TMP_4, X_TMP_4, 1); + b(NE, loop); + } + + if (tail_to_zeroing) mov_imm(W_TMP_4, 0); + for (int i = 0; i < tail_to_zeroing; i++) + strb(W_TMP_4, ptr(o_addr(i, false))); + + // Restore dst offset to initial value + if (xmms_to_zeroing > 0) { + sub_imm(reg_off_out_, reg_off_out_, + num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0); + sub_imm(x_ptr_out_off, x_ptr_out_off, + num_of_bytes_in_xmm * xmms_to_zeroing, X_TMP_0); + } +} + +void jit_uni_reorder_kernel_f32_t::finalize_tail_loop(int i_step, int o_step, + int s_step, int c_step, const int curr_node_id) { + static constexpr int empty_chunk_info = -1; + + mov(X_TMP_0, empty_chunk_info); + str(X_TMP_0, ptr(data_chunk_addr(curr_node_id))); + + const int padded_area + = prb_.nodes[curr_node_id].n - prb_.nodes[curr_node_id].tail_size; + + if (prb_.nodes[curr_node_id].is_zero_pad_needed) { + int num_of_zero_padded_values = padded_area; + for (int i = curr_node_id - 1; i >= 0; i--) { + num_of_zero_padded_values *= prb_.nodes[i].n; + } + + const int bytes_to_zeroing = num_of_zero_padded_values * otype_sz_; + zero_dst_memory(bytes_to_zeroing); + } + + // This function is called by loop_end. At the end + // of loop_end is section that is responsible for + // restoring offset values. Restoring is based on + // len value which is equal to prb.nodes[x].n. 
+ // If fill_zero_padded_area is called then it means + // offsets were shifted prb.nodes[x].tail_size times. + // Therefore, this function has to shift offsets by + // zero pad area. + add_imm(reg_off_in_, reg_off_in_, padded_area * i_step * itype_sz_, + X_TMP_0); + add_imm(reg_off_out_, reg_off_out_, padded_area * o_step * otype_sz_, + X_TMP_0); + add_imm(x_ptr_in_off, x_ptr_in_off, padded_area * i_step * itype_sz_, + X_TMP_0); + add_imm(x_ptr_out_off, x_ptr_out_off, padded_area * o_step * otype_sz_, + X_TMP_0); + if (prb_.src_scale_type == scale_type_t::MANY) + add_imm(x_ptr_src_scale_off, x_ptr_src_scale_off, + padded_area * s_step * stype_sz_, X_TMP_0); + if (prb_.dst_scale_type == scale_type_t::MANY) + add_imm(x_ptr_dst_scale_off, x_ptr_dst_scale_off, + padded_area * s_step * stype_sz_, X_TMP_0); + + if (compensation_needed_) { + add_imm(reg_off_comp_, reg_off_comp_, + padded_area * c_step * sizeof(int32_t), X_TMP_0); + add_imm(x_ptr_comp_off, x_ptr_comp_off, + padded_area * c_step * sizeof(int32_t), X_TMP_0); + } +} + +void jit_uni_reorder_kernel_f32_t::loop_end(Label &l, XReg reg_cnt, int len, + int i_step, int o_step, int s_step, int c_step, + const int curr_node_id) { + add_imm(reg_off_in_, reg_off_in_, i_step * itype_sz_, X_TMP_0); + add_imm(reg_off_out_, reg_off_out_, o_step * otype_sz_, X_TMP_0); + add_imm(x_ptr_in_off, x_ptr_in_off, i_step * itype_sz_, X_TMP_0); + add_imm(x_ptr_out_off, x_ptr_out_off, o_step * otype_sz_, X_TMP_0); + + if (prb_.src_scale_type == scale_type_t::MANY) + add_imm(x_ptr_src_scale_off, x_ptr_src_scale_off, s_step * stype_sz_, + X_TMP_0); + if (prb_.dst_scale_type == scale_type_t::MANY) + add_imm(x_ptr_dst_scale_off, x_ptr_dst_scale_off, s_step * stype_sz_, + X_TMP_0); + + if (compensation_needed_) { + add_imm(reg_off_comp_, reg_off_comp_, c_step * sizeof(int32_t), + X_TMP_0); + add_imm(x_ptr_comp_off, x_ptr_comp_off, c_step * sizeof(int32_t), + X_TMP_0); + } + + subs(reg_cnt, reg_cnt, 1); + b(NE, l); + + if (prb_.tail(curr_node_id) != 0) { + Label if_end; + + // On the stack should be an information if node + // was processed with tail or not. + ldr(X_TMP_0, post_ptr(X_SP, X_TMP_0.getBit() / 8)); + + cmp(X_TMP_0, with_tail_info_); + b(NE, if_end); + finalize_tail_loop(i_step, o_step, s_step, c_step, curr_node_id); + L(if_end); + } + + // Restore offset to initial values. It means before + // loop execution. 
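+    // Worked example (illustrative numbers, unroll factor of 1): for a node
+    // with n == 10 and tail_size == 7, the tail path advanced the offsets by
+    // 7 steps and finalize_tail_loop() added the remaining padded_area == 3
+    // steps, so subtracting len * step (== 10 steps) below restores the
+    // offsets to their values from before the loop.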
+ sub_imm(reg_off_in_, reg_off_in_, len * i_step * itype_sz_, X_TMP_0); + sub_imm(reg_off_out_, reg_off_out_, len * o_step * otype_sz_, X_TMP_0); + sub_imm(x_ptr_in_off, x_ptr_in_off, len * i_step * itype_sz_, X_TMP_0); + sub_imm(x_ptr_out_off, x_ptr_out_off, len * o_step * otype_sz_, X_TMP_0); + + if (prb_.src_scale_type == scale_type_t::MANY) + sub_imm(x_ptr_src_scale_off, x_ptr_src_scale_off, + len * s_step * stype_sz_, X_TMP_0); + if (prb_.dst_scale_type == scale_type_t::MANY) + sub_imm(x_ptr_dst_scale_off, x_ptr_dst_scale_off, + len * s_step * stype_sz_, X_TMP_0); + if (compensation_needed_) { + sub_imm(reg_off_comp_, reg_off_comp_, len * c_step * sizeof(int32_t), + X_TMP_0); + sub_imm(x_ptr_comp_off, x_ptr_comp_off, len * c_step * sizeof(int32_t), + X_TMP_0); + } +} + +void jit_uni_reorder_kernel_f32_t::compute_blk_ker( + const simple_impl_desc_t &desc) { + static constexpr bool with_tail_processing = true; + Label no_last_chunk, end_label; + int omp_ndims = prb_.full_ndims - prb_.ndims; + + if (prb_.nodes[0].tail_size > 0) { + if (!prb_.nodes[0].is_parent_empty()) { + const int parent_node_id = prb_.nodes[0].parent_node_id; + ldr(X_TMP_0, ptr(data_chunk_addr(parent_node_id))); + check_if_this_is_last_chunk(X_TMP_0, parent_node_id); + b(NE, no_last_chunk); + } + + const int len_unroll = desc.tail_len_unroll > 0 ? desc.tail_len_unroll + : desc.len_unroll; + compute_ker(omp_ndims, len_unroll, with_tail_processing); + b(end_label); + } + + L(no_last_chunk); + compute_ker(omp_ndims, desc.len_unroll, !with_tail_processing); + L(end_label); +} + +void jit_uni_reorder_kernel_f32_t::create_loops(const simple_impl_desc_t &desc, + const std::array ®_cnt, int jit_loop) { + assert(jit_loop <= ndims_jit_loop_max); + + if (jit_loop > 0) { + const int nfu = desc.ndims_full_unroll; + const int unroll_factor = jit_loop == 1 ? desc.len_last_dim_unroll : 1; + const int curr_node_id = nfu + (jit_loop - 1); + const int parent_node_id = prb_.nodes[curr_node_id].parent_node_id; + const int tail_size = prb_.tail(curr_node_id) / unroll_factor; + const int node_size = prb_.n(curr_node_id) / unroll_factor; + const XReg reg_loop_cnt = reg_cnt[jit_loop - 1]; + const bool curr_node_has_tail = prb_.tail(curr_node_id) != 0; + Label loop, if_no_tail, if_end; + + if (curr_node_has_tail) { + const size_t reg_bytes = X_TMP_0.getBit() / 8; + if (prb_.nodes[curr_node_id].is_parent_empty()) { + mov(reg_loop_cnt, tail_size); + // Put info that node is being processed with tail. + mov(X_TMP_0, with_tail_info_); + str(X_TMP_0, pre_ptr(X_SP, -static_cast(reg_bytes))); + } else { + ldr(X_TMP_0, ptr(data_chunk_addr(parent_node_id))); + check_if_this_is_last_chunk(X_TMP_0, parent_node_id); + b(NE, if_no_tail); + mov(reg_loop_cnt, tail_size); + // Put info that node is being processed with tail. + mov(X_TMP_0, with_tail_info_); + str(X_TMP_0, pre_ptr(X_SP, -static_cast(reg_bytes))); + b(if_end); + + L(if_no_tail); + mov(reg_loop_cnt, node_size); + // Put info that node is being processed without tail. 
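+                // (The flag pushed onto the stack just below is popped again
+                // in loop_end(), which uses it to decide whether
+                // finalize_tail_loop() has to run for this node.)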
+ mov(X_TMP_0, without_tail_info_); + str(X_TMP_0, pre_ptr(X_SP, -static_cast(reg_bytes))); + L(if_end); + } + } + + if (prb_.is_tail_in_one_of_child_nodes(curr_node_id)) { + if (!curr_node_has_tail) { + mov(reg_loop_cnt, node_size); + str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); + } + L(loop); + if (!prb_.nodes[curr_node_id].is_parent_empty()) { + Label if_no_tail_in_child_node; + ldr(X_TMP_0, ptr(data_chunk_addr(parent_node_id))); + check_if_this_is_last_chunk(X_TMP_0, parent_node_id); + b(NE, if_no_tail_in_child_node); + str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); + L(if_no_tail_in_child_node); + } else { + str(reg_loop_cnt, ptr(data_chunk_addr(curr_node_id))); + } + } else if (curr_node_has_tail) { + L(loop); + } else { + loop_begin(loop, reg_loop_cnt, node_size); + } + + create_loops(desc, reg_cnt, jit_loop - 1); + + loop_end(loop, reg_loop_cnt, node_size, + prb_.is(curr_node_id) * unroll_factor, + prb_.os(curr_node_id) * unroll_factor, + prb_.ss(curr_node_id) * unroll_factor, + prb_.cs(curr_node_id) * unroll_factor, curr_node_id); + } else { + compute_blk_ker(desc); + } +} + +bool jit_uni_reorder_kernel_f32_t::simple_impl() { + simple_impl_desc_t d; + if (!simple_impl_desc_init(prb_, &d)) return false; + + eor(reg_off_in_, reg_off_in_, reg_off_in_); + eor(reg_off_out_, reg_off_out_, reg_off_out_); + + if (prb_.src_scale_type == scale_type_t::MANY) + mov(x_ptr_src_scale_off, reg_ptr_src_scales_); + if (prb_.dst_scale_type == scale_type_t::MANY) + mov(x_ptr_dst_scale_off, reg_ptr_dst_scales_); + + if (compensation_needed_) eor(reg_off_comp_, reg_off_comp_, reg_off_comp_); + + std::array reg_cnt({{x15, x14, x13}}); + + const int n_jit_loops = prb_.ndims - d.ndims_full_unroll; + create_loops(d, reg_cnt, n_jit_loops); + + return true; +} + +void jit_uni_reorder_kernel_f32_t::impl() { + if (simple_impl()) return; + assert(!"no implementation available"); +} + +#define UNROLL_INST(inst, reg, ...) \ + for (size_t i = startIdx; i < startIdx + regNum; i++) { \ + reg tmp(i); \ + inst(__VA_ARGS__); \ + } +#define UNROLL_INST2(inst, ...) 
\ + for (size_t i = startIdx; i < startIdx + regNum; i++) \ + inst(__VA_ARGS__); + +void jit_uni_reorder_kernel_f32_t::cvt_z_s32_f32( + const size_t startIdx, const size_t regNum) { + UNROLL_INST(scvtf, ZRegS, tmp, P_ALL_ONE / T_m, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_s32_f32( + const size_t startIdx, const size_t regNum) { + UNROLL_INST(scvtf, VReg4S, tmp, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_f32_s32( + const size_t startIdx, const size_t regNum) { + UNROLL_INST(frinti, ZRegS, tmp, P_ALL_ONE / T_m, tmp); + UNROLL_INST(fcvtzs, ZRegS, tmp, P_ALL_ONE / T_m, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_f32_s32( + const size_t startIdx, const size_t regNum) { + UNROLL_INST(frinti, VReg4S, tmp, tmp); + UNROLL_INST(fcvtzs, VReg4S, tmp, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_f32_bf16( + const size_t startIdx, const size_t regNum) { + UNROLL_INST2(bfcvtn, VReg4H(i), VReg4S(i)); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_bf16_fp32( + const size_t startIdx, const size_t regNum) { + UNROLL_INST2(shll, VReg4S(i), VReg4H(i), 16); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_f16_f32( + const size_t startIdx, const size_t regNum) { + UNROLL_INST2(fcvtl, VReg4S(i), VReg4H(i)); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_f32_f16( + const size_t startIdx, const size_t regNum) { + UNROLL_INST2(fcvtn, VReg4H(i), VReg4S(i)); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_s8_s32( + const size_t startIdx, const size_t regNum) { + cvt_z_b_s(startIdx, regNum); + UNROLL_INST(sxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_s8_s32( + const size_t startIdx, const size_t regNum) { + UNROLL_INST(sxtl, VReg, tmp.h8, tmp.b8); + UNROLL_INST(sxtl, VReg, tmp.s4, tmp.h4); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_s8_f32( + const size_t startIdx, const size_t regNum) { + cvt_z_b_s(startIdx, regNum); + cvt_z_s32_f32(startIdx, regNum); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_s8_f32( + const size_t startIdx, const size_t regNum) { + cvt_v_b_s(startIdx, regNum); + cvt_v_s32_f32(startIdx, regNum); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_b_s( + const size_t startIdx, const size_t regNum) { + assert(z_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < z_tmp7.getIdx()); + + dup(z_tmp7.b, 0); + UNROLL_INST(zip1, ZRegB, tmp, tmp, z_tmp7.b); + UNROLL_INST(zip1, ZRegH, tmp, tmp, z_tmp7.h); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_b_s( + const size_t startIdx, const size_t regNum) { + assert(v_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < v_tmp7.getIdx()); + + mov_imm(W_TMP_0, 0); + dup(v_tmp7.b16, W_TMP_0); + UNROLL_INST(zip1, VReg16B, tmp, tmp, v_tmp7.b16); + UNROLL_INST(zip1, VReg8H, tmp, tmp, v_tmp7.h8); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_u8_s32( + const size_t startIdx, const size_t regNum) { + cvt_z_b_s(startIdx, regNum); + UNROLL_INST(uxtb, ZRegS, tmp, P_ALL_ONE / T_m, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_u8_s32( + const size_t startIdx, const size_t regNum) { + UNROLL_INST(uxtl, VReg, tmp.h8, tmp.b8); + UNROLL_INST(uxtl, VReg, tmp.s4, tmp.h4); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_s32_s8( + const size_t startIdx, const size_t regNum) { + assert(z_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < z_tmp7.getIdx()); + + dup(z_tmp7.s, 0); + UNROLL_INST2(smin, ZRegS(i), 127); + UNROLL_INST2(smax, ZRegS(i), -128); + UNROLL_INST(uzp1, ZRegH, tmp, tmp, z_tmp7.h); + UNROLL_INST(uzp1, ZRegB, tmp, tmp, z_tmp7.b); +} + +void 
jit_uni_reorder_kernel_f32_t::cvt_v_s32_s8( + const size_t startIdx, const size_t regNum) { + assert(v_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < v_tmp7.getIdx()); + + mov_imm(W_TMP_0, 127); + dup(v_tmp7.s4, W_TMP_0); + mov_imm(W_TMP_0, -128); + UNROLL_INST2(smin, VReg4S(i), VReg4S(i), v_tmp7.s4); + dup(v_tmp7.s4, W_TMP_0); + UNROLL_INST2(smax, VReg4S(i), VReg4S(i), v_tmp7.s4); + mov_imm(W_TMP_0, 0); + dup(v_tmp7.s4, W_TMP_0); + UNROLL_INST(uzp1, VReg8H, tmp, tmp, v_tmp7.h8); + UNROLL_INST(uzp1, VReg16B, tmp, tmp, v_tmp7.b16); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_u8_s8( + const size_t startIdx, const size_t regNum) { + UNROLL_INST2(umin, ZRegB(i), 127); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_u8_s8( + const size_t startIdx, const size_t regNum) { + assert(v_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < v_tmp7.getIdx()); + + mov_imm(W_TMP_0, 127); + dup(v_tmp7.b16, W_TMP_0); + UNROLL_INST(umin, VReg16B, tmp, tmp, v_tmp7.b16); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_u32_u8( + const size_t startIdx, const size_t regNum) { + UNROLL_INST2(umin, ZRegS(i), 255); + UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp); + UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_u32_u8( + const size_t startIdx, const size_t regNum) { + assert(v_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < v_tmp7.getIdx()); + + mov_imm(W_TMP_0, 255); + dup(v_tmp7.s4, W_TMP_0); + UNROLL_INST(umin, VReg4S, tmp, tmp, v_tmp7.s4); + UNROLL_INST(uzp1, VReg8H, tmp, tmp, tmp); + UNROLL_INST(uzp1, VReg16B, tmp, tmp, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_s32_u8( + const size_t startIdx, const size_t regNum) { + assert(z_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < z_tmp7.getIdx()); + + dupm(z_tmp7.s, 255); + UNROLL_INST2(smax, ZRegS(i), 0); + UNROLL_INST2(smin, ZRegS(i), P_ALL_ONE / T_m, z_tmp7.s); + UNROLL_INST(uzp1, ZRegH, tmp, tmp, tmp); + UNROLL_INST(uzp1, ZRegB, tmp, tmp, tmp); + UNROLL_INST2(mov, ZRegB(i), P_NOT_128 / T_m, 0); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_s32_u8( + const size_t startIdx, const size_t regNum) { + assert(v_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < v_tmp7.getIdx()); + + mov_imm(W_TMP_0, 0); + dup(v_tmp7.s4, W_TMP_0); + mov_imm(W_TMP_0, 255); + UNROLL_INST(smax, VReg4S, tmp, tmp, v_tmp7.s4); + dup(v_tmp7.s4, W_TMP_0); + UNROLL_INST(smin, VReg4S, tmp, tmp, v_tmp7.s4); + UNROLL_INST(uzp1, VReg8H, tmp, tmp, tmp); + UNROLL_INST(uzp1, VReg16B, tmp, tmp, tmp); +} + +void jit_uni_reorder_kernel_f32_t::cvt_z_s8_u8( + const size_t startIdx, const size_t regNum) { + UNROLL_INST2(smax, ZRegB(i), 0); +} + +void jit_uni_reorder_kernel_f32_t::cvt_v_s8_u8( + const size_t startIdx, const size_t regNum) { + assert(v_tmp7.getIdx() < startIdx + || startIdx + regNum - 1 < v_tmp7.getIdx()); + + mov_imm(W_TMP_0, 0); + dup(v_tmp7.b16, W_TMP_0); + UNROLL_INST(smax, VReg16B, tmp, tmp, v_tmp7.b16); +} +#undef UNROLL_INST +#undef UNROLL_INST + +jit_uni_reorder_kernel_f32_t::jit_uni_reorder_kernel_f32_t(const desc_t &desc) + : kernel_t(desc), isa_(get_max_cpu_isa()) { + assert(!utils::one_of(isa_, isa_undef, isa_all)); + itype_sz_ = data_type_size(prb_.itype); + otype_sz_ = data_type_size(prb_.otype); + stype_sz_ = sizeof(float); +} + +void jit_uni_reorder_kernel_f32_t::generate() { + using namespace Xbyak_aarch64::util; + uint64_t sveLen = get_sve_length(); + Label end_of_kernel; + + preamble(); + + if (prb_.src_scale_type == scale_type_t::COMMON) { + add_imm(X_DEFAULT_ADDR, PARAM(src_scales), X_TMP_1); + 
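// COMMON scale: load the src_scales pointer, then ld1r broadcasts the single scale value to every lane.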
ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); + ld1r(xmm_src_scales_, ptr(X_TMP_0)); + } else if (prb_.src_scale_type == scale_type_t::MANY) { + add_imm(X_DEFAULT_ADDR, PARAM(src_scales), X_TMP_0); + ldr(reg_ptr_src_scales_, ptr(X_DEFAULT_ADDR)); + } + + if (prb_.dst_scale_type == scale_type_t::COMMON) { + add_imm(X_DEFAULT_ADDR, PARAM(dst_scales), X_TMP_1); + ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); + ld1r(xmm_dst_scales_, ptr(X_TMP_0)); + } else if (prb_.dst_scale_type == scale_type_t::MANY) { + add_imm(X_DEFAULT_ADDR, PARAM(dst_scales), X_TMP_0); + ldr(reg_ptr_dst_scales_, ptr(X_DEFAULT_ADDR)); + } + + if (compensation_needed_) { + add_imm(X_DEFAULT_ADDR, PARAM(compensation_scratch), X_TMP_0); + ldr(reg_ptr_comp_, ptr(X_DEFAULT_ADDR)); + } + if (prb_.scale_adjust == 0.5f) { mov(reg_scale_adjust_, 0x3f000000); } + add_imm(X_TMP_0, PARAM(in), X_TMP_2); + add_imm(X_TMP_1, PARAM(out), X_TMP_2); + ldr(reg_ptr_in_, ptr(X_TMP_0)); + ldr(reg_ptr_out_, ptr(X_TMP_1)); + + if (sveLen) { /* SVE is available. */ + ptrue(p_lsb_256.b, VL32); + ptrue(p_lsb_128.b, VL16); + ptrue(p_lsb_64.b, VL8); + } + + bool is_tail_in_drv_dims = false; + for (int i = prb_.ndims; i < prb_.full_ndims; i++) + if (prb_.nodes[i].tail_size > 0) { + is_tail_in_drv_dims = true; + break; + } + + if (is_tail_in_drv_dims) { + Label reorder_kernel; + add_imm(X_DEFAULT_ADDR, TAIL_PARAM(skip_kernel_execution), X_TMP_0); + ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); + cmp(X_TMP_0, static_cast(true)); + b(EQ, end_of_kernel); + + add_imm(X_DEFAULT_ADDR, TAIL_PARAM(zeroing_data), X_TMP_0); + ldr(X_TMP_0, ptr(X_DEFAULT_ADDR)); + cmp(X_TMP_0, static_cast(false)); + b(EQ, reorder_kernel); + // If zeroing data is set then all dst memory + // will be zeroed and nothing more will be done. + int bytes_to_zeroing = otype_sz_; + for (int i = 0; i < prb_.ndims; i++) { + bytes_to_zeroing *= prb_.nodes[i].n; + } + eor(reg_off_out_, reg_off_out_, reg_off_out_); + mov(x_ptr_out_off, reg_ptr_out_); + zero_dst_memory(bytes_to_zeroing); + b(end_of_kernel); + L(reorder_kernel); + } + + if (can_do_tr8x8()) { + dup(ymm_zero_, 0); + } else { + movi(xmm_zero_, 0); + } + + impl(); + + L(end_of_kernel); + postamble(); +} + +#undef TAIL_PARAM +#undef PARAM +} //namespace tr +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl diff --git a/src/cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp b/src/cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp new file mode 100644 index 00000000000..90c5fe45801 --- /dev/null +++ b/src/cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp @@ -0,0 +1,424 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* Copyright 2020-2023 FUJITSU LIMITED +* Copyright 2022, 2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_AARCH64_REORDER_JIT_UNI_REORDER_KERNEL_HPP +#define CPU_AARCH64_REORDER_JIT_UNI_REORDER_KERNEL_HPP + +#include + +#include "common/c_types_map.hpp" + +#include "cpu/aarch64/jit_generator.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder_utils.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { +namespace tr { +struct call_param_t { + const void *in = nullptr; + void *out = nullptr; + const float *src_scales = nullptr; + const float *dst_scales = nullptr; + int32_t src_zp = 0; + int32_t dst_zp = 0; + int32_t *compensation_scratch = nullptr; +}; + +// The additional structure is needed because +// using a data structure with tail processing +// data for non-tail cases reduces kernel +// performance. This is because there is too +// much data that has to be transferred to the kernel. +struct tail_call_param_t { + call_param_t base_params; + int64_t curr_data_chunks[DNNL_MAX_NDIMS] = {-1}; + int64_t zeroing_data = static_cast(false); + int64_t skip_kernel_execution = static_cast(false); +}; + +struct kernel_t { + struct desc_t { + int id; + prb_t prb; + }; + + kernel_t(const desc_t &desc) + : desc_(desc) + , compensation_needed_( + desc.prb.req_s8s8_comp || desc.prb.req_asymmetric_comp) {} + virtual void operator()(const call_param_t *c) const = 0; + virtual void operator()(const tail_call_param_t *c) const = 0; + virtual status_t create_kernel() = 0; + virtual ~kernel_t() = default; + + /** inits kernel descriptor: + * desc -- kernel descriptor (output) + * prb -- transposition problem (input) + * ndims_ker_max -- limit the maximum number of dimensions kernel + * will process (optional, 0 -- no limitation) */ + static status_t desc_init( + desc_t &desc, const prb_t &prb, int ndims_ker_max = 0); + + /** creates kernel for the problem described in desc */ + static kernel_t *create(const desc_t &desc); + + /** Minimal reasonable/desirable kernel size. + * The constant might be used to determine how a problem should be split + * between kernel and threading driver. 
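+     * e.g. prb_thread_kernel_balance() borrows a factor from the innermost
+     * driver dimension whenever the kernel dims multiply to fewer than
+     * ker_prb_size_min elements.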
*/ + static constexpr size_t ker_prb_size_min = 64; + +protected: + const desc_t desc_; + const prb_t &prb_ = desc_.prb; + bool compensation_needed_ = false; +}; + +/* kernel */ +struct jit_uni_reorder_kernel_f32_t : public kernel_t, public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32) + + using XReg = Xbyak_aarch64::XReg; + using WReg = Xbyak_aarch64::WReg; + using ZReg = Xbyak_aarch64::ZReg; + using ZRegS = Xbyak_aarch64::ZRegS; + using VReg = Xbyak_aarch64::VReg; + using VReg4S = Xbyak_aarch64::VReg4S; + using PReg = Xbyak_aarch64::PReg; + + void operator()(const call_param_t *c) const override; + void operator()(const tail_call_param_t *c) const override; + + status_t create_kernel() override; + + enum class scale_arg_t { NONE, SRC, DST }; + + enum { + len_unroll_max = 256, + ndims_jit_loop_max = 3, + }; + + struct simple_impl_desc_t { + int ndims_full_unroll = 0; + int len_last_dim_unroll = 0; + int tail_len_unroll = 0; + int len_unroll = 0; + }; + + static bool simple_impl_desc_init( + const prb_t &prb, simple_impl_desc_t *desc); + + static bool applicable(const prb_t &p); + + XReg o_addr(int o_off, bool with_type_multiplier = true); + + XReg src_s_addr(int s_off); + + XReg dst_s_addr(int s_off); + + XReg c_addr(int c_off); + + XReg data_chunk_addr(int node_id); + + void step(int off, int prev_i_off, int prev_o_off, int prev_s_off, + int prev_c_off, int &i_off, int &o_off, int &s_off, int &c_off, + int step_size = 1); + + void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off, + int step_size = 1); + + bool can_do_tr4x8(); + bool process_unroll_tr4x8(const int ndims, const int len); + void tr4x8_sve256(int i_off, int o_off); + + void tr8x8_sve256(int i_off, int o_off); + + bool can_do_tr8x8(); + + bool process_unroll_tr8x8(const int ndims, const int len); + + template + bool process_direct_copy(const int ndims, const int len); + + void process_unroll_generic_step(int reg_unroll, const int *i_off, + const int *o_off, const int *s_off, const int *c_off, + const int *zero_padding, const bool tail_processing); + + static bool interim_f32_needed(const prb_t &prb, bool compensation_needed); + + void process_unroll_generic( + const int ndims, int len, const bool tail_processing); + + void compute_ker( + const int ndims, const int len_unroll, const bool tail_processing); + + void loop_begin(Xbyak_aarch64::Label &l, XReg reg_cnt, int len); + + void check_if_this_is_last_chunk(const XReg reg_curr_chunk, int node_id); + + void zero_dst_memory(const int bytes_to_zeroing); + + void finalize_tail_loop(int i_step, int o_step, int s_step, int c_step, + const int curr_node_id); + + void loop_end(Xbyak_aarch64::Label &l, XReg reg_cnt, int len, int i_step, + int o_step, int s_step, int c_step, const int curr_node_id); + + void compute_blk_ker(const simple_impl_desc_t &desc); + + void create_loops(const simple_impl_desc_t &desc, + const std::array ®_cnt, int jit_loop); + + bool simple_impl(); + + void impl(); + + void cvt_z_s32_f32(const size_t startIdx, const size_t regNum); + void cvt_v_s32_f32(const size_t startIdx, const size_t regNum); + void cvt_z_f32_s32(const size_t startIdx, const size_t regNum); + void cvt_v_f32_s32(const size_t startIdx, const size_t regNum); + void cvt_v_f32_bf16(const size_t startIdx, const size_t regNum); + void cvt_v_bf16_fp32(const size_t startIdx, const size_t regNum); + void cvt_v_f16_f32(const size_t startIdx, const size_t regNum); + void cvt_v_f32_f16(const size_t startIdx, const size_t regNum); + void 
cvt_z_s8_s32(const size_t startIdx, const size_t regNum); + void cvt_v_s8_s32(const size_t startIdx, const size_t regNum); + void cvt_z_s8_f32(const size_t startIdx, const size_t regNum); + void cvt_v_s8_f32(const size_t startIdx, const size_t regNum); + void cvt_z_b_s(const size_t startIdx, const size_t regNum); + void cvt_v_b_s(const size_t startIdx, const size_t regNum); + void cvt_z_u8_s32(const size_t startIdx, const size_t regNum); + void cvt_v_u8_s32(const size_t startIdx, const size_t regNum); + void cvt_z_s32_s8(const size_t startIdx, const size_t regNum); + void cvt_v_s32_s8(const size_t startIdx, const size_t regNum); + void cvt_z_u8_s8(const size_t startIdx, const size_t regNum); + void cvt_v_u8_s8(const size_t startIdx, const size_t regNum); + void cvt_z_u32_u8(const size_t startIdx, const size_t regNum); + void cvt_v_u32_u8(const size_t startIdx, const size_t regNum); + void cvt_z_s32_u8(const size_t startIdx, const size_t regNum); + void cvt_v_s32_u8(const size_t startIdx, const size_t regNum); + void cvt_z_s8_u8(const size_t startIdx, const size_t regNum); + void cvt_v_s8_u8(const size_t startIdx, const size_t regNum); + + jit_uni_reorder_kernel_f32_t(const desc_t &desc); + + void generate() override; + + ~jit_uni_reorder_kernel_f32_t() override = default; + +private: + static constexpr int64_t with_tail_info_ = static_cast(true); + static constexpr int64_t without_tail_info_ = static_cast(false); + + int itype_sz_; + int otype_sz_; + int stype_sz_; + + const cpu_isa_t isa_; + + const XReg reg_ptr_in_ = x6; + const XReg reg_ptr_out_ = x2; + const XReg reg_ptr_src_scales_ = x1; + const XReg reg_ptr_dst_scales_ = x12; + const XReg reg_ptr_comp_ = x3; + const WReg reg_scale_adjust_ = w5; + + const XReg reg_off_in_ = x8; + const XReg reg_off_out_ = x9; + const XReg reg_off_comp_ = x11; + + /* X_TMP is required to set address to + x_tmp_vec(X_TMP_0 - X_TMP_4). */ + XReg X_TMP = x20; + + VReg4S xmm_src_scales_ = v15.s; + VReg4S xmm_dst_scales_ = v11.s; + VReg4S xmm_zero_ = v14.s; + ZRegS ymm_zero_ = z14.s; + VReg4S xmm_tmp_ = v12.s; + const VReg4S xmm_src_zp_ = v9.s; + const VReg4S xmm_dst_zp_ = v10.s; + const VReg4S xmm_compensation = v8.s; + VReg4S xmm_saturation_ubound_ = v12.s; + ZRegS ymm_saturation_ubound_ = z12.s; + + /* Note: x22 - x28 are already used as temporal registgers + in jit_generator.hpp. + x_ptr_(in|out|scale|comp)_off keeps (base + offset) address. */ + XReg x_ptr_in_off = reg_ptr_in_; + XReg x_ptr_out_off = reg_ptr_out_; + XReg x_ptr_comp_off = reg_ptr_comp_; + XReg x_ptr_src_scale_off = x19; + XReg x_ptr_dst_scale_off = x29; + + /* Caution: Chose predicate registers not used by x64's implementation. 
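+       p_lsb_256, p_lsb_128 and p_lsb_64 are initialized in generate() with
+       ptrue(..., VL32/VL16/VL8), i.e. they mask the low 32, 16 and 8 bytes
+       of an SVE vector.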
*/ + PReg p_lsb_256 = p7; + PReg p_lsb_128 = p6; + PReg p_lsb_64 = p4; + PReg p_tmp0 = p5; + + const std::vector tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27}; + VReg v_tmp0 = v20; + ZReg z_tmp0 = z20; + ZReg z_tmp1 = z21; + ZReg z_tmp2 = z22; + ZReg z_tmp3 = z23; + ZReg z_tmp4 = z24; + ZReg z_tmp5 = z25; + ZReg z_tmp6 = z26; + ZReg z_tmp7 = z27; + VReg v_tmp7 = v27; + + const std::vector z_tmp_vec + = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7}; + constexpr static int z_tmp_vec_size = 8; +}; + +/* TODO: add trans_t class */ + +// Seperate class for no unroll/threading burden +struct jit_single_blk_kernel_t : public jit_generator { + DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_single_blk_kernel) + using XReg = Xbyak_aarch64::XReg; + using ZRegS = Xbyak_aarch64::ZRegS; + using ZReg = Xbyak_aarch64::ZReg; + using PReg = Xbyak_aarch64::PReg; + using VReg = Xbyak_aarch64::VReg; + + static bool applicable(const prb_t &p); + + jit_single_blk_kernel_t(const prb_t &prb); + + void generate() override; + + void gen_loadu(const ZRegS ymm, const XReg &addr, int size); + + void gen_storeu(const XReg &addr, const ZRegS ymm, int size); + + void gen_maskloadu( + const ZRegS ymm, const XReg &addr, const PReg mask, int size); + + void gen_maskstoreu( + const XReg &addr, const ZRegS ymm, const PReg mask, int size); + + // Register allocation xmm0~11 + void gen_transpose_8x8(); + + void gen_transpose_4x4(); + + // keep order nchw -> nChw()C + // or nChw()C -> nchw + void gen_setmask(int mask); + + void gen_tr4x4(int i_off, int o_off, int input_stride, int output_stride, + int in_tail, int out_tail); + void gen_ker4x4(int i_off, int o_off, int input_stride, int output_stride, + int in_tail, int out_tail); + + // TODO: Mark parameter with type information + // XXX: ! + // offset in byte offset + // stride in element number + // + // Gen specific 8x8 transform respect to certain tail condition + void gen_tr8x8(int i_off, int o_off, int input_stride, int output_stride, + int in_tail, int out_tail); + + // tail: 0 ~ 8 + // support: either in_tail or out_tail is not 8, but not both + void gen_ker8x8(int i_off, int o_off, int input_stride, int output_stride, + int in_tail, int out_tail); + + void gen_ker16x16_in_8x8( + int i_off, int o_off, int input_stride, int output_stride); + + // tail can be 1 ~ 16, using sve2 for now + void gen_ker16x16_in_8x8(int i_off, int o_off, int input_stride, + int output_stride, int in_tail, int out_tail); + + void gen_ker32x32_in_16x16( + int i_off, int o_off, int input_stride, int output_stride); + + void gen_ker32x32_in_16x16(int i_off, int o_off, int input_stride, + int output_stride, int in_tail, int out_tail); + + void gen_ker64x64_in_32x32( + int i_off, int o_off, int input_stride, int output_stride); + + void gen_ker64x64_in_32x32(int i_off, int o_off, int input_stride, + int output_stride, int in_tail, int out_tail); + +private: + // 6 ~ 12 + constexpr static int xmm_save_start_from = 6; + constexpr static int xmm_width = 16; + + void preamble(); + + void postamble(); + + const tr::prb_t &prb_; + + int itype_sz_; + int otype_sz_; + int block_sz; + + XReg reg_ptr_in_ = abi_param1; + XReg reg_ptr_out_ = abi_param2; + XReg reg_ptr_tail = abi_param3; + + /* Because the callee-saved registers are not restored blk_reorder, + the temporary registers (x9-x15) must be assigned. + Must be selected from the temporary registers (x9-x15). */ + XReg x_addr = x10; + XReg x_tmp_0 = x11; + XReg x_tmp_1 = x12; + + /* Avoid P_TMP(p7) in jit_generator.hpp. 
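+       Hence p_lsb_256 below is p6 here, whereas jit_uni_reorder_kernel_f32_t
+       above uses p7.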
*/ + PReg p_lsb_256 = p6; + PReg p_mask = p5; + PReg p_tmp1 = p4; + PReg p_tmp2 = p3; + + ZRegS ymm_tmp = z0.s; + + const std::vector tmp_vec_idx = {20, 21, 22, 23, 24, 25, 26, 27}; + VReg v_tmp0 = v20; + ZReg z_tmp0 = z20; + ZReg z_tmp1 = z21; + ZReg z_tmp2 = z22; + ZReg z_tmp3 = z23; + ZReg z_tmp4 = z24; + ZReg z_tmp5 = z25; + ZReg z_tmp6 = z26; + ZReg z_tmp7 = z27; + VReg v_tmp7 = v27; + + const std::vector z_tmp_vec + = {z_tmp0, z_tmp1, z_tmp2, z_tmp3, z_tmp4, z_tmp5, z_tmp6, z_tmp7}; + constexpr static int z_tmp_vec_size = 8; +}; + +} // namespace tr +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/aarch64/jit_uni_reorder_utils.cpp b/src/cpu/aarch64/reorder/jit_uni_reorder_utils.cpp similarity index 64% rename from src/cpu/aarch64/jit_uni_reorder_utils.cpp rename to src/cpu/aarch64/reorder/jit_uni_reorder_utils.cpp index 90e78f3877b..47f8298d8ce 100644 --- a/src/cpu/aarch64/jit_uni_reorder_utils.cpp +++ b/src/cpu/aarch64/reorder/jit_uni_reorder_utils.cpp @@ -1,7 +1,7 @@ /******************************************************************************* -* Copyright 2018-2023 Intel Corporation +* Copyright 2018 Intel Corporation * Copyright 2020-2023 FUJITSU LIMITED -* Copyright 2022 Arm Ltd. and affiliates +* Copyright 2022, 2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,17 +17,20 @@ *******************************************************************************/ #include -#include +#include #include "common/c_types_map.hpp" -#include "common/dnnl_thread.hpp" #include "common/memory_desc_wrapper.hpp" #include "common/nstl.hpp" +#include "common/primitive_attr.hpp" +#include "common/reorder_pd.hpp" #include "common/type_helpers.hpp" #include "common/utils.hpp" +#include "cpu/platform.hpp" #include "oneapi/dnnl/dnnl_debug.h" -#include "cpu/aarch64/jit_uni_reorder.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder_kernel.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder_utils.hpp" // #define DNNL_DEV_MODE #if defined(DNNL_DEV_MODE) @@ -50,6 +53,24 @@ namespace aarch64 { namespace tr { +namespace { +inline bool is_direct_copy(const prb_t &prb) { + return prb.ndims == 1 && prb.nodes[0].is == 1 && prb.nodes[0].os == 1; +} +} // namespace + +bool prb_has_small_strides(const prb_t &prb) { + constexpr ptrdiff_t max_stride = (1LL << 31) - 1; + for (int d = 0; d < prb.ndims; ++d) { + const ptrdiff_t cms = max_stride / prb.nodes[d].n; + const bool small_strides = true + && prb.nodes[d].is < cms / (int)data_type_size(prb.itype) + && prb.nodes[d].os < cms / (int)data_type_size(prb.otype); + if (!small_strides) return false; + } + return true; +} + /** ad-hoc structure to describe blocked memory layout */ struct layout_desc_t { layout_desc_t() @@ -289,7 +310,11 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, = dst_mask == 0 ? scale_type_t::COMMON : scale_type_t::MANY; } - if (src_mask != dst_mask) return status::unimplemented; + VDISPATCH_REORDER_IC( + IMPLICATION(p.src_scale_type != scale_type_t::NONE + && p.dst_scale_type != scale_type_t::NONE, + src_mask == dst_mask), + VERBOSE_UNSUPPORTED_SCALES_CFG); p.scale_adjust = (om_d.extra().flags & memory_extra_flags::scale_adjust) ? 
om_d.extra().scale_adjust @@ -310,7 +335,7 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, om_d.extra().asymm_compensation_mask)) return status::unimplemented; - ptrdiff_t ss[max_ndims] = {0}; // scales strides + ptrdiff_t ss[DNNL_MAX_NDIMS] = {0}; // scales strides if (p.src_scale_type == scale_type_t::MANY || p.dst_scale_type == scale_type_t::MANY) { const int mask = nstl::max(src_mask, dst_mask); @@ -336,13 +361,13 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, p.compensation_mask = p.req_s8s8_comp ? om_d.extra().compensation_mask : (p.req_asymmetric_comp ? om_d.extra().asymm_compensation_mask - : tr::prb_t::invalid_comp_mask); + : prb_t::invalid_comp_mask); - if (p.compensation_mask == tr::prb_t::asymmetric_comp_mask) + if (p.compensation_mask == prb_t::asymmetric_comp_mask) return unimplemented; - assert(p.compensation_mask == tr::prb_t::standard_comp_mask - || p.compensation_mask == tr::prb_t::comp_mask_with_groups); + assert(p.compensation_mask == prb_t::standard_comp_mask + || p.compensation_mask == prb_t::comp_mask_with_groups); } int ndims = 0; @@ -353,8 +378,8 @@ status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd, while (i_pos < ild.ndims && o_pos < old.ndims) { assert(ild.id[i_pos] == old.id[o_pos]); - assert(ndims < max_ndims); - if (ndims == max_ndims) return runtime_error; + assert(ndims < DNNL_MAX_NDIMS); + if (ndims == DNNL_MAX_NDIMS) return runtime_error; if (ild.dims[i_pos] == old.dims[o_pos]) { p.nodes[ndims].n = ild.dims[i_pos]; @@ -536,7 +561,7 @@ void prb_simplify(prb_t &p) { void prb_node_split(prb_t &p, int dim, size_t new_node_size) { assert(dim < p.ndims); - assert(p.ndims < max_ndims); + assert(p.ndims < DNNL_MAX_NDIMS); assert(p.nodes[dim].n % new_node_size == 0); p.ndims += 1; @@ -575,7 +600,7 @@ void prb_node_split(prb_t &p, int dim, size_t new_node_size) { void prb_node_swap(prb_t &p, int d0, int d1) { assert(d0 < p.ndims); assert(d1 < p.ndims); - assert(p.ndims < max_ndims); + assert(p.ndims < DNNL_MAX_NDIMS); if (d0 == d1) return; @@ -585,7 +610,7 @@ void prb_node_swap(prb_t &p, int d0, int d1) { void prb_node_move(prb_t &p, int d0, int d1) { assert(d0 < p.ndims); assert(d1 < p.ndims); - assert(p.ndims < max_ndims); + assert(p.ndims < DNNL_MAX_NDIMS); if (d0 == d1) return; @@ -617,6 +642,233 @@ std::string prb_dump(const prb_t &p) { return ss.str(); } +void prb_block_for_cache(prb_t &prb) { + // Performance improvements when doing simple inner blocking of 8 or 4 + // This covers ab->Ba8b, ab->Ba4b, ba->Ab8a, ba->Ab4a and cdba->Acdb8a + // Split middle node, then swap to improve cache locality + // Before split+swap we traverse src column-wise 8 row elements at a time from top to bottom + // After split+swap in src we traverse split_countx8 row-wise for inner_blk=8, and split_countx4 for inner_blk=4 + if (prb.ndims == 3) { + const int inner_blk = prb.nodes[0].n; + if ((inner_blk == 8 || inner_blk == 4) && prb.nodes[0].is == 1 + && prb.nodes[0].os == 1 && prb.nodes[1].os == inner_blk + && prb.nodes[2].is == inner_blk) { + + // Try finding value to split on + // prb.nodes[1].n > 32 to ensure we only split if enough rows for split to have cache benefit + size_t split_value = 0; + for (size_t d = 8; d >= 2; --d) { + if (prb.nodes[1].n % d == 0 && prb.nodes[1].n != d + && prb.nodes[1].n > 32) { + split_value = d; + break; + } + } + + // Split on found split_value, if any + if (split_value) { + prb_node_split(prb, 1, split_value); + prb_node_swap(prb, 3, 2); + 
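// Refresh parent/child node dependencies after the split and swap.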
prb_node_dependency(prb); + } + } + } + + /* If strides for 0th and 1st nodes are cache friendly + * then one can altogether do away with blocking ! */ + static constexpr int num_elems_thr = 16; + const bool stride_cache_friendly + = ((prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > num_elems_thr) + || (prb.ndims > 1 && prb.nodes[1].is % num_elems_thr == 0 + && prb.nodes[1].n > num_elems_thr)) + && !prb.is_tail_present; + + // TODO: Find a way to associate the caching logic to its kernel. + // The issue is that this swap logic is separated from the tr4x8 kernel that + // it is relevant to. This is a performance improvement for the + // f32:bf16 ab->BA8b4a reorder. + if (mayiuse(sve_256) && prb.ndims == 4 && prb.n(0) == 4 && prb.is(0) != 1 + && prb.is(1) == 1 && prb.os(1) == 4 && prb.n(1) == 8 + && prb.is(3) == 8 && prb.itype == data_type::f32 + && prb.otype == data_type::bf16) { + // Changes the order of traversal of the tile from column-wise to + // row-wise. This makes the reads more cache-friendly at the cost of + // the writes being separated by some stride. + tr::prb_node_move(prb, 2, 3); + } + + // performance improvement for shapes with large inner-most dimension + const size_t L1_cache_sz + = size_t(3) * platform::get_per_core_cache_size(1) / 4; + const size_t itype_sz_ = data_type_size(prb.itype); + const size_t inner_block_sz = prb.nodes[0].n * itype_sz_; + const bool requires_inner_blocking = inner_block_sz > L1_cache_sz + // 'is_tail_present' is not supported for cache_blocking when + // asymmetric_comp is executed. + && IMPLICATION(prb.req_asymmetric_comp, !prb.is_tail_present); + + const bool cache_blocking_needed + = stride_cache_friendly || requires_inner_blocking; + if (!cache_blocking_needed || is_direct_copy(prb)) return; + + int unit_input_stride_idx = -1; + for (auto idx = 0; idx < prb.ndims; ++idx) { + if (prb.nodes[idx].is == 1) unit_input_stride_idx = idx; + } + + /* Re-prioritize the sequential read over sequential write: + * /-> [n0:is0:1][16n1:1:osk]... + * [n0:is0:1]...[nk:1:osk] --> or + * \-> [16n1:1:osk][n0:is0:1]... */ + if (unit_input_stride_idx != -1) { + const auto output_stride = prb.nodes[unit_input_stride_idx].os; + const auto num_elems = prb.nodes[unit_input_stride_idx].n; + + const bool split_needed = (num_elems > num_elems_thr) + && (num_elems % num_elems_thr == 0); + const int move_location = (output_stride % 4 != 0) ? 0 : 1; + if (split_needed) + prb_node_split(prb, unit_input_stride_idx, num_elems_thr); + + /* Because of cache-unfriendly nature of unit-output stride node, let + * us move unit-input stride node on or near front! 
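+         * It goes to the very front when the output stride is not a multiple
+         * of 4, otherwise to position 1 (see move_location above).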
*/ + if (unit_input_stride_idx != move_location) + prb_node_move(prb, unit_input_stride_idx, move_location); + } + + /* Potentially, split the node with os=1 in two and pull in the node with + * is=1 between them for better cache reuse: + * [n0:is0:1][n1:1:os1] --> [16n0:is0:1][n1:1:os1][n0/16:is0*16:16] */ + if (prb.ndims >= 2 && prb.nodes[0].os == 1 && prb.nodes[1].is == 1) { + const auto num_elems = prb.nodes[0].n; + + const bool split_needed = (num_elems > num_elems_thr) + && (num_elems % num_elems_thr == 0); + if (split_needed) { + prb_node_split(prb, 0, num_elems_thr); + prb_node_move(prb, 1, 2); + + // Update node information + prb_node_dependency(prb); + + // heuristics - looping over the unrolled dims should maximize reuse + // of the already cached data; observation is choosing the smallest + // dim from the remaining (from 2 up to ndims) gives good results + constexpr int new_position = 2; + const auto dim_beg_it = std::begin(prb.nodes); + const auto dim_two_it = dim_beg_it + new_position; + const auto dim_last_it = dim_beg_it + prb.ndims; + const auto min_n_node_it = std::min_element(dim_two_it, dim_last_it, + [](const tr::node_t &lhs, const tr::node_t &rhs) { + return lhs.n < rhs.n; + }); + const auto min_idx = std::distance(dim_beg_it, min_n_node_it); + // check if min_idx node is parent of node with tail processing which + // is currently unsupported (i.e. tail processing can only be handled + // at the inner-most dimension) + bool inner_block_has_tail = false; + for (int idx = min_idx - 1; idx >= new_position; idx--) { + if (prb.nodes[idx].parent_node_id == min_idx) { + inner_block_has_tail = true; + break; + } + } + + if (min_idx > new_position && (!inner_block_has_tail)) + prb_node_move(prb, min_idx, new_position); + } + } +} + +void prb_thread_kernel_balance(prb_t &prb, int &ndims_ker_max, int nthr) { + size_t size_total = 1; + for (int d = 0; d < prb.ndims; ++d) + size_total *= prb.nodes[d].n; + + /* The general expression for size_drv_thr can be written as + * size_drv_min = C0 + FC * (nthr > 1 ? 1 : 0) + VC * (nthr - 1) + * where FC and VC are fixed and variable costs respectively. + * Though for now, the below heuristic seems to be good enough */ + // Note: direct copy needs only as many kernels as nthr. + const size_t size_drv_thr = is_direct_copy(prb) ? nthr + : (nthr > 1) ? 16 * nthr + : 1; + + /* size_drv_min is the minimal size for the parallel + * driver required for good parallelization */ + const size_t size_drv_min + = nstl::min(size_drv_thr, utils::div_up(size_total, 1024)); + + /* kdims -- # of dimensions processed by a kernel + * size_ker_cur -- product of the dimension processed by a kernel + * size_drv_cur -- product of the dimension processed by a driver */ + + int kdims = prb.ndims; + size_t size_drv_cur = 1; + for (; kdims > 1 && size_drv_cur < size_drv_min; --kdims) + size_drv_cur *= prb.nodes[kdims - 1].n; + + size_t size_ker_cur = 1; + for (int d = 0; d < kdims; ++d) + size_ker_cur *= prb.nodes[d].n; + + /* Initially kdims is chosen so that size_drv_cur >= size_drv_min. + * + * It might happen that for chosen kdims the size_ker_cur is too small + * (less than tr::ker_prb_size_min). In that case try to split the + * innermost driver dimension into two, to increase size_ker_cur. 
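+     * E.g. size_ker_cur = 16 with ker_prb_size_min = 64 gives
+     * size_want_borrow = div_up(64, 16) = 4 below, so a factor of 4 is
+     * taken from the innermost driver dimension and handed to the kernel.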
*/ + const bool want_borrow_ker_from_drv = kdims < prb.ndims + && size_ker_cur < kernel_t::ker_prb_size_min + && size_drv_cur > size_drv_min; + if (want_borrow_ker_from_drv) { + /* size_want_borrow is the minimal size, so that: + * o) size_ker_cur * size_want_borrow >= tr::ker_prb_size_min + * o) current innermost driver dimension is divisible by + * size_want_borrow (so that we can evenly split that + * dimension into two) + * + * In the worst case the minimal size_want_borrow is equal + * to the innermost driver dimension itself. In that case + * we will sacrifice it in favor of kernel (is it fine?). */ + size_t size_want_borrow + = utils::div_up(kernel_t::ker_prb_size_min, size_ker_cur); + for (; prb.nodes[kdims].n % size_want_borrow; ++size_want_borrow) + ; + + if (size_want_borrow != prb.nodes[kdims].n) + prb_node_split(prb, kdims, size_want_borrow); + kdims += 1; + } + + /* On the other hand it might happen that for chosen kdims + * the size_drv_cur is too small (less than size_drv_min). In that case + * try to split the outermost kernel dimension into two, to increase + * size_drv_cur. */ + const bool want_borrow_drv_from_ker + = size_ker_cur > kernel_t::ker_prb_size_min + && size_drv_cur < size_drv_min; + if (want_borrow_drv_from_ker) { + size_t size_want_borrow = utils::div_up(size_drv_min, size_drv_cur); + for (; prb.nodes[kdims - 1].n % size_want_borrow; ++size_want_borrow) + ; + + if (size_want_borrow != prb.nodes[kdims - 1].n) + prb_node_split( + prb, kdims - 1, prb.nodes[kdims - 1].n / size_want_borrow); + } + + ndims_ker_max = kdims; + + if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) { + DEBUG({ + verbose_printf( + verbose_t::debuginfo, "split: %s\n", prb_dump(prb).c_str()); + verbose_printf(verbose_t::debuginfo, "ndims_ker_max = %d\n", + ndims_ker_max); + }); + } +} + } // namespace tr } // namespace aarch64 diff --git a/src/cpu/aarch64/reorder/jit_uni_reorder_utils.hpp b/src/cpu/aarch64/reorder/jit_uni_reorder_utils.hpp new file mode 100644 index 00000000000..d1fa8d33a65 --- /dev/null +++ b/src/cpu/aarch64/reorder/jit_uni_reorder_utils.hpp @@ -0,0 +1,168 @@ +/******************************************************************************* +* Copyright 2018 Intel Corporation +* Copyright 2020-2023 FUJITSU LIMITED +* Copyright 2022, 2025 Arm Ltd. and affiliates +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*******************************************************************************/ + +#ifndef CPU_AARCH64_REORDER_JIT_UNI_REORDER_UTILS_HPP +#define CPU_AARCH64_REORDER_JIT_UNI_REORDER_UTILS_HPP + +#include +#include + +#include "common/c_types_map.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace aarch64 { + +namespace tr { +struct node_t { + static constexpr int64_t empty_field = -1; + + size_t n = 0; + size_t tail_size = 0; + int dim_id = empty_field; + int parent_node_id = empty_field; + bool is_zero_pad_needed = false; + ptrdiff_t is = 0; // input stride + ptrdiff_t os = 0; // output stride + ptrdiff_t ss = 0; // scale stride + ptrdiff_t cs = 0; // compensation stride + + bool is_dim_id_empty() const { return dim_id == empty_field; } + bool is_parent_empty() const { return parent_node_id == empty_field; } +}; + +enum class scale_type_t { NONE, COMMON, MANY }; + +struct prb_t { + /* The compensation mask value indicates how big an additional buffer should be. + * Possible values for reorder: + * 1) standard compensation = 1 = 0b01 + * 2) asymmetric compensation = 2 = 0b10 + * 3) compensation if tensor contains group = 3 = 0b11 */ + static constexpr int invalid_comp_mask = 0; + static constexpr int standard_comp_mask = 0b1; + static constexpr int asymmetric_comp_mask = 0b10; + static constexpr int comp_mask_with_groups + = standard_comp_mask + asymmetric_comp_mask; + + bool is_tail_in_one_of_child_nodes(int parent_node_id) const { + for (int i = parent_node_id; i >= 0; i--) { + if (nodes[i].parent_node_id == parent_node_id) { + if (nodes[i].tail_size != 0) + return true; + else + parent_node_id = i; + } + } + + return false; + } + + int tail(int d) const { + assert(d < ndims); + return static_cast(nodes[d].tail_size); + } + + int n(int d) const { + assert(d < ndims); + return static_cast(nodes[d].n); + } + int is(int d) const { + assert(d < ndims); + return static_cast(nodes[d].is); + } + int os(int d) const { + assert(d < ndims); + return static_cast(nodes[d].os); + } + int ss(int d) const { + assert(d < ndims); + return static_cast(nodes[d].ss); + } + + int cs(int d) const { + assert(d < ndims); + return static_cast(nodes[d].cs); + } + + data_type_t itype; + data_type_t otype; + int ndims; + node_t nodes[DNNL_MAX_NDIMS]; + ptrdiff_t ioff; + ptrdiff_t ooff; + scale_type_t src_scale_type; + scale_type_t dst_scale_type; + float beta; + int full_ndims; + bool is_tail_present = false; + float scale_adjust = 1.f; + int compensation_mask = invalid_comp_mask; + bool req_s8s8_comp = false; + bool req_asymmetric_comp = false; + bool req_src_zp = false; + bool req_dst_zp = false; +}; + +bool prb_has_small_strides(const prb_t &prb); + +status_t prb_init(prb_t &prb, const memory_desc_t &imd, + const memory_desc_t &omd, const primitive_attr_t *attr); + +/** sorts the problem nodes so that output strides come in ascending order */ +void prb_normalize(prb_t &p); + +/** fill parent node info for blocked nodes */ +void prb_node_dependency(prb_t &p); + +/** folds nodes together if possible */ +void prb_simplify(prb_t &p); + +/** splits the node dim into two of sizes n1 and n / n1 + * @warning n must be multiple of n1 */ +void prb_node_split(prb_t &p, int dim, size_t n1); + +/** swaps d0 and d1 nodes */ +void prb_node_swap(prb_t &p, int d0, int d1); + +/** moves node d0 to the d1 position. 
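+ * (e.g. used by prb_block_for_cache() when re-prioritizing sequential reads)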
+ * nodes (d0, d1] are shifted to the left if d0 < d1 or + * to the right if d0 > d1 */ +void prb_node_move(prb_t &p, int d0, int d1); + +bool prb_has_small_strides(const prb_t &prb); + +/** dumps the problem to a string */ +std::string prb_dump(const prb_t &p); + +void prb_block_for_cache(prb_t &prb); + +/** finds the maximum number of dimension the kernel should process and + * optionally splits one of the dimension to achieve better balance between + * parallel driver and the kernel. */ +void prb_thread_kernel_balance(prb_t &prb, int &ndims_ker_max, int nthr); + +} // namespace tr + +} // namespace aarch64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/cpu/platform.hpp b/src/cpu/platform.hpp index af0d6e944a8..f5b15bba4d9 100644 --- a/src/cpu/platform.hpp +++ b/src/cpu/platform.hpp @@ -97,8 +97,8 @@ // Helper macros: expand the parameters only on the corresponding architecture. // Equivalent to: #if DNNL_$ARCH ... #endif #define DNNL_X64_ONLY(...) Z_CONDITIONAL_DO(DNNL_X64, __VA_ARGS__) -#define DNNL_PPC64_ONLY(...) Z_CONDITIONAL_DO(DNNL_PPC64_ONLY, __VA_ARGS__) -#define DNNL_S390X_ONLY(...) Z_CONDITIONAL_DO(DNNL_S390X_ONLY, __VA_ARGS__) +#define DNNL_PPC64_ONLY(...) Z_CONDITIONAL_DO(DNNL_PPC64, __VA_ARGS__) +#define DNNL_S390X_ONLY(...) Z_CONDITIONAL_DO(DNNL_S390X, __VA_ARGS__) #define DNNL_AARCH64_ONLY(...) Z_CONDITIONAL_DO(DNNL_AARCH64, __VA_ARGS__) #define DNNL_ARM_ONLY(...) Z_CONDITIONAL_DO(DNNL_ARM, __VA_ARGS__) @@ -122,6 +122,12 @@ #define DNNL_ACL_ONLY(...) #endif +#if DNNL_AARCH64 && defined(DNNL_USE_ACL) +#define DNNL_AARCH64_ACL_ONLY(...) __VA_ARGS__ +#else +#define DNNL_AARCH64_ACL_ONLY(...) +#endif + // Primitive ISA section for configuring knobs. // Note: MSVC preprocessor by some reason "eats" symbols it's not supposed to // if __VA_ARGS__ is passed as empty. Then things happen like this for non-x64: diff --git a/src/cpu/reorder/cpu_reorder.hpp b/src/cpu/reorder/cpu_reorder.hpp index 6b40b927d33..8ca709b5c55 100644 --- a/src/cpu/reorder/cpu_reorder.hpp +++ b/src/cpu/reorder/cpu_reorder.hpp @@ -36,12 +36,13 @@ #include "cpu/x64/jit_uni_reorder_direct_copy.hpp" #include "cpu/x64/matmul/brgemm_matmul_reorders.hpp" #elif DNNL_AARCH64 -#include "cpu/aarch64/jit_uni_reorder.hpp" +#include "cpu/aarch64/reorder/jit_uni_reorder.hpp" +#include "cpu/aarch64/reorder/jit_blk_reorder.hpp" #include "cpu/aarch64/matmul/brgemm_matmul_reorders.hpp" #endif #if DNNL_AARCH64 && DNNL_USE_ACL -#include "cpu/aarch64/acl_reorder.hpp" +#include "cpu/aarch64/reorder/acl_reorder.hpp" #endif #include "cpu/rnn/rnn_reorders.hpp" diff --git a/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp index 759706c737b..762948b38d3 100644 --- a/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_comp_bf16_s8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -26,170 +26,170 @@ const impl_list_map_t &comp_bf16_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // bf16 -> s8 {{bf16, s8, 2}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp)) - REG_SR(bf16, ab, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, ab, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, ab, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, ab, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, ba, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, ba, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, ba, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, ba, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oi, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, format_tag::io, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp)) + REG_SR(bf16, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, ab, 
s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp) nullptr, }}, // bf16 -> s8 {{bf16, s8, 3}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, wio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - REG_SR(bf16, abc, s8, aCB16b16c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, abc, s8, aCB16b32c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, abc, s8, aCB16b48c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, abc, s8, aCB16b64c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, acb, s8, aCB16b16c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, acb, s8, aCB16b32c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, acb, s8, aCB16b48c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(bf16, acb, s8, aCB16b64c4b, fmt_order_keep, spec_conv_req_comp) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, wio, fmt_order::keep, 
spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + REG_SR(bf16, abc, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, abc, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, abc, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, abc, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, acb, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, acb, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, acb, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(bf16, acb, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp) nullptr, }}, {{bf16, s8, 4}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, hwio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, wigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw4o4i, fmt_order_keep, spec_conv_req_comp)) - 
DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOwi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOwi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, hwio, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, 
any, s8, wigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OhwI16o4i, 
fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, {{bf16, s8, 5}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, hwigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, dhwio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOwhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOwhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOhwI16o4i, fmt_order_keep, 
spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, hwigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, dhwio, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, Goihw4g, 
fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, {{bf16, s8, 6}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, dhwigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(bf16, any, s8, dhwigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(bf16, goidhw, s8, gOIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, }); diff --git a/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp index 104868ac072..b2adb3c6471 100644 --- a/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_comp_f32_s8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020 Intel Corporation * Copyright 2023 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,168 +27,168 @@ const impl_list_map_t &comp_f32_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f32 -> s8 {{f32, s8, 2}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i32o4i, 
fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i64o4i, fmt_order_keep, spec_conv_req_comp)) - REG_SR(f32, ab, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, ab, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, ab, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, ab, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, ba, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, ba, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, ba, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, ba, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oi, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, format_tag::io, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp)) + REG_SR(f32, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp) nullptr, }}, // f32 -> s8 {{f32, s8, 3}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, wio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4o4i, 
fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - REG_SR(f32, abc, s8, aCB16b16c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, abc, s8, aCB16b32c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, abc, s8, aCB16b48c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, abc, s8, aCB16b64c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, acb, s8, aCB16b16c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, acb, s8, aCB16b32c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, acb, s8, aCB16b48c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(f32, acb, s8, aCB16b64c4b, fmt_order_keep, spec_conv_req_comp) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, wio, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, 
oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + REG_SR(f32, abc, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, abc, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, abc, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, abc, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, acb, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, acb, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, acb, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(f32, acb, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp) nullptr, }}, {{f32, s8, 4}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, hwio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, wigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, 
wigo, s8, Goiw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOwi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOwi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, wigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw2i8o4i, fmt_order::keep, 
spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, {{f32, s8, 5}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, hwigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, dhwio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, 
s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOwhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOwhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, dhwio, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, 
spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, {{f32, s8, 6}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, dhwigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw4o4i, fmt_order_keep, 
spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(f32, any, s8, dhwigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(f32, goidhw, s8, gOIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, }); diff --git a/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp b/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp index 4cb92ea0832..d430b9a22a5 100644 --- a/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_comp_s8_s8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2025 Intel Corporation +* Copyright 2020 Intel Corporation * Copyright 2023 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,175 +27,175 @@ const impl_list_map_t &comp_s8_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // s8 -> s8 {{s8, s8, 2}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i64o4i, fmt_order_keep, spec_conv_req_comp)) - REG_SR(s8, ab, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, ab, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, ab, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, ab, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, ba, s8, BA16a16b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, ba, s8, BA16a32b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, ba, s8, BA16a48b4a, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, ba, s8, BA16a64b4a, fmt_order_keep, spec_conv_req_comp) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i32o4i, fmt_order::keep, 
spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oi, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, format_tag::io, s8, OI4i64o4i, fmt_order::keep, spec::conv_req_comp)) + REG_SR(s8, ab, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, ab, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, ab, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, ab, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, ba, s8, BA16a16b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, ba, s8, BA16a32b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, ba, s8, BA16a48b4a, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, ba, s8, BA16a64b4a, fmt_order::keep, spec::conv_req_comp) nullptr, }}, // s8 -> s8 {{s8, s8, 3}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, wio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, Owi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - REG_SR(s8, abc, s8, aCB16b16c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, abc, s8, aCB16b32c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, abc, s8, aCB16b48c4b, fmt_order_keep, 
spec_conv_req_comp) - REG_SR(s8, abc, s8, aCB16b64c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, acb, s8, aCB16b16c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, acb, s8, aCB16b32c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, acb, s8, aCB16b48c4b, fmt_order_keep, spec_conv_req_comp) - REG_SR(s8, acb, s8, aCB16b64c4b, fmt_order_keep, spec_conv_req_comp) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, wio, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, Owi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, iwo, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oiw, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wio, s8, OIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + REG_SR(s8, abc, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, abc, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, abc, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, abc, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, acb, s8, aCB16b16c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, acb, s8, aCB16b32c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, acb, s8, aCB16b48c4b, fmt_order::keep, spec::conv_req_comp) + REG_SR(s8, acb, s8, aCB16b64c4b, fmt_order::keep, spec::conv_req_comp) nullptr, }}, {{s8, s8, 4}, { DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - 
DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, hwio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, wigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOwi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOwi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, Owhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - 
DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, wigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, Goiw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOwi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOwI16o4i, fmt_order::keep, 
spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goiw, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, wigo, s8, gOIw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, Owhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, ihwo, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oihw, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwio, s8, OIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, {{s8, s8, 5}, { DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, hwigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, dhwio, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i32o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i64o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw16g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw16g, fmt_order_keep, spec_conv_req_comp)) - 
DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw8g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw4g, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOwhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOwhi16o, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, dhwio, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i32o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4i64o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + 
DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw16g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw8g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, Goihw4g, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOwhi16o, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goihw, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, hwigo, s8, gOIhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, idhwo, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, oidhw, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, dhwio, s8, OIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, {{s8, s8, 6}, { DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, dhwigo, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw4i16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw2i8o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw4o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOdhwI16o4i, fmt_order_keep, spec_conv_req_comp)) - DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw16i16o4i, fmt_order_keep, spec_conv_req_comp)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + DNNL_NON_X64_ONLY(REG_SR(s8, any, s8, dhwigo, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw4i16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw2i8o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw4o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOdhwI16o4i, fmt_order::keep, spec::conv_req_comp)) + DNNL_NON_X64_ONLY(REG_SR(s8, goidhw, s8, gOIdhw16i16o4i, fmt_order::keep, spec::conv_req_comp)) nullptr, }}, }); diff --git a/src/cpu/reorder/cpu_reorder_regular_bf16.cpp b/src/cpu/reorder/cpu_reorder_regular_bf16.cpp index 388afbfa39c..f917f049c59 100644 --- a/src/cpu/reorder/cpu_reorder_regular_bf16.cpp +++ 
b/src/cpu/reorder/cpu_reorder_regular_bf16.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020 Intel Corporation
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,10 +27,10 @@ const impl_list_map_t &regular_bf16_impl_list_map() {
// bf16 ->
{{bf16, data_type::undef, 0}, {
CPU_REORDER_INSTANCE(rnn_weights_reorder_t, bf16, bf16)
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t))
DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
DNNL_NON_X64_ONLY(REG_SR_BIDIR(bf16, any, f32, nChw16c))
DNNL_NON_X64_ONLY(REG_SR_BIDIR(bf16, any, f32, nCdhw16c))
@@ -53,14 +53,14 @@ const impl_list_map_t &regular_bf16_impl_list_map() {
DNNL_NON_X64_ONLY(REG_SR_BIDIR(bf16, any, u8, OIdhw16o16i))
DNNL_NON_X64_ONLY(REG_SR_BIDIR(bf16, any, u8, OIdhw16i16o))
- DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t))
+ DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- REG_SR(bf16, any, bf16, any, fmt_order_any, spec_reference)
- REG_SR(bf16, any, f32, any, fmt_order_any, spec_reference)
- REG_SR(bf16, any, s8, any, fmt_order_any, spec_reference)
- REG_SR(bf16, any, u8, any, fmt_order_any, spec_reference)
- REG_SR(bf16, any, f8_e5m2, any, fmt_order_any, spec_reference)
- REG_SR(bf16, any, f8_e4m3, any, fmt_order_any, spec_reference)
+ REG_SR(bf16, any, bf16, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, f32, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, s8, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, u8, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, f8_e5m2, any, fmt_order::any, spec::reference)
+ REG_SR(bf16, any, f8_e4m3, any, fmt_order::any, spec::reference)
nullptr,
}},
diff --git a/src/cpu/reorder/cpu_reorder_regular_f16.cpp b/src/cpu/reorder/cpu_reorder_regular_f16.cpp
index 5d6bb97ac57..de36293cddd 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f16.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f16.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,19 +28,19 @@ const impl_list_map_t &regular_f16_impl_list_map() {
{{f16, data_type::undef, 0}, {
DNNL_AARCH64_ONLY(REG_SR_DIRECT_COPY(f16, f16))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t))
DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
- REG_SR(f16, any, f8_e5m2, any, fmt_order_any, spec_reference)
- REG_SR(f16, any, f8_e4m3, any, fmt_order_any, spec_reference)
- REG_SR(f16, any, f16, any, fmt_order_any, spec_reference)
- REG_SR(f16, any, f32, any, fmt_order_any, spec_reference)
- REG_SR(f16, any, s8, any, fmt_order_any, spec_reference)
- REG_SR(f16, any, u8, any, fmt_order_any, spec_reference)
+ REG_SR(f16, any, f8_e5m2, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, f8_e4m3, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, f16, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, f32, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, s8, any, fmt_order::any, spec::reference)
+ REG_SR(f16, any, u8, any, fmt_order::any, spec::reference)
nullptr,
}},
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
index 9b6d5cd4f2d..fb2e6a9d950 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_bf16.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020 Intel Corporation
* Copyright 2023 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -30,13 +30,13 @@ const impl_list_map_t &regular_f32_bf16_impl_list_map() {
CPU_REORDER_INSTANCE(rnn_weights_reorder_t, f32, bf16)
DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, bf16, nChw16c))
DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, bf16, nCdhw16c))
- DNNL_AARCH64_ONLY(DNNL_ACL_ONLY(CPU_REORDER_INSTANCE(acl::acl_reorder_fwd_t)))
+ DNNL_AARCH64_ACL_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t))
DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
DNNL_NON_X64_ONLY(REG_SR(f32, oihw, bf16, OIhw8i16o2i, fmt_order::keep))
@@ -48,7 +48,7 @@ const impl_list_map_t &regular_f32_bf16_impl_list_map() {
DNNL_NON_X64_ONLY(REG_SR(f32, oihw, bf16, OIhw16i16o, fmt_order::keep))
DNNL_NON_X64_ONLY(REG_SR(f32, goihw, bf16, gOIhw16i16o, fmt_order::keep))
- REG_SR(f32, any, bf16, any, fmt_order_any, spec_reference)
+ REG_SR(f32, any, bf16, any, fmt_order::any, spec::reference)
nullptr,
}},
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f16.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f16.cpp
index d4da37cc42d..2b16fb79c9a 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_f16.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_f16.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020 Intel Corporation
* Copyright 2025 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,12 +28,13 @@ const impl_list_map_t &regular_f32_f16_impl_list_map() {
// f32 -> f16
{{f32, f16, 0}, {
DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t))
- DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t))
+ DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t))
DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t))
REG_SR(f32, any, f16, any, fmt_order::any, spec::reference)
+
nullptr,
}},
});
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
index b2f86ff709d..5b76e3ce9dd 100644
--- a/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
+++ b/src/cpu/reorder/cpu_reorder_regular_f32_f32.cpp
@@ -1,7 +1,7 @@
/*******************************************************************************
-* Copyright 2020-2024 Intel Corporation
+* Copyright 2020 Intel Corporation
* Copyright 2022-2024 FUJITSU LIMITED
-* Copyright 2023 Arm Ltd. and affiliates
+* Copyright 2023, 2025 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -28,31 +28,31 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f32 -> f32 {{f32, f32, 0}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(DNNL_ACL_ONLY(CPU_REORDER_INSTANCE(acl::acl_reorder_fwd_t))) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::brgemm_matmul_matrix_B_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ACL_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::brgemm_matmul_copy_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY_F32_F32 - REG_SR(f32, any, f32, any, fmt_order_any, spec::reference) + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, {{f32, f32, 3}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::brgemm_matmul_matrix_B_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::brgemm_matmul_copy_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY_F32_F32 @@ -73,21 +73,21 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, IOw8o8i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, IOw16o16i)) - REG_SR(f32, any, f32, any, fmt_order_any, spec_reference) + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, {{f32, f32, 4}, { CPU_REORDER_INSTANCE(rnn_weights_reorder_t, f32, f32) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(DNNL_ACL_ONLY(CPU_REORDER_INSTANCE(acl::acl_reorder_fwd_t))) - 
DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ACL_ONLY(CPU_REORDER_INSTANCE(aarch64::acl_reorder_fwd_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY_F32_F32 @@ -126,20 +126,20 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i)) - REG_SR(f32, any, f32, any, fmt_order_any, spec_reference) + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, {{f32, f32, 5}, { CPU_REORDER_INSTANCE(rnn_weights_reorder_t, f32, f32) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY_F32_F32 @@ -181,17 +181,17 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i)) - REG_SR(f32, any, f32, any, fmt_order_any, spec_reference) + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, {{f32, f32, 6}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY_F32_F32 @@ -209,7 +209,7 @@ const impl_list_map_t ®ular_f32_f32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gIOdhw8o8i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, f32, gIOdhw16o16i)) - REG_SR(f32, any, f32, any, fmt_order_any, spec_reference) + REG_SR(f32, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_fp8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_fp8.cpp index dd642125d53..3613dabe3f1 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_fp8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_fp8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp index 7961f8f361b..347a19e94c0 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_s32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,17 +28,17 @@ const impl_list_map_t ®ular_f32_s32_impl_list_map() { // f32 -> s32 {{f32, s32, 0}, { DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY(f32, s32) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s32, nChw16c)) - REG_SR(f32, any, s32, any, fmt_order_any, spec_reference) + REG_SR(f32, any, s32, any, fmt_order::any, spec::reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp index a3878c5d630..9425dccdf38 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_s8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,20 +27,16 @@ const impl_list_map_t ®ular_f32_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // f32 -> s8 {{f32, s8, 0}, { - // TODO: move it down when checks for sparse md are implemented in other implementations. 
- DNNL_X64_ONLY(REG_SPARSE_SR(f32, oi, s8, OI16i64o4i, sparse_inputs_order::keep, sparse_spec::reference)) - DNNL_X64_ONLY(REG_SPARSE_SR(f32, format_tag::io, s8, OI16i64o4i, sparse_inputs_order::keep, sparse_spec::reference)) - CPU_REORDER_INSTANCE(rnn_data_reorder_t, f32, s8) CPU_REORDER_INSTANCE(rnn_weights_reorder_s8_t, f32) CPU_REORDER_INSTANCE(rnn_brgemm_weights_reorder_s8_t, f32, s8) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY(f32, s8) @@ -48,9 +44,9 @@ const impl_list_map_t ®ular_f32_s8_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i)) - REG_SR(f32, any, s8, any, fmt_order_any, spec_reference) + REG_SR(f32, any, s8, any, fmt_order::any, spec::reference) - REG_SPARSE_SR_X64(f32, any, s8, any) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(simple_sparse_reorder_t)) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp index 923e74a28ac..0f4007504b3 100644 --- a/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_f32_u8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -30,17 +30,21 @@ const impl_list_map_t ®ular_f32_u8_impl_list_map() { CPU_REORDER_INSTANCE(rnn_data_reorder_t, f32, u8) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) + + DNNL_PPC64_ONLY(CPU_REORDER_INSTANCE(ppc64::ppc64_matrixA_reorder_t)) + + DNNL_RV64GCV_ONLY(CPU_REORDER_INSTANCE(rv64::rvv_matrixA_reorder_t)) REG_FAST_DIRECT_COPY(f32, u8) DNNL_NON_X64_ONLY(REG_SR_BIDIR(f32, any, u8, nChw16c)) - REG_SR(f32, any, u8, any, fmt_order_any, spec_reference) + REG_SR(f32, any, u8, any, fmt_order::any, spec::reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_fp4.cpp b/src/cpu/reorder/cpu_reorder_regular_fp4.cpp index f97631ef80e..49e3a0ae604 100644 --- a/src/cpu/reorder/cpu_reorder_regular_fp4.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_fp4.cpp @@ -30,15 +30,6 @@ const impl_list_map_t ®ular_fp4_impl_list_map() { }}, {{f4_e2m1, data_type::undef, 0}, { REG_SR(f4_e2m1, any, f32, any, fmt_order::any, spec::reference) - REG_SR(f4_e2m1, any, f4_e2m1, OI8i8o2i, fmt_order_keep) - REG_SR(f4_e2m1, any, 
f4_e2m1, OI8i16o2i, fmt_order_keep) - REG_SR(f4_e2m1, any, f4_e2m1, OI8i24o2i, fmt_order_keep) - REG_SR(f4_e2m1, any, f4_e2m1, OI8i32o2i, fmt_order_keep) - REG_SR(f4_e2m1, any, f4_e2m1, OI8i64o2i, fmt_order_keep) - REG_SR(f4_e2m1, any, f4_e2m1, OI16i16o2i, fmt_order_keep) - REG_SR(f4_e2m1, any, f4_e2m1, OI16i32o2i, fmt_order_keep) - REG_SR(f4_e2m1, any, f4_e2m1, OI16i48o2i, fmt_order_keep) - REG_SR(f4_e2m1, any, f4_e2m1, OI16i64o2i, fmt_order_keep) nullptr, }}, {{f32, f4_e3m0, 0}, { diff --git a/src/cpu/reorder/cpu_reorder_regular_fp8.cpp b/src/cpu/reorder/cpu_reorder_regular_fp8.cpp index bd08fda826d..c27e029762c 100644 --- a/src/cpu/reorder/cpu_reorder_regular_fp8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_fp8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,12 +50,12 @@ const impl_list_map_t ®ular_fp8_impl_list_map() { nullptr, }}, - // f8_e8m0 -> - {{e8m0, data_type::undef, 0}, { - REG_SR(e8m0, any, e8m0, any, fmt_order::any, spec::reference) - + // e8m0 -> f32 + {{e8m0, f32, 0}, { + REG_SR(e8m0, any, f32, any, fmt_order::any, spec::reference) nullptr, }}, + }); return the_map; } diff --git a/src/cpu/reorder/cpu_reorder_regular_s32.cpp b/src/cpu/reorder/cpu_reorder_regular_s32.cpp index 30cd1392b37..6dd107e7ad1 100644 --- a/src/cpu/reorder/cpu_reorder_regular_s32.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_s32.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,11 +28,11 @@ const impl_list_map_t ®ular_s32_impl_list_map() { // s32 -> {{s32, data_type::undef, 0}, { DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY(s32, f32) REG_FAST_DIRECT_COPY(s32, s32) @@ -44,10 +44,10 @@ const impl_list_map_t ®ular_s32_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, s8, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s32, any, u8, nChw16c)) - REG_SR(s32, any, f32, any, fmt_order_any, spec_reference) - REG_SR(s32, any, s32, any, fmt_order_any, spec_reference) - REG_SR(s32, any, s8, any, fmt_order_any, spec_reference) - REG_SR(s32, any, u8, any, fmt_order_any, spec_reference) + REG_SR(s32, any, f32, any, fmt_order::any, spec::reference) + REG_SR(s32, any, s32, any, fmt_order::any, spec::reference) + REG_SR(s32, any, s8, any, fmt_order::any, spec::reference) + REG_SR(s32, any, u8, any, fmt_order::any, spec::reference) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_s4.cpp b/src/cpu/reorder/cpu_reorder_regular_s4.cpp index e7381490303..c42bf013e29 100644 --- a/src/cpu/reorder/cpu_reorder_regular_s4.cpp +++ 
b/src/cpu/reorder/cpu_reorder_regular_s4.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,40 +29,12 @@ const impl_list_map_t ®ular_s4_impl_list_map() { nullptr, }}, {{s4, data_type::undef, 0}, { - REG_SR(s4, any, s4, OI8i8o2i, fmt_order_keep) - REG_SR(s4, any, s4, OI8i16o2i, fmt_order_keep) - REG_SR(s4, any, s4, OI8i24o2i, fmt_order_keep) - REG_SR(s4, any, s4, OI8i32o2i, fmt_order_keep) - REG_SR(s4, any, s4, OI8i64o2i, fmt_order_keep) - REG_SR(s4, any, s4, OI16i16o2i, fmt_order_keep) - REG_SR(s4, any, s4, OI16i32o2i, fmt_order_keep) - REG_SR(s4, any, s4, OI16i48o2i, fmt_order_keep) - REG_SR(s4, any, s4, OI16i64o2i, fmt_order_keep) - REG_SR(s4, any, s4, aBC8c8b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC8c16b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC8c24b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC8c32b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC8c64b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC16c16b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC16c32b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC16c48b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC16c64b2c, fmt_order_keep) - REG_SR(s4, any, s4, aBC16c16b4c, fmt_order_keep) - REG_SR(s4, any, s4, aBC16c32b4c, fmt_order_keep) - REG_SR(s4, any, s4, aBC16c48b4c, fmt_order_keep) - REG_SR(s4, any, s4, aBC16c64b4c, fmt_order_keep) - REG_SR(s4, any, u8, any, fmt_order_keep, spec::reference) - REG_SR(s4, any, f32, any, fmt_order_keep, spec::reference) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) REG_SR(s4, any, f32, any, fmt_order::any, spec::reference) REG_SR(s4, any, bf16, any, fmt_order::any, spec::reference) REG_SR(s4, any, f16, any, fmt_order::any, spec::reference) nullptr, }}, - {{s4, f32, 0}, { - REG_SR(s4, any, f32, any, fmt_order::any, spec::reference) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) - nullptr, - }}, }); return the_map; } diff --git a/src/cpu/reorder/cpu_reorder_regular_s8.cpp b/src/cpu/reorder/cpu_reorder_regular_s8.cpp index f434192a942..1e49b4c1942 100644 --- a/src/cpu/reorder/cpu_reorder_regular_s8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_s8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,13 +28,16 @@ const impl_list_map_t ®ular_s8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // s8 -> {{s8, data_type::undef, 0}, { - // // TODO: move it down when checks for sparse md are implemented in other implementations. 
- DNNL_X64_ONLY(REG_SPARSE_SR(s8, oi, s8, OI16i64o4i, sparse_inputs_order::keep, sparse_spec::reference)) - DNNL_X64_ONLY(REG_SPARSE_SR(s8, format_tag::io, s8, OI16i64o4i, sparse_inputs_order::keep, sparse_spec::reference)) + CPU_REORDER_INSTANCE(rnn_weights_reorder_s8_t, s8) + CPU_REORDER_INSTANCE(rnn_brgemm_weights_reorder_s8_t, s8, s8) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) - CPU_REORDER_INSTANCE(rnn_weights_reorder_s8_t,s8) - CPU_REORDER_INSTANCE(rnn_brgemm_weights_reorder_s8_t,s8, s8) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) + + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY(s8, f32) REG_FAST_DIRECT_COPY(s8, s32) @@ -43,13 +46,6 @@ const impl_list_map_t ®ular_s8_impl_list_map() { REG_FAST_DIRECT_COPY(s8, s8) REG_FAST_DIRECT_COPY(s8, u8) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) - - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) - DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, f32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, s32, nChw16c)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, bf16, nChw16c)) @@ -63,14 +59,14 @@ const impl_list_map_t ®ular_s8_impl_list_map() { DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, bf16, gOIhw4i16o4i)) DNNL_NON_X64_ONLY(REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i)) - REG_SR(s8, any, f32, any, fmt_order_any, spec_reference) - REG_SR(s8, any, s32, any, fmt_order_any, spec_reference) - REG_SR(s8, any, bf16, any, fmt_order_any, spec_reference) - REG_SR(s8, any, f16, any, fmt_order_any, spec_reference) - REG_SR(s8, any, s8, any, fmt_order_any, spec_reference) - REG_SR(s8, any, u8, any, fmt_order_any, spec_reference) + REG_SR(s8, any, f32, any, fmt_order::any, spec::reference) + REG_SR(s8, any, s32, any, fmt_order::any, spec::reference) + REG_SR(s8, any, bf16, any, fmt_order::any, spec::reference) + REG_SR(s8, any, f16, any, fmt_order::any, spec::reference) + REG_SR(s8, any, s8, any, fmt_order::any, spec::reference) + REG_SR(s8, any, u8, any, fmt_order::any, spec::reference) - REG_SPARSE_SR_X64(s8, any, s8, any) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(simple_sparse_reorder_t)) nullptr, }}, diff --git a/src/cpu/reorder/cpu_reorder_regular_u4.cpp b/src/cpu/reorder/cpu_reorder_regular_u4.cpp index 09b45b2402c..280eaeadff4 100644 --- a/src/cpu/reorder/cpu_reorder_regular_u4.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_u4.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023-2024 Intel Corporation +* Copyright 2023 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -29,44 +29,12 @@ const impl_list_map_t ®ular_u4_impl_list_map() { nullptr, }}, {{u4, data_type::undef, 0}, { - REG_SR(u4, any, u4, OI8i8o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI8i16o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI8i24o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI8i32o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI8i64o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI16i16o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI16i32o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI16i48o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI16i64o2i, fmt_order_keep) - REG_SR(u4, any, u4, OI16i16o4i, fmt_order_keep) - REG_SR(u4, any, u4, OI16i32o4i, fmt_order_keep) - REG_SR(u4, any, u4, OI16i48o4i, fmt_order_keep) - REG_SR(u4, any, u4, OI16i64o4i, fmt_order_keep) - REG_SR(u4, any, u4, aBC8c8b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC8c16b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC8c24b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC8c32b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC8c64b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC16c16b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC16c32b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC16c48b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC16c64b2c, fmt_order_keep) - REG_SR(u4, any, u4, aBC16c16b4c, fmt_order_keep) - REG_SR(u4, any, u4, aBC16c32b4c, fmt_order_keep) - REG_SR(u4, any, u4, aBC16c48b4c, fmt_order_keep) - REG_SR(u4, any, u4, aBC16c64b4c, fmt_order_keep) - REG_SR(u4, any, u8, any, fmt_order_keep, spec::reference) - REG_SR(u4, any, f32, any, fmt_order_keep, spec::reference) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) REG_SR(u4, any, f32, any, fmt_order::any, spec::reference) REG_SR(u4, any, bf16, any, fmt_order::any, spec::reference) REG_SR(u4, any, f16, any, fmt_order::any, spec::reference) nullptr, }}, - {{u4, f32, 0}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) - REG_SR(u4, any, f32, any, fmt_order::any, spec::reference) - nullptr, - }}, }); return the_map; } diff --git a/src/cpu/reorder/cpu_reorder_regular_u8.cpp b/src/cpu/reorder/cpu_reorder_regular_u8.cpp index 97c5c135420..e3fcc659a28 100644 --- a/src/cpu/reorder/cpu_reorder_regular_u8.cpp +++ b/src/cpu/reorder/cpu_reorder_regular_u8.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Intel Corporation +* Copyright 2020 Intel Corporation * Copyright 2022 FUJITSU LIMITED * * Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,14 +27,14 @@ const impl_list_map_t ®ular_u8_impl_list_map() { static const impl_list_map_t the_map = REG_REORDER_P({ // u8 -> {{u8, data_type::undef, 0}, { - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_brgemm_matmul_copy_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::brgemm_matmul_copy_reorder_t)) DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_direct_copy_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t)) - DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_blk_reorder_t)) + DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64::jit_uni_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_blk_reorder_t)) - DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_blk_reorder_t)) + DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64::jit_uni_reorder_t)) REG_FAST_DIRECT_COPY(u8, f32) REG_FAST_DIRECT_COPY(u8, s32) @@ -48,11 +48,11 @@ const impl_list_map_t ®ular_u8_impl_list_map() { 
DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, s8, nChw16c))
DNNL_NON_X64_ONLY(REG_SR_BIDIR(u8, any, u8, nChw16c))
- REG_SR(u8, any, f32, any, fmt_order_any, spec_reference)
- REG_SR(u8, any, s32, any, fmt_order_any, spec_reference)
- REG_SR(u8, any, bf16, any, fmt_order_any, spec_reference)
- REG_SR(u8, any, u8, any, fmt_order_any, spec_reference)
- REG_SR(u8, any, s8, any, fmt_order_any, spec_reference)
+ REG_SR(u8, any, f32, any, fmt_order::any, spec::reference)
+ REG_SR(u8, any, s32, any, fmt_order::any, spec::reference)
+ REG_SR(u8, any, bf16, any, fmt_order::any, spec::reference)
+ REG_SR(u8, any, u8, any, fmt_order::any, spec::reference)
+ REG_SR(u8, any, s8, any, fmt_order::any, spec::reference)
nullptr,
}},
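
Each {{src_dt, dst_dt, ndims}, {...}} entry edited above is an ordered candidate list: reorder creation walks it front to back and the first implementation whose checks accept the given memory descriptors is used, with the trailing REG_SR(..., spec::reference) rows acting as the catch-all and nullptr terminating the list. The sketch below only illustrates that dispatch shape; it is not oneDNN's actual code, and names such as reorder_req_t, candidate_t, and create_reorder are invented for the example.

// Minimal sketch (not oneDNN's real API): an ordered impl list keyed by
// {src_dt, dst_dt, ndims}; creation walks the list and returns the first
// candidate that accepts the request, so JIT entries listed before the
// reference entry win whenever their own checks pass.
#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <tuple>
#include <vector>

enum class dt { f16, f32 };

struct reorder_req_t {
    dt src, dst;
    int ndims;
    bool small_strides; // stand-in for the layout checks a JIT kernel performs
};

// A candidate either accepts the request (filling in an impl name) or declines.
using candidate_t = std::function<bool(const reorder_req_t &, std::string &)>;

using impl_list_map_t
        = std::map<std::tuple<dt, dt, int>, std::vector<candidate_t>>;

static impl_list_map_t make_map() {
    impl_list_map_t m;
    // Hypothetical f16->f16 bucket: JIT first, reference last (always accepts).
    m[{dt::f16, dt::f16, 0}] = {
            [](const reorder_req_t &req, std::string &name) {
                // Deliberately does not consult req.small_strides for pure
                // f16->f16, mirroring a relaxed stride requirement.
                if (req.src == dt::f16 && req.dst == dt::f16) {
                    name = "aarch64::jit_uni_reorder_t";
                    return true;
                }
                return false;
            },
            [](const reorder_req_t &, std::string &name) {
                name = "simple_reorder_t (spec::reference)";
                return true; // reference fallback accepts everything
            },
    };
    return m;
}

static std::string create_reorder(
        const impl_list_map_t &m, const reorder_req_t &req) {
    auto it = m.find({req.src, req.dst, 0}); // 0 acts as the "any ndims" bucket
    if (it == m.end()) return "unimplemented";
    for (const auto &candidate : it->second) {
        std::string name;
        if (candidate(req, name)) return name; // first accepting impl wins
    }
    return "unimplemented";
}

int main() {
    auto m = make_map();
    reorder_req_t req {dt::f16, dt::f16, 4, /*small_strides=*/false};
    std::printf("picked: %s\n", create_reorder(m, req).c_str());
    return 0;
}

Compiled stand-alone (e.g. with g++ -std=c++17), the sketch prints the JIT candidate name, showing how an entry placed before the reference rows is selected first even when stride-related hints are unfavorable.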