
Missing gemm_batch data types #446

@AidanBeltonS

Description

Summary

I believe some gemm_batch implementations are missing. According to the oneMKL docs, gemm_batch should support two half matrices as input, a float matrix as output, and float scaling factors (alpha and beta). My reference: https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-dpcpp/2023-0/gemm-batch.html
However, the compiler cannot find this overload. Am I reading the documentation correctly, or have I misunderstood something?
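For reference, the group form of gemm_batch computes, for every matrix in every group, C = alpha * op(A) * op(B) + beta * C. The following is a minimal CPU sketch of that computation with separate type parameters Ta, Tb, Tc and Ts, matching the template parameters used in the reproducer. It is not oneMKL code: `reference_gemm_batch` and the `transpose` enum are hypothetical stand-ins, and accumulating the products in Ts is an assumption about how the mixed-precision case behaves.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for oneapi::mkl::transpose so the sketch builds without oneMKL.
enum class transpose { nontrans, trans };

// Element (i, j) of op(X), where X is column-major with leading dimension ld.
template <class T>
T op_at(const T *x, transpose t, std::int64_t i, std::int64_t j, std::int64_t ld) {
    return t == transpose::nontrans ? x[i + j * ld] : x[j + i * ld];
}

// Hypothetical CPU reference for the group (pointer-array) form of gemm_batch.
// For every matrix in every group: C = alpha * op(A) * op(B) + beta * C.
// Ta/Tb are the input types, Tc the output type, Ts the type of alpha/beta;
// products are accumulated in Ts (an assumption for mixed precision).
template <class Ta, class Tb, class Tc, class Ts>
void reference_gemm_batch(const transpose *transa, const transpose *transb,
                          const std::int64_t *m, const std::int64_t *n,
                          const std::int64_t *k, const Ts *alpha,
                          const Ta **a, const std::int64_t *lda,
                          const Tb **b, const std::int64_t *ldb,
                          const Ts *beta, Tc **c, const std::int64_t *ldc,
                          std::int64_t group_count,
                          const std::int64_t *group_size) {
    std::int64_t idx = 0;  // position in the flat a/b/c pointer arrays
    for (std::int64_t g = 0; g < group_count; ++g) {
        for (std::int64_t s = 0; s < group_size[g]; ++s, ++idx) {
            for (std::int64_t j = 0; j < n[g]; ++j) {
                for (std::int64_t i = 0; i < m[g]; ++i) {
                    Ts acc = 0;
                    for (std::int64_t p = 0; p < k[g]; ++p) {
                        acc += static_cast<Ts>(op_at(a[idx], transa[g], i, p, lda[g])) *
                               static_cast<Ts>(op_at(b[idx], transb[g], p, j, ldb[g]));
                    }
                    Tc &out = c[idx][i + j * ldc[g]];
                    // Following the reference BLAS convention, C is not read when beta == 0.
                    const Ts prev = beta[g] == Ts(0) ? Ts(0) : static_cast<Ts>(out);
                    out = static_cast<Tc>(alpha[g] * acc + beta[g] * prev);
                }
            }
        }
    }
}
```

With SYCL available, `reference_gemm_batch<sycl::half, sycl::half, float, float>` would instantiate the same type combination that the library call below fails to find.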

Version

oneMKL hash: 7d2044e

Environment

oneMKL works with multiple HW and backend libraries and also depends on the compiler and build environment. Include the following information to help reproduce the issue:

  • HW: A100 GPU
  • Backend: cuBlas
  • OS: Ubuntu 20.04
  • Compiler version: DPC++ 2024.0.2

Steps to reproduce

Compile for NVIDIA GPUs with: icpx -fsycl -fsycl-targets=nvptx64-nvidia-cuda reproducer_onemkl_batch.cpp -lonemkl
or for Intel GPUs with: icpx -fsycl reproducer_onemkl_batch.cpp -lonemkl

#include <sycl/sycl.hpp>
#include <oneapi/mkl.hpp>

template <class Ta, class Tb, class Tc, class Ts>
void run_gemm(sycl::queue q) {
    // Construct arbitrary data: the error occurs at compile time, so the data does not need to be valid.
    const Ta *a[4] = {nullptr};
    const Tb *b[4] = {nullptr};
    Tc *c[4] = {nullptr};

    int64_t batch_size = 4;

    oneapi::mkl::transpose a_trans = oneapi::mkl::transpose::trans;
    oneapi::mkl::transpose b_trans = oneapi::mkl::transpose::nontrans;

    int64_t m = 10;
    int64_t n = 10;
    int64_t k = 10;

    int64_t lda = 10;
    int64_t ldb = 10;
    int64_t ldc = 10;

    int64_t group_size = 1;

    Ts alpha = 1;
    Ts beta = 0;
    oneapi::mkl::transpose *trans =
        reinterpret_cast<oneapi::mkl::transpose *>( 
            std::malloc(sizeof(oneapi::mkl::transpose) * 2 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      trans[batch + batch_size * 0] = a_trans;
      trans[batch + batch_size * 1] = b_trans;
    }   

    // structured m, n, k, lda, ldb, ldc, group_size
    int64_t *dims = reinterpret_cast<int64_t *>( 
        std::malloc(sizeof(int64_t) * 7 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      dims[batch + batch_size * 0] = m;
      dims[batch + batch_size * 1] = n;
      dims[batch + batch_size * 2] = k;

      dims[batch + batch_size * 3] = lda;
      dims[batch + batch_size * 4] = ldb;
      dims[batch + batch_size * 5] = ldc;

      dims[batch + batch_size * 6] = group_size;
    }   

    // structured alpha, beta
    Ts *coeff =
        reinterpret_cast<Ts *>(std::malloc(sizeof(Ts) * 2 * batch_size));
    for (int batch = 0; batch < batch_size; ++batch) {
      coeff[batch + batch_size * 0] = alpha;
      coeff[batch + batch_size * 1] = beta;
    }


    oneapi::mkl::blas::column_major::gemm_batch(
        q, trans + batch_size * 0 /*a_trans*/,
        trans + batch_size * 1 /*b_trans*/, dims + batch_size * 0 /*m*/,
        dims + batch_size * 1 /*n*/, dims + batch_size * 2 /*k*/,
        coeff + batch_size * 0 /*alpha*/,
        reinterpret_cast<const Ta **>(a), dims + batch_size * 3 /*lda*/,
        reinterpret_cast<const Tb **>(b), dims + batch_size * 4 /*ldb*/,
        coeff + batch_size * 1 /*beta*/, reinterpret_cast<Tc **>(c),
        dims + batch_size * 5 /*ldc*/, batch_size,
        dims + batch_size * 6 /*group_size*/);
}

int main() {
    sycl::queue q;
    //run_gemm<float, float, float, float>(q); // Compiles
    run_gemm<sycl::half, sycl::half, float, float>(q); // Fails to compile
}

Error:

reproducer_onemkl_batch.cpp:60:5: error: no matching function for call to 'gemm_batch'
   60 |     oneapi::mkl::blas::column_major::gemm_batch(
      |     ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
reproducer_onemkl_batch.cpp:75:5: note: in instantiation of function template specialization 'run_gemm<sycl::detail::half_impl::half, sycl::detail::half_impl::half, float, float>' requested here
   75 |     run_gemm<sycl::half, sycl::half, float, float>(q);
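The error is what one would expect if the library exposes one concrete (non-template) overload per supported type combination, which is an assumption here: `const sycl::half **` has no implicit conversion to `const float **`, so a float-only overload cannot be selected as a fallback either. The self-contained sketch below reproduces that overload-resolution behaviour with hypothetical stand-ins (`half_stub` for `sycl::half`, `gemm_batch_stub` for the library function); neither is part of oneMKL.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for sycl::half.
struct half_stub { unsigned short bits; };

// Hypothetical stand-in for a library that ships only a float overload
// (float inputs, float output, float alpha/beta).
inline int gemm_batch_stub(const float * /*alpha*/, const float ** /*a*/,
                           const float ** /*b*/, const float * /*beta*/,
                           float ** /*c*/) {
    return 0;
}

// Detection idiom: true when gemm_batch_stub is callable with pointers
// to the given element types (C++17 std::void_t).
template <class Ta, class Tb, class Tc, class Ts, class = void>
struct has_gemm_batch : std::false_type {};

template <class Ta, class Tb, class Tc, class Ts>
struct has_gemm_batch<Ta, Tb, Tc, Ts,
    std::void_t<decltype(gemm_batch_stub(
        std::declval<const Ts *>(), std::declval<const Ta **>(),
        std::declval<const Tb **>(), std::declval<const Ts *>(),
        std::declval<Tc **>()))>> : std::true_type {};

// const half_stub** never converts to const float**, so no overload matches.
static_assert(has_gemm_batch<float, float, float, float>::value,
              "the float overload is found");
static_assert(!has_gemm_batch<half_stub, half_stub, float, float>::value,
              "no overload accepts half inputs with float output");
```

Because the mismatch is in pointer-to-pointer types, adding a conversion operator from the half type to float would not help; only a dedicated overload for that type combination would.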

Given the documentation linked above, I would expect this to compile, since the docs state that this combination of data types is supported.
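The reproducer packs the per-group parameters into std::malloc'd arrays laid out as array[group + group_count * field] (with group_size = 1, group_count equals the number of matrices) and never frees them. A minimal sketch of the same layout with std::vector owning the storage; `GemmBatchParams`, the `dim` field enum and the `transpose` stand-in are hypothetical names, not part of oneMKL.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for oneapi::mkl::transpose so the sketch builds without oneMKL.
enum class transpose { nontrans, trans };

// Field blocks inside the packed dims array, in the reproducer's order.
namespace dim {
enum Field : std::int64_t { m, n, k, lda, ldb, ldc, group_size, count };
}

// Hypothetical owner of the per-group parameter arrays. Each array stores one
// block per field, indexed as array[group + group_count * field], like the
// std::malloc'd arrays in the reproducer, but released automatically.
template <class Ts>
struct GemmBatchParams {
    std::int64_t group_count;
    std::vector<transpose> trans;    // [transa block][transb block]
    std::vector<std::int64_t> dims;  // dim::count blocks of group_count entries
    std::vector<Ts> coeff;           // [alpha block][beta block]

    GemmBatchParams(std::int64_t groups, transpose ta, transpose tb,
                    std::int64_t m, std::int64_t n, std::int64_t k,
                    std::int64_t lda, std::int64_t ldb, std::int64_t ldc,
                    std::int64_t group_size, Ts alpha, Ts beta)
        : group_count(groups), trans(2 * groups),
          dims(dim::count * groups), coeff(2 * groups) {
        const std::int64_t values[dim::count] = {m, n, k, lda, ldb, ldc, group_size};
        for (std::int64_t g = 0; g < groups; ++g) {
            trans[g] = ta;
            trans[g + groups] = tb;
            for (std::int64_t f = 0; f < dim::count; ++f)
                dims[g + groups * f] = values[f];
            coeff[g] = alpha;
            coeff[g + groups] = beta;
        }
    }

    // Pointers to the start of each block, i.e. the gemm_batch arguments.
    const transpose *transa() const { return trans.data(); }
    const transpose *transb() const { return trans.data() + group_count; }
    const std::int64_t *field(dim::Field f) const { return dims.data() + group_count * f; }
    const Ts *alpha() const { return coeff.data(); }
    const Ts *beta() const { return coeff.data() + group_count; }
};
```

With oneMKL available, the stand-in enum would presumably be replaced by oneapi::mkl::transpose, and p.transa(), p.transb(), p.field(dim::m), ..., p.field(dim::group_size), p.alpha() and p.beta() would go in the positions where the reproducer passes trans + batch_size * 0, dims + batch_size * 0, coeff + batch_size * 0, and so on. This is untested against the library.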
