Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <stdexcept>
#include <stdio.h>
#include <stdlib.h>

#include <stdexcept>

#define MAX_THREADS 1024
#define WARP_SIZE 32

Expand Down Expand Up @@ -132,8 +133,9 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3,
}

/* Convert 4-dim tensor index into vector index */
__forceinline__ __host__ __device__ int
flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int dim4) {
__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3,
int id4, int dim2, int dim3,
int dim4) {
// return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4;
int res = id4;

Expand Down Expand Up @@ -201,9 +203,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3,
}

/* Convert vector index to 6-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5,
int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) {
__forceinline__ __host__ __device__ void decompose_6dim(
int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0,
int *id1, int *id2, int *id3, int *id4, int *id5) {
*id5 = src % dim5;
src /= dim5;

Expand All @@ -221,9 +223,11 @@ decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5,
}

/* Convert vector index to 5-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0,
int *id1, int *id2, int *id3, int *id4) {
__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1,
int dim2, int dim3,
int dim4, int *id0,
int *id1, int *id2,
int *id3, int *id4) {
*id4 = src % dim4;
src /= dim4;

Expand Down Expand Up @@ -253,8 +257,9 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1,
}

/* Convert vector index to 3-dim tensor index */
__forceinline__ __host__ __device__ void
decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) {
__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1,
int dim2, int *id0,
int *id1, int *id2) {
*id2 = src % dim2;
src /= dim2;

Expand Down