From eb4700b10f510284ab94c8bfdcce66a416849303 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Tue, 8 Dec 2020 01:46:35 +0000
Subject: [PATCH 1/6] Transformer-kernel - supporting any arbitrary sequence-length

---
 csrc/transformer/ds_transformer_cuda.cpp |  7 ++---
 deepspeed/ops/transformer/transformer.py | 23 +++++++++++++---
 tests/unit/test_cuda_backward.py         | 32 +++++++++-------------
 tests/unit/test_cuda_forward.py          | 34 +++++++++---------------
 4 files changed, 48 insertions(+), 48 deletions(-)
 mode change 100644 => 100755 csrc/transformer/ds_transformer_cuda.cpp

diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp
old mode 100644
new mode 100755
index 85ec0418971c..c7bf99522218
--- a/csrc/transformer/ds_transformer_cuda.cpp
+++ b/csrc/transformer/ds_transformer_cuda.cpp
@@ -14,6 +14,8 @@
 static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;

+const int init_seq_length = 128;
+
 // C++ interface

 template <typename T>
@@ -591,7 +593,6 @@ int create_transformer_layer(int layer_id,
                             int hidden_dim,
                             int num_heads,
                             int intermediate_size,
-                            int seq_length,
                             float attn_dropout_ratio,
                             float hidden_dropout_ratio,
                             int seed,
@@ -604,14 +605,14 @@
 {
     Context::Instance().SetSeed(seed);
     Context::Instance().TestGemmFP16(
-        test_gemm, batch_size, seq_length, num_heads, hidden_dim / num_heads);
+        test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);

     auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
                                                            batch_size,
                                                            hidden_dim,
                                                            num_heads,
                                                            intermediate_size,
-                                                           seq_length,
+                                                           init_seq_length,
                                                            attn_dropout_ratio,
                                                            hidden_dropout_ratio,
                                                            pre_or_postLayerNorm,
diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py
index a91e5ce6f08b..2c9df33ae77e 100755
--- a/deepspeed/ops/transformer/transformer.py
+++ b/deepspeed/ops/transformer/transformer.py
@@ -18,7 +18,6 @@ class TransformerConfig():
     def __init__(self,
                  batch_size,
-                 max_seq_length,
                  hidden_size,
                  intermediate_size,
                  heads,
@@ -92,7 +91,6 @@ class DeepSpeedTransformerConfig(TransformerConfig):
     """
     def __init__(self,
                  batch_size=-1,
-                 max_seq_length=-1,
                  hidden_size=-1,
                  intermediate_size=-1,
                  heads=-1,
@@ -112,7 +110,6 @@ def __init__(self,
         super(DeepSpeedTransformerConfig,
               self).__init__(
             batch_size,
-            max_seq_length,
             hidden_size,
             (intermediate_size if intermediate_size > 0 else 4 * hidden_size),
             heads,
@@ -177,6 +174,16 @@ def forward(ctx,
         cuda_module = stochastic_transformer_cuda_module if config.stochastic_mode else transformer_cuda_module
         forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32

+        inp_size = input.size()
+        if inp_size[1] % 8 != 0:
+            input = torch.cat((input,
+                               torch.randn(inp_size[0],
+                                           (8 - (inp_size[1] % 8)),
+                                           inp_size[2]).to(input.device).type_as(input)),
+                              1)
+            input_mask = torch.cat((input_mask, torch.ones(inp_size[0], input_mask.shape[1], input_mask.shape[2], \
+                                                (8 - (inp_size[1] % 8))).to(input_mask.device).type_as(input_mask) * -10000), 3)
+
         (output,
          inp_norm,
          qkv_tf,
@@ -303,11 +310,17 @@ def forward(ctx,
         ctx.attn_layer_norm_var = attn_layer_norm_var
         ctx.layer_norm_var = layer_norm_var

+        if inp_size[1] % 8 != 0:
+            output = torch.narrow(output, 1, 0, inp_size[1])
         return output

     @staticmethod
     def backward(ctx, grad_output):
         bsz = grad_output.shape[0]
+        grad_output_shape = grad_output.size()
+        if grad_output_shape[1] % 8 != 0:
+            grad_output = torch.cat((grad_output, torch.zeros(bsz, (8 - (grad_output_shape[1] % 8)), \
+                            grad_output_shape[2]).to(grad_output.device).type_as(grad_output)), 1)

         if bsz > ctx.config.batch_size:
             raise ValueError('grad_output batch size exceeds the limit.')
@@ -398,6 +411,9 @@ def backward(ctx, grad_output):
                                               norm_w,
                                               norm_b)

+        if grad_output_shape[1] % 8 != 0:
+            grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])
+
         return (grad_input,
                 None,
                 None,
@@ -501,7 +517,6 @@ def __init__(self, layer_id, config, initial_weights=None, initial_biases=None):
                                          self.config.hidden_size,
                                          self.config.heads,
                                          self.config.intermediate_size,
-                                         self.config.max_seq_length,
                                          self.config.attn_dropout_ratio,
                                          self.config.hidden_dropout_ratio,
                                          self.config.seed,
diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py
index 317cd7aa33c0..fd3f9887ad42 100755
--- a/tests/unit/test_cuda_backward.py
+++ b/tests/unit/test_cuda_backward.py
@@ -150,7 +150,7 @@ def create_models(ds_config):
         hidden_act="gelu",
         hidden_dropout_prob=ds_config.hidden_dropout_ratio,
         attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
-        max_position_embeddings=ds_config.max_seq_length,
+        max_position_embeddings=512,
         type_vocab_size=2,
         initializer_range=ds_config.initializer_range)

@@ -210,25 +210,18 @@ def set_seed(seed):
     torch.manual_seed(seed)


-def run_backward(ds_config, atol=1e-2, verbose=False):
+def run_backward(ds_config, seq_len, atol=1e-2, verbose=False):
     set_seed(123)
     bert_encoder, ds_encoder = create_models(ds_config)

     # prepare test data
     kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
     hidden_states = torch.randn(ds_config.batch_size,
-                                ds_config.max_seq_length,
+                                seq_len,
                                 ds_config.hidden_size,
                                 **kwargs)
-    input_mask = torch.randn(ds_config.batch_size,
-                             1,
-                             1,
-                             ds_config.max_seq_length,
-                             **kwargs)
-    Y = torch.randn(ds_config.batch_size,
-                    ds_config.max_seq_length,
-                    ds_config.hidden_size,
-                    **kwargs)
+    input_mask = torch.randn(ds_config.batch_size, 1, 1, seq_len, **kwargs)
+    Y = torch.randn(ds_config.batch_size, seq_len, ds_config.hidden_size, **kwargs)

     # run baseline
     base_results = bert_encoder(hidden_states,
@@ -257,12 +250,12 @@ def run_backward(ds_config, atol=1e-2, verbose=False):

 #test_backward[3-1024-120-16-24-True-True-0.05]
 @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
                          [
-                             (3,1024,120,16,24,True,False, 0.05),
-                             (3,1024,120,16,24,True,True, 0.05),
-                             (3,1024,56,16,24,False,False, 0.1),
-                             (3,1024,56,16,24,False,True, 0.2),
-                             (3,128,56,2,24,False,False, 0.1),
-                             (3,128,56,2,24,False,True, 0.2),
+                             (3,1024,119,16,24,True,False, 0.05),
+                             (3,1024,115,16,24,True,True, 0.05),
+                             (1024,128,10,2,2,False,False, 0.1),
+                             (3,1024,52,16,24,False,True, 0.2),
+                             (3,128,51,2,24,False,False, 0.1),
+                             (3,128,54,2,24,False,True, 0.2),
                          ]) # yapf: disable
 def test_backward(batch_size,
                   hidden_size,
@@ -282,7 +275,6 @@ def test_backward(batch_size,
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
     ds_config.intermediate_size = hidden_size
-    ds_config.max_seq_length = seq_len
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0
     ds_config.hidden_dropout_ratio = 0.0
@@ -291,7 +283,7 @@ def test_backward(batch_size,
     ds_config.initializer_range = 0.02
     ds_config.fp16 = use_fp16

-    run_backward(ds_config, atol=atol)
+    run_backward(ds_config, seq_len, atol=atol)


 #@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16, atol',
diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py
index 893b66c904bb..88cb90848603 100755
--- a/tests/unit/test_cuda_forward.py
+++ b/tests/unit/test_cuda_forward.py
@@ -117,7 +117,7 @@ def create_models(ds_config):
         hidden_act="gelu",
         hidden_dropout_prob=ds_config.hidden_dropout_ratio,
         attention_probs_dropout_prob=ds_config.attn_dropout_ratio,
-        max_position_embeddings=ds_config.max_seq_length,
+        max_position_embeddings=512,
         type_vocab_size=2,
         initializer_range=ds_config.initializer_range,
         fp16=ds_config.fp16)
@@ -186,13 +186,8 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):

     # prepare test data
     kwargs = kwargs_fp16 if ds_config.fp16 else kwargs_fp32
-    hidden_states = torch.randn(bsz,
-                                seq_len, #ds_config.max_seq_length,
-                                ds_config.hidden_size,
-                                **kwargs)
-    input_mask = torch.randn(bsz, 1, 1,
-                             seq_len, #ds_config.max_seq_length,
-                             **kwargs)
+    hidden_states = torch.randn(bsz, seq_len, ds_config.hidden_size, **kwargs)
+    input_mask = torch.randn(bsz, 1, 1, seq_len, **kwargs)

     # run baseline
     base_results = bert_encoder(hidden_states,
@@ -213,25 +208,25 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):

 # FP16 test cases can only run on the devices support FP16.
 @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                          [
-                             (8,256,128,4,3,True,False),
-                             (8,256,128,4,3,True,True),
-                             (64,1024,128,16,3,True,False),
-                             (64,1024,128,16,3,True,True),
-                             (8,1024,384,16,3,True,False),
+                             (8,256,53,4,3,True,False),
+                             (8,256,52,4,3,True,True),
+                             (3,1024,51,16,3,True,False),
+                             (3,1024,54,16,3,True,True),
+                             (8,1024,381,16,3,True,False),
                              (8,1024,384,16,3,True,True),
                              (8,1024,384,16,3,True,True),
-                             (8,1024,120,16,3,True,False),
+                             (8,1024,119,16,3,True,False),
                              (8,1024,120,16,3,True,True),
-                             (8,1024,512,16,3,True,False),
+                             (8,1024,509,16,3,True,False),
                              (8,1024,512,16,3,True,True),
                              (64,1024,56,16,3,False,False),
-                             (64,1024,56,16,3,False,True),
+                             (64,1024,53,16,3,False,True),
                              (64,1024,24,16,3,False,False),
-                             (64,1024,24,16,3,False,True),
+                             (64,1024,21,16,3,False,True),
                              (8,1024,384,16,3,False,False),
                              (8,1024,384,16,3,False,True),
                              (8,1024,512,16,3,False,False),
-                             (8,1024,512,16,3,False,True),
+                             (8,1024,511,16,3,False,True),
                              (8,1536,128,24,3,False,False),
                              (8,1536,128,24,3,False,True),
                              (8,2048,128,32,3,False,False),
@@ -259,7 +254,6 @@ def test_forward(batch_size,
     ds_config.layer_id = None
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
-    ds_config.max_seq_length = 128 #seq_len
     ds_config.intermediate_size = 4 * hidden_size
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0
@@ -297,7 +291,6 @@ def test_forward_with_small_bsz(batch_size,
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
     ds_config.intermediate_size = 4 * hidden_size
-    ds_config.max_seq_length = seq_len
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0
     ds_config.hidden_dropout_ratio = 0.0
@@ -332,7 +325,6 @@ def test_forward_stochastic(batch_size,
     ds_config.batch_size = batch_size
     ds_config.hidden_size = hidden_size
     ds_config.intermediate_size = 4 * hidden_size
-    ds_config.max_seq_length = seq_len
     ds_config.heads = heads
     ds_config.attn_dropout_ratio = 0.0
     ds_config.hidden_dropout_ratio = 0.0

From 06743480f071a95eec63f1a01159bdfa28490c9a Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Tue, 8 Dec 2020 02:38:33 +0000
Subject: [PATCH 2/6] remove seq-len from transformer config

---
 deepspeed/ops/transformer/transformer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py
index 2c9df33ae77e..ca1e3ffd0899 100755
--- a/deepspeed/ops/transformer/transformer.py
+++ b/deepspeed/ops/transformer/transformer.py
@@ -29,7 +29,6 @@ def __init__(self,
         self.batch_size = batch_size
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
-        self.max_seq_length = max_seq_length
         self.heads = heads
         self.attn_dropout_ratio = attn_dropout_ratio
         self.hidden_dropout_ratio = hidden_dropout_ratio

From 0659f9284734627cd89b18a6c11434c2a215f8ee Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Tue, 8 Dec 2020 07:11:34 +0000
Subject: [PATCH 3/6] pad seq-len to be 16-aligned

---
 deepspeed/ops/transformer/transformer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py
index ca1e3ffd0899..f2f8938cccea 100755
--- a/deepspeed/ops/transformer/transformer.py
+++ b/deepspeed/ops/transformer/transformer.py
@@ -138,7 +138,7 @@ def from_dict(cls, json_object):

     @classmethod
     def from_json_file(cls, json_file):
-        with open(json_file, "r", encoding='utf-8') as reader:
+        with open(json_file, "r", encoding='utf-16') as reader:
             text = reader.read()
         return cls.from_dict(json.loads(text))

@@ -174,14 +174,14 @@ def forward(ctx,
         forward_func = cuda_module.forward_fp16 if config.fp16 else cuda_module.forward_fp32

         inp_size = input.size()
-        if inp_size[1] % 8 != 0:
+        if inp_size[1] % 16 != 0:
             input = torch.cat((input,
                                torch.randn(inp_size[0],
-                                           (8 - (inp_size[1] % 8)),
+                                           (16 - (inp_size[1] % 16)),
                                            inp_size[2]).to(input.device).type_as(input)),
                               1)
-            input_mask = torch.cat((input_mask, torch.ones(inp_size[0], input_mask.shape[1], input_mask.shape[2], \
-                                                (8 - (inp_size[1] % 8))).to(input_mask.device).type_as(input_mask) * -10000), 3)
+            input_mask = torch.cat((input_mask, torch.ones(inp_size[0], input_mask.shape[1], input_mask.shape[2], \
+                                                (16 - (inp_size[1] % 16))).to(input_mask.device).type_as(input_mask) * -10000), 3)

         (output,
          inp_norm,
@@ -309,7 +309,7 @@ def forward(ctx,
         ctx.attn_layer_norm_var = attn_layer_norm_var
         ctx.layer_norm_var = layer_norm_var

-        if inp_size[1] % 8 != 0:
+        if inp_size[1] % 16 != 0:
             output = torch.narrow(output, 1, 0, inp_size[1])
         return output

@@ -317,8 +317,8 @@ def forward(ctx,
     def backward(ctx, grad_output):
         bsz = grad_output.shape[0]
         grad_output_shape = grad_output.size()
-        if grad_output_shape[1] % 8 != 0:
-            grad_output = torch.cat((grad_output, torch.zeros(bsz, (8 - (grad_output_shape[1] % 8)), \
+        if grad_output_shape[1] % 16 != 0:
+            grad_output = torch.cat((grad_output, torch.zeros(bsz, (16 - (grad_output_shape[1] % 16)), \
                             grad_output_shape[2]).to(grad_output.device).type_as(grad_output)), 1)

         if bsz > ctx.config.batch_size:
@@ -410,7 +410,7 @@ def backward(ctx, grad_output):
                                               norm_w,
                                               norm_b)

-        if grad_output_shape[1] % 8 != 0:
+        if grad_output_shape[1] % 16 != 0:
             grad_input = torch.narrow(grad_input, 1, 0, grad_output_shape[1])

         return (grad_input,

From cb15de6bc0012b305d7784370c652ef767c237e0 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Wed, 9 Dec 2020 03:08:25 +0000
Subject: [PATCH 4/6] resolve the issue with softmax forward when sequence is low

---
 csrc/transformer/ds_transformer_cuda.cpp |  6 ++++++
 csrc/transformer/softmax_kernels.cu      | 20 ++++++++++++++------
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp
index c7bf99522218..ebd534d04ab3 100755
--- a/csrc/transformer/ds_transformer_cuda.cpp
+++ b/csrc/transformer/ds_transformer_cuda.cpp
@@ -874,6 +874,12 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
     std::shared_ptr<BertTransformerLayer<T>> layer =
         std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

+    int seq_len = layer->GetSeqLength();
+    if (g_output.size(1) != seq_len) {
+        seq_len = g_output.size(1);
+        layer->SetSeqLength(seq_len, bsz);
+    }
+
     auto grad_input = torch::empty_like(input);
     auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
     auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
diff --git a/csrc/transformer/softmax_kernels.cu b/csrc/transformer/softmax_kernels.cu
index 582da4829f47..be776b0c074d 100644
--- a/csrc/transformer/softmax_kernels.cu
+++ b/csrc/transformer/softmax_kernels.cu
@@ -80,7 +80,8 @@ __global__ void attn_softmax(float* vals,
 #endif

     int iters = warp_num;
-    if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
+    if (seq_length < iteration_stride)
+        iters = warp_num / (iteration_stride / max_threads_in_sequence);

     for (int i = 1; i < iters; i *= 2) {
         auto temp = g.shfl_xor(max_val, i);
@@ -113,7 +114,8 @@ __global__ void attn_softmax(float* vals,
 #endif

     int iters = warp_num;
-    if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
+    if (seq_length < iteration_stride)
+        iters = warp_num / (iteration_stride / max_threads_in_sequence);

     for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

@@ -216,7 +218,8 @@ __global__ void attn_softmax(__half* vals,
 #endif

     int iters = warp_num;
-    if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
+    if (seq_length < iteration_stride)
+        iters = warp_num / (iteration_stride / max_threads_in_sequence);

     for (int i = 1; i < iters; i *= 2) {
         auto temp = g.shfl_xor(max_val, i);
@@ -252,7 +255,8 @@ __global__ void attn_softmax(__half* vals,
 #endif

     int iters = warp_num;
-    if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
+    if (seq_length < iteration_stride)
+        iters = warp_num / (iteration_stride / max_threads_in_sequence);

     for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }

@@ -339,7 +343,9 @@ void launch_attn_softmax(float* vals,
         dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                 subblock_max_workload * threads)
                                              : threads);
-
+        iterations =
+            (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
+                                                     : MAX_THREAD_ITERATIONS);
         if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);
@@ -408,7 +414,9 @@ void launch_attn_softmax<__half>(__half* vals,
         dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
                                                 subblock_max_workload * threads)
                                              : threads);
-
+        iterations =
+            (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
+                                                     : MAX_THREAD_ITERATIONS);
         if (sequence_length <= 512)
            attn_softmax<32, (threads / 128), 128><<<grid_dim, block_dim, 0, stream>>>(
                vals, attn_mask, heads, seq_length4, iterations);

From 3b34bcca182e111dee1dda4ffddd318a54285528 Mon Sep 17 00:00:00 2001
From: Reza Yazdani
Date: Wed, 16 Dec 2020 17:59:22 +0000
Subject: [PATCH 5/6] make the padding more efficient

---
 deepspeed/ops/transformer/transformer.py | 16 +++++++++------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/deepspeed/ops/transformer/transformer.py b/deepspeed/ops/transformer/transformer.py
index f2f8938cccea..ea4b98848d3c 100755
--- a/deepspeed/ops/transformer/transformer.py
+++ b/deepspeed/ops/transformer/transformer.py
@@ -176,12 +176,14 @@ def forward(ctx,
         inp_size = input.size()
         if inp_size[1] % 16 != 0:
             input = torch.cat((input,
-                               torch.randn(inp_size[0],
-                                           (16 - (inp_size[1] % 16)),
-                                           inp_size[2]).to(input.device).type_as(input)),
+                               torch.randn((inp_size[0],
+                                            (16 - (inp_size[1] % 16)),
+                                            inp_size[2]),
+                                           device=input.device,
+                                           dtype=input.dtype)),
                               1)
-            input_mask = torch.cat((input_mask, torch.ones(inp_size[0], input_mask.shape[1], input_mask.shape[2], \
-                                                (16 - (inp_size[1] % 16))).to(input_mask.device).type_as(input_mask) * -10000), 3)
+            input_mask = torch.cat((input_mask, torch.ones((inp_size[0], input_mask.shape[1], input_mask.shape[2], \
+                                                (16 - (inp_size[1] % 16))), device=input_mask.device, dtype=input_mask.dtype) * -10000), 3)

         (output,
          inp_norm,
@@ -318,8 +320,8 @@ def forward(ctx,
         bsz = grad_output.shape[0]
         grad_output_shape = grad_output.size()
         if grad_output_shape[1] % 16 != 0:
-            grad_output = torch.cat((grad_output, torch.zeros(bsz, (16 - (grad_output_shape[1] % 16)), \
-                            grad_output_shape[2]).to(grad_output.device).type_as(grad_output)), 1)
+            grad_output = torch.cat((grad_output, torch.zeros((bsz, (16 - (grad_output_shape[1] % 16)), \
+                            grad_output_shape[2]), device=grad_output.device, dtype=grad_output.dtype)), 1)

         if bsz > ctx.config.batch_size:
             raise ValueError('grad_output batch size exceeds the limit.')

From 23c70a3b4ae372e6eeb5aed7a0769db0e4e56cb3 Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Thu, 17 Dec 2020 09:05:16 -0800
Subject: [PATCH 6/6] bump DSE to support this PR

---
 DeepSpeedExamples | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/DeepSpeedExamples b/DeepSpeedExamples
index fa1d1a71c486..abb270641ca8 160000
--- a/DeepSpeedExamples
+++ b/DeepSpeedExamples
@@ -1 +1 @@
-Subproject commit fa1d1a71c48623db8a091d9cf636a5fe3b8f43c7
+Subproject commit abb270641ca8c33476282bde29916c395a060ae9
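
Note (illustration only, not part of the patches): the padding scheme that PATCH 1, 3 and 5 apply around the fused kernel can be summarized by the minimal Python sketch below. It assumes a hypothetical `run_fused_layer` callable standing in for the compiled DeepSpeed transformer op; everything else mirrors the logic in the diffs. The sequence dimension is padded up to the next multiple of 16 with random tokens, the additive attention mask is extended with -10000 so those positions vanish after softmax, and the output is narrowed back to the caller's original sequence length. The backward pass in the patches uses the same idea: the incoming gradient is zero-padded before the kernel call and the resulting input gradient is narrowed back.

# Illustrative sketch only -- not the DeepSpeed implementation.
# `run_fused_layer` is a hypothetical stand-in for the compiled transformer op.
import torch


def pad_to_multiple_of_16(hidden_states, input_mask):
    """Pad the sequence dimension so the fused kernel sees a 16-aligned length."""
    bsz, seq_len, hidden = hidden_states.size()
    pad = (16 - seq_len % 16) % 16
    if pad == 0:
        return hidden_states, input_mask, seq_len
    # Random filler tokens; their contribution is masked out below.
    filler = torch.randn((bsz, pad, hidden),
                         device=hidden_states.device,
                         dtype=hidden_states.dtype)
    hidden_states = torch.cat((hidden_states, filler), dim=1)
    # Extend the additive attention mask with a large negative value so the
    # padded positions receive ~zero weight after softmax.
    mask_pad = torch.ones((bsz, input_mask.size(1), input_mask.size(2), pad),
                          device=input_mask.device,
                          dtype=input_mask.dtype) * -10000.0
    input_mask = torch.cat((input_mask, mask_pad), dim=3)
    return hidden_states, input_mask, seq_len


def forward_with_padding(hidden_states, input_mask, run_fused_layer):
    padded, mask, orig_len = pad_to_multiple_of_16(hidden_states, input_mask)
    output = run_fused_layer(padded, mask)
    # Drop the padded positions so callers see the original sequence length.
    return torch.narrow(output, 1, 0, orig_len)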