1 change: 0 additions & 1 deletion colossalai/__init__.py
@@ -2,4 +2,3 @@
                          launch_from_slurm, launch_from_torch, get_default_parser)
 
 __version__ = '0.0.1'
-
7 changes: 4 additions & 3 deletions colossalai/builder/pipeline.py
@@ -251,9 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
     partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
     module_list = []
     for start, end in partitions[pipeline_rank]:
-        module_list.append(
-            nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end],
-                          *[nn.Identity() for _ in range(len(layers) - end)]))
+        module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
+                                         *layers[start:end],
+                                         *[nn.Identity() for _ in range(len(layers) - end)]))
     if verbose:
         logger = get_dist_logger()
         logger.info(f'Total {len(layers)} layers', ranks=[0])
@@ -264,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
             log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
             logger.info(log_str, ranks=[0])
     return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
+
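Note: the reformatted build_pipeline_model above pads each rank's slice with
nn.Identity modules so every chunk keeps the original layer indexing, and it
relies on partition_uniform to hand out near-equal contiguous layer ranges.
Below is a hypothetical C++ re-derivation of that balanced split, for
illustration only; partition_uniform_sketch is an invented name and the
repo's real partition_uniform may differ in detail.

// Hypothetical sketch: split num_items layers into num_parts * num_chunks
// contiguous [start, end) ranges, spreading the remainder one extra layer
// at a time, and give rank r the ranges r, r + num_parts, r + 2*num_parts...
#include <cstdio>
#include <utility>
#include <vector>

std::vector<std::vector<std::pair<int, int>>>
partition_uniform_sketch(int num_items, int num_parts, int num_chunks) {
  std::vector<std::vector<std::pair<int, int>>> parts(num_parts);
  const int total = num_parts * num_chunks;
  const int base = num_items / total;   // minimum layers per chunk
  const int extra = num_items % total;  // leading chunks take one more
  int start = 0;
  for (int idx = 0; idx < total; ++idx) {
    const int len = base + (idx < extra ? 1 : 0);
    parts[idx % num_parts].push_back({start, start + len});
    start += len;
  }
  return parts;
}

int main() {
  // 10 layers over 4 pipeline ranks, 1 chunk each -> sizes 3, 3, 2, 2
  auto parts = partition_uniform_sketch(10, 4, 1);
  for (int r = 0; r < 4; ++r)
    for (auto [s, e] : parts[r])
      std::printf("rank %d gets layers [%d, %d)\n", r, s, e);
  return 0;
}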
79 changes: 48 additions & 31 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.cpp
@@ -20,14 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE
 */
 #include "cpu_adam.h"
-
+#include <iostream>
 #include <math.h>
+#include <memory>
 #include <omp.h>
 #include <string.h>
 #include <torch/extension.h>
-
-#include <iostream>
-#include <memory>
 #include <type_traits>
 #include <unordered_map>
 
@@ -84,7 +82,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -146,7 +145,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_param_size > rounded_size) {
     for (size_t t = rounded_size; t < _param_size; t += TILE) {
       size_t copy_size = TILE;
-      if ((t + TILE) > _param_size) copy_size = _param_size - t;
+      if ((t + TILE) > _param_size)
+        copy_size = _param_size - t;
       size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -235,7 +235,8 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -320,6 +321,7 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
   s_optimizers[optimizer_id] = opt;
 
   if (should_log) {
+
     std::string avx_type = "";
 #if defined(__AVX512__)
     avx_type = "AVX512";
@@ -384,7 +386,8 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
 
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 
 #pragma omp parallel for
@@ -460,29 +463,43 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
               grad_half_precision, loss_scale);
 }
 
-int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
-              float epsilon, float weight_decay, bool bias_correction,
-              torch::Tensor &params, torch::Tensor &grads,
-              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
-              float loss_scale) {
-  auto params_c = params.contiguous();
-  auto grads_c = grads.contiguous();
-  auto exp_avg_c = exp_avg.contiguous();
-  auto exp_avg_sq_c = exp_avg_sq.contiguous();
-
-  float *params_ptr = (float *)params_c.data_ptr();
-  float *grads_ptr = (float *)grads_c.data_ptr();
-  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
-  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
-  std::shared_ptr<Adam_Optimizer> opt =
-      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
-  opt->IncrementStep(step, beta1, beta2);
-  opt->update_state(lr, epsilon, weight_decay, bias_correction);
-  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
-              params_c.numel(), (params.options().dtype() == at::kHalf),
-              (grads.options().dtype() == at::kHalf), loss_scale);
-
-  return 0;
+int adam_step(int optimizer_id,
+              size_t step,
+              float lr,
+              float beta1,
+              float beta2,
+              float epsilon,
+              float weight_decay,
+              bool bias_correction,
+              torch::Tensor& params,
+              torch::Tensor& grads,
+              torch::Tensor& exp_avg,
+              torch::Tensor& exp_avg_sq,
+              float loss_scale)
+{
+    auto params_c = params.contiguous();
+    auto grads_c = grads.contiguous();
+    auto exp_avg_c = exp_avg.contiguous();
+    auto exp_avg_sq_c = exp_avg_sq.contiguous();
+
+    float* params_ptr = (float*)params_c.data_ptr();
+    float* grads_ptr = (float*)grads_c.data_ptr();
+    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
+    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
+    std::shared_ptr<Adam_Optimizer> opt =
+        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
+    opt->IncrementStep(step, beta1, beta2);
+    opt->update_state(lr, epsilon, weight_decay, bias_correction);
+    opt->Step_8(params_ptr,
+                grads_ptr,
+                exp_avg_ptr,
+                exp_avg_sq_ptr,
+                params_c.numel(),
+                (params.options().dtype() == at::kHalf),
+                (grads.options().dtype() == at::kHalf),
+                loss_scale);
+
+    return 0;
 }
 
 int destroy_adam_optimizer(int optimizer_id) {
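Note: Step_1, Step_4 and Step_8 differ only in SIMD width; each applies the
same per-element Adam/AdamW update driven by the hyperparameters adam_step
receives. A scalar sketch of one such step, assuming the standard Adam math
(the real kernels additionally handle fp16 input/output, loss scaling and
OpenMP/SIMD tiling; adam_step_scalar is an invented name):

#include <cmath>
#include <cstddef>

// Scalar reference for the update the vectorized Step_* paths compute.
void adam_step_scalar(float *params, const float *grads, float *exp_avg,
                      float *exp_avg_sq, size_t n, size_t step, float lr,
                      float beta1, float beta2, float eps, float weight_decay,
                      bool bias_correction, bool adamw_mode) {
  // Bias-correction denominators 1 - beta^t for the current step t.
  const float bc1 = bias_correction ? 1.0f - std::pow(beta1, (float)step) : 1.0f;
  const float bc2 = bias_correction ? 1.0f - std::pow(beta2, (float)step) : 1.0f;
  for (size_t i = 0; i < n; ++i) {
    float g = grads[i];
    if (!adamw_mode && weight_decay > 0)
      g += weight_decay * params[i];  // classic Adam folds L2 into the grad
    exp_avg[i] = beta1 * exp_avg[i] + (1.0f - beta1) * g;            // m_t
    exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0f - beta2) * g * g;  // v_t
    const float update = (exp_avg[i] / bc1) /
                         (std::sqrt(exp_avg_sq[i] / bc2) + eps);
    if (adamw_mode && weight_decay > 0)
      params[i] -= lr * weight_decay * params[i];  // decoupled decay (AdamW)
    params[i] -= lr * update;
  }
}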
32 changes: 13 additions & 19 deletions colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -48,10 +48,10 @@ SOFTWARE
 #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm512_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
-#define SIMD_LOAD_HALF(x) \
+#define SIMD_LOAD_HALF(x)                                                      \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm256_store_ps( \
+#define SIMD_STORE_HALF(x, d)                                                  \
+  _mm256_store_ps(                                                             \
       x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 
 #elif defined(__AVX256__) or defined(__AVX2__)
@@ -66,8 +66,8 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
-#define SIMD_STORE_HALF(x, d) \
-  _mm_store_ps( \
+#define SIMD_STORE_HALF(x, d)                                                  \
+  _mm_store_ps(                                                                \
       x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 
 #endif
@@ -83,25 +83,19 @@ union AVX_Data {
 
 #endif
 
-#define STEP(SPAN) \
-  void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
-                   float *_exp_avg_sq, size_t _param_size, \
-                   bool param_half_precision = false, \
+#define STEP(SPAN)                                                             \
+  void Step_##SPAN(float *_params, float *grads, float *_exp_avg,              \
+                   float *_exp_avg_sq, size_t _param_size,                     \
+                   bool param_half_precision = false,                          \
                    bool grad_half_precision = false, float loss_scale = -1);
 
 class Adam_Optimizer {
- public:
+public:
   Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                  float eps = 1e-8, float weight_decay = 0,
                  bool adamw_mode = true)
-      : _alpha(alpha),
-        _betta1(betta1),
-        _betta2(betta2),
-        _eps(eps),
-        _weight_decay(weight_decay),
-        _betta1_t(1.0),
-        _betta2_t(1.0),
-        _step(0),
+      : _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps),
+        _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0),
         _adamw_mode(adamw_mode) {}
   ~Adam_Optimizer() {}
 
@@ -141,7 +135,7 @@ class Adam_Optimizer {
     }
   }
 
- private:
+private:
   float _alpha;
   float _betta1;
   float _betta2;
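Note: the SIMD_LOAD_HALF / SIMD_STORE_HALF macros above keep parameters in
fp16 memory while doing arithmetic in fp32 registers. A self-contained
sketch of the AVX2/F16C round trip they build on, assuming a CPU with F16C
support (compile with -mavx2 -mf16c); this is an illustration, not code
from the repo:

#include <cstdint>
#include <immintrin.h>

// Load 8 fp16 values, widen to fp32, scale, then narrow back to fp16
// with round-to-nearest -- the same conversion pair the macros wrap.
void scale_fp16_avx2(uint16_t *half_data, float scale) {
  __m128i raw = _mm_loadu_si128((const __m128i *)half_data);  // 8 x fp16
  __m256 vals = _mm256_cvtph_ps(raw);                         // -> 8 x fp32
  vals = _mm256_mul_ps(vals, _mm256_set1_ps(scale));          // fp32 math
  __m128i packed = _mm256_cvtps_ph(vals, _MM_FROUND_TO_NEAREST_INT);
  _mm_storeu_si128((__m128i *)half_data, packed);             // store fp16
}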
@@ -16,7 +16,7 @@ __global__ void ls_cross_entropy_fw_kernel(
   const int left_idx = block_start + threadIdx.x;
   const int right_idx = (blockIdx.x + 1) * vocab_size;
   float max_input[1] = {REDUCE_FLOAT_INF_NEG};
-  float sum_logits[2] = {0.f, 0.f};  // logit and logit exp
+  float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
   int target_tid = targets[blockIdx.x];
 
   if (target_tid == padding_idx) {
46 changes: 30 additions & 16 deletions colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu
@@ -1,10 +1,10 @@
-#include <cooperative_groups.h>
-
 #include <chrono>
 #include <ctime>
 
 #include "kernels.h"
 
+#include <cooperative_groups.h>
+
 namespace cg = cooperative_groups;
 
 curandStatePhilox4_32_10_t *curandstate;
@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -201,7 +202,8 @@
 
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
 
   uint8_t m[4];
 
@@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
 
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
 
   float4 *out4 = reinterpret_cast<float4 *>(out);
   const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
@@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -419,7 +424,8 @@
 
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();
 
-  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < 32; i <<= 1)
+    sum += g.shfl_down(sum, i);
 
-  if (y == 0) tile[0][x] = sum;
+  if (y == 0)
+    tile[0][x] = sum;
   __syncthreads();
 
   if (threadIdx.x < 8) {
@@ -613,9 +621,11 @@
   }
   __syncthreads();
 
-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_down(sum, i);
 
-  if (y == 0) tile[0][x] = sum;
+  if (y == 0)
+    tile[0][x] = sum;
   __syncthreads();
 
   if (threadIdx.x < 8) {
@@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -724,7 +735,8 @@
 
   int i = blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
 
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
   float sum = tile[threadIdx.y][threadIdx.x];
   __syncthreads();
 
-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_down(sum, i);
 
-  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
+  if (threadIdx.x == 0)
+    tile[0][threadIdx.y] = sum;
   __syncthreads();
 
   if (threadIdx.y == 0) {
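Note: the reformatted loops above are the usual cooperative-groups warp
reduction: each shfl_down with a doubling offset adds the partial sum held
i lanes higher, so after offsets 1, 2, 4, 8, 16 lane 0 holds the sum over
all 32 lanes. A minimal standalone sketch of the pattern (generic CUDA,
not the repo's kernel; launch as warp_sum_kernel<<<1, 32>>>(d_in, d_out)):

#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Sum 32 floats within one warp via shuffle; only lane 0 gets the result.
__global__ void warp_sum_kernel(const float *in, float *out) {
  cg::thread_block_tile<32> g =
      cg::tiled_partition<32>(cg::this_thread_block());

  float sum = in[threadIdx.x];
  for (int i = 1; i < 32; i <<= 1)
    sum += g.shfl_down(sum, i);  // fold in the value i lanes above

  if (g.thread_rank() == 0)
    *out = sum;
}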
@@ -1,7 +1,7 @@
-#include <cooperative_groups.h>
-
 #include "kernels.h"
 
+#include <cooperative_groups.h>
+
 namespace cg = cooperative_groups;
 
 /**