Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions ggml_blas_adapter.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,44 @@ cl_platform_id platform;
cl_device_id device;
cl_context context;
cl_command_queue queue;
cl_event event;
bool cl_initialized = false;

static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a, const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k, const float alpha, const float *host_a, const int lda, const float *host_b, const int ldb, const float beta, float *host_c, const int ldc) {
cl_int err = 0;

cl_event events[3];
events[0] = NULL;
events[1] = NULL;
events[2] = NULL;

if (!cl_initialized) {
char * KCPP_CLBLAST_PLATFORM = getenv("KCPP_CLBLAST_PLATFORM");
char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES");
char * KCPP_CLBLAST_DEVICES = getenv("KCPP_CLBLAST_DEVICES");
int plat_num = (KCPP_CLBLAST_PLATFORM == NULL ? 0 : atoi(KCPP_CLBLAST_PLATFORM));
int dev_num = (KCPP_CLBLAST_DEVICES == NULL ? 0 : atoi(KCPP_CLBLAST_DEVICES));
printf("\nInitializing CLBlast (First Run)...");
printf("\nSelected: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
printf("\nAttempting to use: Platform=%d, Device=%d (If invalid, program will crash)\n",plat_num,dev_num);
cl_uint num_platforms;
clGetPlatformIDs(0, NULL, &num_platforms);
cl_platform_id* platforms = (cl_platform_id*)malloc(num_platforms*sizeof(cl_platform_id));
clGetPlatformIDs(num_platforms, platforms, NULL);
platform = platforms[plat_num];
char platform_buffer[1024];
clGetPlatformInfo(platform, CL_PLATFORM_NAME, sizeof(platform_buffer), &platform_buffer, NULL);
cl_uint num_devices;
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices);
cl_device_id* devices = (cl_device_id*)malloc(num_devices*sizeof(cl_device_id));
clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL);
device = devices[dev_num];
char device_buffer[1024];
clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(device_buffer), &device_buffer, NULL);
printf("Using Platform: %s Device: %s\n", platform_buffer, device_buffer);
context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
if (err != CL_SUCCESS) {
printf("Error creating OpenCL context: %d\n", err);
fflush(stdout);
}
queue = clCreateCommandQueue(context, device, 0, &err);
event = NULL;

if (err != CL_SUCCESS) {
printf("Error creating OpenCL Command Queue: %d\n", err);
Expand All @@ -56,16 +64,17 @@ static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS

free(platforms);
free(devices);

cl_initialized = true;
}

// Prepare buffers
cl_mem cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_WRITE, m*k*sizeof(float), NULL, &err);
cl_mem cl_buffer_a = clCreateBuffer(context, CL_MEM_READ_ONLY, m*k*sizeof(float), NULL, &err);
if (err != CL_SUCCESS) {
printf("Error creating OpenCL Buffer A: %d\n", err);
fflush(stdout);
}
cl_mem cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_WRITE, n*k*sizeof(float), NULL, &err);
cl_mem cl_buffer_b = clCreateBuffer(context, CL_MEM_READ_ONLY, n*k*sizeof(float), NULL, &err);
if (err != CL_SUCCESS) {
printf("Error creating OpenCL Buffer B: %d\n", err);
fflush(stdout);
Expand All @@ -76,9 +85,13 @@ static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS
fflush(stdout);
}

clEnqueueWriteBuffer(queue, cl_buffer_a, CL_TRUE, 0, m*k*sizeof(float), host_a, 0, NULL, NULL);
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_TRUE, 0, n*k*sizeof(float), host_b, 0, NULL, NULL);
clEnqueueWriteBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL);
clEnqueueWriteBuffer(queue, cl_buffer_a, CL_FALSE, 0, m*k*sizeof(float), host_a, 0, NULL, events);
clEnqueueWriteBuffer(queue, cl_buffer_b, CL_FALSE, 0, n*k*sizeof(float), host_b, 0, NULL, events + 1);
clEnqueueWriteBuffer(queue, cl_buffer_c, CL_FALSE, 0, m*n*sizeof(float), host_c, 0, NULL, events + 2);
clWaitForEvents(3, events);
clReleaseEvent(events[0]);
clReleaseEvent(events[1]);
clReleaseEvent(events[2]);

// Call the SGEMM routine.
CLBlastStatusCode status = CLBlastSgemm(order,
Expand All @@ -89,16 +102,17 @@ static void ggml_cl_sgemm_wrapper(const enum CBLAS_ORDER order, const enum CBLAS
cl_buffer_b, 0, ldb,
beta,
cl_buffer_c, 0, ldc,
&queue, &event);
&queue, events);

clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 1, events, events + 1);

// Wait for completion
if (status == CLBlastSuccess) {
clWaitForEvents(1, &event);
clReleaseEvent(event);
clWaitForEvents(2, events);
clReleaseEvent(events[0]);
clReleaseEvent(events[1]);
}

clEnqueueReadBuffer(queue, cl_buffer_c, CL_TRUE, 0, m*n*sizeof(float), host_c, 0, NULL, NULL);

clReleaseMemObject(cl_buffer_a);
clReleaseMemObject(cl_buffer_b);
clReleaseMemObject(cl_buffer_c);
Expand All @@ -117,4 +131,4 @@ ggml_cl_sgemm_wrapper(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, bet
cblas_sgemm(Order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);\
})
#endif
#endif
#endif