109 changes: 44 additions & 65 deletions tensorflow/core/kernels/mkl/mkl_matmul_op_fused.cc
@@ -27,7 +27,7 @@ limitations under the License.
namespace tensorflow {

// Fuse Operation
template <typename Device, typename T, bool native_format = false>
template <typename Device, typename T>
class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {
public:
explicit MklFusedMatMulOp(OpKernelConstruction* ctx)
@@ -68,17 +68,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {
(void)SetFPMathMode();
}

MklDnnShape src_mkl_shape;
MklDnnShape weight_mkl_shape;
GetMklShape(ctx, this->kInputIndexSrc, &src_mkl_shape, native_format);
GetMklShape(ctx, this->kInputIndexWeight, &weight_mkl_shape, native_format);
OP_REQUIRES(
ctx, !weight_mkl_shape.IsMklTensor(),
absl::InvalidArgumentError("Weight should not be in MKL Layout"));

// Get shapes of input tensors
auto src_tf_shape = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape()
: src_tensor.shape();
auto src_tf_shape = src_tensor.shape();
auto weight_tf_shape = weight_tensor.shape();

// Check the constraint of input matrix and bias
@@ -90,42 +81,47 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {
OP_REQUIRES(ctx, bias_tensor.dim_size(i) == 1,
absl::InvalidArgumentError(
absl::StrCat("For bias_dims > 1, all except the "
"last dimension (channel) must be 1, got: ",
"last dimension (n) must be 1, got: ",
bias_tensor.shape().DebugString())));
}

// Expression: [batch, k] * [k, channel] + [channel] = [batch, channel]
// Expression: [m, k] * [k, n] + [n] = [m, n]
//
// Get dimension size of each matrix, dim_pair[] is the location of k
// in the inputs, we have constraint that k of the two inputs are
// the same
const int64_t dim_pair[] = {1, transpose_b_ ? 1 : 0};
const int64_t batch = src_tf_shape.dim_size(1 - dim_pair[0]);
const int64_t m = src_tf_shape.dim_size(1 - dim_pair[0]);
const int64_t k = src_tf_shape.dim_size(dim_pair[0]);
const int64_t channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
const int64_t n = weight_tf_shape.dim_size(1 - dim_pair[1]);

OP_REQUIRES(
ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
absl::InvalidArgumentError(absl::StrCat(
"Matrix size-incompatible: In[0]: ", src_tf_shape.DebugString(),
", In[1]: ", weight_tf_shape.DebugString())));
OP_REQUIRES(ctx, bias_tensor.dim_size(bias_tensor.dims() - 1) == channel,
OP_REQUIRES(ctx, bias_tensor.dim_size(bias_tensor.dims() - 1) == n,
absl::InvalidArgumentError(absl::StrCat(
"Must provide as many biases as the channel size: ",
bias_tensor.shape().DebugString(), " vs. ", channel)));
"Must provide as many biases as the n size: ",
bias_tensor.shape().DebugString(), " vs. ", n)));

// For inputs s[batch, k], w[k, channel] and b[channel], the primitive
// For inputs s[m, k], w[k, n] and b[n], the primitive
// dims should be described like this:
// s[batch, k] * w^T[channel, k] + b[channel] = dst[batch, channel]
// s[m, k] * w^T[n, k] + b[n] = dst[m, n]
// [n, ic] * [oc, ic] + [oc] = [n, oc]
memory::dims src_dims = memory::dims({batch, k});
// Reverse the weights dims from [k, channel] to [channel, k].
memory::dims weight_dims = memory::dims({channel, k});
memory::dims bias_dims = memory::dims({channel});
memory::dims dst_dims = memory::dims({batch, channel});
memory::format_tag src_format = memory::format_tag::nc;
// memory::dims src_dims = memory::dims({m, k});
// // Reverse the weights dims from [k, n] to [n, k].
// memory::dims weight_dims = memory::dims({n, k});
memory::dims src_dims = memory::dims({m, k});
memory::dims weight_dims = memory::dims({k, n});
// Broadcast: this op used to call the oneDNN inner-product primitive,
// which takes a 1-D bias. It now calls the oneDNN matmul primitive,
// so the bias must be 2-D here.
memory::dims bias_dims = memory::dims({1, n});
memory::dims dst_dims = memory::dims({m, n});
memory::format_tag src_format = memory::format_tag::ab;
memory::format_tag weight_format =
transpose_b_ ? memory::format_tag::oi : memory::format_tag::io;
transpose_b_ ? memory::format_tag::ba : memory::format_tag::ab;
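
The block above is the heart of the change: the kernel now describes a plain row-major matmul (src {m, k} with tag ab, weights {k, n} with tag ab, or ba when transpose_b_ is set) and widens the bias to {1, n} so it broadcasts over m. A minimal standalone sketch of the same setup against the stock oneDNN v3 C++ API, outside TensorFlow's wrapper classes; the function name and the f32 data type are illustrative assumptions, not part of the PR:

    #include <cstdint>
    #include "oneapi/dnnl/dnnl.hpp"

    // Sketch only, not part of this PR: runs one fused matmul with the same
    // shape/layout convention as the kernel above.
    void fused_matmul_sketch(int64_t m, int64_t k, int64_t n, bool transpose_b,
                             float* src, float* wei, float* bias, float* dst) {
      using namespace dnnl;
      engine eng(engine::kind::cpu, 0);
      stream s(eng);

      memory::desc src_md({m, k}, memory::data_type::f32, memory::format_tag::ab);
      // Logical weights are always {k, n}; transpose_b only changes the
      // physical layout tag, mirroring the ab/ba switch above.
      memory::desc wei_md({k, n}, memory::data_type::f32,
                          transpose_b ? memory::format_tag::ba
                                      : memory::format_tag::ab);
      // matmul requires the bias to have the same rank as dst; {1, n}
      // broadcasts over m, which is why the 1-D inner-product bias became 2-D.
      memory::desc bias_md({1, n}, memory::data_type::f32, memory::format_tag::ab);
      memory::desc dst_md({m, n}, memory::data_type::f32, memory::format_tag::ab);

      matmul::primitive_desc pd(eng, src_md, wei_md, bias_md, dst_md);
      matmul prim(pd);
      prim.execute(s, {{DNNL_ARG_SRC, memory(src_md, eng, src)},
                       {DNNL_ARG_WEIGHTS, memory(wei_md, eng, wei)},
                       {DNNL_ARG_BIAS, memory(bias_md, eng, bias)},
                       {DNNL_ARG_DST, memory(dst_md, eng, dst)}});
      s.wait();
    }
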

// Set weight format `any` for primitive as per oneDNN recommendation.
MklDnnMatMulFwdParams matmul_params(
@@ -134,7 +130,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {
memory::format_tag::nc, this->is_weight_const_);
// Extend the basic parameters for data types and fusions.
ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
auto st = ExecuteSingleThreadedGemm(batch, channel, k, sizeof(T));
auto st = ExecuteSingleThreadedGemm(m, n, k, sizeof(T));
// Create the oneDNN wrapper over Eigen threadpool and set max threads
// in oneDNN.
Eigen::ThreadPoolInterface* eigen_interface =
@@ -146,7 +142,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {

// Allocate output tensor.
Tensor* dst_tensor = nullptr;
std::shared_ptr<dnnl::inner_product_forward::primitive_desc> matmul_pd =
std::shared_ptr<dnnl::matmul::primitive_desc> matmul_pd =
matmul_prim->GetPrimitiveDesc();

// The output shape of MatMul is same both for MKL and TF version.
@@ -155,48 +151,33 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(false);

TensorShape output_tf_shape({batch, channel});
TensorShape output_tf_shape({m, n});

if (fuse_add_) {
const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
MklDnnShape add_mkl_shape;
GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);

// For native format, we need not to set metadata.
if (native_format && ctx->forward_input_to_output_with_shape(
kInputIndex_Add, kOutputIndex_Dst,
output_tf_shape, &dst_tensor)) {
; // Need to do nothing for native format
} else if (!native_format && ForwardMklTensorInToOutWithMklShape(
ctx, kInputIndex_Add, kOutputIndex_Dst,
&dst_tensor, output_mkl_shape, false)) {
; // If it's not native format, need to forward and set meta first
} else {
// If forward is not successful, we should use reorder to copy add
if (!ctx->forward_input_to_output_with_shape(
kInputIndex_Add, kOutputIndex_Dst,
output_tf_shape, &dst_tensor)) {
// If forward is not successful, we should use reorder to copy add
// tensor to dst tensor
AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
output_tf_shape, output_mkl_shape,
native_format);
OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputIndex_Dst,
output_tf_shape, &dst_tensor));
auto output_format_tag =
MklTensorFormatToMklDnnDataFormat(MklTensorFormat::FORMAT_NC);
auto add_md =
add_mkl_shape.IsMklTensor()
? add_mkl_shape.GetMklLayout()
: memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
auto dst_md =
memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);

void* add_buf =
static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());

if (native_format) {
// We are simply deep copying the add_tensor to dst_tensor without
// changing memory layout, hence using same memory descriptor.
add_md = dst_md =
memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
dnnl::memory::format_tag::x);
}
// We are simply deep copying the add_tensor to dst_tensor without
// changing memory layout, hence using same memory descriptor.
add_md = dst_md =
memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
dnnl::memory::format_tag::x);

auto fuse_add_src_ = memory(add_md, this->cpu_engine_, add_buf);
auto fuse_add_dst_ = memory(dst_md, this->cpu_engine_, dst_buf);
@@ -207,12 +188,12 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {
this->cpu_engine_, ctx);
}
} else {
AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
output_mkl_shape, native_format);
OP_REQUIRES_OK(ctx,
ctx->allocate_output(0, output_tf_shape, &dst_tensor));
}
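
When the add tensor's buffer cannot simply be forwarded to the output, the branch above falls back to a oneDNN reorder that views both tensors as flat 1-D memories (format_tag::x over NumElements), i.e. a layout-preserving deep copy. A hedged sketch of that fallback in plain oneDNN, outside TensorFlow's helpers; the function name and the f32 type are assumptions:

    #include <cstdint>
    #include "oneapi/dnnl/dnnl.hpp"

    // Sketch only: deep-copy `add` into `dst` through a oneDNN reorder.
    void copy_add_to_dst(int64_t num_elements, const float* add_buf, float* dst_buf) {
      using namespace dnnl;
      engine eng(engine::kind::cpu, 0);
      stream s(eng);

      // The same 1-D descriptor on both sides: no layout change, just a
      // memcpy-like copy carried out by the reorder primitive.
      memory::desc flat_md({num_elements}, memory::data_type::f32,
                           memory::format_tag::x);
      memory src_mem(flat_md, eng, const_cast<float*>(add_buf));
      memory dst_mem(flat_md, eng, dst_buf);

      reorder(src_mem, dst_mem).execute(s, src_mem, dst_mem);
      s.wait();
    }
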

// if there's nothing to compute, just return.
if (batch == 0 || channel == 0) {
if (m == 0 || n == 0) {
return;
}

@@ -227,9 +208,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {
MklDnnData<T> src_mkl(&(this->cpu_engine_));
MklDnnData<T> weight_mkl(&(this->cpu_engine_));

auto src_md = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), src_format);
auto src_md = memory::desc(src_dims, MklDnnType<T>(), src_format);

if (src_md != matmul_pd->src_desc()) {
src_mkl.SetUsrMem(src_md, src_data);
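
The src_md != matmul_pd->src_desc() check above is the standard oneDNN pattern: the primitive may prefer a different, possibly blocked, layout than the user's plain tensor, in which case the input is reordered once into a primitive-friendly buffer. A small sketch of that pattern in plain oneDNN; the helper name is illustrative:

    #include "oneapi/dnnl/dnnl.hpp"

    // Sketch only: return a memory usable by the primitive, reordering the
    // user's buffer only when the descriptors actually differ.
    dnnl::memory prepare_input(const dnnl::engine& eng, dnnl::stream& s,
                               const dnnl::memory::desc& user_md, void* user_buf,
                               const dnnl::memory::desc& expected_md) {
      using namespace dnnl;
      memory user_mem(user_md, eng, user_buf);
      if (user_md == expected_md) return user_mem;  // no reorder needed

      memory prim_mem(expected_md, eng);  // oneDNN allocates this buffer
      reorder(user_mem, prim_mem).execute(s, user_mem, prim_mem);
      s.wait();
      return prim_mem;
    }
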
@@ -344,12 +323,12 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, void, T> {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklFusedMatMulOp<CPUDevice, type, true>);
MklFusedMatMulOp<CPUDevice, type>);
TF_CALL_float(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_half(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
#undef REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES

} // namespace tensorflow

#endif // INTEL_MKL
#endif // INTEL_MKL
24 changes: 11 additions & 13 deletions tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
@@ -33,7 +33,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::inner_product_forward;
using dnnl::matmul;
using dnnl::primitive_attr;
using dnnl::prop_kind;
using dnnl::stream;
@@ -188,7 +188,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(DummyData);
}

std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
std::shared_ptr<dnnl::matmul::primitive_desc>
GetPrimitiveDesc() const {
return context_.fwd_pd;
}
@@ -209,9 +209,9 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {

// Descriptor and primitive-descriptor for forward inner-product.
#ifndef ENABLE_ONEDNN_V3
std::shared_ptr<dnnl::inner_product_forward::desc> fwd_desc;
std::shared_ptr<dnnl::matmul::desc> fwd_desc;
#endif // !ENABLE_ONEDNN_V3
std::shared_ptr<dnnl::inner_product_forward::primitive_desc> fwd_pd;
std::shared_ptr<dnnl::matmul::primitive_desc> fwd_pd;

// Memory descriptors.
std::shared_ptr<dnnl::memory::desc> src_md;
@@ -283,12 +283,12 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
}
// Create an inner-product.
#ifndef ENABLE_ONEDNN_V3
context_.fwd_desc.reset(new inner_product_forward::desc(
context_.fwd_desc.reset(new matmul::desc(
matmul_fwd_params.const_weight ? prop_kind::forward_inference
: prop_kind::forward_training,
*context_.src_md, *context_.weight_md, *context_.bias_md,
*context_.dst_md));
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
context_.fwd_pd.reset(new matmul::primitive_desc(
*context_.fwd_desc, cpu_engine_));
#endif // !ENABLE_ONEDNN_V3

@@ -396,13 +396,11 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
}

#ifndef ENABLE_ONEDNN_V3
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
context_.fwd_pd.reset(new matmul::primitive_desc(
*context_.fwd_desc, post_ops_attr, cpu_engine_));
#else
context_.fwd_pd.reset(new inner_product_forward::primitive_desc(
context_.fwd_pd.reset(new matmul::primitive_desc(
cpu_engine_,
matmul_fwd_params.const_weight ? prop_kind::forward_inference
: prop_kind::forward_training,
*context_.src_md, *context_.weight_md, *context_.bias_md,
*context_.dst_md, post_ops_attr));
#endif // !ENABLE_ONEDNN_V3
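
Under ENABLE_ONEDNN_V3 the primitive descriptor is built directly from the memory descriptors plus a primitive_attr carrying the fused post-ops; there is no separate op descriptor and no prop_kind argument. A standalone sketch of that v3 construction path, assuming oneDNN v3, with ReLU as an illustrative post-op; the helper name is made up:

    #include "oneapi/dnnl/dnnl.hpp"

    // Sketch only: build a fused matmul primitive the oneDNN v3 way.
    dnnl::matmul make_fused_matmul(const dnnl::engine& eng,
                                   const dnnl::memory::desc& src_md,
                                   const dnnl::memory::desc& bias_md,
                                   const dnnl::memory::desc& dst_md,
                                   dnnl::memory::dims wei_dims) {
      using namespace dnnl;
      // Let oneDNN pick the weight layout, mirroring the `any` recommendation
      // mentioned in the kernel above.
      memory::desc wei_any(wei_dims, memory::data_type::f32,
                           memory::format_tag::any);

      post_ops po;
      po.append_eltwise(algorithm::eltwise_relu, 0.f, 0.f);  // illustrative fusion
      primitive_attr attr;
      attr.set_post_ops(po);

      // v3 signature: engine, memory descs, then the attr; no prop_kind.
      matmul::primitive_desc pd(eng, src_md, wei_any, bias_md, dst_md, attr);
      // A caller would compare pd.weights_desc() with its plain {k, n}
      // descriptor and reorder the user weights once if they differ, as the
      // `!= matmul_pd->src_desc()` checks in the kernel do.
      return matmul(pd);
    }
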
@@ -421,7 +419,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
new dnnl::memory(scratchpad_md, cpu_engine_, DummyData));

// Create inner-product primitive.
context_.matmul_fwd.reset(new inner_product_forward(*context_.fwd_pd));
context_.matmul_fwd.reset(new matmul(*context_.fwd_pd));
std::unordered_map<int, memory> net_args = {
{DNNL_ARG_SRC, *context_.src_mem},
{DNNL_ARG_WEIGHTS, *context_.weight_mem},
@@ -561,7 +559,7 @@ class MklDnnMatMulOpBase : public OpKernel {
// Allocate output tensor.
virtual void AllocateOutputTensor(
OpKernelContext* context,
const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc,
const matmul::primitive_desc& mkldnn_matmul_prim_desc,
const memory::dims& output_dims_mkl_order,
MklTensorFormat output_tf_format, Tensor** output_tensor,
bool native_format = false) {
@@ -599,7 +597,7 @@ class MklDnnMatMulOpBase : public OpKernel {
// Only one thread can execute this method at any given time.
void CacheWeight(
OpKernelContext* context,
const std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
const std::shared_ptr<dnnl::matmul::primitive_desc>&
matmul_fwd_pd,
Tweight* weight_data, const Tensor& weight_tensor,
MklDnnData<Tweight>& weight, const memory::desc& weight_md)
6 changes: 3 additions & 3 deletions tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc
@@ -262,7 +262,7 @@ class MklDnnQuantizedMatMulOp
Toutput>::Get(matmul_fwd_dims, 0);

// Allocate output Tensor.
std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
std::shared_ptr<dnnl::matmul::primitive_desc>
matmul_fwd_pd = matmul_fwd->GetPrimitiveDesc();
this->AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order,
input_output_fmt_mkldnn, &dst_tensor,
@@ -515,7 +515,7 @@ class MklDnnQuantizedMatMulOp
#ifndef ENABLE_ONEDNN_V3
Tbias* GetBiasHandle(
OpKernelContext* context,
std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
std::shared_ptr<dnnl::matmul::primitive_desc>&
mkldnn_matmul_fwd_pd,
const Tensor& bias_tensor, const Tensor& weight_tensor,
std::shared_ptr<stream> reorder_stream) {
@@ -621,7 +621,7 @@ class MklDnnQuantizedMatMulOp
#else
void GetBiasHandle(
OpKernelContext* context,
std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
std::shared_ptr<dnnl::matmul::primitive_desc>&
mkldnn_matmul_fwd_pd,
const Tensor& bias_tensor, const Tensor& weight_tensor,
std::shared_ptr<stream> reorder_stream, Tensor* temp_scaled_bias_tensor,