Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -536,11 +536,7 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
else
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Bdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));

//Set User Preference attributes
int64_t max_workspace_size = 32 * 1024 * 1024 * 4;
void* d_workspace;
//NEED HIP CHECK ERROR
//hipMalloc(&d_workspace, max_workspace_size);
const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel

if(DTYPE_OUT == 32)
{
Expand Down Expand Up @@ -580,17 +576,14 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
heuristicResult,
&returnedAlgoCount));

auto toMalloc = max(heuristicResult[0].workspaceSize, max_workspace_size);

//printf("\n\n1Got algosn: %d %d %d\n\n",returnedAlgoCount, heuristicResult[0].workspaceSize, toMalloc);
//NEED HIP CHECK ERROR
auto err = hipMalloc(&d_workspace, toMalloc);
//printf("Hipmalloc\n");
//printf(hipError_to_string(err).c_str());
//printf("\n");
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, toMalloc, 0));
//hipStreamSynchronize(0);
hipFree(d_workspace);
if (returnedAlgoCount == 0)
{
has_error = 1;
}
else
{
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
}
}
else
{
Expand Down Expand Up @@ -622,23 +615,19 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
heuristicResult,
&returnedAlgoCount));

//NEED HIP CHECK ERROR
hipMalloc(&d_workspace, heuristicResult[0].workspaceSize);
if(!SCALE_ROWS)
{
float alpha = 1.0f, beta = 0.0f;

has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, max_workspace_size, 0));
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
}
else
{
//has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, hipblasLt_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
float beta = 0.0f;

has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, max_workspace_size, 0));
has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0));
}

hipFree(d_workspace);
}


Expand Down