From d124ed79a4b640d000197e3fea20d6058e88a127 Mon Sep 17 00:00:00 2001 From: nijkah Date: Fri, 24 Mar 2023 21:09:34 +0900 Subject: [PATCH] [Lint] Run lint by pre-commit --- .github/ISSUE_TEMPLATE/config.yml | 2 +- .github/ISSUE_TEMPLATE/feature_request.yml | 2 +- .github/workflows/doc_test_on_pr.yml | 2 +- .../example_checks/check_dispatch_inputs.py | 54 +- .../example_checks/check_example_weekly.py | 74 +- .../example_checks/detect_changed_example.py | 48 +- CONTRIBUTING.md | 2 +- .../ChatGPT/chatgpt/dataset/__init__.py | 6 +- .../ChatGPT/chatgpt/dataset/reward_dataset.py | 3 + .../ChatGPT/chatgpt/dataset/sft_dataset.py | 49 +- applications/ChatGPT/chatgpt/dataset/utils.py | 4 +- .../ChatGPT/chatgpt/models/__init__.py | 2 +- .../ChatGPT/chatgpt/models/base/__init__.py | 2 +- .../ChatGPT/chatgpt/models/base/lm.py | 5 +- .../ChatGPT/chatgpt/models/bloom/__init__.py | 2 +- .../ChatGPT/chatgpt/models/bloom/bloom_lm.py | 1 - .../ChatGPT/chatgpt/models/bloom/bloom_rm.py | 2 +- .../ChatGPT/chatgpt/models/gpt/__init__.py | 2 +- .../ChatGPT/chatgpt/models/gpt/gpt_lm.py | 1 - .../ChatGPT/chatgpt/models/gpt/gpt_rm.py | 2 +- .../ChatGPT/chatgpt/models/llama/__init__.py | 2 +- .../ChatGPT/chatgpt/models/llama/llama_lm.py | 3 +- applications/ChatGPT/chatgpt/models/lora.py | 1 - applications/ChatGPT/chatgpt/models/loss.py | 2 + .../ChatGPT/chatgpt/models/opt/__init__.py | 2 +- .../ChatGPT/chatgpt/models/opt/opt_lm.py | 1 - .../ChatGPT/chatgpt/models/opt/opt_rm.py | 2 +- applications/ChatGPT/chatgpt/trainer/rm.py | 22 +- applications/ChatGPT/chatgpt/trainer/sft.py | 13 +- .../chatgpt/trainer/strategies/colossalai.py | 2 +- .../ChatGPT/chatgpt/trainer/strategies/ddp.py | 6 +- .../ChatGPT/chatgpt/utils/__init__.py | 4 +- .../ChatGPT/chatgpt/utils/tokenizer_utils.py | 26 +- applications/ChatGPT/examples/test_ci.sh | 2 +- .../ChatGPT/examples/train_reward_model.py | 35 +- applications/ChatGPT/examples/train_sft.py | 11 +- colossalai/amp/__init__.py | 10 +- 
.../auto_parallel/offload/amp_optimizer.py | 11 +- .../offload/base_offload_module.py | 10 +- .../auto_parallel/offload/mem_optimize.py | 19 +- colossalai/auto_parallel/offload/region.py | 13 +- .../auto_parallel/offload/region_manager.py | 73 +- colossalai/auto_parallel/offload/runtime.py | 31 +- colossalai/auto_parallel/offload/solver.py | 63 +- .../offload/training_simulator.py | 76 +- colossalai/auto_parallel/offload/util.py | 18 +- .../strategy/unary_elementwise_generator.py | 2 +- colossalai/cli/benchmark/models.py | 1 + colossalai/cli/benchmark/utils.py | 11 +- colossalai/cli/check/__init__.py | 1 + colossalai/communication/__init__.py | 18 +- colossalai/communication/p2p.py | 12 +- colossalai/communication/utils.py | 3 +- colossalai/context/__init__.py | 2 +- colossalai/context/config.py | 1 + colossalai/context/parallel_context.py | 3 +- .../process_group_initializer/__init__.py | 2 +- .../initializer_1d.py | 1 + .../initializer_2p5d.py | 1 + .../initializer_3d.py | 7 +- .../initializer_data.py | 3 +- .../initializer_model.py | 4 +- .../initializer_tensor.py | 3 +- colossalai/context/random/__init__.py | 15 +- colossalai/context/random/_helper.py | 2 +- colossalai/core.py | 2 +- .../engine/gradient_accumulation/__init__.py | 15 +- .../_gradient_accumulation.py | 13 +- .../engine/gradient_handler/__init__.py | 5 +- .../_base_gradient_handler.py | 2 +- .../_data_parallel_gradient_handler.py | 7 +- .../gradient_handler/_moe_gradient_handler.py | 91 +- .../_pipeline_parallel_gradient_handler.py | 7 +- .../_sequence_parallel_gradient_handler.py | 7 +- .../_zero_gradient_handler.py | 1 + colossalai/engine/schedule/__init__.py | 2 +- colossalai/engine/schedule/_base_schedule.py | 2 +- .../engine/schedule/_non_pipeline_schedule.py | 9 +- .../engine/schedule/_pipeline_schedule_v2.py | 13 +- .../adding_shape_consistency_pass.py | 12 +- colossalai/fx/passes/passes_for_gpt2_test.py | 15 +- colossalai/fx/passes/shard_1d_pass.py | 10 +- 
colossalai/fx/passes/split_module.py | 19 +- colossalai/fx/passes/utils.py | 11 +- .../profiler_function/activation_function.py | 2 + .../profiler_function/embedding.py | 4 +- .../experimental/profiler_function/linear.py | 2 + .../profiler_function/normalization.py | 2 + .../experimental/profiler_function/pooling.py | 2 + .../profiler_function/python_ops.py | 2 + .../profiler_function/torch_ops.py | 4 +- .../profiler_module/activation_function.py | 2 + .../experimental/profiler_module/attention.py | 2 + .../experimental/profiler_module/dropout.py | 2 + .../experimental/profiler_module/embedding.py | 4 +- .../experimental/profiler_module/linear.py | 2 + .../experimental/profiler_module/pooling.py | 2 + .../experimental/profiler_module/rnn.py | 6 +- .../experimental/profiler_module/torch_op.py | 4 +- colossalai/fx/proxy.py | 6 +- colossalai/fx/tracer/_tracer_utils.py | 6 +- .../meta_patch/patched_module/__init__.py | 2 +- .../gemini/ophooks/_shard_param_ophook.py | 1 + .../gemini/paramhooks/_param_hookmgr.py | 3 +- colossalai/gemini/stateful_tensor.py | 6 +- colossalai/gemini/stateful_tensor_mgr.py | 12 +- colossalai/gemini/tensor_placement_policy.py | 14 +- colossalai/gemini/tensor_utils.py | 4 +- colossalai/global_variables.py | 112 +- colossalai/kernel/cuda_native/csrc/compat.h | 2 +- .../cuda_native/csrc/kernels/cuda_util.cu | 1 - .../csrc/kernels/dropout_kernels.cu | 2003 +++-- .../csrc/kernels/general_kernels.cu | 464 +- .../csrc/kernels/include/dropout.h | 192 +- .../csrc/kernels/include/kernels.h | 27 +- .../csrc/kernels/include/normalize_layer.h | 129 +- .../csrc/kernels/include/softmax.h | 84 +- .../csrc/kernels/normalize_kernels.cu | 2341 +++--- .../csrc/kernels/softmax_kernels.cu | 730 +- .../csrc/kernels/transform_kernels.cu | 626 +- .../cuda_native/csrc/layer_norm_cuda.cpp | 2 +- .../csrc/layer_norm_cuda_kernel.cu | 2 +- .../kernel/cuda_native/csrc/moe_cuda.cpp | 194 +- .../cuda_native/csrc/moe_cuda_kernel.cu | 1318 ++-- 
.../csrc/multi_tensor_l2norm_kernel.cu | 2 +- .../cuda_native/csrc/multi_tensor_lamb.cu | 2 +- .../csrc/multi_tensor_scale_kernel.cu | 2 +- .../csrc/multi_tensor_sgd_kernel.cu | 2 +- .../csrc/scaled_masked_softmax.cpp | 84 +- .../cuda_native/csrc/scaled_masked_softmax.h | 868 +- .../csrc/scaled_upper_triang_masked_softmax.h | 928 ++- colossalai/kernel/jit/__init__.py | 7 +- colossalai/kernel/jit/bias_dropout_add.py | 8 +- colossalai/logging/__init__.py | 2 +- colossalai/nn/_ops/_utils.py | 11 +- colossalai/nn/_ops/addmm.py | 18 +- colossalai/nn/_ops/embedding.py | 8 +- colossalai/nn/_ops/embedding_bag.py | 8 +- colossalai/nn/_ops/layernorm.py | 5 +- colossalai/nn/_ops/loss.py | 9 +- colossalai/nn/init.py | 2 +- colossalai/nn/layer/__init__.py | 2 +- colossalai/nn/layer/base_layer.py | 3 +- .../nn/layer/colossalai_layer/__init__.py | 14 +- .../nn/layer/colossalai_layer/embedding.py | 303 +- .../layer/colossalai_layer/normalization.py | 83 +- colossalai/nn/layer/moe/__init__.py | 18 +- colossalai/nn/layer/moe/experts.py | 345 +- colossalai/nn/layer/moe/layers.py | 413 +- colossalai/nn/layer/moe/routers.py | 453 +- colossalai/nn/layer/moe/utils.py | 138 +- colossalai/nn/layer/parallel_1d/__init__.py | 14 +- colossalai/nn/layer/parallel_1d/_operation.py | 1 + colossalai/nn/layer/parallel_1d/_utils.py | 3 +- colossalai/nn/layer/parallel_2d/__init__.py | 11 +- colossalai/nn/layer/parallel_2d/_operation.py | 9 +- colossalai/nn/layer/parallel_2d/layers.py | 17 +- colossalai/nn/layer/parallel_2p5d/__init__.py | 11 +- .../nn/layer/parallel_2p5d/_operation.py | 7 +- colossalai/nn/layer/parallel_2p5d/layers.py | 24 +- colossalai/nn/layer/parallel_3d/__init__.py | 11 +- .../nn/layer/parallel_sequence/__init__.py | 2 +- .../nn/layer/parallel_sequence/_operation.py | 4 +- .../nn/layer/parallel_sequence/layers.py | 10 +- colossalai/nn/layer/utils/__init__.py | 22 +- colossalai/nn/layer/utils/common.py | 3 +- .../nn/layer/wrapper/pipeline_wrapper.py | 6 +- 
colossalai/nn/loss/__init__.py | 5 +- colossalai/nn/loss/loss_1d.py | 211 +- colossalai/nn/loss/loss_2d.py | 7 +- colossalai/nn/loss/loss_2p5d.py | 7 +- colossalai/nn/loss/loss_3d.py | 9 +- colossalai/nn/loss/loss_moe.py | 161 +- colossalai/nn/lr_scheduler/__init__.py | 2 +- colossalai/nn/lr_scheduler/cosine.py | 1 + colossalai/nn/lr_scheduler/multistep.py | 1 + colossalai/nn/lr_scheduler/poly.py | 1 + colossalai/nn/lr_scheduler/torch.py | 2 +- colossalai/nn/metric/__init__.py | 54 +- colossalai/nn/metric/_utils.py | 14 +- colossalai/nn/metric/accuracy_2d.py | 3 +- colossalai/nn/metric/accuracy_2p5d.py | 3 +- colossalai/nn/metric/accuracy_3d.py | 68 +- colossalai/nn/optimizer/__init__.py | 4 +- .../nn/optimizer/colossalai_optimizer.py | 1 + colossalai/nn/optimizer/lars.py | 33 +- colossalai/nn/optimizer/nvme_optimizer.py | 7 +- colossalai/nn/parallel/layers/__init__.py | 17 +- .../layers/cache_embedding/__init__.py | 4 +- .../layers/cache_embedding/base_embedding.py | 1 + .../layers/cache_embedding/cache_mgr.py | 20 +- .../cache_embedding/cached_embedding.py | 11 +- .../parallel/layers/cache_embedding/copyer.py | 4 +- .../parallel_cached_embedding.py | 7 +- .../parallel_cached_embedding_tablewise.py | 13 +- ..._cached_embedding_tablewise_split_cache.py | 14 +- colossalai/nn/parallel/layers/colo_module.py | 5 +- colossalai/nn/parallel/layers/embedding.py | 3 +- colossalai/nn/parallel/layers/linear.py | 3 +- colossalai/nn/parallel/layers/module_utils.py | 8 +- colossalai/pipeline/__init__.py | 4 +- colossalai/pipeline/layer_spec.py | 6 +- colossalai/pipeline/middleware/__init__.py | 4 +- .../pipeline/middleware/adaptor/__init__.py | 2 +- colossalai/pipeline/middleware/adaptor/fx.py | 30 +- colossalai/pipeline/middleware/topo.py | 86 +- colossalai/pipeline/pipelinable.py | 23 +- colossalai/pipeline/pipeline_process_group.py | 4 +- colossalai/pipeline/rpc/__init__.py | 4 +- colossalai/pipeline/utils.py | 9 +- colossalai/registry/registry.py | 4 +- 
colossalai/tensor/op_wrapper.py | 5 +- colossalai/testing/__init__.py | 4 +- colossalai/testing/pytest_wrapper.py | 5 +- colossalai/testing/utils.py | 15 +- colossalai/trainer/_trainer.py | 5 +- colossalai/trainer/hooks/__init__.py | 9 +- colossalai/trainer/hooks/_checkpoint_hook.py | 3 +- colossalai/trainer/hooks/_log_hook.py | 10 +- .../trainer/hooks/_lr_scheduler_hook.py | 3 +- colossalai/trainer/hooks/_metric_hook.py | 11 +- colossalai/utils/activation_checkpoint.py | 9 +- colossalai/utils/checkpoint/__init__.py | 2 +- .../utils/checkpoint/module_checkpoint.py | 10 +- colossalai/utils/checkpoint/utils.py | 127 +- colossalai/utils/checkpoint_io/__init__.py | 2 +- colossalai/utils/checkpoint_io/convertor.py | 2 +- colossalai/utils/checkpoint_io/distributed.py | 6 +- colossalai/utils/checkpoint_io/io.py | 9 +- colossalai/utils/checkpoint_io/meta.py | 2 +- colossalai/utils/checkpoint_io/writer.py | 12 +- colossalai/utils/checkpointing.py | 4 +- colossalai/utils/cuda.py | 2 +- .../data_sampler/data_parallel_sampler.py | 24 +- colossalai/utils/memory.py | 12 +- colossalai/utils/moe.py | 107 +- colossalai/utils/profiler/legacy/__init__.py | 12 +- .../utils/profiler/legacy/comm_profiler.py | 619 +- .../utils/profiler/legacy/pcie_profiler.py | 298 +- .../utils/profiler/legacy/prof_utils.py | 263 +- colossalai/utils/profiler/profiler.py | 18 +- .../profiler/stateful_tensor_mem_extention.py | 8 +- colossalai/utils/rank_recorder/README.md | 8 +- colossalai/utils/rank_recorder/__init__.py | 2 +- .../utils/rank_recorder/rank_recorder.py | 13 +- colossalai/utils/tensor_detector/__init__.py | 2 +- colossalai/utils/tensor_detector/readme.md | 3 +- .../utils/tensor_detector/tensor_detector.py | 5 +- colossalai/utils/timer.py | 1 + .../zero/shard_utils/base_shard_strategy.py | 1 + .../bucket_tensor_shard_strategy.py | 9 +- colossalai/zero/shard_utils/commons.py | 3 +- .../zero/shard_utils/tensor_shard_strategy.py | 5 +- colossalai/zero/sharded_model/__init__.py | 2 +- 
colossalai/zero/sharded_model/_utils.py | 4 +- colossalai/zero/sharded_model/utils.py | 5 +- colossalai/zero/sharded_param/__init__.py | 2 +- .../zero/sharded_param/sharded_param.py | 9 +- .../zero/sharded_param/sharded_tensor.py | 1 + colossalai/zero/utils/__init__.py | 2 +- examples/images/diffusion/README.md | 2 +- .../images/diffusion/configs/train_ddp.yaml | 2 +- examples/images/diffusion/ldm/data/cifar10.py | 60 +- .../images/diffusion/ldm/data/imagenet.py | 102 +- examples/images/diffusion/ldm/data/lsun.py | 45 +- examples/images/diffusion/ldm/data/teyvat.py | 49 +- examples/images/diffusion/ldm/lr_scheduler.py | 27 +- .../diffusion/ldm/models/autoencoder.py | 64 +- .../ldm/models/diffusion/classifier.py | 52 +- .../diffusion/ldm/models/diffusion/ddim.py | 202 +- .../models/diffusion/dpm_solver/__init__.py | 2 +- .../models/diffusion/dpm_solver/dpm_solver.py | 387 +- .../models/diffusion/dpm_solver/sampler.py | 66 +- .../diffusion/ldm/models/diffusion/plms.py | 177 +- .../ldm/models/diffusion/sampling_util.py | 4 +- .../images/diffusion/ldm/modules/attention.py | 143 +- .../modules/diffusionmodules/openaimodel.py | 145 +- .../ldm/modules/diffusionmodules/upscaling.py | 23 +- .../modules/distributions/distributions.py | 37 +- examples/images/diffusion/ldm/modules/ema.py | 5 +- .../diffusion/ldm/modules/encoders/modules.py | 90 +- .../ldm/modules/image_degradation/bsrgan.py | 118 +- .../modules/image_degradation/bsrgan_light.py | 77 +- .../modules/image_degradation/utils_image.py | 133 +- .../images/diffusion/ldm/modules/midas/api.py | 93 +- .../ldm/modules/midas/midas/base_model.py | 1 + .../ldm/modules/midas/midas/blocks.py | 177 +- .../ldm/modules/midas/midas/dpt_depth.py | 16 +- .../ldm/modules/midas/midas/midas_net.py | 4 +- .../modules/midas/midas/midas_net_custom.py | 90 +- .../ldm/modules/midas/midas/transforms.py | 41 +- .../diffusion/ldm/modules/midas/midas/vit.py | 107 +- .../diffusion/ldm/modules/midas/utils.py | 22 +- 
examples/images/diffusion/ldm/util.py | 64 +- .../scripts/download_first_stages.sh | 2 +- examples/images/diffusion/scripts/img2img.py | 84 +- examples/images/diffusion/scripts/inpaint.py | 47 +- examples/images/diffusion/scripts/knn2img.py | 124 +- .../diffusion/scripts/sample_diffusion.py | 153 +- .../scripts/tests/test_checkpoint.py | 11 +- .../diffusion/scripts/tests/test_watermark.py | 2 +- .../diffusion/scripts/train_searcher.py | 63 +- examples/images/diffusion/scripts/txt2img.py | 155 +- examples/images/diffusion/scripts/utils.py | 55 +- examples/images/diffusion/setup.py | 4 +- examples/images/diffusion/train_ddp.sh | 6 +- examples/images/dreambooth/colossalai.sh | 6 +- examples/images/dreambooth/inference.py | 2 +- .../gpt/experiments/auto_offload/model_zoo.py | 11 +- .../experiments/auto_offload/requirements.txt | 2 +- .../auto_offload/train_gpt_offload.py | 35 +- examples/language/opt/train_gemini_opt.py | 16 +- examples/language/roberta/README.md | 4 +- .../roberta/configs/colossalai_ddp.py | 2 +- .../roberta/configs/colossalai_zero.py | 4 +- .../language/roberta/preprocessing/README.md | 10 +- .../roberta/preprocessing/get_mask.py | 71 +- .../language/roberta/preprocessing/mask.cpp | 310 +- .../roberta/preprocessing/sentence_split.py | 60 +- .../roberta/preprocessing/tokenize_mask.py | 110 +- .../language/roberta/pretraining/README.md | 5 +- .../language/roberta/pretraining/arguments.py | 182 +- .../pretraining/bert_dataset_provider.py | 1 + .../roberta/pretraining/evaluation.py | 31 +- examples/language/roberta/pretraining/loss.py | 2 +- .../roberta/pretraining/model/bert.py | 142 +- .../roberta/pretraining/model/deberta_v2.py | 158 +- .../nvidia_bert_dataset_provider.py | 74 +- .../roberta/pretraining/pretrain_utils.py | 62 +- .../roberta/pretraining/run_pretrain.sh | 1 - .../pretraining/run_pretrain_resume.sh | 1 - .../roberta/pretraining/run_pretraining.py | 130 +- .../roberta/pretraining/utils/WandbLog.py | 14 +- 
.../roberta/pretraining/utils/exp_util.py | 27 +- .../roberta/pretraining/utils/global_vars.py | 18 +- .../roberta/pretraining/utils/logger.py | 13 +- examples/tutorial/fp8/mnist/README.md | 26 +- examples/tutorial/fp8/mnist/main.py | 56 +- examples/tutorial/opt/inference/batch.py | 4 +- .../opt/inference/benchmark/locustfile.py | 4 +- examples/tutorial/opt/inference/cache.py | 5 +- .../tutorial/opt/inference/opt_fastapi.py | 33 +- examples/tutorial/opt/inference/opt_server.py | 47 +- .../script/process-opt-175b/README.md | 1 - .../script/process-opt-175b/convert_ckpt.py | 3 +- .../script/process-opt-175b/flat-meta.json | 6945 ++++++++++++++++- .../inference/script/processing_ckpt_66b.py | 25 +- examples/tutorial/opt/opt/run_clm.py | 24 +- .../sequence_parallel/data/__init__.py | 30 +- .../sequence_parallel/data/bert_helper.py | 23 +- .../data/datasets/blendable_dataset.py | 5 +- .../data/datasets/builder.py | 84 +- .../data/datasets/data_samplers.py | 6 +- .../data/datasets/dataset_utils.py | 183 +- .../data/datasets/helpers.cpp | 1163 ++- .../data/datasets/ict_dataset.py | 54 +- .../data/datasets/indexed_dataset.py | 77 +- .../datasets/test/test_indexed_dataset.py | 46 +- .../data/dummy_dataloader.py | 2 +- .../data/tokenizer/__init__.py | 2 - .../data/tokenizer/bert_tokenization.py | 65 +- .../data/tokenizer/tokenizer.py | 28 +- .../sequence_parallel/loss_func/bert_loss.py | 26 +- .../loss_func/cross_entropy.py | 12 +- .../sequence_parallel/loss_func/utils.py | 13 +- .../lr_scheduler/annealing_lr.py | 21 +- .../sequence_parallel/model/__init__.py | 2 - .../tutorial/sequence_parallel/model/bert.py | 60 +- .../model/layers/__init__.py | 2 +- .../model/layers/bert_layer.py | 24 +- .../sequence_parallel/model/layers/dropout.py | 5 +- .../model/layers/embedding.py | 22 +- .../sequence_parallel/model/layers/head.py | 23 +- .../model/layers/init_method.py | 4 +- .../sequence_parallel/model/layers/linear.py | 15 +- .../sequence_parallel/model/layers/mlp.py | 13 +- 
.../sequence_parallel/model/layers/pooler.py | 1 + .../model/layers/preprocess.py | 7 +- op_builder/cpu_adam.py | 5 +- op_builder/fused_optim.py | 2 +- op_builder/moe.py | 5 +- op_builder/multi_head_attn.py | 4 +- op_builder/scaled_masked_softmax.py | 13 +- .../scaled_upper_triangle_masked_softmax.py | 8 +- op_builder/utils.py | 19 +- pytest.ini | 2 +- tests/components_to_test/resnet.py | 11 +- .../test_analyzer/test_fx/test_nested_ckpt.py | 2 +- .../test_subclasses/test_aten.py | 2 +- .../test_offload/model_utils.py | 28 +- .../test_offload/test_perf.py | 51 +- .../test_offload/test_solver.py | 20 +- .../test_node_handler/test_where_handler.py | 5 +- .../test_comm/test_boardcast_send_recv_v2.py | 7 +- tests/test_comm/test_comm.py | 3 +- tests/test_comm/test_object_list_p2p.py | 12 +- tests/test_comm/test_object_list_p2p_v2.py | 9 +- tests/test_config/sample_config.py | 38 +- .../test_context/configs/parallel_2d_init.py | 8 +- .../configs/parallel_2p5d_init.py | 9 +- .../test_context/configs/parallel_3d_init.py | 8 +- tests/test_context/test_hybrid_parallel.py | 5 +- tests/test_data/test_cifar10_dataset.py | 2 +- tests/test_data/test_data_parallel_sampler.py | 6 +- .../test_deterministic_dataloader.py | 7 +- .../test_cifar_with_data_pipeline_tensor.py | 15 +- ...test_cifar_with_data_pipeline_tensor_v2.py | 219 +- tests/test_ddp/test_ddp_state_dict.py | 13 +- tests/test_ddp/test_reducer.py | 14 +- tests/test_device/test_device_mesh.py | 3 +- tests/test_device/test_init_logical_pg.py | 7 +- tests/test_engine/test_engine.py | 5 +- .../test_engine/test_gradient_accumluation.py | 13 +- tests/test_fx/test_coloproxy.py | 5 +- tests/test_fx/test_comm_size_compute.py | 9 +- tests/test_fx/test_graph_manipulation.py | 7 +- tests/test_fx/test_meta/test_aten.py | 1 + tests/test_fx/test_meta/test_backward.py | 1 + tests/test_fx/test_meta/test_meta_trace.py | 1 + tests/test_fx/test_meta_info_prop.py | 3 +- tests/test_fx/test_parallel_1d.py | 9 +- 
.../test_pipeline/test_hf_model/hf_utils.py | 11 +- .../test_timm_model/timm_utils.py | 11 +- .../test_pipeline/test_topo/test_topo.py | 17 +- .../test_pipeline/test_topo/topo_utils.py | 29 +- tests/test_fx/test_pipeline_passes.py | 14 +- .../test_fx/test_tracer/test_control_flow.py | 1 + .../test_tracer/test_functional_conv.py | 1 + .../test_tracer/test_patched_module.py | 1 + tests/test_fx/test_tracer/test_patched_op.py | 4 +- tests/test_gemini/test_gemini_manager.py | 146 +- tests/test_layers/test_1d/checks_1d/common.py | 31 +- .../test_2d/checks_2d/check_layer_2d.py | 25 +- .../test_2d/checks_2d/check_operation_2d.py | 6 +- tests/test_layers/test_2d/test_2d.py | 23 +- .../test_2p5d/checks_2p5d/check_layer_2p5d.py | 25 +- .../checks_2p5d/check_operation_2p5d.py | 7 +- .../test_2p5d/checks_2p5d/common.py | 2 +- tests/test_layers/test_2p5d/test_2p5d.py | 7 +- tests/test_layers/test_3d/checks_3d/common.py | 2 +- tests/test_layers/test_3d/test_3d.py | 19 +- tests/test_layers/test_cache_embedding.py | 32 +- .../test_sequence/test_sequence.py | 11 +- tests/test_moe/test_grad_handler.py | 12 +- tests/test_moe/test_kernel.py | 10 +- tests/test_moe/test_moe_colo_init.py | 123 +- tests/test_moe/test_moe_group.py | 12 +- tests/test_moe/test_moe_zero_init.py | 226 +- tests/test_ops/test_addmm_tp.py | 15 +- tests/test_ops/test_embedding_bag_tp.py | 9 +- tests/test_ops/test_embedding_tp.py | 9 +- tests/test_ops/test_linear_tp.py | 7 +- tests/test_ops/test_loss_func.py | 104 +- tests/test_ops/test_op.py | 15 +- tests/test_ops/test_view.py | 201 +- tests/test_optimizer/test_fused_adam.py | 2 +- tests/test_optimizer/test_hybrid_adam.py | 2 +- tests/test_optimizer/test_nvme.py | 3 +- tests/test_pipeline/rpc_test_utils.py | 14 +- tests/test_pipeline/test_cuda_rpc_chimera.py | 4 +- .../test_pipeline/test_cuda_rpc_optimizer.py | 7 +- .../test_cuda_rpc_performance.py | 12 +- tests/test_pipeline/test_cuda_rpc_pipeline.py | 2 +- .../test_cuda_rpc_value_correctness.py | 5 +- 
tests/test_pipeline/test_middleware_1f1b.py | 59 +- tests/test_pipeline/test_pipelinable.py | 1 - .../test_pipeline_process_group.py | 8 +- tests/test_tensor/common_utils/__init__.py | 2 +- tests/test_tensor/core/test_dist_spec_mgr.py | 10 +- tests/test_tensor/core/test_tensor.py | 12 +- .../test_tensor/test_colo_checkpoint_tools.py | 94 +- tests/test_tensor/test_parameter.py | 5 +- tests/test_tensor/test_shape_consistency.py | 5 +- tests/test_trainer/test_pipeline/test_p2p.py | 16 +- .../test_pipeline/test_pipeline_schedule.py | 23 +- .../test_trainer_with_non_pipe_schedule.py | 5 +- .../test_trainer_with_pipe_schedule.py | 13 +- .../test_activation_checkpointing.py | 3 +- .../test_checkpoint/test_checkpoint_1d.py | 161 +- .../test_checkpoint/test_checkpoint_2d.py | 161 +- .../test_checkpoint/test_checkpoint_2p5d.py | 161 +- .../test_checkpoint/test_checkpoint_3d.py | 161 +- .../test_build_checkpoints.py | 3 +- .../test_checkpoint_io/test_load.py | 11 +- .../test_checkpoint_io/test_merge.py | 24 +- .../test_checkpoint_io/test_merge_param.py | 3 +- .../test_checkpoint_io/test_redist.py | 14 +- .../test_checkpoint_io/test_save.py | 15 +- .../test_checkpoint_io/test_unmerge_param.py | 3 +- tests/test_utils/test_colo_checkpoint.py | 24 +- tests/test_utils/test_commons.py | 13 +- tests/test_utils/test_lazy_init_ctx.py | 8 +- tests/test_utils/test_memory.py | 10 +- .../test_utils/test_norm_gradient_clipping.py | 16 +- .../test_zero_gradient_clippling.py | 13 +- tests/test_zero/common.py | 1 + tests/test_zero/test_found_inf.py | 144 +- tests/test_zero/test_shard_param.py | 7 +- .../test_sharded_optim_state_dict.py | 19 +- .../test_sharded_optim_with_sync_bn.py | 5 +- tests/test_zero/test_state_dict.py | 8 +- tests/test_zero/test_tensor_utils.py | 23 +- tests/test_zero/test_zero_engine.py | 8 +- 503 files changed, 20186 insertions(+), 12798 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 673b1274c94b..b310fcfefc15 
100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -8,4 +8,4 @@ contact_links: about: This issue tracker is not for technical support. Please use WeChat, and ask the community for help. - name: 😊 Advanced question - GitHub Discussions url: https://github.com/hpcaitech/ColossalAI/discussions - about: Use GitHub Discussions for advanced and unanswered technical questions, requiring a maintainer's answer. \ No newline at end of file + about: Use GitHub Discussions for advanced and unanswered technical questions, requiring a maintainer's answer. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index d05bc25f6f41..f12c41b52e6f 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -22,7 +22,7 @@ body: If applicable, add screenshots to help explain your problem. **Suggest a potential alternative/fix** Tell us how we could improve this project. - **Optional: Affiliation** + **Optional: Affiliation** Institution/email information helps better analyze and evaluate users to improve the project. Welcome to establish in-depth cooperation. placeholder: | A clear and concise description of your idea. 
diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml index a083362a7f0f..fbe669582c20 100644 --- a/.github/workflows/doc_test_on_pr.yml +++ b/.github/workflows/doc_test_on_pr.yml @@ -71,7 +71,7 @@ jobs: - name: Checkout ColossalAI uses: actions/checkout@v3 - + - name: Install Doc Test Requirements run: | source activate pytorch diff --git a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py index 04d2063ec5fc..5bec96187e0c 100644 --- a/.github/workflows/scripts/example_checks/check_dispatch_inputs.py +++ b/.github/workflows/scripts/example_checks/check_dispatch_inputs.py @@ -1,27 +1,27 @@ -import argparse -import os - - -def check_inputs(input_list): - for path in input_list: - real_path = os.path.join('examples', path) - if not os.path.exists(real_path): - return False - return True - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-f', '--fileNameList', type=str, help="List of file names") - args = parser.parse_args() - name_list = args.fileNameList.split(",") - is_correct = check_inputs(name_list) - - if is_correct: - print('success') - else: - print('failure') - - -if __name__ == '__main__': - main() +import argparse +import os + + +def check_inputs(input_list): + for path in input_list: + real_path = os.path.join('examples', path) + if not os.path.exists(real_path): + return False + return True + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-f', '--fileNameList', type=str, help="List of file names") + args = parser.parse_args() + name_list = args.fileNameList.split(",") + is_correct = check_inputs(name_list) + + if is_correct: + print('success') + else: + print('failure') + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/example_checks/check_example_weekly.py b/.github/workflows/scripts/example_checks/check_example_weekly.py index 941e90901f3d..83eff644e315 
100644 --- a/.github/workflows/scripts/example_checks/check_example_weekly.py +++ b/.github/workflows/scripts/example_checks/check_example_weekly.py @@ -1,37 +1,37 @@ -import os - - -def show_files(path, all_files): - # Traverse all the folder/file in current directory - file_list = os.listdir(path) - # Determine the element is folder or file. If file, pass it into list, if folder, recurse. - for file_name in file_list: - # Get the abs directory using os.path.join() and store into cur_path. - cur_path = os.path.join(path, file_name) - # Determine whether folder - if os.path.isdir(cur_path): - show_files(cur_path, all_files) - else: - all_files.append(cur_path) - return all_files - - -def join(input_list, sep=None): - return (sep or ' ').join(input_list) - - -def main(): - contents = show_files('examples/', []) - all_loc = [] - for file_loc in contents: - split_loc = file_loc.split('/') - # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not. - if len(split_loc) >= 4: - re_loc = '/'.join(split_loc[1:3]) - if re_loc not in all_loc: - all_loc.append(re_loc) - print(all_loc) - - -if __name__ == '__main__': - main() +import os + + +def show_files(path, all_files): + # Traverse all the folder/file in current directory + file_list = os.listdir(path) + # Determine the element is folder or file. If file, pass it into list, if folder, recurse. + for file_name in file_list: + # Get the abs directory using os.path.join() and store into cur_path. 
+ cur_path = os.path.join(path, file_name) + # Determine whether folder + if os.path.isdir(cur_path): + show_files(cur_path, all_files) + else: + all_files.append(cur_path) + return all_files + + +def join(input_list, sep=None): + return (sep or ' ').join(input_list) + + +def main(): + contents = show_files('examples/', []) + all_loc = [] + for file_loc in contents: + split_loc = file_loc.split('/') + # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not. + if len(split_loc) >= 4: + re_loc = '/'.join(split_loc[1:3]) + if re_loc not in all_loc: + all_loc.append(re_loc) + print(all_loc) + + +if __name__ == '__main__': + main() diff --git a/.github/workflows/scripts/example_checks/detect_changed_example.py b/.github/workflows/scripts/example_checks/detect_changed_example.py index df4fd67368fc..c69d95a552e9 100644 --- a/.github/workflows/scripts/example_checks/detect_changed_example.py +++ b/.github/workflows/scripts/example_checks/detect_changed_example.py @@ -1,24 +1,24 @@ -import argparse - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files") - args = parser.parse_args() - name_list = args.fileNameList.split(":") - folder_need_check = set() - for loc in name_list: - # Find only the sub-sub-folder of 'example' folder - # the examples folder structure is like - # - examples - # - area - # - application - # - file - if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4: - folder_need_check.add('/'.join(loc.split("/")[1:3])) - # Output the result using print. Then the shell can get the values. 
- print(list(folder_need_check)) - - -if __name__ == '__main__': - main() +import argparse + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files") + args = parser.parse_args() + name_list = args.fileNameList.split(":") + folder_need_check = set() + for loc in name_list: + # Find only the sub-sub-folder of 'example' folder + # the examples folder structure is like + # - examples + # - area + # - application + # - file + if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4: + folder_need_check.add('/'.join(loc.split("/")[1:3])) + # Output the result using print. Then the shell can get the values. + print(list(folder_need_check)) + + +if __name__ == '__main__': + main() diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 00abcf650158..915c43174c6a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -138,4 +138,4 @@ You can now create a pull request on the GitHub webpage of your repository. The Do write clearly the description of your pull request and [link the pull request to your target issue](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue). This will automatically close the issue when the pull request is approved. -In case of code conflict, you should rebase your branch and resolve the conflicts manually. \ No newline at end of file +In case of code conflict, you should rebase your branch and resolve the conflicts manually. 
diff --git a/applications/ChatGPT/chatgpt/dataset/__init__.py b/applications/ChatGPT/chatgpt/dataset/__init__.py index df484f46d24c..473b8787fcaf 100644 --- a/applications/ChatGPT/chatgpt/dataset/__init__.py +++ b/applications/ChatGPT/chatgpt/dataset/__init__.py @@ -1,5 +1,5 @@ -from .reward_dataset import RmStaticDataset, HhRlhfDataset +from .reward_dataset import HhRlhfDataset, RmStaticDataset +from .sft_dataset import AlpacaDataCollator, AlpacaDataset, SFTDataset from .utils import is_rank_0 -from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator -__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator'] +__all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator'] diff --git a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py index 9ee13490b893..faa1c94d2728 100644 --- a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py +++ b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py @@ -5,6 +5,7 @@ from .utils import is_rank_0 + # Dahaos/rm-static class RmStaticDataset(Dataset): """ @@ -58,6 +59,7 @@ def __getitem__(self, idx): return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][ "input_ids"], self.reject[idx]["attention_mask"] + # Anthropic/hh-rlhf class HhRlhfDataset(Dataset): """ @@ -69,6 +71,7 @@ class HhRlhfDataset(Dataset): max_length: max length of input special_token: special token at the end of sentence """ + def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: super().__init__() self.chosen = [] diff --git a/applications/ChatGPT/chatgpt/dataset/sft_dataset.py b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py index 11ec61908aef..c7e1bd952006 100644 --- a/applications/ChatGPT/chatgpt/dataset/sft_dataset.py +++ b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py @@ -13,35 +13,34 @@ # 
limitations under the License. import copy +import random from dataclasses import dataclass, field from typing import Callable, Dict, Sequence -import random -from torch.utils.data import Dataset + +import torch import torch.distributed as dist +import transformers +from torch.utils.data import Dataset from tqdm import tqdm -import torch -from .utils import is_rank_0, jload - -import transformers from colossalai.logging import get_dist_logger +from .utils import is_rank_0, jload + logger = get_dist_logger() IGNORE_INDEX = -100 PROMPT_DICT = { - "prompt_input": ( - "Below is an instruction that describes a task, paired with an input that provides further context. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" - ), - "prompt_no_input": ( - "Below is an instruction that describes a task. " - "Write a response that appropriately completes the request.\n\n" - "### Instruction:\n{instruction}\n\n### Response:" - ), + "prompt_input": + ("Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), + "prompt_no_input": ("Below is an instruction that describes a task. 
" + "Write a response that appropriately completes the request.\n\n" + "### Instruction:\n{instruction}\n\n### Response:"), } + class SFTDataset(Dataset): """ Dataset for sft model @@ -52,7 +51,7 @@ class SFTDataset(Dataset): max_length: max length of input """ - def __init__(self, dataset, tokenizer: Callable, max_length: int=512) -> None: + def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None: super().__init__() # self.prompts = [] self.input_ids = [] @@ -77,7 +76,7 @@ def __getitem__(self, idx): # dict(input_ids=self.input_ids[i], labels=self.labels[i]) return dict(input_ids=self.input_ids[i], labels=self.labels[i]) # return dict(self.prompts[idx], self.prompts[idx]) - + def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: """Tokenize a list of strings.""" @@ -88,8 +87,7 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedToken padding="longest", max_length=tokenizer.model_max_length, truncation=True, - ) - for text in strings + ) for text in strings ] input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] input_ids_lens = labels_lens = [ @@ -102,6 +100,7 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedToken labels_lens=labels_lens, ) + def preprocess( sources: Sequence[str], targets: Sequence[str], @@ -116,6 +115,7 @@ def preprocess( label[:source_len] = IGNORE_INDEX return dict(input_ids=input_ids, labels=labels) + class AlpacaDataset(Dataset): """Dataset for supervised fine-tuning.""" @@ -143,7 +143,8 @@ def __len__(self): def __getitem__(self, i) -> Dict[str, torch.Tensor]: return dict(input_ids=self.input_ids[i], labels=self.labels[i]) - + + @dataclass class AlpacaDataCollator(object): """Collate examples for supervised fine-tuning.""" @@ -152,9 +153,9 @@ class AlpacaDataCollator(object): def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: input_ids, labels = tuple([instance[key] for 
instance in instances] for key in ("input_ids", "labels")) - input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id - ) + input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, + batch_first=True, + padding_value=self.tokenizer.pad_token_id) labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) return dict( input_ids=input_ids, diff --git a/applications/ChatGPT/chatgpt/dataset/utils.py b/applications/ChatGPT/chatgpt/dataset/utils.py index 0e88cc8c39b4..f37fce67a7c6 100644 --- a/applications/ChatGPT/chatgpt/dataset/utils.py +++ b/applications/ChatGPT/chatgpt/dataset/utils.py @@ -7,14 +7,16 @@ def is_rank_0() -> bool: return not dist.is_initialized() or dist.get_rank() == 0 + def _make_r_io_base(f, mode: str): if not isinstance(f, io.IOBase): f = open(f, mode=mode) return f + def jload(f, mode="r"): """Load a .json file into a dictionary.""" f = _make_r_io_base(f, mode) jdict = json.load(f) f.close() - return jdict \ No newline at end of file + return jdict diff --git a/applications/ChatGPT/chatgpt/models/__init__.py b/applications/ChatGPT/chatgpt/models/__init__.py index b274188a21df..7489b2e87ca0 100644 --- a/applications/ChatGPT/chatgpt/models/__init__.py +++ b/applications/ChatGPT/chatgpt/models/__init__.py @@ -1,4 +1,4 @@ from .base import Actor, Critic, RewardModel -from .loss import PolicyLoss, PPOPtxActorLoss, ValueLoss, LogSigLoss, LogExpLoss +from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss __all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss'] diff --git a/applications/ChatGPT/chatgpt/models/base/__init__.py b/applications/ChatGPT/chatgpt/models/base/__init__.py index 7c7b1ceba257..7cf82309af7b 100644 --- a/applications/ChatGPT/chatgpt/models/base/__init__.py +++ b/applications/ChatGPT/chatgpt/models/base/__init__.py @@ -1,6 +1,6 @@ from .actor import Actor 
from .critic import Critic -from .reward_model import RewardModel from .lm import LM +from .reward_model import RewardModel __all__ = ['Actor', 'Critic', 'RewardModel', 'LM'] diff --git a/applications/ChatGPT/chatgpt/models/base/lm.py b/applications/ChatGPT/chatgpt/models/base/lm.py index b6bd7aff8315..e32ba4253369 100644 --- a/applications/ChatGPT/chatgpt/models/base/lm.py +++ b/applications/ChatGPT/chatgpt/models/base/lm.py @@ -21,13 +21,10 @@ class LM(Actor): def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None: super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias) - def forward(self, - sequences: torch.LongTensor, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: """Returns output log probs """ output = self.model(sequences, attention_mask=attention_mask) logits = output['logits'] log_probs = F.log_softmax(logits, dim=-1) return log_probs - diff --git a/applications/ChatGPT/chatgpt/models/bloom/__init__.py b/applications/ChatGPT/chatgpt/models/bloom/__init__.py index 7d6d7753bb9a..39dfe036a2f2 100644 --- a/applications/ChatGPT/chatgpt/models/bloom/__init__.py +++ b/applications/ChatGPT/chatgpt/models/bloom/__init__.py @@ -1,6 +1,6 @@ from .bloom_actor import BLOOMActor from .bloom_critic import BLOOMCritic -from .bloom_rm import BLOOMRM from .bloom_lm import BLOOMLM +from .bloom_rm import BLOOMRM __all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM'] diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py index 81e17f27c11a..628af2e341a2 100644 --- a/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py +++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py @@ -33,4 +33,3 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, 
lora_train_bias) - diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py index 2dba227ff7d0..22cfab441abb 100644 --- a/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py +++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py @@ -33,5 +33,5 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.hidden_size, 1) - value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1)) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1)) super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/models/gpt/__init__.py b/applications/ChatGPT/chatgpt/models/gpt/__init__.py index c6ae05113cc0..9dc68e37544f 100644 --- a/applications/ChatGPT/chatgpt/models/gpt/__init__.py +++ b/applications/ChatGPT/chatgpt/models/gpt/__init__.py @@ -1,6 +1,6 @@ from .gpt_actor import GPTActor from .gpt_critic import GPTCritic -from .gpt_rm import GPTRM from .gpt_lm import GPTLM +from .gpt_rm import GPTRM __all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM'] diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py index 5740c80d3e77..23fc13bf23a4 100644 --- a/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py +++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py @@ -33,4 +33,3 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) - diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py index 19d673de6825..054432e1ce86 100644 --- a/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py +++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py @@ -35,5 +35,5 @@ def __init__(self, model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.n_embd, 1) - 
value_head.weight.data.normal_(mean=0.0, std=1/(model.config.n_embd + 1)) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1)) super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/models/llama/__init__.py b/applications/ChatGPT/chatgpt/models/llama/__init__.py index 3edb51e14376..0d4dada3c9f1 100644 --- a/applications/ChatGPT/chatgpt/models/llama/__init__.py +++ b/applications/ChatGPT/chatgpt/models/llama/__init__.py @@ -1,6 +1,6 @@ from .llama_actor import LlamaActor from .llama_critic import LlamaCritic -from .llama_rm import LlamaRM from .llama_lm import LlamaLM +from .llama_rm import LlamaRM __all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM'] diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_lm.py b/applications/ChatGPT/chatgpt/models/llama/llama_lm.py index c63077b1ac04..a44196ea6a44 100644 --- a/applications/ChatGPT/chatgpt/models/llama/llama_lm.py +++ b/applications/ChatGPT/chatgpt/models/llama/llama_lm.py @@ -33,6 +33,5 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() - - super().__init__(model, lora_rank, lora_train_bias) + super().__init__(model, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/models/lora.py b/applications/ChatGPT/chatgpt/models/lora.py index 9c19f472d726..f8f7a1cb5d81 100644 --- a/applications/ChatGPT/chatgpt/models/lora.py +++ b/applications/ChatGPT/chatgpt/models/lora.py @@ -127,4 +127,3 @@ def convert_to_lora(self) -> None: return convert_to_lora_recursively(self, self.lora_rank) lora.mark_only_lora_as_trainable(self, self.lora_train_bias) - diff --git a/applications/ChatGPT/chatgpt/models/loss.py b/applications/ChatGPT/chatgpt/models/loss.py index c5b1ccc93228..7fc437d90fdb 100644 --- a/applications/ChatGPT/chatgpt/models/loss.py +++ b/applications/ChatGPT/chatgpt/models/loss.py @@ -98,6 +98,7 @@ class LogSigLoss(nn.Module): Pairwise Loss for Reward Model Details: 
https://arxiv.org/abs/2203.02155 """ + def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: probs = torch.sigmoid(chosen_reward - reject_reward) log_probs = torch.log(probs) @@ -110,6 +111,7 @@ class LogExpLoss(nn.Module): Pairwise Loss for Reward Model Details: https://arxiv.org/abs/2204.05862 """ + def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor: loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean() return loss diff --git a/applications/ChatGPT/chatgpt/models/opt/__init__.py b/applications/ChatGPT/chatgpt/models/opt/__init__.py index fccec3bdff99..3d7a8adbf82e 100644 --- a/applications/ChatGPT/chatgpt/models/opt/__init__.py +++ b/applications/ChatGPT/chatgpt/models/opt/__init__.py @@ -1,6 +1,6 @@ from .opt_actor import OPTActor from .opt_critic import OPTCritic -from .opt_rm import OPTRM from .opt_lm import OPTLM +from .opt_rm import OPTRM __all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM'] diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_lm.py b/applications/ChatGPT/chatgpt/models/opt/opt_lm.py index 35bfe198a225..65d79e1b2307 100644 --- a/applications/ChatGPT/chatgpt/models/opt/opt_lm.py +++ b/applications/ChatGPT/chatgpt/models/opt/opt_lm.py @@ -33,4 +33,3 @@ def __init__(self, if checkpoint: model.gradient_checkpointing_enable() super().__init__(model, lora_rank, lora_train_bias) - diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_rm.py b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py index ef7f0fb16fd1..50fc0dee8568 100644 --- a/applications/ChatGPT/chatgpt/models/opt/opt_rm.py +++ b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py @@ -34,5 +34,5 @@ def __init__(self, model.gradient_checkpointing_enable() value_head = nn.Linear(model.config.word_embed_proj_dim, 1) - value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1)) + value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 
1)) super().__init__(model, value_head, lora_rank, lora_train_bias) diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py index 7fa87a64968b..bb56ec1c2bd0 100644 --- a/applications/ChatGPT/chatgpt/trainer/rm.py +++ b/applications/ChatGPT/chatgpt/trainer/rm.py @@ -1,12 +1,13 @@ from abc import ABC -import pandas as pd +from datetime import datetime + import loralib as lora +import pandas as pd import torch -from datetime import datetime from torch.optim import Optimizer, lr_scheduler from torch.utils.data import DataLoader, Dataset from tqdm import tqdm - + from .strategies import Strategy from .utils import is_rank_0 @@ -45,12 +46,11 @@ def __init__( self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True) self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True) - + self.model = strategy.setup_model(model) self.loss_fn = loss_fn self.optimizer = strategy.setup_optimizer(optim, self.model) - self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__()//100) - + self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100) def eval_acc(self, dataloader): dist = 0 @@ -74,7 +74,6 @@ def eval_acc(self, dataloader): acc = on / cnt self.model.train() return dist_mean, acc - def fit(self): time = datetime.now() @@ -105,16 +104,17 @@ def fit(self): dist, acc = self.eval_acc(self.valid_dataloader) cnt = 0 if is_rank_0(): - log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc']) + log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], + columns=['step', 'loss', 'dist', 'acc']) log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False) step_bar.update() step_bar.set_postfix({'dist': dist, 'acc': acc}) - + # eval dist, acc = 
self.eval_acc(self.eval_dataloader) if is_rank_0(): - log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc']) - log.to_csv('log.csv', mode='a', header=False, index=False) + log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc']) + log.to_csv('log.csv', mode='a', header=False, index=False) epoch_bar.update() step_bar.set_postfix({'dist': dist, 'acc': acc}) step_bar.close() diff --git a/applications/ChatGPT/chatgpt/trainer/sft.py b/applications/ChatGPT/chatgpt/trainer/sft.py index 3b35f516816f..f695ea1689bd 100644 --- a/applications/ChatGPT/chatgpt/trainer/sft.py +++ b/applications/ChatGPT/chatgpt/trainer/sft.py @@ -1,16 +1,19 @@ from abc import ABC from typing import Optional + import loralib as lora import torch +import torch.distributed as dist from chatgpt.models.loss import GPTLMLoss from torch.optim import Adam, Optimizer from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from tqdm import tqdm -import torch.distributed as dist + +from colossalai.logging import get_dist_logger + from .strategies import Strategy from .utils import is_rank_0 -from colossalai.logging import get_dist_logger class SFTTrainer(ABC): @@ -74,7 +77,8 @@ def fit(self, logger, use_lora, log_interval=10): self.strategy.optimizer_step(self.optimizer) self.optimizer.zero_grad() if batch_id % log_interval == 0: - logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}') + logger.info( + f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}') # eval if self.eval_dataloader is not None: @@ -96,6 +100,5 @@ def fit(self, logger, use_lora, log_interval=10): loss_mean = loss_sum / num_seen if dist.get_rank() == 0: logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}') - - epoch_bar.update() + epoch_bar.update() diff --git 
a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py index b20b02d3d34d..56693d8edf39 100644 --- a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py +++ b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py @@ -156,7 +156,7 @@ def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> N # merge lora_weights into weights for module in unwrapped_model.modules(): if isinstance(module, LoraLinear): - module.merge_weights=True + module.merge_weights = True module.eval() # get state_dict and save state_dict = unwrapped_model.state_dict() diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py index c9f92c12fe0a..ef6885f39074 100644 --- a/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py +++ b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py @@ -75,15 +75,15 @@ def _unwrap_actor(actor: Actor) -> nn.Module: def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None: for module in model.modules(): if isinstance(module, LoraLinear): - module.merge_weights=True + module.merge_weights = True module.eval() - + if only_rank0 and dist.get_rank() != 0: return model = model.model.module state_dict = model.state_dict() torch.save(state_dict, path) - + def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None: if only_rank0 and dist.get_rank() != 0: return diff --git a/applications/ChatGPT/chatgpt/utils/__init__.py b/applications/ChatGPT/chatgpt/utils/__init__.py index 8f526d7efdad..e75401d382a8 100644 --- a/applications/ChatGPT/chatgpt/utils/__init__.py +++ b/applications/ChatGPT/chatgpt/utils/__init__.py @@ -1,3 +1,3 @@ -from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding +from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize 
-__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding'] \ No newline at end of file +__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding'] diff --git a/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py b/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py index 8699bf64c7b5..1daab793f205 100644 --- a/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py +++ b/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py @@ -21,10 +21,11 @@ DEFAULT_BOS_TOKEN = "" DEFAULT_UNK_TOKEN = "" + def prepare_llama_tokenizer_and_embedding( - tokenizer: transformers.PreTrainedTokenizer, - model: transformers.PreTrainedModel, - special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN), + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, + special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN), ): """prepare llama tokenizer and embedding. @@ -37,21 +38,19 @@ def prepare_llama_tokenizer_and_embedding( model=model, ) - tokenizer.add_special_tokens( - { - "eos_token": DEFAULT_EOS_TOKEN, - "bos_token": DEFAULT_BOS_TOKEN, - "unk_token": DEFAULT_UNK_TOKEN, - } - ) + tokenizer.add_special_tokens({ + "eos_token": DEFAULT_EOS_TOKEN, + "bos_token": DEFAULT_BOS_TOKEN, + "unk_token": DEFAULT_UNK_TOKEN, + }) return tokenizer def smart_tokenizer_and_embedding_resize( - tokenizer: transformers.PreTrainedTokenizer, - model: transformers.PreTrainedModel, - special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN), + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, + special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN), ): """Resize tokenizer and embedding. 
@@ -71,4 +70,3 @@ def smart_tokenizer_and_embedding_resize( input_embeddings[-num_new_tokens:] = input_embeddings_avg output_embeddings[-num_new_tokens:] = output_embeddings_avg - \ No newline at end of file diff --git a/applications/ChatGPT/examples/test_ci.sh b/applications/ChatGPT/examples/test_ci.sh index 1d05c4c58341..db1d0b64e3b3 100755 --- a/applications/ChatGPT/examples/test_ci.sh +++ b/applications/ChatGPT/examples/test_ci.sh @@ -81,7 +81,7 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ --pretrain 'gpt2' --model 'gpt2' \ --strategy colossalai_gemini --loss_fn 'log_exp'\ --dataset 'Dahoas/rm-static' --test True --lora_rank 4 - + torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \ --pretrain 'bigscience/bloom-560m' --model 'bloom' \ --strategy colossalai_zero2 --loss_fn 'log_sig'\ diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py index a9c844b7b1f8..a261d87d5b18 100644 --- a/applications/ChatGPT/examples/train_reward_model.py +++ b/applications/ChatGPT/examples/train_reward_model.py @@ -1,24 +1,25 @@ import argparse +from random import randint import loralib as lora import torch from chatgpt.dataset import HhRlhfDataset, RmStaticDataset -from chatgpt.models import LogSigLoss, LogExpLoss +from chatgpt.models import LogExpLoss, LogSigLoss from chatgpt.models.base import RewardModel from chatgpt.models.bloom import BLOOMRM +from chatgpt.models.deberta import DebertaRM from chatgpt.models.gpt import GPTRM from chatgpt.models.opt import OPTRM -from chatgpt.models.deberta import DebertaRM from chatgpt.trainer import RewardModelTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from datasets import load_dataset -from random import randint from torch.optim import Adam from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer from transformers.models.gpt2.tokenization_gpt2 import 
GPT2Tokenizer from colossalai.nn.optimizer import HybridAdam + def train(args): # configure strategy if args.strategy == 'naive': @@ -44,11 +45,11 @@ def train(args): model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') - + if args.model_path is not None: state_dict = torch.load(args.model_path) model.load_state_dict(state_dict) - + # configure tokenizer if args.model == 'gpt2': tokenizer = GPT2Tokenizer.from_pretrained('gpt2') @@ -68,7 +69,7 @@ def train(args): optim = HybridAdam(model.parameters(), lr=1.5e-5) else: optim = Adam(model.parameters(), lr=1.5e-5) - + # configure loss function if args.loss_fn == 'log_sig': loss_fn = LogSigLoss() @@ -76,21 +77,21 @@ def train(args): loss_fn = LogExpLoss() else: raise ValueError(f'Unsupported loss function "{args.loss_fn}"') - + # prepare for data and dataset if args.subset is not None: data = load_dataset(args.dataset, data_dir=args.subset) else: data = load_dataset(args.dataset) - + if args.test: train_data = data['train'].select(range(100)) - eval_data = data['test'].select(range(10)) + eval_data = data['test'].select(range(10)) else: train_data = data['train'] eval_data = data['test'] - valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data)//10))) - + valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 10))) + if args.dataset == 'Dahoas/rm-static': train_dataset = RmStaticDataset(train_data, tokenizer, max_len) valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len) @@ -101,11 +102,11 @@ def train(args): eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len) else: raise ValueError(f'Unsupported dataset "{args.dataset}"') - + trainer = RewardModelTrainer(model=model, strategy=strategy, optim=optim, - loss_fn = loss_fn, + loss_fn=loss_fn, train_dataset=train_dataset, valid_dataset=valid_dataset, 
eval_dataset=eval_dataset, @@ -117,7 +118,10 @@ def train(args): strategy.save_model(trainer.model, args.save_path, only_rank0=True) # save optimizer checkpoint on all ranks if args.need_optim_ckpt: - strategy.save_optimizer(trainer.optimizer, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False) + strategy.save_optimizer(trainer.optimizer, + 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -128,7 +132,8 @@ def train(args): parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--model_path', type=str, default=None) parser.add_argument('--need_optim_ckpt', type=bool, default=False) - parser.add_argument('--dataset', type=str, + parser.add_argument('--dataset', + type=str, choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'], default='Dahoas/rm-static') parser.add_argument('--subset', type=str, default=None) diff --git a/applications/ChatGPT/examples/train_sft.py b/applications/ChatGPT/examples/train_sft.py index 83b34f9dd1ea..009a8f9d0d2d 100644 --- a/applications/ChatGPT/examples/train_sft.py +++ b/applications/ChatGPT/examples/train_sft.py @@ -3,24 +3,24 @@ import loralib as lora import torch import torch.distributed as dist -from torch.utils.data.distributed import DistributedSampler -from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator +from chatgpt.dataset import AlpacaDataCollator, AlpacaDataset, SFTDataset from chatgpt.models.base import RewardModel from chatgpt.models.bloom import BLOOMLM from chatgpt.models.gpt import GPTLM -from chatgpt.models.opt import OPTLM from chatgpt.models.llama import LlamaLM +from chatgpt.models.opt import OPTLM from chatgpt.trainer import SFTTrainer from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy from chatgpt.utils import prepare_llama_tokenizer_and_embedding from datasets import load_dataset from torch.optim import Adam from 
torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from transformers import AutoTokenizer, BloomTokenizerFast from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer -from colossalai.nn.optimizer import HybridAdam from colossalai.logging import get_dist_logger +from colossalai.nn.optimizer import HybridAdam def train(args): @@ -66,7 +66,7 @@ def train(args): ) else: raise ValueError(f'Unsupported model "{args.model}"') - + if args.model == 'llama': tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model) else: @@ -138,4 +138,3 @@ def train(args): parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log") args = parser.parse_args() train(args) - diff --git a/colossalai/amp/__init__.py b/colossalai/amp/__init__.py index 16da81f23898..963215476b6b 100644 --- a/colossalai/amp/__init__.py +++ b/colossalai/amp/__init__.py @@ -1,14 +1,16 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from .amp_type import AMP_TYPE -from colossalai.context import Config import torch.nn as nn -from torch.optim import Optimizer from torch.nn.modules.loss import _Loss -from .torch_amp import convert_to_torch_amp +from torch.optim import Optimizer + +from colossalai.context import Config + +from .amp_type import AMP_TYPE from .apex_amp import convert_to_apex_amp from .naive_amp import convert_to_naive_amp +from .torch_amp import convert_to_torch_amp __all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE'] diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py index a79e5006e7d2..19d85b80dd3d 100644 --- a/colossalai/auto_parallel/offload/amp_optimizer.py +++ b/colossalai/auto_parallel/offload/amp_optimizer.py @@ -1,24 +1,25 @@ -from typing import Dict, Tuple from enum import Enum +from typing import Dict, Tuple + import torch from torch.optim import Optimizer +from 
colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.utils import get_current_device from .base_offload_module import BaseOffloadModule -from .region_manager import RegionManager from .region import Region +from .region_manager import RegionManager class OptimState(Enum): SCALED = 0 UNSCALED = 1 -class AMPOptimizer(ColossalaiOptimizer): +class AMPOptimizer(ColossalaiOptimizer): """ A wrapper for Optimizer. Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py @@ -174,4 +175,4 @@ def __init__optimizer(self): # Leverage state_dict() and load_state_dict() to # recast preexisting per-param state tensors - self.optim.load_state_dict(self.optim.state_dict()) \ No newline at end of file + self.optim.load_state_dict(self.optim.state_dict()) diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index 59cea4ece266..47afb7ae9f4e 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -1,10 +1,11 @@ -from typing import Optional, Set from functools import partial +from typing import Optional, Set + import torch import torch.nn as nn -from colossalai.nn.parallel.data_parallel import _cast_float from colossalai.gemini.tensor_utils import free_storage +from colossalai.nn.parallel.data_parallel import _cast_float from .region_manager import RegionManager from .util import GlobalRuntimeInfo @@ -20,10 +21,7 @@ class BaseOffloadModule: is_sync (bool): synchronous mode or not. 
""" - def __init__(self, - model: nn.Module, - region_manager: RegionManager, - is_sync=True): + def __init__(self, model: nn.Module, region_manager: RegionManager, is_sync=True): self.model = model self.region_manager = region_manager diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py index 02778696a106..8114834b614d 100644 --- a/colossalai/auto_parallel/offload/mem_optimize.py +++ b/colossalai/auto_parallel/offload/mem_optimize.py @@ -1,4 +1,5 @@ from typing import Dict + import torch import torch.fx from torch.fx import GraphModule @@ -7,10 +8,11 @@ from colossalai.fx import ColoTracer, is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from .region_manager import RegionManager -from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass from .base_offload_module import BaseOffloadModule -from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo +from .region_manager import RegionManager +from .runtime import runtime_asyn_offload_apply_pass, runtime_syn_offload_apply_pass +from .util import GlobalRuntimeInfo, compute_act_peak_mem, compute_max_param_mem, compute_total_param_mem + def memory_optimize(model: torch.nn.Module, inps: Dict[str, torch.Tensor], @@ -31,11 +33,12 @@ def memory_optimize(model: torch.nn.Module, region_manager._build_regions() GlobalRuntimeInfo.region_list = region_manager.region_list - act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2 - max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2 - total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2 + act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024**2 + max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024**2 + total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024**2 print( - 
f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}") + f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}" + ) if solver_name == 'syn': gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list) @@ -45,5 +48,5 @@ def memory_optimize(model: torch.nn.Module, raise TypeError(f"Unknown solver name {solver_name}!") gm.recompile() - optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn') + optimized_model = BaseOffloadModule(gm, region_manager, solver_name == 'syn') return optimized_model diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py index e6907cc4b81d..994e3f1c0f4c 100644 --- a/colossalai/auto_parallel/offload/region.py +++ b/colossalai/auto_parallel/offload/region.py @@ -1,8 +1,11 @@ -from typing import List, Dict, Tuple +from typing import Dict, List, Tuple + import torch from torch.fx import Node + from colossalai.gemini.tensor_utils import alloc_storage, free_storage + class Region: """ Region: A container owning a piece of contiguous nodes in the DNN computing graph. @@ -52,15 +55,13 @@ def init_param_data(self, pre_alloc_tensor: torch.Tensor = None): Map the parameters in the region to a contiguous memory space. 
""" - self.fp16_data = torch.zeros( - self.param_num, dtype=torch.half, device='cuda') + self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda') offset = 0 for param in self.fp16_params: param.data = param.data.cuda() p_num = param.data.numel() self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) - param.data = self.fp16_data[offset:offset + - p_num].view(param.data.shape) + param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape) self.param_to_range[param] = (offset, offset + p_num) offset += p_num @@ -141,4 +142,4 @@ def split(self, cut_node_idx: int, cut_param_idx: int): def __update_params_ptr(self) -> None: for param in self.fp16_params: begin, end = self.param_to_range[param] - param.data = self.fp16_data[begin:end].view(param.data.shape) \ No newline at end of file + param.data = self.fp16_data[begin:end].view(param.data.shape) diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py index 30bfaf00d493..3a9d1cfa8b67 100644 --- a/colossalai/auto_parallel/offload/region_manager.py +++ b/colossalai/auto_parallel/offload/region_manager.py @@ -1,10 +1,11 @@ -from typing import List, Any, Dict, Tuple +from typing import Any, Dict, List, Tuple + import torch from torch.fx import Graph, Node +from .region import Region from .solver import SolverFactory from .training_simulator import TrainingSimulator -from .region import Region from .util import NodeInfo @@ -19,11 +20,7 @@ class RegionManager: cnode (List[str], optional): Common node List, should be the subset of input. 
""" - def __init__(self, - graph: Graph, - solver_name: str = 'asyn', - memory_budget: float = -1.0, - cnode: List[str] = None): + def __init__(self, graph: Graph, solver_name: str = 'asyn', memory_budget: float = -1.0, cnode: List[str] = None): self.graph = graph assert graph.owning_module is not None, 'The given graph is not associated with a owning_module' @@ -65,8 +62,7 @@ def _pre_process(self): init_region_list = self._linearize_graph() if len(self.shared_region_pairs) > 1: - raise NotImplementedError( - 'The current version only considers at most one pair of parameter sharing.') + raise NotImplementedError('The current version only considers at most one pair of parameter sharing.') elif len(self.shared_region_pairs) == 1: shared_regs = self.shared_region_pairs[0] @@ -122,12 +118,9 @@ def _early_region_placement(self, ts: TrainingSimulator): it may not find a suitable region placement strategy for the given execution flow. """ - reg_flow = torch.cat( - [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) - mem_block_num = torch.max( - torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) - coexist_matrix = torch.logical_or( - ts.fwd_reg_flow, ts.bwd_reg_flow) + reg_flow = torch.cat([ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0) + mem_block_num = torch.max(torch.sum(reg_flow[:, self.rid_in_pool], dim=1)) + coexist_matrix = torch.logical_or(ts.fwd_reg_flow, ts.bwd_reg_flow) block_to_regs = {} for block_idx in range(mem_block_num): @@ -135,8 +128,7 @@ def _early_region_placement(self, ts: TrainingSimulator): for reg in self.region_list: if reg.r_id in self.rid_in_pool: cur_reg_appears = coexist_matrix[:, reg.r_id] - cur_reg_coexists = torch.sum( - coexist_matrix[cur_reg_appears], dim=0).bool() + cur_reg_coexists = torch.sum(coexist_matrix[cur_reg_appears], dim=0).bool() for block_idx in range(mem_block_num): if not any(cur_reg_coexists[block_to_regs[block_idx]]): block_to_regs[block_idx].append(reg.r_id) @@ -146,8 +138,10 @@ def _early_region_placement(self, ts: 
TrainingSimulator): if reg.r_id not in self.reg_to_block: raise NotImplementedError( f'can not find a block from the memory pool to store parameters of the region') - self.memory_pool = torch.chunk(torch.zeros(int( - mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num)) + self.memory_pool = torch.chunk(torch.zeros(int(mem_block_num * self.mem_block_size / 2), + dtype=torch.half, + device='cuda'), + chunks=int(mem_block_num)) def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: """ @@ -181,7 +175,7 @@ def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]: def _search_block_size(self, region_list: List[Region], search_interval_byte: int = 1024, - search_range_byte: int = 128 * 1024 ** 2) -> int: + search_range_byte: int = 128 * 1024**2) -> int: """ Search for a suitable memory block size. @@ -208,8 +202,7 @@ def _get_wasted_mem(size_list: List[int], blk_size: int): acc_wasted += blk_size - left return acc_wasted - param_size_list = [ - region.param_size for region in region_list if region.r_id == region.shared_rid] + param_size_list = [region.param_size for region in region_list if region.r_id == region.shared_rid] start_size = max(param_size_list) min_mem_waste = float('+inf') @@ -244,8 +237,7 @@ def _init_region_data(self): region.fp16_data = shared_region.fp16_data region.fp32_data = shared_region.fp32_data region.param_to_range = shared_region.param_to_range - region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach( - ) + region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach() torch.cuda.empty_cache() @@ -343,10 +335,8 @@ def _maybe_param_comp_start() -> bool: elif n.op == "call_module": target = n.target submod = self.root_module.get_submodule(target) - if ( - len(list(submod.named_parameters(recurse=False))) != 0 - or len(list(submod.named_buffers(recurse=False))) != 0 - ): + if (len(list(submod.named_parameters(recurse=False))) != 0 + or 
len(list(submod.named_buffers(recurse=False))) != 0): label = True return label and not sum([v for _, v in param_op_deps.items()]) @@ -368,8 +358,7 @@ def _is_inplace(n: Node): if n.op == "call_function": inplace = n.kwargs.get("inplace", False) elif n.op == "call_module": - inplace = getattr(n.graph.owning_module.get_submodule( - n.target), "inplace", False) + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) return inplace label = False @@ -377,10 +366,8 @@ def _is_inplace(n: Node): if n.op == "call_module": target = n.target submod = self.root_module.get_submodule(target) - if ( - len(list(submod.named_parameters(recurse=False))) != 0 - or len(list(submod.named_buffers(recurse=False))) != 0 - ): + if (len(list(submod.named_parameters(recurse=False))) != 0 + or len(list(submod.named_buffers(recurse=False))) != 0): label = True elif n.op == "call_function": @@ -449,18 +436,16 @@ def _exception_node_handling(): # propagate common node attr if possible if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode - ]) or _is_cop(n.target): + ]) or _is_cop(n.target): self.cnode.append(n.name) else: - deps[n] = len( - [user for user in n.users if user.op != "output"]) + deps[n] = len([user for user in n.users if user.op != "output"]) # propagate param node attr if possible - if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops - ]) or n.op == "get_attr": + if len(n.all_input_nodes) == len( + [node for node in n.all_input_nodes if node.name in self.only_param_ops]) or n.op == "get_attr": self.only_param_ops.append(n.name) - param_op_deps[n] = len( - [user for user in n.users if user.op != "output"]) + param_op_deps[n] = len([user for user in n.users if user.op != "output"]) # record last activation node if _is_act(n._meta_data): @@ -483,8 +468,7 @@ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): if p in 
self.param_region_map: cur_reg.shared_rid = self.param_region_map[p].r_id self.param_region_map[p].shared_rid = cur_reg.r_id - self.shared_region_pairs.append( - (self.param_region_map[p], cur_reg)) + self.shared_region_pairs.append((self.param_region_map[p], cur_reg)) else: self.param_region_map[p] = cur_reg @@ -503,8 +487,7 @@ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region): if attr_itr in self.param_region_map: cur_reg.shared_rid = self.param_region_map[attr_itr].r_id self.param_region_map[attr_itr].shared_rid = cur_reg.r_id - self.shared_region_pairs.append( - (self.param_region_map[attr_itr], cur_reg)) + self.shared_region_pairs.append((self.param_region_map[attr_itr], cur_reg)) else: self.param_region_map[attr_itr] = cur_reg diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py index 91c7945bd65f..bfa30ebba6a3 100644 --- a/colossalai/auto_parallel/offload/runtime.py +++ b/colossalai/auto_parallel/offload/runtime.py @@ -1,4 +1,5 @@ from typing import List + import torch from torch.fx.node import Node @@ -65,8 +66,7 @@ def forward(ctx, input_, fwd_info, bwd_info): sync_rid = fwd_info.get('sync_rid', None) if sync_rid is not None: - prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get( - sync_rid, None) + prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(sync_rid, None) if prefetch_event: prefetch_event.wait() @@ -92,8 +92,7 @@ def backward(ctx, grad_output): if sync_rid is not None: wait_region = GlobalRuntimeInfo.region_list[sync_rid] assert isinstance(wait_region, Region) - prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get( - sync_rid, None) + prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(sync_rid, None) if prefetch_event: prefetch_event.wait() else: @@ -129,6 +128,7 @@ def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info): ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info) return ret + def 
convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info): ''' Convert Prefetch and Offload operation into runtime action. @@ -189,7 +189,8 @@ def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[R if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action, + new_node = mod_graph.create_node('call_function', + convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)) replace_node_users(last_inp_node, new_node) @@ -206,11 +207,11 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ # upload parameters of the first region last_inp_node = tuple(mod_graph.nodes)[0] - first_region_with_p = [ - region for region in region_list if region.param_size][0] + first_region_with_p = [region for region in region_list if region.param_size][0] fwd_info = {"h2d_rid": first_region_with_p.r_id} with mod_graph.inserting_after(last_inp_node): - upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action, + upload_apply_node = mod_graph.create_node('call_function', + convert_fwd_upload_bwd_offload_to_action, args=(last_inp_node, fwd_info, {})) replace_node_users(last_inp_node, upload_apply_node) last_inp_node = upload_apply_node @@ -225,19 +226,20 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ fwd_info['h2d_rid'] = fwd_prefetch_region.r_id # forward offload - if r_idx > 0 and region_list[r_idx-1].need_offload: + if r_idx > 0 and region_list[r_idx - 1].need_offload: fwd_info['d2h_rid'] = r_idx - 1 bwd_info = {} # backward prefetch - if r_idx > 0 and region_list[r_idx-1].need_offload: + if r_idx > 0 and region_list[r_idx - 1].need_offload: bwd_info['sync_rid'] = r_idx - 1 - if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region: - bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id + if r_idx > 0 
and region_list[r_idx - 1].bwd_prefetch_region: + bwd_info['h2d_rid'] = region_list[r_idx - 1].bwd_prefetch_region.r_id if fwd_info or bwd_info: with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action, + new_node = mod_graph.create_node('call_function', + convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, fwd_info, bwd_info)) replace_node_users(last_inp_node, new_node) @@ -246,7 +248,8 @@ def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[ if region.bwd_prefetch_region: bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id} with mod_graph.inserting_after(last_inp_node): - new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action, + new_node = mod_graph.create_node('call_function', + convert_fwd_prefetch_bwd_offload_to_action, args=(last_inp_node, {}, bwd_info)) replace_node_users(last_inp_node, new_node) # gm.graph.print_tabular() diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py index 161f7ff86898..7c59d8e5dc0f 100644 --- a/colossalai/auto_parallel/offload/solver.py +++ b/colossalai/auto_parallel/offload/solver.py @@ -1,6 +1,6 @@ import time -from typing import List, Dict, Type from abc import ABC, abstractmethod +from typing import Dict, List, Type NOT_NVML = False try: @@ -10,10 +10,11 @@ import torch from torch.fx.node import Node + from colossalai.utils.cuda import get_current_device -from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator from .region import Region +from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator from .util import NodeInfo, NvDevicePower @@ -49,10 +50,7 @@ class Solver(ABC): It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time. 
""" - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0, - error_factor: float = 0.95) -> None: + def __init__(self, region_list: List[Region], memory_budget: float = -1.0, error_factor: float = 0.95) -> None: self.region_list = region_list @@ -60,8 +58,7 @@ def __init__(self, if memory_budget > 0: self.memory_budget = memory_budget * self.error_factor else: - self.memory_budget = torch.cuda.get_device_properties( - get_current_device()).total_memory * self.error_factor + self.memory_budget = torch.cuda.get_device_properties(get_current_device()).total_memory * self.error_factor self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth() self.comp_power: float = self._extract_computing_power() @@ -122,9 +119,7 @@ def _update_state(self, best_ts: TrainingSimulator): self.best_ts = best_ts self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem) - def _update_node_mem_info(self, - fwd_mem_info: Dict[Node, float], - bwd_mem_info: Dict[Node, float]): + def _update_node_mem_info(self, fwd_mem_info: Dict[Node, float], bwd_mem_info: Dict[Node, float]): """ Update the runtime memory information of the node. 
@@ -134,12 +129,10 @@ def _update_node_mem_info(self, """ for node, mem in fwd_mem_info.items(): - assert hasattr(node, 'node_info') and isinstance( - node.node_info, NodeInfo) + assert hasattr(node, 'node_info') and isinstance(node.node_info, NodeInfo) node.node_info.runtime_fwd_mem = mem for node, mem in bwd_mem_info.items(): - assert hasattr(node, 'node_info') and isinstance( - node.node_info, NodeInfo) + assert hasattr(node, 'node_info') and isinstance(node.node_info, NodeInfo) node.node_info.runtime_bwd_mem = mem def _extract_computing_power(self): @@ -183,15 +176,11 @@ def _profile_bandwidth(self): # from 1KB to 1GB for i in range(21): if link == 'h2d': - src_tensor = torch.ones( - int(t_size), dtype=torch.int8, pin_memory=True) - dst_tensor = torch.ones( - (int(t_size)), dtype=torch.int8, device='cuda') + src_tensor = torch.ones(int(t_size), dtype=torch.int8, pin_memory=True) + dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, device='cuda') elif link == 'd2h': - src_tensor = torch.ones( - int(t_size), dtype=torch.int8, device='cuda') - dst_tensor = torch.ones( - (int(t_size)), dtype=torch.int8, pin_memory=True) + src_tensor = torch.ones(int(t_size), dtype=torch.int8, device='cuda') + dst_tensor = torch.ones((int(t_size)), dtype=torch.int8, pin_memory=True) def func(): dst_tensor.copy_(src_tensor) @@ -209,9 +198,7 @@ def func(): class SynGreedySolver(Solver): - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0) -> None: + def __init__(self, region_list: List[Region], memory_budget: float = -1.0) -> None: super().__init__(region_list, memory_budget) self.best_ts: SynTrainingSimulator = None @@ -302,18 +289,14 @@ def _eval_one_choice(self, offload_region: Region): # the shared region needs to be moved twice if offload_region.r_id < offload_region.shared_rid: extra_comm_cost *= 2.0 - profit = self._compute_offload_profit( - ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + profit = 
self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class AsynGreedySolver(Solver): - def __init__(self, - region_list: List[Region], - memory_budget: float = -1.0, - search_window_size: int = 3): + def __init__(self, region_list: List[Region], memory_budget: float = -1.0, search_window_size: int = 3): super().__init__(region_list, memory_budget) self.search_window_size = search_window_size @@ -331,7 +314,7 @@ def _init_state(self): ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth) ts.execute() self._update_state(ts) - print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB") + print("init peak memory", self.best_ts.peak_mem / 1024**2, "MB") def _call_solver(self): """ @@ -362,8 +345,7 @@ def _call_solver(self): if host_region.bwd_prefetch_region is not None: continue - temp_ts, profit = self._try_to_offload( - host_region, region) + temp_ts, profit = self._try_to_offload(host_region, region) if self._compare_profit(profit, max_prefetch_profit): region_to_region_map[region.r_id] = host_region @@ -464,8 +446,7 @@ def _repair_strategy(self): assert offload_region.need_offload assert not offload_region.is_syn - ts, profit = self._try_convert_to_syn_upload(host_region, - offload_region) + ts, profit = self._try_convert_to_syn_upload(host_region, offload_region) if self._compare_profit(profit, max_profit): undo_host_region = host_region @@ -500,17 +481,13 @@ def _eval_one_choice(self): ts.execute() extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0) - profit = self._compute_offload_profit( - ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) + profit = self._compute_offload_profit(ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost) return ts, profit class SolverFactory: - solvers: Dict[str, Type[Solver]] = { - 'syn': SynGreedySolver, - 'asyn': AsynGreedySolver - } + solvers: Dict[str, 
Type[Solver]] = {'syn': SynGreedySolver, 'asyn': AsynGreedySolver} @staticmethod def create(solver_name: str) -> Type[Solver]: diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py index f277c183a912..24d9e13b54ba 100644 --- a/colossalai/auto_parallel/offload/training_simulator.py +++ b/colossalai/auto_parallel/offload/training_simulator.py @@ -1,7 +1,7 @@ import bisect -from typing import List, Dict -from collections import OrderedDict from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Dict, List from torch.fx.node import Node @@ -26,10 +26,7 @@ class TrainingSimulator(ABC): link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth. """ - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: self.region_list = region_list self.region_num = len(region_list) @@ -88,10 +85,7 @@ def _get_computing_overhead(self, flop: float) -> float: class SynTrainingSimulator(TrainingSimulator): - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: super().__init__(region_list, comp_power, link_to_bw) def execute(self): @@ -151,8 +145,7 @@ def _eval_bwd_mem_per_region(self, region: Region): self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem self.bwd_node_mem[node] = self.runtime_mem - self.runtime_mem -= (node.meta['bwd_mem_tmp'] + - calculate_fwd_tmp(node)) + self.runtime_mem -= (node.meta['bwd_mem_tmp'] + calculate_fwd_tmp(node)) # free bwd_mem_out self.bwd_node_deps[node] = len(node.all_input_nodes) @@ -178,22 +171,16 @@ def 
_eval_bwd_mem_per_region(self, region: Region): class AsynTrainingSimulator(TrainingSimulator): - def __init__(self, - region_list: List[Region], - comp_power: float, - link_to_bw: Dict[str, Dict[float, float]]) -> None: + def __init__(self, region_list: List[Region], comp_power: float, link_to_bw: Dict[str, Dict[float, float]]) -> None: super().__init__(region_list, comp_power, link_to_bw) self.iter_end_time: int = 0 # the last computation execution period - self.last_comp: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_comp: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the last parameter prefetch execution period - self.last_h2d: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_h2d: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the last gradient offload execution period - self.last_d2h: ExecutionPeriod = ExecutionPeriod( - start_time=0, end_time=0) + self.last_d2h: ExecutionPeriod = ExecutionPeriod(start_time=0, end_time=0) # the forward computation execution period of the region self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the forward parameter prefetch execution period of the region @@ -204,10 +191,8 @@ def __init__(self, self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the gradient offload execution period of the region # which is divided into those that are waiting and those that have been released - self.bwd_reg_to_offl_waiting: OrderedDict[int, - ExecutionPeriod] = OrderedDict() - self.bwd_reg_to_offl_freed: OrderedDict[int, - ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_waiting: OrderedDict[int, ExecutionPeriod] = OrderedDict() + self.bwd_reg_to_offl_freed: OrderedDict[int, ExecutionPeriod] = OrderedDict() # the region buffer, which records regions that are offloaded but not released self.reg_buffer_to_free: List[int] = [] @@ -217,10 +202,8 @@ def __init__(self, # the region execution flow, # where 
fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU # when the execution reaches the i-th region. - self.fwd_reg_flow = torch.zeros( - (self.region_num, self.region_num)).bool() - self.bwd_reg_flow = torch.zeros( - (self.region_num, self.region_num)).bool() + self.fwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool() + self.bwd_reg_flow = torch.zeros((self.region_num, self.region_num)).bool() def execute(self): """ @@ -249,8 +232,7 @@ def execute(self): self.runtime_mem -= self.region_list[reg_id].param_size self.bwd_reg_to_offl_waiting.clear() - self.iter_end_time = max( - self.last_comp.end_time, self.last_d2h.end_time) + self.iter_end_time = max(self.last_comp.end_time, self.last_d2h.end_time) def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): """ @@ -260,8 +242,7 @@ def _insert_h2d_exec(self, region: Region, is_fwd: bool = True): pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time) pref_end_time = pref_start_time + \ 2.0 * self._get_communication_overhead('h2d', region.param_size) - pref_ep = ExecutionPeriod( - start_time=pref_start_time, end_time=pref_end_time) + pref_ep = ExecutionPeriod(start_time=pref_start_time, end_time=pref_end_time) if is_fwd: self.fwd_reg_to_pref[region.r_id] = pref_ep else: @@ -281,13 +262,11 @@ def _insert_comp_exec(self, region: Region, is_fwd: bool = True): reg_to_comp = self.bwd_reg_to_comp reg_to_pref = self.bwd_reg_to_pref flop_key = 'bwd_flop' - comp_start_time = max(self.last_comp.end_time, reg_to_pref.get( - region.r_id, ExecutionPeriod(0, 0)).end_time) + comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(region.r_id, ExecutionPeriod(0, 0)).end_time) comp_end_time = comp_start_time + \ sum([self._get_computing_overhead(node.meta.get(flop_key, 0)) for node in region.nodes]) - comp_ep = ExecutionPeriod( - start_time=comp_start_time, end_time=comp_end_time) + comp_ep = ExecutionPeriod(start_time=comp_start_time, end_time=comp_end_time) 
reg_to_comp[region.r_id] = comp_ep self.last_comp = comp_ep @@ -299,8 +278,7 @@ def _insert_d2h_exec(self, region: Region): offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time) offl_end_time = offl_start_time + \ self._get_communication_overhead('d2h', region.param_size) - offl_ep = ExecutionPeriod( - start_time=offl_start_time, end_time=offl_end_time) + offl_ep = ExecutionPeriod(start_time=offl_start_time, end_time=offl_end_time) self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep self.last_d2h = offl_ep @@ -332,16 +310,14 @@ def _eval_fwd_mem_per_region(self, region: Region): self.fwd_reg_flow[region.r_id, region.r_id] = True else: self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1] - self.fwd_reg_flow[region.r_id, - self.reg_buffer_to_free] = False + self.fwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False self.reg_buffer_to_free.clear() # prefetch parameters of the next region fwd_prefetch_region = region.fwd_prefetch_region if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]): self.runtime_mem += fwd_prefetch_region.param_size - self.fwd_reg_flow[region.r_id, - fwd_prefetch_region.r_id] = True + self.fwd_reg_flow[region.r_id, fwd_prefetch_region.r_id] = True for node in region.nodes: self.runtime_mem += calculate_fwd_tmp(node) + \ @@ -354,8 +330,7 @@ def _eval_fwd_mem_per_region(self, region: Region): if region.need_offload: self.runtime_mem -= region.param_size - assert len( - self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}' + assert len(self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}' self.reg_buffer_to_free.append(region.r_id) def _eval_bwd_cost_per_region(self, region: Region): @@ -398,8 +373,7 @@ def _eval_bwd_mem_per_region(self, region: Region): self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1] else: self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1] - self.bwd_reg_flow[region.r_id, - self.reg_buffer_to_free] = 
False + self.bwd_reg_flow[region.r_id, self.reg_buffer_to_free] = False # free gradients in the buffer while len(self.reg_buffer_to_free): @@ -415,8 +389,7 @@ def _eval_bwd_mem_per_region(self, region: Region): bwd_prefetch_region = region.bwd_prefetch_region if bwd_prefetch_region: self.runtime_mem += bwd_prefetch_region.param_size - self.bwd_reg_flow[region.r_id, - bwd_prefetch_region.r_id] = True + self.bwd_reg_flow[region.r_id, bwd_prefetch_region.r_id] = True # add the gradient of the parameter if region.r_id < region.shared_rid: @@ -437,8 +410,7 @@ def _eval_bwd_mem_per_region(self, region: Region): self.bwd_node_mem[node] = self.runtime_mem - self.runtime_mem -= (node.meta['bwd_mem_tmp'] + - calculate_fwd_tmp(node)) + self.runtime_mem -= (node.meta['bwd_mem_tmp'] + calculate_fwd_tmp(node)) # free bwd_mem_out self.bwd_node_deps[node] = len(node.all_input_nodes) diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py index a99c4eb20225..531a48e794a0 100644 --- a/colossalai/auto_parallel/offload/util.py +++ b/colossalai/auto_parallel/offload/util.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import List + import torch + from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp from .region import Region @@ -12,6 +14,7 @@ class NodeInfo: runtime_fwd_mem: float = 0 runtime_bwd_mem: float = 0 + class NvDevicePower: """ NVIDIA GPU computing performance (TFLOPs). 
@@ -70,21 +73,24 @@ def compute_act_peak_mem(region_list: List[Region]) -> float: return act_peak_mem + def compute_max_param_mem(region_list: List[Region]) -> float: return max(region.param_size for region in region_list) + def compute_total_param_mem(region_list: List[Region]) -> float: return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid) + def requires_upload_p_in_fwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or ( - shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid + and shared_reg.need_offload) + def requires_release_p_in_bwd(shared_reg: Region): - return (shared_reg.r_id >= shared_reg.shared_rid) or ( - shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload) + return (shared_reg.r_id >= shared_reg.shared_rid) or (shared_reg.r_id < shared_reg.shared_rid + and shared_reg.need_offload) + def requires_offload_g_in_bwd(region: Region): return region.param_size and (region.r_id <= region.shared_rid) - - diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py index b867a30686eb..39799a67c5a0 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/unary_elementwise_generator.py @@ -1,7 +1,7 @@ import copy from typing import List -from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from .strategy_generator import FollowingStrategyGenerator diff --git a/colossalai/cli/benchmark/models.py b/colossalai/cli/benchmark/models.py index 
38ea54188b8c..f8fd1c41a059 100644 --- a/colossalai/cli/benchmark/models.py +++ b/colossalai/cli/benchmark/models.py @@ -1,4 +1,5 @@ import torch + import colossalai.nn as col_nn diff --git a/colossalai/cli/benchmark/utils.py b/colossalai/cli/benchmark/utils.py index 825b795f21f6..ee7d92d6ea6a 100644 --- a/colossalai/cli/benchmark/utils.py +++ b/colossalai/cli/benchmark/utils.py @@ -1,10 +1,11 @@ import math import time +from typing import Callable, Dict, List, Tuple + import torch +from colossalai.context import Config, ParallelMode from colossalai.utils import MultiTimer -from colossalai.context import ParallelMode, Config -from typing import List, Dict, Tuple, Callable def get_time_stamp() -> int: @@ -25,8 +26,8 @@ def get_memory_states() -> Tuple[float]: Return the memory statistics. Returns: - max_allocated (float): the allocated CUDA memory - max_cached (float): the cached CUDA memory + max_allocated (float): the allocated CUDA memory + max_cached (float): the cached CUDA memory """ max_allocated = torch.cuda.max_memory_allocated() / (1024**3) @@ -101,7 +102,7 @@ def profile_model(model: torch.nn.Module, warmup_steps: int, profile_steps: int, profile_steps (int): the number of steps for profiling data_func (Callable): a function to generate random data timer (colossalai.utils.Multitimer): a timer instance for time recording - + Returns: fwd_time (float): the average forward time taken by forward pass in second bwd_time (float): the average backward time taken by forward pass in second diff --git a/colossalai/cli/check/__init__.py b/colossalai/cli/check/__init__.py index a86b32bb6a18..e2bb5a6aa18d 100644 --- a/colossalai/cli/check/__init__.py +++ b/colossalai/cli/check/__init__.py @@ -1,4 +1,5 @@ import click + from .check_installation import check_installation __all__ = ['check'] diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py index 220481b7af15..88ad0487b785 100644 --- a/colossalai/communication/__init__.py +++ 
b/colossalai/communication/__init__.py @@ -1,9 +1,17 @@ -from .collective import all_gather, reduce_scatter, all_reduce, broadcast, reduce -from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, send_backward, - send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward, - recv_forward, recv_backward) +from .collective import all_gather, all_reduce, broadcast, reduce, reduce_scatter +from .p2p import ( + recv_backward, + recv_forward, + send_backward, + send_backward_recv_backward, + send_backward_recv_forward, + send_forward, + send_forward_backward_recv_forward_backward, + send_forward_recv_backward, + send_forward_recv_forward, +) from .ring import ring_forward -from .utils import send_obj_meta, recv_obj_meta +from .utils import recv_obj_meta, send_obj_meta __all__ = [ 'all_gather', diff --git a/colossalai/communication/p2p.py b/colossalai/communication/p2p.py index 6dd4d0d6608d..750fa21c017d 100644 --- a/colossalai/communication/p2p.py +++ b/colossalai/communication/p2p.py @@ -1,16 +1,18 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import operator +from functools import reduce from typing import List, Tuple, Union + import torch import torch.distributed as dist from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import get_current_device -from functools import reduce -import operator -from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor + +from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks TensorShape = Union[torch.Size, List[int], Tuple[int]] @@ -260,7 +262,7 @@ def send_forward_recv_backward(output_tensor, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: - """Batched communication operation. Sends the input tensor to the + """Batched communication operation. 
Sends the input tensor to the next stage in pipeline, while receives the gradient tensor from the next stage in pipeline as the input gradient tensor of this stage. @@ -319,7 +321,7 @@ def send_forward_recv_forward(output_tensor, next_rank=None, dtype=torch.float, scatter_gather_tensors=False) -> Union[torch.Tensor, List[torch.Tensor]]: - """Batched communication operation. Sends the input tensor to the + """Batched communication operation. Sends the input tensor to the next stage in pipeline, while receives the output tensor from the previous stage in pipeline as the input of this stage. diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py index ef9eceea847d..1516df356278 100644 --- a/colossalai/communication/utils.py +++ b/colossalai/communication/utils.py @@ -1,10 +1,11 @@ +from typing import List, Tuple, Union + import torch import torch.distributed as dist from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import get_current_device -from typing import Union, List, Tuple TensorShape = Union[torch.Size, List[int], Tuple[int]] diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py index 50178b5fa850..08ef4e35fe2d 100644 --- a/colossalai/context/__init__.py +++ b/colossalai/context/__init__.py @@ -1,6 +1,6 @@ from .config import Config, ConfigException +from .moe_context import MOE_CONTEXT from .parallel_context import ParallelContext from .parallel_mode import ParallelMode -from .moe_context import MOE_CONTEXT from .process_group_initializer import * from .random import * diff --git a/colossalai/context/config.py b/colossalai/context/config.py index 8903707708df..41a6c77fe57d 100644 --- a/colossalai/context/config.py +++ b/colossalai/context/config.py @@ -5,6 +5,7 @@ import sys from importlib.machinery import SourceFileLoader from pathlib import Path + from colossalai.logging import get_dist_logger diff --git 
a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index dd12dad6d347..0cd533fdef1a 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -10,15 +10,16 @@ import numpy as np import torch import torch.distributed as dist + from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING from colossalai.context.config import Config +from colossalai.context.singleton_meta import SingletonMeta from colossalai.global_variables import tensor_parallel_env as env from colossalai.logging import get_dist_logger from colossalai.registry import DIST_GROUP_INITIALIZER from .parallel_mode import ParallelMode from .random import add_seed, get_seeds, set_mode -from colossalai.context.singleton_meta import SingletonMeta class ParallelContext(metaclass=SingletonMeta): diff --git a/colossalai/context/process_group_initializer/__init__.py b/colossalai/context/process_group_initializer/__init__.py index d3937a947437..48d52d7b9e52 100644 --- a/colossalai/context/process_group_initializer/__init__.py +++ b/colossalai/context/process_group_initializer/__init__.py @@ -3,10 +3,10 @@ from .initializer_2p5d import Initializer_2p5D from .initializer_3d import Initializer_3D from .initializer_data import Initializer_Data +from .initializer_model import Initializer_Model from .initializer_pipeline import Initializer_Pipeline from .initializer_sequence import Initializer_Sequence from .initializer_tensor import Initializer_Tensor -from .initializer_model import Initializer_Model from .process_group_initializer import ProcessGroupInitializer __all__ = [ diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py index 4c05028041ce..ea5c2d56db85 100644 --- a/colossalai/context/process_group_initializer/initializer_1d.py +++ b/colossalai/context/process_group_initializer/initializer_1d.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- import 
torch.distributed as dist + from colossalai.global_variables import tensor_parallel_env as env from colossalai.registry import DIST_GROUP_INITIALIZER diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py index 6b6fdc5d715c..635c9971ca40 100644 --- a/colossalai/context/process_group_initializer/initializer_2p5d.py +++ b/colossalai/context/process_group_initializer/initializer_2p5d.py @@ -4,6 +4,7 @@ import math import torch.distributed as dist + from colossalai.context import Config from colossalai.global_variables import tensor_parallel_env as env from colossalai.registry import DIST_GROUP_INITIALIZER diff --git a/colossalai/context/process_group_initializer/initializer_3d.py b/colossalai/context/process_group_initializer/initializer_3d.py index b752b8f45654..1ed8eec86efc 100644 --- a/colossalai/context/process_group_initializer/initializer_3d.py +++ b/colossalai/context/process_group_initializer/initializer_3d.py @@ -4,6 +4,7 @@ import math import torch.distributed as dist + from colossalai.global_variables import tensor_parallel_env as env from colossalai.registry import DIST_GROUP_INITIALIZER @@ -213,7 +214,8 @@ def init_dist_group(self): for h in range(self.num_group): for k in range(self.depth): ranks = [ - h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth) + h * self.depth**3 + i + self.depth * (j + self.depth * k) + for j in range(self.depth) for i in range(self.depth) ] group = dist.new_group(ranks) @@ -266,7 +268,8 @@ def init_dist_group(self): for h in range(self.num_group): for j in range(self.depth): ranks = [ - h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth) + h * self.depth**3 + i + self.depth * (j + self.depth * k) + for k in range(self.depth) for i in range(self.depth) ] group = dist.new_group(ranks) diff --git a/colossalai/context/process_group_initializer/initializer_data.py 
b/colossalai/context/process_group_initializer/initializer_data.py index 0b8b0d91fcb9..9715ebff7f00 100644 --- a/colossalai/context/process_group_initializer/initializer_data.py +++ b/colossalai/context/process_group_initializer/initializer_data.py @@ -4,8 +4,9 @@ from torch import distributed as dist from colossalai.registry import DIST_GROUP_INITIALIZER -from .process_group_initializer import ProcessGroupInitializer + from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer @DIST_GROUP_INITIALIZER.register_module diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/context/process_group_initializer/initializer_model.py index 99b9cc0d4edc..bf42108ed893 100644 --- a/colossalai/context/process_group_initializer/initializer_model.py +++ b/colossalai/context/process_group_initializer/initializer_model.py @@ -2,9 +2,11 @@ # -*- encoding: utf-8 -*- import torch.distributed as dist + from colossalai.registry import DIST_GROUP_INITIALIZER -from .process_group_initializer import ProcessGroupInitializer + from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer @DIST_GROUP_INITIALIZER.register_module diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/context/process_group_initializer/initializer_tensor.py index d2b5be9cfffb..79963408a889 100644 --- a/colossalai/context/process_group_initializer/initializer_tensor.py +++ b/colossalai/context/process_group_initializer/initializer_tensor.py @@ -4,8 +4,9 @@ import torch.distributed as dist from colossalai.registry import DIST_GROUP_INITIALIZER -from .process_group_initializer import ProcessGroupInitializer + from ..parallel_mode import ParallelMode +from .process_group_initializer import ProcessGroupInitializer @DIST_GROUP_INITIALIZER.register_module diff --git a/colossalai/context/random/__init__.py b/colossalai/context/random/__init__.py index 
422c3676c09d..d64b993257c1 100644 --- a/colossalai/context/random/__init__.py +++ b/colossalai/context/random/__init__.py @@ -1,5 +1,16 @@ -from ._helper import (seed, set_mode, with_seed, add_seed, get_seeds, get_states, get_current_mode, set_seed_states, - sync_states, moe_set_seed, reset_seeds) +from ._helper import ( + add_seed, + get_current_mode, + get_seeds, + get_states, + moe_set_seed, + reset_seeds, + seed, + set_mode, + set_seed_states, + sync_states, + with_seed, +) __all__ = [ 'seed', 'set_mode', 'with_seed', 'add_seed', 'get_seeds', 'get_states', 'get_current_mode', 'set_seed_states', diff --git a/colossalai/context/random/_helper.py b/colossalai/context/random/_helper.py index 973c4d9faa32..fbe9a3123c2e 100644 --- a/colossalai/context/random/_helper.py +++ b/colossalai/context/random/_helper.py @@ -7,8 +7,8 @@ import torch.cuda from torch import Tensor -from .seed_manager import SeedManager from ..parallel_mode import ParallelMode +from .seed_manager import SeedManager _SEED_MANAGER = SeedManager() diff --git a/colossalai/core.py b/colossalai/core.py index 153247bbed9c..91a753496bc3 100644 --- a/colossalai/core.py +++ b/colossalai/core.py @@ -3,4 +3,4 @@ from colossalai.context.parallel_context import global_context -__all__ = ['global_context'] \ No newline at end of file +__all__ = ['global_context'] diff --git a/colossalai/engine/gradient_accumulation/__init__.py b/colossalai/engine/gradient_accumulation/__init__.py index 4585b9a2529c..4cb6f4ad7384 100644 --- a/colossalai/engine/gradient_accumulation/__init__.py +++ b/colossalai/engine/gradient_accumulation/__init__.py @@ -1,10 +1,17 @@ +from typing import Iterable, List + import torch.nn as nn -from typing import List -from colossalai.engine import BaseGradientHandler -from typing import Iterable from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler -from ._gradient_accumulation import GradAccumDataloader, GradAccumOptimizer, GradAccumLrSchedulerByStep, 
GradAccumGradientHandler + +from colossalai.engine import BaseGradientHandler + +from ._gradient_accumulation import ( + GradAccumDataloader, + GradAccumGradientHandler, + GradAccumLrSchedulerByStep, + GradAccumOptimizer, +) __all__ = [ 'accumulate_gradient', 'GradAccumDataloader', 'GradAccumOptimizer', 'GradAccumLrSchedulerByStep', diff --git a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py index 89c28c3be87a..cf66be1cd821 100644 --- a/colossalai/engine/gradient_accumulation/_gradient_accumulation.py +++ b/colossalai/engine/gradient_accumulation/_gradient_accumulation.py @@ -1,21 +1,22 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Union +from typing import Any, Iterable, Tuple, Union + import torch.nn as nn from torch import Tensor -from typing import Iterable, Any, Tuple -from colossalai.nn.optimizer import ColossalaiOptimizer from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from colossalai.utils import conditional_context + from colossalai.engine import BaseGradientHandler +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.utils import conditional_context class GradAccumOptimizer(ColossalaiOptimizer): - """A wrapper for the optimizer to enable gradient accumulation by skipping the steps + """A wrapper for the optimizer to enable gradient accumulation by skipping the steps before accumulation size is reached. Args: @@ -161,7 +162,7 @@ def __next__(self) -> Union[Tensor, Tuple[Tensor]]: class GradAccumLrSchedulerByStep(_LRScheduler): - """A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps + """A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps before accumulation size is reached. 
Args: diff --git a/colossalai/engine/gradient_handler/__init__.py b/colossalai/engine/gradient_handler/__init__.py index 6177da69ba5b..2dea768bad7e 100644 --- a/colossalai/engine/gradient_handler/__init__.py +++ b/colossalai/engine/gradient_handler/__init__.py @@ -1,10 +1,9 @@ from ._base_gradient_handler import BaseGradientHandler from ._data_parallel_gradient_handler import DataParallelGradientHandler -from ._zero_gradient_handler import ZeROGradientHandler -from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler -from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler from ._moe_gradient_handler import MoeGradientHandler +from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler +from ._zero_gradient_handler import ZeROGradientHandler __all__ = [ 'BaseGradientHandler', 'DataParallelGradientHandler', 'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler', diff --git a/colossalai/engine/gradient_handler/_base_gradient_handler.py b/colossalai/engine/gradient_handler/_base_gradient_handler.py index c212359867d1..7d96dd8a88a6 100644 --- a/colossalai/engine/gradient_handler/_base_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_base_gradient_handler.py @@ -5,7 +5,7 @@ class BaseGradientHandler(ABC): - """A basic helper class to handle all-reduce operations of gradients across different parallel groups + """A basic helper class to handle all-reduce operations of gradients across different parallel groups before optimization. 
Args: diff --git a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py index d113fc516459..5cc7169c5a9f 100644 --- a/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_data_parallel_gradient_handler.py @@ -1,16 +1,17 @@ from colossalai.core import global_context as gpc from colossalai.registry import GRADIENT_HANDLER -from ._base_gradient_handler import BaseGradientHandler + from ...context.parallel_mode import ParallelMode +from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce @GRADIENT_HANDLER.register_module class DataParallelGradientHandler(BaseGradientHandler): """A helper class to handle all-reduce operations in a data parallel group. - A all-reduce collective communication will be operated in + A all-reduce collective communication will be operated in :func:`handle_gradient` among a data parallel group. - For better performance, it bucketizes the gradients of all parameters that are + For better performance, it bucketizes the gradients of all parameters that are the same type to improve the efficiency of communication. 
Args: diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py index 02cea5e67a12..b499345d4e18 100644 --- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_moe_gradient_handler.py @@ -1,45 +1,46 @@ -from colossalai.core import global_context as gpc -from colossalai.registry import GRADIENT_HANDLER -from colossalai.utils.moe import get_moe_epsize_param_dict -from ._base_gradient_handler import BaseGradientHandler -from ...context.parallel_mode import ParallelMode -from .utils import bucket_allreduce -from colossalai.context.moe_context import MOE_CONTEXT - - -@GRADIENT_HANDLER.register_module -class MoeGradientHandler(BaseGradientHandler): - """A helper class to handle all-reduce operations in a data parallel group and - moe model parallel. A all-reduce collective communication will be operated in - :func:`handle_gradient` among a data parallel group. - For better performance, it bucketizes the gradients of all parameters that are - the same type to improve the efficiency of communication. - - Args: - model (Module): Model where the gradients accumulate. - optimizer (Optimizer): Optimizer for updating the parameters. - """ - - def __init__(self, model, optimizer=None): - super().__init__(model, optimizer) - - def handle_gradient(self): - """A method running an all-reduce operation in a data parallel group. 
- Then running an all-reduce operation for all parameters in experts - across moe model parallel group - """ - global_data = gpc.data_parallel_size - - if global_data > 1: - epsize_param_dict = get_moe_epsize_param_dict(self._model) - - # epsize is 1, indicating the params are replicated among processes in data parallelism - # use the ParallelMode.DATA to get data parallel group - # reduce gradients for all parameters in data parallelism - if 1 in epsize_param_dict: - bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA)) - - for ep_size in epsize_param_dict: - if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - bucket_allreduce(param_list=epsize_param_dict[ep_size], - group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.core import global_context as gpc +from colossalai.registry import GRADIENT_HANDLER +from colossalai.utils.moe import get_moe_epsize_param_dict + +from ...context.parallel_mode import ParallelMode +from ._base_gradient_handler import BaseGradientHandler +from .utils import bucket_allreduce + + +@GRADIENT_HANDLER.register_module +class MoeGradientHandler(BaseGradientHandler): + """A helper class to handle all-reduce operations in a data parallel group and + moe model parallel. A all-reduce collective communication will be operated in + :func:`handle_gradient` among a data parallel group. + For better performance, it bucketizes the gradients of all parameters that are + the same type to improve the efficiency of communication. + + Args: + model (Module): Model where the gradients accumulate. + optimizer (Optimizer): Optimizer for updating the parameters. + """ + + def __init__(self, model, optimizer=None): + super().__init__(model, optimizer) + + def handle_gradient(self): + """A method running an all-reduce operation in a data parallel group. 
+ Then running an all-reduce operation for all parameters in experts + across moe model parallel group + """ + global_data = gpc.data_parallel_size + + if global_data > 1: + epsize_param_dict = get_moe_epsize_param_dict(self._model) + + # epsize is 1, indicating the params are replicated among processes in data parallelism + # use the ParallelMode.DATA to get data parallel group + # reduce gradients for all parameters in data parallelism + if 1 in epsize_param_dict: + bucket_allreduce(param_list=epsize_param_dict[1], group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in epsize_param_dict: + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + bucket_allreduce(param_list=epsize_param_dict[ep_size], + group=MOE_CONTEXT.parallel_info_dict[ep_size].dp_group) diff --git a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py index 83f5c00cf2af..5b49a9c0360d 100644 --- a/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_pipeline_parallel_gradient_handler.py @@ -4,9 +4,10 @@ import torch import torch.distributed as dist +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + from colossalai.core import global_context as gpc from colossalai.registry import GRADIENT_HANDLER -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from ._base_gradient_handler import BaseGradientHandler @@ -14,9 +15,9 @@ @GRADIENT_HANDLER.register_module class PipelineSharedModuleGradientHandler(BaseGradientHandler): """A helper class to handle all-reduce operations in sub parallel groups. - A all-reduce collective communication will be operated in + A all-reduce collective communication will be operated in :func:`handle_gradient` among all sub pipeline parallel groups. 
- For better performance, it bucketizes the gradients of all parameters that are + For better performance, it bucketizes the gradients of all parameters that are the same type to improve the efficiency of communication. Args: diff --git a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py index 53a8ea935a42..ea4f0fbb1c71 100644 --- a/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_sequence_parallel_gradient_handler.py @@ -1,16 +1,17 @@ from colossalai.core import global_context as gpc from colossalai.registry import GRADIENT_HANDLER -from ._base_gradient_handler import BaseGradientHandler + from ...context.parallel_mode import ParallelMode +from ._base_gradient_handler import BaseGradientHandler from .utils import bucket_allreduce @GRADIENT_HANDLER.register_module class SequenceParallelGradientHandler(BaseGradientHandler): """A helper class to handle all-reduce operations in a data parallel group. - A all-reduce collective communication will be operated in + A all-reduce collective communication will be operated in :func:`handle_gradient` among a data parallel group. - For better performance, it bucketizes the gradients of all parameters that are + For better performance, it bucketizes the gradients of all parameters that are the same type to improve the efficiency of communication. 
Args: diff --git a/colossalai/engine/gradient_handler/_zero_gradient_handler.py b/colossalai/engine/gradient_handler/_zero_gradient_handler.py index f85303e75184..19fd1e97f86f 100644 --- a/colossalai/engine/gradient_handler/_zero_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_zero_gradient_handler.py @@ -1,4 +1,5 @@ from colossalai.registry import GRADIENT_HANDLER + from ._base_gradient_handler import BaseGradientHandler diff --git a/colossalai/engine/schedule/__init__.py b/colossalai/engine/schedule/__init__.py index 54170286e99b..0f2c039d7057 100644 --- a/colossalai/engine/schedule/__init__.py +++ b/colossalai/engine/schedule/__init__.py @@ -1,5 +1,5 @@ from ._base_schedule import BaseSchedule -from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape from ._non_pipeline_schedule import NonPipelineSchedule +from ._pipeline_schedule import InterleavedPipelineSchedule, PipelineSchedule, get_tensor_shape __all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule', 'get_tensor_shape'] diff --git a/colossalai/engine/schedule/_base_schedule.py b/colossalai/engine/schedule/_base_schedule.py index ba797bad9778..a2d50041127a 100644 --- a/colossalai/engine/schedule/_base_schedule.py +++ b/colossalai/engine/schedule/_base_schedule.py @@ -2,10 +2,10 @@ # -*- encoding: utf-8 -*- from abc import ABC, abstractmethod +from typing import Callable, Iterable import torch -from typing import Iterable, Callable from colossalai.logging import get_dist_logger from colossalai.utils import get_current_device diff --git a/colossalai/engine/schedule/_non_pipeline_schedule.py b/colossalai/engine/schedule/_non_pipeline_schedule.py index c62bfb7d7375..b9239d928a7b 100644 --- a/colossalai/engine/schedule/_non_pipeline_schedule.py +++ b/colossalai/engine/schedule/_non_pipeline_schedule.py @@ -1,13 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Iterable +import inspect +from 
typing import Callable, Iterable import torch -import inspect -from ._base_schedule import BaseSchedule + from colossalai.utils import conditional_context -from typing import Callable + +from ._base_schedule import BaseSchedule class NonPipelineSchedule(BaseSchedule): diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/engine/schedule/_pipeline_schedule_v2.py index 50a87aafad02..28c58bd82b5c 100644 --- a/colossalai/engine/schedule/_pipeline_schedule_v2.py +++ b/colossalai/engine/schedule/_pipeline_schedule_v2.py @@ -1,11 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Tuple, Iterable +from typing import Iterable, Tuple -from colossalai import engine -import colossalai.communication.p2p_v2 as comm import torch.cuda + +import colossalai.communication.p2p_v2 as comm +from colossalai import engine from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils.cuda import get_current_device @@ -35,7 +36,7 @@ def pack_return_tensors(return_tensors): class PipelineScheduleV2(PipelineSchedule): """Derived class of PipelineSchedule, the only difference is that forward_backward_step is reconstructed with p2p_v2 - + Args: num_microbatches (int): The number of microbatches. data_process_func (Callable, optional): @@ -43,9 +44,9 @@ class PipelineScheduleV2(PipelineSchedule): tensor_shape (torch.Size, optional): Specified shape in pipeline communication. scatter_gather_tensors (bool, optional): If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. 
- + Example: - + # this shows an example of customized data_process_func def data_process_func(stage_output, dataloader_output): output1, output2 = stage_output diff --git a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py index f28d65e2668a..cb31eb424c9a 100644 --- a/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py +++ b/colossalai/fx/passes/experimental/adding_shape_consistency_pass.py @@ -1,14 +1,16 @@ -import torch +import builtins +import operator +from copy import deepcopy from typing import List + +import torch from torch.fx import symbolic_trace from torch.fx.node import Node + +from colossalai.device.device_mesh import DeviceMesh from colossalai.fx.passes.split_module import split_module from colossalai.tensor.shape_consistency import ShapeConsistencyManager -from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec -import builtins -import operator -from copy import deepcopy def apply(*args, **kwargs): diff --git a/colossalai/fx/passes/passes_for_gpt2_test.py b/colossalai/fx/passes/passes_for_gpt2_test.py index f98fcd686ea4..abc1a089e9a9 100644 --- a/colossalai/fx/passes/passes_for_gpt2_test.py +++ b/colossalai/fx/passes/passes_for_gpt2_test.py @@ -1,14 +1,15 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional + import torch -from torch.fx.graph_module import GraphModule -from typing import Callable, List, Dict, Any, Optional -from torch.fx._compatibility import compatibility from packaging import version +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node + +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, pipe_split from colossalai.fx.passes.meta_info_prop import TensorMetadata -import inspect -from typing import List from colossalai.fx.passes.split_module import 
Partition -from colossalai.fx.passes.adding_split_node_pass import pipe_split, balanced_split_pass -from torch.fx.node import Node def customized_split_pass_for_gpt2(gm: torch.fx.GraphModule, pp_size: int, partition_list: List[int]): diff --git a/colossalai/fx/passes/shard_1d_pass.py b/colossalai/fx/passes/shard_1d_pass.py index d2bad06bb45a..9e262db5375a 100644 --- a/colossalai/fx/passes/shard_1d_pass.py +++ b/colossalai/fx/passes/shard_1d_pass.py @@ -1,9 +1,11 @@ +import operator + import torch import torch.nn as nn -import operator + from colossalai.tensor import ProcessGroup -from colossalai.tensor.distspec import ShardSpec from colossalai.tensor.compute_spec import ComputePattern, ComputeSpec +from colossalai.tensor.distspec import ShardSpec ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_FUNC_OP = [ @@ -13,7 +15,7 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int, col_normal: bool) -> torch.nn.parameter.Parameter: - """weight_split + """weight_split split a nn.Parameter Args: @@ -60,7 +62,7 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule): def transformer_mlp_pass(graph_module: torch.fx.GraphModule, process_group: ProcessGroup): """ - This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. + This IR pass checks for transformer MLP like structure and annotate column and row sharding to the linear layers. 
""" #TODO: Needs to handle special cases, like x = linear(x) + linear(x) graph = graph_module.graph diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py index bc257edc8c89..9bc4bf1f5c42 100644 --- a/colossalai/fx/passes/split_module.py +++ b/colossalai/fx/passes/split_module.py @@ -1,9 +1,10 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional + import torch -from torch.fx.graph_module import GraphModule -from typing import Callable, List, Dict, Any, Optional -from torch.fx._compatibility import compatibility from packaging import version -import inspect +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule @compatibility(is_backward_compatible=True) @@ -38,7 +39,7 @@ def split_module( m: GraphModule, root_m: torch.nn.Module, split_callback: Callable[[torch.fx.node.Node], int], - merge_output = False, + merge_output=False, ): """ Adapted from https://github.com/pytorch/pytorch/blob/master/torch/fx/passes/split_module.py @@ -132,10 +133,8 @@ def record_cross_partition_use(def_node: torch.fx.node.Node, use_partition.inputs.setdefault(def_node.name) if def_partition_name is not None: use_partition.partitions_dependent_on.setdefault(def_partition_name) - - def record_output( - def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node] - ): # noqa: B950 + + def record_output(def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]): # noqa: B950 def_partition_name = getattr(def_node, "_fx_partition", None) use_partition_name = getattr(use_node, "_fx_partition", None) if def_partition_name != use_partition_name: @@ -291,7 +290,7 @@ def record_output( for partition_name in sorted_partitions: partition = partitions[partition_name] - + new_gm = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) return new_gm diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index bb4f3cd6a490..93554c828840 100644 --- 
a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -1,7 +1,9 @@ -import torch from typing import Dict -from torch.fx.node import Node, map_arg + +import torch from torch.fx.graph import Graph +from torch.fx.node import Node, map_arg + def get_comm_size(prev_partition, next_partition): """ @@ -83,7 +85,7 @@ def is_top(graph: Graph, node: Node): def get_all_consumers(graph: Graph, node: Node): """ Given a graph and a node of this graph, return all consumers of the node. - + Returns: List of ``Nodes`` that node appear in these nodes ``args`` and ``kwargs``. """ @@ -120,7 +122,7 @@ def forward(self, x): for node in gm.graph.nodes: if hasattr(node, 'bfs_level'): print(node.name, node.bfs_level) - + Output: graph(): %x : [#users=2] = placeholder[target=x] @@ -169,4 +171,3 @@ def get_node_module(node) -> torch.nn.Module: assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' module = node.graph.owning_module.get_submodule(node.target) return module - diff --git a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py index a43aef063e19..c518ec28da41 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/activation_function.py +++ b/colossalai/fx/profiler/experimental/profiler_function/activation_function.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_function # TODO: different activation has different FLOPs count, currently unused. 
diff --git a/colossalai/fx/profiler/experimental/profiler_function/embedding.py b/colossalai/fx/profiler/experimental/profiler_function/embedding.py index d6e43d781b8b..1d362015fc8b 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/embedding.py +++ b/colossalai/fx/profiler/experimental/profiler_function/embedding.py @@ -1,5 +1,7 @@ -import torch from typing import Optional + +import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/linear.py b/colossalai/fx/profiler/experimental/profiler_function/linear.py index 01fe4c871370..ecc578d61b91 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/linear.py +++ b/colossalai/fx/profiler/experimental/profiler_function/linear.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/normalization.py b/colossalai/fx/profiler/experimental/profiler_function/normalization.py index c4ea508d70f8..d5a50ef1ed47 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/normalization.py +++ b/colossalai/fx/profiler/experimental/profiler_function/normalization.py @@ -1,5 +1,7 @@ from typing import List, Optional, Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/pooling.py b/colossalai/fx/profiler/experimental/profiler_function/pooling.py index a639f5ee83c1..75d0f10b3a83 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/pooling.py +++ b/colossalai/fx/profiler/experimental/profiler_function/pooling.py @@ -1,5 +1,7 @@ from typing import Tuple, Union + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py index 1e8561206ba0..705483f18e82 100644 --- 
a/colossalai/fx/profiler/experimental/profiler_function/python_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/python_ops.py @@ -1,6 +1,8 @@ import operator from typing import Any, Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py index abdd7ad565ba..eed444ce7da8 100644 --- a/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py +++ b/colossalai/fx/profiler/experimental/profiler_function/torch_ops.py @@ -1,7 +1,9 @@ -from functools import reduce import operator +from functools import reduce from typing import Any, Optional, Tuple + import torch + from ..registry import meta_profiler_function diff --git a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py index 2ebf514ad269..ae065e0c7c17 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/activation_function.py +++ b/colossalai/fx/profiler/experimental/profiler_module/activation_function.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module # TODO: different activation has different FLOPs count, currently unused. 
diff --git a/colossalai/fx/profiler/experimental/profiler_module/attention.py b/colossalai/fx/profiler/experimental/profiler_module/attention.py index 8daf74b232bf..ba9f264d732a 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/attention.py +++ b/colossalai/fx/profiler/experimental/profiler_module/attention.py @@ -1,5 +1,7 @@ from typing import Optional, Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/dropout.py b/colossalai/fx/profiler/experimental/profiler_module/dropout.py index 417e0ed46863..7361239eb1bd 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/dropout.py +++ b/colossalai/fx/profiler/experimental/profiler_module/dropout.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/embedding.py b/colossalai/fx/profiler/experimental/profiler_module/embedding.py index dca6f9453af3..a1ade5d3ad93 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/embedding.py +++ b/colossalai/fx/profiler/experimental/profiler_module/embedding.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module @@ -8,4 +10,4 @@ def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[i # nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. 
(https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6) flops = 0 macs = 0 - return flops, macs \ No newline at end of file + return flops, macs diff --git a/colossalai/fx/profiler/experimental/profiler_module/linear.py b/colossalai/fx/profiler/experimental/profiler_module/linear.py index e1ffb6f244d2..71fed3196c13 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/linear.py +++ b/colossalai/fx/profiler/experimental/profiler_module/linear.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/pooling.py b/colossalai/fx/profiler/experimental/profiler_module/pooling.py index e429ac3eea28..b3b630b2dee9 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/pooling.py +++ b/colossalai/fx/profiler/experimental/profiler_module/pooling.py @@ -1,5 +1,7 @@ from typing import Tuple + import torch + from ..registry import meta_profiler_module diff --git a/colossalai/fx/profiler/experimental/profiler_module/rnn.py b/colossalai/fx/profiler/experimental/profiler_module/rnn.py index 6e733d6da915..f2186f8bdc8a 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/rnn.py +++ b/colossalai/fx/profiler/experimental/profiler_module/rnn.py @@ -1,8 +1,10 @@ -from functools import reduce import operator +from functools import reduce +from typing import Optional, Tuple, Union + import torch + from ..registry import meta_profiler_module -from typing import Optional, Tuple, Union def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor, diff --git a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py index d3aed874eb10..a04890d24c23 100644 --- a/colossalai/fx/profiler/experimental/profiler_module/torch_op.py +++ b/colossalai/fx/profiler/experimental/profiler_module/torch_op.py @@ -1,7 +1,9 @@ import operator +from 
typing import Optional, Tuple, Union + import torch + from ..registry import meta_profiler_module -from typing import Optional, Tuple, Union @meta_profiler_module.register(torch.nn.Flatten) diff --git a/colossalai/fx/proxy.py b/colossalai/fx/proxy.py index 06272c48f852..7317072c6298 100644 --- a/colossalai/fx/proxy.py +++ b/colossalai/fx/proxy.py @@ -1,7 +1,9 @@ import operator +from typing import Any, List, Union + import torch -from torch.fx.proxy import Proxy, Attribute -from typing import List, Union, Any +from torch.fx.proxy import Attribute, Proxy + from colossalai.fx.tracer.meta_patch import meta_patched_function __all__ = ['ColoProxy'] diff --git a/colossalai/fx/tracer/_tracer_utils.py b/colossalai/fx/tracer/_tracer_utils.py index 0ec49a90a133..e160497a7444 100644 --- a/colossalai/fx/tracer/_tracer_utils.py +++ b/colossalai/fx/tracer/_tracer_utils.py @@ -1,6 +1,8 @@ -from typing import List, Union, Any -from ..proxy import ColoProxy, ColoAttribute +from typing import Any, List, Union + import torch + +from ..proxy import ColoAttribute, ColoProxy from .meta_patch import meta_patched_function, meta_patched_module __all__ = ['is_element_in_list', 'extract_meta'] diff --git a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py index e28e52585fff..3f40ec2a67ee 100644 --- a/colossalai/fx/tracer/meta_patch/patched_module/__init__.py +++ b/colossalai/fx/tracer/meta_patch/patched_module/__init__.py @@ -4,4 +4,4 @@ from .linear import * from .normalization import * from .pooling import * -from .rnn import * \ No newline at end of file +from .rnn import * diff --git a/colossalai/gemini/ophooks/_shard_param_ophook.py b/colossalai/gemini/ophooks/_shard_param_ophook.py index 57f76970cc86..80736d14085e 100644 --- a/colossalai/gemini/ophooks/_shard_param_ophook.py +++ b/colossalai/gemini/ophooks/_shard_param_ophook.py @@ -1,4 +1,5 @@ import torch + from colossalai.registry import OPHOOKS from . 
import BaseOpHook diff --git a/colossalai/gemini/paramhooks/_param_hookmgr.py b/colossalai/gemini/paramhooks/_param_hookmgr.py index ee57cb46a90d..84f32be358e3 100644 --- a/colossalai/gemini/paramhooks/_param_hookmgr.py +++ b/colossalai/gemini/paramhooks/_param_hookmgr.py @@ -1,6 +1,7 @@ +import functools from typing import Callable, List + import torch -import functools class BaseParamHookMgr(object): diff --git a/colossalai/gemini/stateful_tensor.py b/colossalai/gemini/stateful_tensor.py index 18fc8fd14d3c..cb8d52f68dfb 100644 --- a/colossalai/gemini/stateful_tensor.py +++ b/colossalai/gemini/stateful_tensor.py @@ -1,7 +1,7 @@ from enum import Enum -from typing import Optional +from typing import Optional, Union + import torch -from typing import Union from colossalai.gemini.gemini_context import GeminiMemoryManager @@ -19,7 +19,7 @@ class TensorState(Enum): class StatefulTensor(object): - """A Structure stores a Torch Tensor and labeled states. + """A Structure stores a Torch Tensor and labeled states. 
Inspired from the paper: PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management diff --git a/colossalai/gemini/stateful_tensor_mgr.py b/colossalai/gemini/stateful_tensor_mgr.py index c300f9bffc89..0feaf1de46c3 100644 --- a/colossalai/gemini/stateful_tensor_mgr.py +++ b/colossalai/gemini/stateful_tensor_mgr.py @@ -1,13 +1,15 @@ import functools -import torch import types -from colossalai.utils.cuda import get_current_device -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from time import time +from typing import List + +import torch + from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy -from typing import List +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.logging import get_dist_logger -from time import time +from colossalai.utils.cuda import get_current_device class StatefulTensorMgr(object): diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/gemini/tensor_placement_policy.py index cfcfb385667c..0e575254c0b6 100644 --- a/colossalai/gemini/tensor_placement_policy.py +++ b/colossalai/gemini/tensor_placement_policy.py @@ -1,15 +1,15 @@ +import functools from abc import ABC, abstractmethod from time import time -from typing import List, Optional +from typing import List, Optional, Type + import torch -from colossalai.utils import get_current_device -from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage -from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.gemini.memory_tracer import MemStatsCollector -from typing import Type -import functools +from colossalai.gemini.stateful_tensor import StatefulTensor +from colossalai.gemini.tensor_utils import 
colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.utils import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity class TensorPlacementPolicy(ABC): diff --git a/colossalai/gemini/tensor_utils.py b/colossalai/gemini/tensor_utils.py index bcc159f9954a..72ab6a3e0018 100644 --- a/colossalai/gemini/tensor_utils.py +++ b/colossalai/gemini/tensor_utils.py @@ -1,6 +1,8 @@ +from typing import Tuple, Union + import torch + from colossalai.gemini.stateful_tensor import StatefulTensor -from typing import Union, Tuple def is_storage_empty(tensor: torch.Tensor) -> bool: diff --git a/colossalai/global_variables.py b/colossalai/global_variables.py index e3575ea12ad0..61b31965e2e6 100644 --- a/colossalai/global_variables.py +++ b/colossalai/global_variables.py @@ -1,56 +1,56 @@ -from typing import Optional - - -class TensorParallelEnv(object): - _instance = None - - def __new__(cls, *args, **kwargs): - if cls._instance is None: - cls._instance = object.__new__(cls, *args, **kwargs) - return cls._instance - - def __init__(self, *args, **kwargs): - self.load(*args, **kwargs) - - def load(self, - mode: Optional[str] = None, - vocab_parallel: bool = False, - parallel_input_1d: bool = False, - summa_dim: int = None, - tesseract_dim: int = None, - tesseract_dep: int = None, - depth_3d: int = None, - input_group_3d=None, - weight_group_3d=None, - output_group_3d=None, - input_x_weight_group_3d=None, - output_x_weight_group_3d=None): - self.mode = mode - self.vocab_parallel = vocab_parallel - self.parallel_input_1d = parallel_input_1d - self.summa_dim = summa_dim - self.tesseract_dim = tesseract_dim - self.tesseract_dep = tesseract_dep - self.depth_3d = depth_3d - self.input_group_3d = input_group_3d - self.weight_group_3d = weight_group_3d - self.output_group_3d = output_group_3d - self.input_x_weight_group_3d = input_x_weight_group_3d - self.output_x_weight_group_3d = output_x_weight_group_3d - - def save(self): - return 
dict(mode=self.mode, - vocab_parallel=self.vocab_parallel, - parallel_input_1d=self.parallel_input_1d, - summa_dim=self.summa_dim, - tesseract_dim=self.tesseract_dim, - tesseract_dep=self.tesseract_dep, - depth_3d=self.depth_3d, - input_group_3d=self.input_group_3d, - weight_group_3d=self.weight_group_3d, - output_group_3d=self.output_group_3d, - input_x_weight_group_3d=self.input_x_weight_group_3d, - output_x_weight_group_3d=self.output_x_weight_group_3d) - - -tensor_parallel_env = TensorParallelEnv() +from typing import Optional + + +class TensorParallelEnv(object): + _instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = object.__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self, *args, **kwargs): + self.load(*args, **kwargs) + + def load(self, + mode: Optional[str] = None, + vocab_parallel: bool = False, + parallel_input_1d: bool = False, + summa_dim: int = None, + tesseract_dim: int = None, + tesseract_dep: int = None, + depth_3d: int = None, + input_group_3d=None, + weight_group_3d=None, + output_group_3d=None, + input_x_weight_group_3d=None, + output_x_weight_group_3d=None): + self.mode = mode + self.vocab_parallel = vocab_parallel + self.parallel_input_1d = parallel_input_1d + self.summa_dim = summa_dim + self.tesseract_dim = tesseract_dim + self.tesseract_dep = tesseract_dep + self.depth_3d = depth_3d + self.input_group_3d = input_group_3d + self.weight_group_3d = weight_group_3d + self.output_group_3d = output_group_3d + self.input_x_weight_group_3d = input_x_weight_group_3d + self.output_x_weight_group_3d = output_x_weight_group_3d + + def save(self): + return dict(mode=self.mode, + vocab_parallel=self.vocab_parallel, + parallel_input_1d=self.parallel_input_1d, + summa_dim=self.summa_dim, + tesseract_dim=self.tesseract_dim, + tesseract_dep=self.tesseract_dep, + depth_3d=self.depth_3d, + input_group_3d=self.input_group_3d, + weight_group_3d=self.weight_group_3d, + 
output_group_3d=self.output_group_3d, + input_x_weight_group_3d=self.input_x_weight_group_3d, + output_x_weight_group_3d=self.output_x_weight_group_3d) + + +tensor_parallel_env = TensorParallelEnv() diff --git a/colossalai/kernel/cuda_native/csrc/compat.h b/colossalai/kernel/cuda_native/csrc/compat.h index 00066dc95475..a62beef91a8a 100644 --- a/colossalai/kernel/cuda_native/csrc/compat.h +++ b/colossalai/kernel/cuda_native/csrc/compat.h @@ -7,4 +7,4 @@ #define DATA_PTR data_ptr #else #define DATA_PTR data -#endif \ No newline at end of file +#endif diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu index 26efa2ad6f31..9a6a8ebc3983 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu @@ -1,7 +1,6 @@ #include #include - #include "cuda_util.h" /* GPU function guard */ diff --git a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu index a39a6dae0f7f..0294410d52ac 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/dropout_kernels.cu @@ -1,1002 +1,1001 @@ -#include -#include - -#include "kernels.h" - -#include - - -namespace cg = cooperative_groups; - -curandStatePhilox4_32_10_t *curandstate; - -/** - * @brief element-wise activation function on device, like Relu, Gelu - * - * @tparam enum class ActivationType, kRelu, kGelu - * @tparam input type - * @param any shape of float and __half2 - * @return same shape and type with input - */ -template -__forceinline__ __device__ T activation_kernel(T x); - -template <> -__device__ float activation_kernel(float x) { - float cdf = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); - return x * cdf; -} - -template <> -__device__ __half2 -activation_kernel(__half2 val) { - __half2 val_pow3 = __hmul2(val, 
__hmul2(val, val)); - float2 tmp_pow = __half22float2(val_pow3); - float2 tmp = __half22float2(val); - - tmp.x = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); - tmp.y = - 0.5f * - (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); - return __hmul2(val, __float22half2_rn(tmp)); -} - -template <> -__device__ float activation_kernel(float x) { - return fmaxf(x, 0); -} - -template <> -__device__ __half2 -activation_kernel(__half2 x) { - return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), - fmaxf(0.f, __half2float(x.y))); -} - -/** - * @brief element-wise activation backward function on device - * - * @tparam enum class ActivationType - * @tparam input type - * @param any shape of float and __half2 - * @return same shape of input - */ -template -__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * (dg1 + dg2 + dg3); -} - -template <> -__device__ __half activation_bwd_kernel( - __half grad, __half x_half) { - float x = __half2float(x_half); - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return grad * __float2half(dg1 + dg2 + dg3); -} - -template <> -__device__ float activation_bwd_kernel(float grad, - float x) { - return x > 0.f ? 
grad : 0.f; -} - -template <> -__device__ __half -activation_bwd_kernel(__half grad, __half x) { - const __half half_zero = __float2half(0.f); - return x > half_zero ? grad : half_zero; -} - -template <> -__device__ __half2 activation_bwd_kernel( - __half2 grad2, __half2 x_half2) { - const __half half_zero = __float2half(0.f); - return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, - x_half2.y > half_zero ? grad2.y : half_zero); -} - -/** - * @brief init curand states in global memory - * - * @thread grid_dim * block*dim to suuport any size of states - * @param state persistant curand states - * @param seed seed to init states - * @return void - */ -__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, - int seed) { - /* Each thread gets same seed, a different sequence - number, no offset */ - int id = threadIdx.x + blockIdx.x * blockDim.x; - curand_init(seed, id, 0, &state[id]); -} - -void launch_curand_init(int total_count, int dim, cudaStream_t stream) { - cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); - int grid_dim = total_count >> 9; - curand_init_kernel<<>>( - curandstate, std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); -} - -/** - * @brief element-wise dropout, store dropped position in mask, it's not - * in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param out any size of float and __half - * @param in same with out - * @param mask uint8 type, same size with out - * @param seed seed to curand - * @return void - */ -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - float *__restrict__ out, - const float *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - 
curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - - float4 input4 = data4[i]; - float4 res4; - res4.x = input4.x * scale * m[0]; - res4.y = input4.y * scale * m[1]; - res4.z = input4.z * scale * m[2]; - res4.w = input4.w * scale * m[3]; - out4[i] = res4; -} - -__global__ void ls_dropout_kernel(const int total_count, const float ratio, - __half *__restrict__ out, - const __half *__restrict__ in, - uint8_t *__restrict__ mask, const int seed) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - float4 val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 
scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - outs_float4[i] = out_float4; -} - -/** - * @brief element-wise dropout backward with dropout mask, it's - * not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total elements - * @param ratio drop ratio - * @param in any size of float and __half - * @param mask uint8 type, same size with in - * @return void - */ -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - float *out, const float *in, - const uint8_t *__restrict__ mask) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *in4 = reinterpret_cast(in); - const uint32_t *mask4 = reinterpret_cast(mask); - - uint32_t *m4 = reinterpret_cast(m); - m4[0] = mask4[i]; - - float4 input4 = in4[i]; - float4 res4; - res4.x = input4.x * scale * static_cast(m[0]); - res4.y = input4.y * scale * static_cast(m[1]); - res4.z = input4.z * scale * static_cast(m[2]); - res4.w = input4.w * scale * static_cast(m[3]); - out4[i] = res4; -} - -__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, - __half *out, const __half *in, - const uint8_t *__restrict__ mask) { - const __half scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - float4 *out4 = reinterpret_cast(out); - const float4 *vals_float4 = reinterpret_cast(in); - const uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - uint64_t *m8 = reinterpret_cast(m); - m8[0] = mask8[i]; - - float4 
val_float4 = vals_float4[i]; - float4 out_float4; - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = __hmul2(val_half2[0], scale_mask_1); - out_half2[1] = __hmul2(val_half2[1], scale_mask_2); - out_half2[2] = __hmul2(val_half2[2], scale_mask_3); - out_half2[3] = __hmul2(val_half2[3], scale_mask_4); - out4[i] = out_float4; -} - -template <> -void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, - int total_count, float ratio, cudaStream_t stream, - bool backward) { - int grid_dim = total_count >> 12; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -template <> -void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, - int total_count, float ratio, - cudaStream_t stream, bool backward) { - int grid_dim = total_count >> 13; - if (!backward) { - ls_dropout_kernel<<>>( - total_count, ratio, out, vals, mask, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count()); - } else { - ls_dropout_bwd_kernel<<>>(total_count, ratio, - out, vals, mask); - } -} - -/** - * @brief fused bias, dropout, and residual at the end of Attention and FFN, - * store dropped position in mask, it's not in-place - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @param total_count total 
elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param residual [batch_size, seq_len, hidden_size], float and __half - * @param seed seed to curand - * @param hidden_size hidden size - * @return void - */ -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const float *__restrict__ residual, - const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 output4; - - output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; - output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; - output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; - output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; - - out4[i] = output4; -} - -__global__ void ls_dropout_res_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half 
*__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const __half *__restrict__ residual, - const int seed, const int hidden_size) { - const __half scale = 1. / (1. - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *residual4 = reinterpret_cast(residual); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = static_cast(rand.x > ratio); - m[1] = static_cast(rand.y > ratio); - m[2] = static_cast(rand.z > ratio); - m[3] = static_cast(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = static_cast(rand.x > ratio); - m[5] = static_cast(rand.y > ratio); - m[6] = static_cast(rand.z > ratio); - m[7] = static_cast(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = m8[0]; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - const float4 res4 = residual4[i]; - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - const __half2 *res_half2 = reinterpret_cast(&res4); - __half2 scale_mask_1 = - __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); - __half2 scale_mask_2 = - __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); - __half2 scale_mask_3 = - __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); - __half2 scale_mask_4 = - __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); - out_half2[0] = - __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); - out_half2[1] = - 
__hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); - out_half2[2] = - __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); - out_half2[3] = - __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_res_bias(float *out, const float *vals, - uint8_t *mask, const float *bias, - const float *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 12; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, - uint8_t *mask, const __half *bias, - const __half *residual, int total_count, - int dim, float ratio, - cudaStream_t stream) { - int grid_dim = total_count >> 13; - ls_dropout_res_bias_kernel<<>>( - total_count, ratio, out, vals, mask, bias, residual, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias and dropout backward at the end of Attention and FFN - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, float *__restrict__ in_grad, - float *__restrict__ bias_grad, const float *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / 
(1.f - ratio); - // every block generate 8 bias result - __shared__ float tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - float val = out_grad[idx]; - val *= scale * static_cast(mask[idx]); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - float sum = 0; - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -__global__ void ls_dropout_bias_bwd_kernel( - const int row_size, const float ratio, __half *__restrict__ in_grad, - __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); - __shared__ __half2 tile[8][129]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); - const __half2 *out_grad2 = reinterpret_cast(out_grad); - __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); - int stride = hidden_size * 128; - __half2 local_sum = __float2half2_rn(0.f); - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - for (int r = threadIdx.y; r < row_size; r += 128) { - __half2 val = 
out_grad2[idx]; - __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); - val *= scale * m2; - local_sum += val; - in_grad2[idx] = val; - idx += stride; - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - - __half2 sum = __float2half2_rn(0.f); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 7; - int y = tid & (127); - if (y < 32) { -#pragma unroll - for (int i = 0; i < 4; i++) { - sum += tile[x][y + i * 32]; - } - } - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (y == 0) tile[0][x] = sum; - __syncthreads(); - - if (threadIdx.x < 8) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); - bias_grad2[pos] = tile[0][threadIdx.x]; - } -} - -template -void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template <> -void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, - const __half *out_grad, const uint8_t *mask, - int row_size, int dim, float ratio, - cudaStream_t stream) { - dim >>= 1; - dim3 grid_dim((dim - 1) / 8 + 1); - dim3 block_dim(8, 128); - ls_dropout_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); -} - -template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, - const float *out_grad, - const uint8_t *mask, int row_size, - int dim, float ratio, - cudaStream_t stream); - -/** - * @brief fused bias, activation, and dropout at the end of first ffn - * - * @thread - * gridDim.x = hidden_size / 8 - * blockDim.x = 8 - * blockDim.y = 1024 / 8 = 128 - * - * @tparam act_type activation function, like kRelu, kGelu - * @param total_count total elements - * @param ratio drop ratio - * @param out [batch_size, seq_len, hidden_size], float 
and __half - * @param in [batch_size, seq_len, hidden_size], float and __half - * @param mask [batch_size, seq_len, hidden_size], uint8 type - * @param bias [hidden_size], ffn bias - * @param seed seed to curand - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, float *__restrict__ out, - const float *__restrict__ in, uint8_t *__restrict__ mask, - const float *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 4 >= total_count) return; - - curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - uint8_t m[4]; - - float4 *out4 = reinterpret_cast(out); - const float4 *data4 = reinterpret_cast(in); - const float4 *bias4 = reinterpret_cast(bias); - uint32_t *mask4 = reinterpret_cast(mask); - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int bias_i = i % (hidden_size >> 2); - uint32_t *m4 = reinterpret_cast(m); - mask4[i] = m4[0]; - const float4 input4 = data4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 output4; - - output4.x = - activation_kernel(input4.x + b4.x) * scale * m[0]; - output4.y = - activation_kernel(input4.y + b4.y) * scale * m[1]; - output4.z = - activation_kernel(input4.z + b4.z) * scale * m[2]; - output4.w = - activation_kernel(input4.w + b4.w) * scale * m[3]; - - out4[i] = output4; -} - -template -__global__ void ls_dropout_act_bias_kernel( - const int total_count, const float ratio, __half *__restrict__ out, - const __half *__restrict__ in, uint8_t *__restrict__ mask, - const __half *__restrict__ bias, const int seed, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - - int i = blockIdx.x * blockDim.x + threadIdx.x; - - if (i * 8 >= total_count) return; - - 
curandStatePhilox4_32_10_t state; - curand_init(seed, i, 0, &state); - - const float4 *vals_float4 = reinterpret_cast(in); - float4 *outs_float4 = reinterpret_cast(out); - const float4 *bias4 = reinterpret_cast(bias); - uint64_t *mask8 = reinterpret_cast(mask); - - uint8_t m[8]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - rand = curand_uniform4(&state); - m[4] = (uint8_t)(rand.x > ratio); - m[5] = (uint8_t)(rand.y > ratio); - m[6] = (uint8_t)(rand.z > ratio); - m[7] = (uint8_t)(rand.w > ratio); - uint64_t *m8 = reinterpret_cast(m); - mask8[i] = *m8; - - int bias_i = i % (hidden_size >> 3); - float4 val_float4 = vals_float4[i]; - const float4 b4 = __ldg(&bias4[bias_i]); - float4 out_float4; - - __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); - __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); - const __half2 *b_half2 = reinterpret_cast(&b4); - - __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); - __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); - __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); - __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); - out_half2[0] = __hmul2( - activation_kernel(__hadd2(val_half2[0], b_half2[0])), - scale_mask_1); - out_half2[1] = __hmul2( - activation_kernel(__hadd2(val_half2[1], b_half2[1])), - scale_mask_2); - out_half2[2] = __hmul2( - activation_kernel(__hadd2(val_half2[2], b_half2[2])), - scale_mask_3); - out_half2[3] = __hmul2( - activation_kernel(__hadd2(val_half2[3], b_half2[3])), - scale_mask_4); - outs_float4[i] = out_float4; -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - 
total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - float *out, const float *vals, uint8_t *mask, const float *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 10; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -template <> -void launch_ls_dropout_act_bias( - __half *out, const __half *vals, uint8_t *mask, const __half *bias, - int total_count, int dim, float ratio, cudaStream_t stream) { - int grid_dim = total_count >> 11; - ls_dropout_act_bias_kernel - <<>>( - total_count, ratio, out, vals, mask, bias, - std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()) - .count(), - dim); -} - -/** - * @brief fused bias, activation, and dropout backward - * - * @thread - * gridDim.x = total_count / 1024 - * blockDim.x = 1024 - * - * @tparam act_type kRelu - * @param row_size batch_size * seq_len - * @param ratio dropout ratio - * @param in_grad [batch_size, seq_len, hidden_size], input grad - * @param bias_grad [hidden_size], bias grad - * @param out_grad [batch_size, seq_len, hidden_size], output grad - * @param mask [batch_size, seq_len, hidden_size], dropout mask - * @param hidden_size - * @return void - */ -template -__global__ void ls_dropout_act_bias_bwd_kernel( - const int row_size, const float 
ratio, T *in_grad, - T *__restrict__ bias_grad, const T *__restrict__ input, - const T *__restrict__ bias, const T *out_grad, - const uint8_t *__restrict__ mask, const int hidden_size) { - const float scale = 1.f / (1.f - ratio); - __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - - int stride = hidden_size * WARP_SIZE; - float local_sum = 0; - - int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); - if (col_idx < hidden_size) { - for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { - float val = out_grad[idx]; - float in = input[idx]; - float b = bias[idx % hidden_size]; - val = activation_bwd_kernel( - val * scale * static_cast(mask[idx]), in + b); - local_sum += val; - in_grad[idx] = val; - idx += stride; - } - } - - tile[threadIdx.x][threadIdx.y] = local_sum; - __syncthreads(); - float sum = tile[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; - __syncthreads(); - - if (threadIdx.y == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - bias_grad[pos] = tile[0][threadIdx.x]; - } -} - -// @brief fused bias, activation, and dropout backward -// It is deprecated for precision reason. Keep it for future optimization. 
-// -// template -// __global__ void ls_dropout_act_bias_bwd_kernel( -// const int row_size, const float ratio, __half * in_grad, -// __half *__restrict__ bias_grad, const __half *__restrict__ input, const -// __half *__restrict__ bias, const __half * out_grad, const uint8_t -// *__restrict__ mask, const int hidden_size) { -// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); -// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; - -// cg::thread_block b = cg::this_thread_block(); -// cg::thread_block_tile g = cg::tiled_partition(b); - -// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); -// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); -// const __half2 *out_grad2 = reinterpret_cast(out_grad); -// const __half2 *input2 = reinterpret_cast(input); -// const __half2 *bias2 = reinterpret_cast(bias); - -// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - -// int stride = hidden_size * WARP_SIZE; -// __half2 local_sum = __float2half2_rn(0.f); - -// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); -// if (col_idx < hidden_size) { -// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { -// __half2 val = out_grad2[idx]; -// __half2 in2 = input2[idx]; -// __half2 b2 = bias2[idx % hidden_size ]; -// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); -// val = activation_bwd_kernel(val * scale -// * -// m2, -// in2+b2); -// local_sum += val; -// in_grad2[idx] = val; -// idx += stride; -// } -// } - -// tile[threadIdx.x][threadIdx.y] = local_sum; -// __syncthreads(); -// __half2 sum = tile[threadIdx.y][threadIdx.x]; -// __syncthreads(); - -// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - -// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; -// __syncthreads(); - -// if (threadIdx.y == 0) { -// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); -// bias_grad2[pos] = tile[0][threadIdx.x]; -// } -// } - -template -void launch_ls_dropout_act_bias_bwd(T *in_grad, T 
*bias_grad, const T *input, - const T *bias, const T *out_grad, - const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream) { - dim3 grid_dim((dim - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - ls_dropout_act_bias_bwd_kernel<<>>( - row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); -} - -// template <> -// void launch_ls_dropout_act_bias_bwd( -// __half *in_grad, __half *bias_grad,const __half *input, const __half -// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int -// dim, float ratio, cudaStream_t stream) { -// dim >>= 1; -// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); -// dim3 block_dim(WARP_SIZE, WARP_SIZE); -// ls_dropout_act_bias_bwd_kernel -// <<>>(row_size, ratio, in_grad, -// bias_grad, -// input, bias,out_grad, mask, dim); -// } - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - float *in_grad, float *bias_grad, const float *input, const float *bias, - const float *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); - -template void launch_ls_dropout_act_bias_bwd( - __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, - const __half *out_grad, const uint8_t *mask, int row_size, int dim, - float ratio, cudaStream_t stream); +#include + +#include +#include + +#include "kernels.h" + +namespace cg = cooperative_groups; + +curandStatePhilox4_32_10_t *curandstate; + +/** + * @brief element-wise activation function on device, like Relu, Gelu + * + * @tparam enum 
class ActivationType, kRelu, kGelu + * @tparam input type + * @param any shape of float and __half2 + * @return same shape and type with input + */ +template +__forceinline__ __device__ T activation_kernel(T x); + +template <> +__device__ float activation_kernel(float x) { + float cdf = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); + return x * cdf; +} + +template <> +__device__ __half2 +activation_kernel(__half2 val) { + __half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = + 0.5f * + (1.0f + tanhf((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); +} + +template <> +__device__ float activation_kernel(float x) { + return fmaxf(x, 0); +} + +template <> +__device__ __half2 +activation_kernel(__half2 x) { + return __floats2half2_rn(fmaxf(0.f, __half2float(x.x)), + fmaxf(0.f, __half2float(x.y))); +} + +/** + * @brief element-wise activation backward function on device + * + * @tparam enum class ActivationType + * @tparam input type + * @param any shape of float and __half2 + * @return same shape of input + */ +template +__forceinline__ __device__ T activation_bwd_kernel(T grad, T x); + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * (dg1 + dg2 + dg3); +} + +template <> +__device__ __half activation_bwd_kernel( + __half grad, __half x_half) { + float x = __half2float(x_half); + const float sqrt_param = 0.79788456080286535587989211986876f; + 
const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return grad * __float2half(dg1 + dg2 + dg3); +} + +template <> +__device__ float activation_bwd_kernel(float grad, + float x) { + return x > 0.f ? grad : 0.f; +} + +template <> +__device__ __half +activation_bwd_kernel(__half grad, __half x) { + const __half half_zero = __float2half(0.f); + return x > half_zero ? grad : half_zero; +} + +template <> +__device__ __half2 activation_bwd_kernel( + __half2 grad2, __half2 x_half2) { + const __half half_zero = __float2half(0.f); + return __floats2half2_rn(x_half2.x > half_zero ? grad2.x : half_zero, + x_half2.y > half_zero ? grad2.y : half_zero); +} + +/** + * @brief init curand states in global memory + * + * @thread grid_dim * block*dim to suuport any size of states + * @param state persistant curand states + * @param seed seed to init states + * @return void + */ +__global__ void curand_init_kernel(curandStatePhilox4_32_10_t *state, + int seed) { + /* Each thread gets same seed, a different sequence + number, no offset */ + int id = threadIdx.x + blockIdx.x * blockDim.x; + curand_init(seed, id, 0, &state[id]); +} + +void launch_curand_init(int total_count, int dim, cudaStream_t stream) { + cudaMalloc(&curandstate, total_count * sizeof(curandStatePhilox4_32_10_t)); + int grid_dim = total_count >> 9; + curand_init_kernel<<>>( + curandstate, std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); +} + +/** + * @brief element-wise dropout, store dropped position in mask, it's not + * in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out any size of float and __half + * @param in same with out + * @param mask uint8 type, same 
size with out + * @param seed seed to curand + * @return void + */ +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + float *__restrict__ out, + const float *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + + float4 input4 = data4[i]; + float4 res4; + res4.x = input4.x * scale * m[0]; + res4.y = input4.y * scale * m[1]; + res4.z = input4.z * scale * m[2]; + res4.w = input4.w * scale * m[3]; + out4[i] = res4; +} + +__global__ void ls_dropout_kernel(const int total_count, const float ratio, + __half *__restrict__ out, + const __half *__restrict__ in, + uint8_t *__restrict__ mask, const int seed) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = 
reinterpret_cast(m); + mask8[i] = *m8; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + outs_float4[i] = out_float4; +} + +/** + * @brief element-wise dropout backward with dropout mask, it's + * not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param in any size of float and __half + * @param mask uint8 type, same size with in + * @return void + */ +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + float *out, const float *in, + const uint8_t *__restrict__ mask) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *in4 = reinterpret_cast(in); + const uint32_t *mask4 = reinterpret_cast(mask); + + uint32_t *m4 = reinterpret_cast(m); + m4[0] = mask4[i]; + + float4 input4 = in4[i]; + float4 res4; + res4.x = input4.x * scale * static_cast(m[0]); + res4.y = input4.y * scale * static_cast(m[1]); + res4.z = input4.z * scale * static_cast(m[2]); + res4.w = input4.w * scale * static_cast(m[3]); + out4[i] = res4; +} + +__global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, + __half *out, const __half *in, + const uint8_t 
*__restrict__ mask) { + const __half scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + float4 *out4 = reinterpret_cast(out); + const float4 *vals_float4 = reinterpret_cast(in); + const uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + uint64_t *m8 = reinterpret_cast(m); + m8[0] = mask8[i]; + + float4 val_float4 = vals_float4[i]; + float4 out_float4; + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + __halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = __hmul2(val_half2[0], scale_mask_1); + out_half2[1] = __hmul2(val_half2[1], scale_mask_2); + out_half2[2] = __hmul2(val_half2[2], scale_mask_3); + out_half2[3] = __hmul2(val_half2[3], scale_mask_4); + out4[i] = out_float4; +} + +template <> +void launch_ls_dropout(float *out, const float *vals, uint8_t *mask, + int total_count, float ratio, cudaStream_t stream, + bool backward) { + int grid_dim = total_count >> 12; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +template <> +void launch_ls_dropout<__half>(__half *out, const __half *vals, uint8_t *mask, + int total_count, float ratio, + cudaStream_t stream, bool backward) { + int grid_dim = total_count >> 13; + if (!backward) { + ls_dropout_kernel<<>>( + total_count, ratio, out, vals, mask, + std::chrono::duration_cast( + 
std::chrono::system_clock::now().time_since_epoch()) + .count()); + } else { + ls_dropout_bwd_kernel<<>>(total_count, ratio, + out, vals, mask); + } +} + +/** + * @brief fused bias, dropout, and residual at the end of Attention and FFN, + * store dropped position in mask, it's not in-place + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param residual [batch_size, seq_len, hidden_size], float and __half + * @param seed seed to curand + * @param hidden_size hidden size + * @return void + */ +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const float *__restrict__ residual, + const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 output4; 
+ + output4.x = (input4.x + b4.x) * scale * m[0] + res4.x; + output4.y = (input4.y + b4.y) * scale * m[1] + res4.y; + output4.z = (input4.z + b4.z) * scale * m[2] + res4.z; + output4.w = (input4.w + b4.w) * scale * m[3] + res4.w; + + out4[i] = output4; +} + +__global__ void ls_dropout_res_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const __half *__restrict__ residual, + const int seed, const int hidden_size) { + const __half scale = 1. / (1. - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *residual4 = reinterpret_cast(residual); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = static_cast(rand.x > ratio); + m[1] = static_cast(rand.y > ratio); + m[2] = static_cast(rand.z > ratio); + m[3] = static_cast(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = static_cast(rand.x > ratio); + m[5] = static_cast(rand.y > ratio); + m[6] = static_cast(rand.z > ratio); + m[7] = static_cast(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = m8[0]; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + const float4 res4 = residual4[i]; + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + const __half2 *res_half2 = reinterpret_cast(&res4); + __half2 scale_mask_1 = + __halves2half2(scale * __float2half(m[0]), scale * __float2half(m[1])); + __half2 scale_mask_2 = + 
__halves2half2(scale * __float2half(m[2]), scale * __float2half(m[3])); + __half2 scale_mask_3 = + __halves2half2(scale * __float2half(m[4]), scale * __float2half(m[5])); + __half2 scale_mask_4 = + __halves2half2(scale * __float2half(m[6]), scale * __float2half(m[7])); + out_half2[0] = + __hfma2(__hadd2(val_half2[0], b_half2[0]), scale_mask_1, res_half2[0]); + out_half2[1] = + __hfma2(__hadd2(val_half2[1], b_half2[1]), scale_mask_2, res_half2[1]); + out_half2[2] = + __hfma2(__hadd2(val_half2[2], b_half2[2]), scale_mask_3, res_half2[2]); + out_half2[3] = + __hfma2(__hadd2(val_half2[3], b_half2[3]), scale_mask_4, res_half2[3]); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_res_bias(float *out, const float *vals, + uint8_t *mask, const float *bias, + const float *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 12; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_res_bias<__half>(__half *out, const __half *vals, + uint8_t *mask, const __half *bias, + const __half *residual, int total_count, + int dim, float ratio, + cudaStream_t stream) { + int grid_dim = total_count >> 13; + ls_dropout_res_bias_kernel<<>>( + total_count, ratio, out, vals, mask, bias, residual, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias and dropout backward at the end of Attention and FFN + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @param row_size batch_size * seq_len + * @param ratio dropout ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * 
@param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, float *__restrict__ in_grad, + float *__restrict__ bias_grad, const float *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + // every block generate 8 bias result + __shared__ float tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + float val = out_grad[idx]; + val *= scale * static_cast(mask[idx]); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + float sum = 0; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +__global__ void ls_dropout_bias_bwd_kernel( + const int row_size, const float ratio, __half *__restrict__ in_grad, + __half *__restrict__ bias_grad, const __half *__restrict__ out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); + __shared__ __half2 tile[8][129]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); + const 
__half2 *out_grad2 = reinterpret_cast(out_grad); + __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, 8); + int stride = hidden_size * 128; + __half2 local_sum = __float2half2_rn(0.f); + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + for (int r = threadIdx.y; r < row_size; r += 128) { + __half2 val = out_grad2[idx]; + __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); + val *= scale * m2; + local_sum += val; + in_grad2[idx] = val; + idx += stride; + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + + __half2 sum = __float2half2_rn(0.f); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 7; + int y = tid & (127); + if (y < 32) { +#pragma unroll + for (int i = 0; i < 4; i++) { + sum += tile[x][y + i * 32]; + } + } + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (y == 0) tile[0][x] = sum; + __syncthreads(); + + if (threadIdx.x < 8) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, 8); + bias_grad2[pos] = tile[0][threadIdx.x]; + } +} + +template +void launch_ls_dropout_bias_bwd(T *in_grad, T *bias_grad, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template <> +void launch_ls_dropout_bias_bwd(__half *in_grad, __half *bias_grad, + const __half *out_grad, const uint8_t *mask, + int row_size, int dim, float ratio, + cudaStream_t stream) { + dim >>= 1; + dim3 grid_dim((dim - 1) / 8 + 1); + dim3 block_dim(8, 128); + ls_dropout_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, out_grad, mask, dim); +} + +template void launch_ls_dropout_bias_bwd(float *in_grad, float *bias_grad, + const float *out_grad, + const uint8_t *mask, int row_size, + int dim, float ratio, + 
cudaStream_t stream); + +/** + * @brief fused bias, activation, and dropout at the end of first ffn + * + * @thread + * gridDim.x = hidden_size / 8 + * blockDim.x = 8 + * blockDim.y = 1024 / 8 = 128 + * + * @tparam act_type activation function, like kRelu, kGelu + * @param total_count total elements + * @param ratio drop ratio + * @param out [batch_size, seq_len, hidden_size], float and __half + * @param in [batch_size, seq_len, hidden_size], float and __half + * @param mask [batch_size, seq_len, hidden_size], uint8 type + * @param bias [hidden_size], ffn bias + * @param seed seed to curand + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, float *__restrict__ out, + const float *__restrict__ in, uint8_t *__restrict__ mask, + const float *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 4 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + uint8_t m[4]; + + float4 *out4 = reinterpret_cast(out); + const float4 *data4 = reinterpret_cast(in); + const float4 *bias4 = reinterpret_cast(bias); + uint32_t *mask4 = reinterpret_cast(mask); + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int bias_i = i % (hidden_size >> 2); + uint32_t *m4 = reinterpret_cast(m); + mask4[i] = m4[0]; + const float4 input4 = data4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 output4; + + output4.x = + activation_kernel(input4.x + b4.x) * scale * m[0]; + output4.y = + activation_kernel(input4.y + b4.y) * scale * m[1]; + output4.z = + activation_kernel(input4.z + b4.z) * scale * m[2]; + output4.w = + activation_kernel(input4.w + b4.w) * scale * m[3]; + + out4[i] = output4; +} + +template 
+__global__ void ls_dropout_act_bias_kernel( + const int total_count, const float ratio, __half *__restrict__ out, + const __half *__restrict__ in, uint8_t *__restrict__ mask, + const __half *__restrict__ bias, const int seed, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i * 8 >= total_count) return; + + curandStatePhilox4_32_10_t state; + curand_init(seed, i, 0, &state); + + const float4 *vals_float4 = reinterpret_cast(in); + float4 *outs_float4 = reinterpret_cast(out); + const float4 *bias4 = reinterpret_cast(bias); + uint64_t *mask8 = reinterpret_cast(mask); + + uint8_t m[8]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + rand = curand_uniform4(&state); + m[4] = (uint8_t)(rand.x > ratio); + m[5] = (uint8_t)(rand.y > ratio); + m[6] = (uint8_t)(rand.z > ratio); + m[7] = (uint8_t)(rand.w > ratio); + uint64_t *m8 = reinterpret_cast(m); + mask8[i] = *m8; + + int bias_i = i % (hidden_size >> 3); + float4 val_float4 = vals_float4[i]; + const float4 b4 = __ldg(&bias4[bias_i]); + float4 out_float4; + + __half2 *val_half2 = reinterpret_cast<__half2 *>(&val_float4); + __half2 *out_half2 = reinterpret_cast<__half2 *>(&out_float4); + const __half2 *b_half2 = reinterpret_cast(&b4); + + __half2 scale_mask_1 = __floats2half2_rn(scale * m[0], scale * m[1]); + __half2 scale_mask_2 = __floats2half2_rn(scale * m[2], scale * m[3]); + __half2 scale_mask_3 = __floats2half2_rn(scale * m[4], scale * m[5]); + __half2 scale_mask_4 = __floats2half2_rn(scale * m[6], scale * m[7]); + out_half2[0] = __hmul2( + activation_kernel(__hadd2(val_half2[0], b_half2[0])), + scale_mask_1); + out_half2[1] = __hmul2( + activation_kernel(__hadd2(val_half2[1], b_half2[1])), + scale_mask_2); + out_half2[2] = __hmul2( + activation_kernel(__hadd2(val_half2[2], b_half2[2])), + scale_mask_3); 
+ out_half2[3] = __hmul2( + activation_kernel(__hadd2(val_half2[3], b_half2[3])), + scale_mask_4); + outs_float4[i] = out_float4; +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + float *out, const float *vals, uint8_t *mask, const float *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 10; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +template <> +void launch_ls_dropout_act_bias( + __half *out, const __half *vals, uint8_t *mask, const __half *bias, + int total_count, int dim, float ratio, cudaStream_t stream) { + int grid_dim = total_count >> 11; + ls_dropout_act_bias_kernel + <<>>( + total_count, ratio, out, vals, mask, bias, + std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(), + dim); +} + +/** + * @brief fused bias, activation, and dropout backward + * + * @thread + * gridDim.x = total_count / 1024 + * blockDim.x = 1024 + * + * @tparam act_type kRelu + * @param row_size batch_size * seq_len + * @param ratio dropout 
ratio + * @param in_grad [batch_size, seq_len, hidden_size], input grad + * @param bias_grad [hidden_size], bias grad + * @param out_grad [batch_size, seq_len, hidden_size], output grad + * @param mask [batch_size, seq_len, hidden_size], dropout mask + * @param hidden_size + * @return void + */ +template +__global__ void ls_dropout_act_bias_bwd_kernel( + const int row_size, const float ratio, T *in_grad, + T *__restrict__ bias_grad, const T *__restrict__ input, + const T *__restrict__ bias, const T *out_grad, + const uint8_t *__restrict__ mask, const int hidden_size) { + const float scale = 1.f / (1.f - ratio); + __shared__ float tile[WARP_SIZE][WARP_SIZE + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + + int stride = hidden_size * WARP_SIZE; + float local_sum = 0; + + int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); + if (col_idx < hidden_size) { + for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { + float val = out_grad[idx]; + float in = input[idx]; + float b = bias[idx % hidden_size]; + val = activation_bwd_kernel( + val * scale * static_cast(mask[idx]), in + b); + local_sum += val; + in_grad[idx] = val; + idx += stride; + } + } + + tile[threadIdx.x][threadIdx.y] = local_sum; + __syncthreads(); + float sum = tile[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; + __syncthreads(); + + if (threadIdx.y == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + bias_grad[pos] = tile[0][threadIdx.x]; + } +} + +// @brief fused bias, activation, and dropout backward +// It is deprecated for precision reason. Keep it for future optimization. 
+// +// template +// __global__ void ls_dropout_act_bias_bwd_kernel( +// const int row_size, const float ratio, __half * in_grad, +// __half *__restrict__ bias_grad, const __half *__restrict__ input, const +// __half *__restrict__ bias, const __half * out_grad, const uint8_t +// *__restrict__ mask, const int hidden_size) { +// const __half2 scale = __float2half2_rn(1.f / (1.f - ratio)); +// __shared__ __half2 tile[WARP_SIZE][WARP_SIZE + 1]; + +// cg::thread_block b = cg::this_thread_block(); +// cg::thread_block_tile g = cg::tiled_partition(b); + +// __half2 *in_grad2 = reinterpret_cast<__half2 *>(in_grad); +// __half2 *bias_grad2 = reinterpret_cast<__half2 *>(bias_grad); +// const __half2 *out_grad2 = reinterpret_cast(out_grad); +// const __half2 *input2 = reinterpret_cast(input); +// const __half2 *bias2 = reinterpret_cast(bias); + +// int col_idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + +// int stride = hidden_size * WARP_SIZE; +// __half2 local_sum = __float2half2_rn(0.f); + +// int idx = flat_2dim(threadIdx.y, col_idx, hidden_size); +// if (col_idx < hidden_size) { +// for (int r = threadIdx.y; r < row_size; r += WARP_SIZE) { +// __half2 val = out_grad2[idx]; +// __half2 in2 = input2[idx]; +// __half2 b2 = bias2[idx % hidden_size ]; +// __half2 m2 = __floats2half2_rn(mask[2 * idx], mask[2 * idx + 1]); +// val = activation_bwd_kernel(val * scale +// * +// m2, +// in2+b2); +// local_sum += val; +// in_grad2[idx] = val; +// idx += stride; +// } +// } + +// tile[threadIdx.x][threadIdx.y] = local_sum; +// __syncthreads(); +// __half2 sum = tile[threadIdx.y][threadIdx.x]; +// __syncthreads(); + +// for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + +// if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; +// __syncthreads(); + +// if (threadIdx.y == 0) { +// int pos = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); +// bias_grad2[pos] = tile[0][threadIdx.x]; +// } +// } + +template +void launch_ls_dropout_act_bias_bwd(T *in_grad, T 
*bias_grad, const T *input, + const T *bias, const T *out_grad, + const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream) { + dim3 grid_dim((dim - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + ls_dropout_act_bias_bwd_kernel<<>>( + row_size, ratio, in_grad, bias_grad, input, bias, out_grad, mask, dim); +} + +// template <> +// void launch_ls_dropout_act_bias_bwd( +// __half *in_grad, __half *bias_grad,const __half *input, const __half +// *bias, const __half *out_grad, const uint8_t *mask, int row_size, int +// dim, float ratio, cudaStream_t stream) { +// dim >>= 1; +// dim3 grid_dim((dim - 1) / WARP_SIZE + 1); +// dim3 block_dim(WARP_SIZE, WARP_SIZE); +// ls_dropout_act_bias_bwd_kernel +// <<>>(row_size, ratio, in_grad, +// bias_grad, +// input, bias,out_grad, mask, dim); +// } + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + float *in_grad, float *bias_grad, const float *input, const float *bias, + const float *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); + +template void launch_ls_dropout_act_bias_bwd( + __half *in_grad, __half *bias_grad, const __half *input, const __half *bias, + const __half *out_grad, const uint8_t *mask, int row_size, int dim, + float ratio, cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu index bc90c54c0a00..625b02cd25d9 100644 --- 
a/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/general_kernels.cu @@ -1,232 +1,232 @@ -#include - -#include "kernels.h" - -namespace cg = cooperative_groups; - -/** -@brief: fuse_transpose_bias -Calculate the sum of elements in each column of the matrix. - -@thread -gridDim.x = ceil(cols / WARP_SIZE) -blockDim.x = WARP_SIZE -blockDim.y = WARP_SIZE - -@param -inp: [rows, cols] -out: [cols] -rows: the number of rows in the matrix -cols: the number of cols in the matrix -*/ -template -__global__ void column_sum_reduce(const T *__restrict__ inp, - T *__restrict__ out, int rows, int cols) { - __shared__ float tile[WARP_SIZE][WARP_SIZE]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); - int y_stride = cols * WARP_SIZE; - float localSum = 0; - - // Loop across matrix row - // TODO: optimize to log complexity - if (idx < cols) { - int offset = flat_2dim(threadIdx.y, idx, cols); - for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - // The sum of a row in tile is equal to the sum of a col in original matrix - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. 
- // The change of threadIdx.x is continuous - float sum = tile[threadIdx.y][threadIdx.x]; - - __syncthreads(); - - // Calculate the sum of a row in tile - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); - if (pos < cols) out[pos] = sum; - } -} - -// [r, c] -> [c] -template <> -void launch_fuse_transpose_bias_kernel(const float *inp, float *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce - <<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, - int rows, int cols, - cudaStream_t stream) { - dim3 grid_dim((cols - 1) / WARP_SIZE + 1); - dim3 block_dim(WARP_SIZE, WARP_SIZE); - - column_sum_reduce<__half> - <<>>(inp, out, rows, cols); -} - -/** -@brief: fused_add2 -Add two matrix inp1 and inp2 to out. - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -inp1: [batch_size, seq_len, hidden_dim] -inp2: [batch_size, seq_len, hidden_dim] -out: [batch_size, seq_len, hidden_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -*/ -template -__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, - int hidden_dim); - -template <> -__global__ void fused_add2_kernel(float *out, const float *inp1, - const float *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - val.x = vinp1.x + vinp2.x; 
- val.y = vinp1.y + vinp2.y; - val.z = vinp1.z + vinp2.z; - val.w = vinp1.w + vinp2.w; - out_4[offset + i] = val; - } -} - -template <> -__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, - const __half *inp2, int hidden_dim) { - int row_id = blockIdx.x; - int offset = flat_2dim(row_id, 0, hidden_dim); - - const float4 *inp1_4 = reinterpret_cast(inp1); - const float4 *inp2_4 = reinterpret_cast(inp2); - float4 *out_4 = reinterpret_cast(out); - float4 vinp1; - float4 vinp2; - float4 val; - __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); - __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); - __half2 *h2_val = reinterpret_cast<__half2 *>(&val); - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinp1 = inp1_4[offset + i]; - vinp2 = inp2_4[offset + i]; - h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); - h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); - h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); - h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); - out_4[offset + i] = val; - } -} - -//[b, s, h] -> [b, s, h] -template <> -void launch_fused_add2(float *out, const float *inp1, const float *inp2, - int batch_size, int seq_len, int hidden_dim, - cudaStream_t &stream) { - hidden_dim >>= 2; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template <> -void launch_fused_add2<__half>(__half *out, const __half *inp1, - const __half *inp2, int batch_size, int seq_len, - int hidden_dim, cudaStream_t &stream) { - hidden_dim >>= 3; - - dim3 grid_dim(batch_size * seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - fused_add2_kernel<<>>(out, inp1, inp2, - hidden_dim); -} - -template -__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, - int sz0, int sz2, int sz1_1, int sz1_2) { - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); - if (idx >= nele) { 
- return; - } - float4 *dst_ptr = (float4 *)output + idx; - int idx2 = idx % sz2; - idx = idx / sz2; - int idx1 = idx % (sz1_1 + sz1_2); - int idx0 = idx / (sz1_1 + sz1_2); - float4 *src_ptr = nullptr; - int sz1 = 0; - if (idx1 < sz1_1) { - sz1 = sz1_1; - src_ptr = (float4 *)inp1; - } else { - idx1 -= sz1_1; - sz1 = sz1_2; - src_ptr = (float4 *)inp2; - } - src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); - dst_ptr[0] = src_ptr[0]; -} - -template <> -void launch_concat3_dim1(const float *inp1, const float *inp2, - float *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 2; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} - -template <> -void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, - __half *output, int sz0, int sz2, int sz1_1, - int sz1_2, cudaStream_t stream) { - sz2 >>= 3; - int nele = sz0 * sz2 * (sz1_1 + sz1_2); - int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; - kernel_concat3_dim1<<>>( - inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); -} +#include + +#include "kernels.h" + +namespace cg = cooperative_groups; + +/** +@brief: fuse_transpose_bias +Calculate the sum of elements in each column of the matrix. 
+ +@thread +gridDim.x = ceil(cols / WARP_SIZE) +blockDim.x = WARP_SIZE +blockDim.y = WARP_SIZE + +@param +inp: [rows, cols] +out: [cols] +rows: the number of rows in the matrix +cols: the number of cols in the matrix +*/ +template +__global__ void column_sum_reduce(const T *__restrict__ inp, + T *__restrict__ out, int rows, int cols) { + __shared__ float tile[WARP_SIZE][WARP_SIZE]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = flat_2dim(blockIdx.x, threadIdx.x, WARP_SIZE); + int y_stride = cols * WARP_SIZE; + float localSum = 0; + + // Loop across matrix row + // TODO: optimize to log complexity + if (idx < cols) { + int offset = flat_2dim(threadIdx.y, idx, cols); + for (int r = threadIdx.y; r < rows; r += WARP_SIZE) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + // The sum of a row in tile is equal to the sum of a col in original matrix + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. 
+ // The change of threadIdx.x is continuous + float sum = tile[threadIdx.y][threadIdx.x]; + + __syncthreads(); + + // Calculate the sum of a row in tile + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = flat_2dim(blockIdx.x, threadIdx.y, WARP_SIZE); + if (pos < cols) out[pos] = sum; + } +} + +// [r, c] -> [c] +template <> +void launch_fuse_transpose_bias_kernel(const float *inp, float *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce + <<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half *inp, __half *out, + int rows, int cols, + cudaStream_t stream) { + dim3 grid_dim((cols - 1) / WARP_SIZE + 1); + dim3 block_dim(WARP_SIZE, WARP_SIZE); + + column_sum_reduce<__half> + <<>>(inp, out, rows, cols); +} + +/** +@brief: fused_add2 +Add two matrix inp1 and inp2 to out. + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +inp1: [batch_size, seq_len, hidden_dim] +inp2: [batch_size, seq_len, hidden_dim] +out: [batch_size, seq_len, hidden_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +*/ +template +__global__ void fused_add2_kernel(T *out, const T *inp1, const T *inp2, + int hidden_dim); + +template <> +__global__ void fused_add2_kernel(float *out, const float *inp1, + const float *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + val.x = vinp1.x + vinp2.x; 
+ val.y = vinp1.y + vinp2.y; + val.z = vinp1.z + vinp2.z; + val.w = vinp1.w + vinp2.w; + out_4[offset + i] = val; + } +} + +template <> +__global__ void fused_add2_kernel<__half>(__half *out, const __half *inp1, + const __half *inp2, int hidden_dim) { + int row_id = blockIdx.x; + int offset = flat_2dim(row_id, 0, hidden_dim); + + const float4 *inp1_4 = reinterpret_cast(inp1); + const float4 *inp2_4 = reinterpret_cast(inp2); + float4 *out_4 = reinterpret_cast(out); + float4 vinp1; + float4 vinp2; + float4 val; + __half2 *h2_inp1 = reinterpret_cast<__half2 *>(&vinp1); + __half2 *h2_inp2 = reinterpret_cast<__half2 *>(&vinp2); + __half2 *h2_val = reinterpret_cast<__half2 *>(&val); + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinp1 = inp1_4[offset + i]; + vinp2 = inp2_4[offset + i]; + h2_val[0] = __hadd2(h2_inp1[0], h2_inp2[0]); + h2_val[1] = __hadd2(h2_inp1[1], h2_inp2[1]); + h2_val[2] = __hadd2(h2_inp1[2], h2_inp2[2]); + h2_val[3] = __hadd2(h2_inp1[3], h2_inp2[3]); + out_4[offset + i] = val; + } +} + +//[b, s, h] -> [b, s, h] +template <> +void launch_fused_add2(float *out, const float *inp1, const float *inp2, + int batch_size, int seq_len, int hidden_dim, + cudaStream_t &stream) { + hidden_dim >>= 2; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template <> +void launch_fused_add2<__half>(__half *out, const __half *inp1, + const __half *inp2, int batch_size, int seq_len, + int hidden_dim, cudaStream_t &stream) { + hidden_dim >>= 3; + + dim3 grid_dim(batch_size * seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + fused_add2_kernel<<>>(out, inp1, inp2, + hidden_dim); +} + +template +__global__ void kernel_concat3_dim1(const T *inp1, const T *inp2, T *output, + int sz0, int sz2, int sz1_1, int sz1_2) { + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int idx = flat_2dim(blockIdx.x, threadIdx.x, blockDim.x); + if (idx >= nele) { 
+ return; + } + float4 *dst_ptr = (float4 *)output + idx; + int idx2 = idx % sz2; + idx = idx / sz2; + int idx1 = idx % (sz1_1 + sz1_2); + int idx0 = idx / (sz1_1 + sz1_2); + float4 *src_ptr = nullptr; + int sz1 = 0; + if (idx1 < sz1_1) { + sz1 = sz1_1; + src_ptr = (float4 *)inp1; + } else { + idx1 -= sz1_1; + sz1 = sz1_2; + src_ptr = (float4 *)inp2; + } + src_ptr += flat_3dim(idx0, idx1, idx2, sz1, sz2); + dst_ptr[0] = src_ptr[0]; +} + +template <> +void launch_concat3_dim1(const float *inp1, const float *inp2, + float *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 2; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} + +template <> +void launch_concat3_dim1<__half>(const __half *inp1, const __half *inp2, + __half *output, int sz0, int sz2, int sz1_1, + int sz1_2, cudaStream_t stream) { + sz2 >>= 3; + int nele = sz0 * sz2 * (sz1_1 + sz1_2); + int nblock = (nele + MAX_THREADS - 1) / MAX_THREADS; + kernel_concat3_dim1<<>>( + inp1, inp2, output, sz0, sz2, sz1_1, sz1_2); +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h index 563a7fe284a3..025fbf3f8f15 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h @@ -1,96 +1,96 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -template -class Dropout { - public: - struct Config { - float ratio; - bool training; - - Config(float r) : ratio(r), training(true) {} - float RATIO() const { return training ? 
ratio : 0.0; } - }; - - Dropout(const Config &config, size_t max_ele_num) - : _config(config), _mask(nullptr) { - _mask = cuda_malloc(max_ele_num); - } - - virtual ~Dropout() { cuda_free(_mask); } - - // after attention softmax - void dropout(T *output, const T *input, int count, cudaStream_t stream, - bool bwd = false) { - launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, - bwd); - } - - void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { - launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), - stream, true); - } - - // transformer layer's postprocessing dropout, after attn or ffn module, - // before residual add. - void bias_dropout_residual(T *output, const T *input, const T *residual, - const T *bias, int rows, int cols, - cudaStream_t stream) { - launch_ls_dropout_res_bias(output, input, _mask, bias, residual, - rows * cols, cols, _config.RATIO(), stream); - } - - void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, - int rows, int cols, cudaStream_t stream) { - launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, - _config.RATIO(), stream); - } - - // dropout inside ffn. 
- void bias_act_dropout(T *output, const T *input, const T *bias, int rows, - int cols, std::string activation_fn, - cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias( - output, input, _mask, bias, rows * cols, cols, _config.RATIO(), - stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, - const T *bias, int rows, int cols, - std::string activation_fn, cudaStream_t stream) { - if (activation_fn == "relu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else if (activation_fn == "gelu") { - launch_ls_dropout_act_bias_bwd( - d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, - _config.RATIO(), stream); - } else { - throw std::runtime_error("not supported activation: " + activation_fn); - } - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - private: - uint8_t *_mask; - Config _config; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +template +class Dropout { + public: + struct Config { + float ratio; + bool training; + + Config(float r) : ratio(r), training(true) {} + float RATIO() const { return training ? 
ratio : 0.0; } + }; + + Dropout(const Config &config, size_t max_ele_num) + : _config(config), _mask(nullptr) { + _mask = cuda_malloc(max_ele_num); + } + + virtual ~Dropout() { cuda_free(_mask); } + + // after attention softmax + void dropout(T *output, const T *input, int count, cudaStream_t stream, + bool bwd = false) { + launch_ls_dropout(output, input, _mask, count, _config.RATIO(), stream, + bwd); + } + + void d_dropout(T *d_inp_out, int count, cudaStream_t stream) { + launch_ls_dropout(d_inp_out, d_inp_out, _mask, count, _config.RATIO(), + stream, true); + } + + // transformer layer's postprocessing dropout, after attn or ffn module, + // before residual add. + void bias_dropout_residual(T *output, const T *input, const T *residual, + const T *bias, int rows, int cols, + cudaStream_t stream) { + launch_ls_dropout_res_bias(output, input, _mask, bias, residual, + rows * cols, cols, _config.RATIO(), stream); + } + + void d_bias_dropout_residual(T *d_input, T *d_bias, const T *d_output, + int rows, int cols, cudaStream_t stream) { + launch_ls_dropout_bias_bwd(d_input, d_bias, d_output, _mask, rows, cols, + _config.RATIO(), stream); + } + + // dropout inside ffn. 
+ void bias_act_dropout(T *output, const T *input, const T *bias, int rows, + int cols, std::string activation_fn, + cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias( + output, input, _mask, bias, rows * cols, cols, _config.RATIO(), + stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + void d_bias_act_dropout(T *d_inp_out, T *d_bias_out, const T *input, + const T *bias, int rows, int cols, + std::string activation_fn, cudaStream_t stream) { + if (activation_fn == "relu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else if (activation_fn == "gelu") { + launch_ls_dropout_act_bias_bwd( + d_inp_out, d_bias_out, input, bias, d_inp_out, _mask, rows, cols, + _config.RATIO(), stream); + } else { + throw std::runtime_error("not supported activation: " + activation_fn); + } + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + private: + uint8_t *_mask; + Config _config; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h index fbb9c5465c24..735e1363cc46 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/kernels.h @@ -3,10 +3,11 @@ #include #include #include -#include #include #include +#include + #define MAX_THREADS 1024 #define WARP_SIZE 32 @@ -132,8 +133,9 @@ __forceinline__ __host__ __device__ int flat_3dim(int id1, int id2, int id3, } /* Convert 4-dim tensor index into vector index */ -__forceinline__ __host__ __device__ int -flat_4dim(int id1, int id2, int id3, int id4, int dim2, int dim3, int 
dim4) { +__forceinline__ __host__ __device__ int flat_4dim(int id1, int id2, int id3, + int id4, int dim2, int dim3, + int dim4) { // return id1*(dim2*dim3*dim4) + id2*(dim3*dim4) + id3*dim4 + id4; int res = id4; @@ -201,9 +203,9 @@ __forceinline__ __host__ __device__ int flat_6dim(int id1, int id2, int id3, } /* Convert vector index to 6-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, - int *id0, int *id1, int *id2, int *id3, int *id4, int *id5) { +__forceinline__ __host__ __device__ void decompose_6dim( + int src, int dim1, int dim2, int dim3, int dim4, int dim5, int *id0, + int *id1, int *id2, int *id3, int *id4, int *id5) { *id5 = src % dim5; src /= dim5; @@ -221,9 +223,11 @@ decompose_6dim(int src, int dim1, int dim2, int dim3, int dim4, int dim5, } /* Convert vector index to 5-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_5dim(int src, int dim1, int dim2, int dim3, int dim4, int *id0, - int *id1, int *id2, int *id3, int *id4) { +__forceinline__ __host__ __device__ void decompose_5dim(int src, int dim1, + int dim2, int dim3, + int dim4, int *id0, + int *id1, int *id2, + int *id3, int *id4) { *id4 = src % dim4; src /= dim4; @@ -253,8 +257,9 @@ __forceinline__ __host__ __device__ void decompose_4dim(int src, int dim1, } /* Convert vector index to 3-dim tensor index */ -__forceinline__ __host__ __device__ void -decompose_3dim(int src, int dim1, int dim2, int *id0, int *id1, int *id2) { +__forceinline__ __host__ __device__ void decompose_3dim(int src, int dim1, + int dim2, int *id0, + int *id1, int *id2) { *id2 = src % dim2; src /= dim2; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h index ded5c0fdcbee..a7767e187ffc 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h +++ 
b/colossalai/kernel/cuda_native/csrc/kernels/include/normalize_layer.h @@ -1,64 +1,65 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template class Normalize_Layer { -public: - struct Config { - uint32_t hidden_dim; - bool use_mean; - Config(uint32_t hidden_dim, bool use_mean = false) - : hidden_dim(hidden_dim), use_mean(use_mean) {} - }; - - Normalize_Layer(Config config, size_t max_rows) - : config_(config), vars_(nullptr), means_(nullptr) { - vars_ = cuda_malloc(max_rows); - if (config_.use_mean) { - means_ = cuda_malloc(max_rows); - } - } - - ~Normalize_Layer() { - cuda_free(vars_); - cuda_free(means_); - } - - void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, - int batch_size, cudaStream_t stream) { - launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, - config_.hidden_dim, stream); - } - - /* - residual_grad, inp_or_out, betta should be treated carefully. - inp_or_out = input if use_mean else output - residual_grad, betta can be nullptr. 
- residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln - betta are only used to compute xhat, - (use_mean == false) ^ (betta == nullptr) should be true - */ - void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, const T *gamma, - const T *betta, int batch_size, cudaStream_t stream[2]) { - launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, - inp_or_out, gamma, betta, vars_, means_, batch_size, - config_.hidden_dim, stream); - } - - inline bool use_mean() const { return config_.use_mean; } - -private: - Config config_; - T *vars_; - T *means_; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + +template +class Normalize_Layer { + public: + struct Config { + uint32_t hidden_dim; + bool use_mean; + Config(uint32_t hidden_dim, bool use_mean = false) + : hidden_dim(hidden_dim), use_mean(use_mean) {} + }; + + Normalize_Layer(Config config, size_t max_rows) + : config_(config), vars_(nullptr), means_(nullptr) { + vars_ = cuda_malloc(max_rows); + if (config_.use_mean) { + means_ = cuda_malloc(max_rows); + } + } + + ~Normalize_Layer() { + cuda_free(vars_); + cuda_free(means_); + } + + void Forward(T *ln_res, const T *inp, const T *gamma, const T *betta, + int batch_size, cudaStream_t stream) { + launch_layer_norm(ln_res, vars_, means_, inp, gamma, betta, batch_size, + config_.hidden_dim, stream); + } + + /* + residual_grad, inp_or_out, betta should be treated carefully. + inp_or_out = input if use_mean else output + residual_grad, betta can be nullptr. 
+ residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln + betta are only used to compute xhat, + (use_mean == false) ^ (betta == nullptr) should be true + */ + void Backward(T *gamma_grad, T *betta_grad, T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, const T *gamma, + const T *betta, int batch_size, cudaStream_t stream[2]) { + launch_ln_bw(gamma_grad, betta_grad, inp_grad, out_grad, residual_grad, + inp_or_out, gamma, betta, vars_, means_, batch_size, + config_.hidden_dim, stream); + } + + inline bool use_mean() const { return config_.use_mean; } + + private: + Config config_; + T *vars_; + T *means_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h index ec447ad84c54..b917abaf0336 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -1,42 +1,42 @@ -#pragma once - -#include -#include -#include - -#include - -#include "kernels.h" - -using namespace std; - -template -class Softmax { - public: - struct Config { - size_t nhead; - Config(size_t nhead) : nhead(nhead) {} - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, - int to_len, cudaStream_t &stream, bool mask_future = true) { - launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, - to_len, mask_future, stream); - } - - void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, - int to_len, cudaStream_t stream) { - launch_attn_softmax_bw(out_grad, soft_out, - batch_size * config_.nhead * from_len, to_len, - stream); - } - - void reset_size(size_t nhead) { config_.nhead = nhead; } - - private: - Config config_; -}; +#pragma once + +#include +#include +#include + +#include + +#include "kernels.h" + +using namespace std; + 
+template +class Softmax { + public: + struct Config { + size_t nhead; + Config(size_t nhead) : nhead(nhead) {} + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(T *vals, const T *attn_mask, int batch_size, int from_len, + int to_len, cudaStream_t &stream, bool mask_future = true) { + launch_attn_softmax(vals, attn_mask, batch_size, config_.nhead, from_len, + to_len, mask_future, stream); + } + + void Backward(T *out_grad, const T *soft_out, int batch_size, int from_len, + int to_len, cudaStream_t stream) { + launch_attn_softmax_bw(out_grad, soft_out, + batch_size * config_.nhead * from_len, to_len, + stream); + } + + void reset_size(size_t nhead) { config_.nhead = nhead; } + + private: + Config config_; +}; diff --git a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu index 3e61d4e35832..e2f1869b165e 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/normalize_kernels.cu @@ -1,1169 +1,1172 @@ -#include "block_reduce.h" -#include "kernels.h" -#include - -namespace cg = cooperative_groups; -const float LN_EPSILON = 1e-8f; -#define TILE_DIM 32 - -template __forceinline__ __device__ T add_eps(T x) { - return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); -} - -/** -@brief: ker_layer_norm -Standard layer normalization. -It will not only output the layer norm result, - but also outputs variance. - may also output means, depends on whether - the means argument is nullptr - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -ln_res: [batch_size* seq_len, hidden_size], ln result. -vars: [batch_size* seq_len], variance per token -means: [batch_size* seq_len], means per token, can be nullput -inp: [batch_size * seq_len, hidden_size], ln input. 
-scale: [hidden_size], ln scale -bias: [hidden_size], ln bias -*/ -template -__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, - const T *scale, const T *bias, int hidden_size) { - // step 0. compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val = inp_f4[idx]; - l_sum += val.x + val.y + val.z + val.w; - l_square_sum += - val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 4.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 vscale = __ldg((const float4 *)scale + idx); - float4 vbias = __ldg((const float4 *)bias + idx); - float4 val = inp_f4[idx]; - val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; - val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; - val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; - val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; - output_f4[idx] = val; - } -} - -template <> -__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, - __half *means, const __half *inp, - const __half *scale, const __half *bias, - int hidden_size) { - // step 0. 
compute local sum - float l_sum = 0; - float l_square_sum = 0; - const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 val_f2 = __half22float2(val_h2[i]); - l_sum += val_f2.x + val_f2.y; - l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; - } - } - - // step 1. compute reduce sum - float mean_dim = float(hidden_size) * 8.f; - float reduce_val[2] = {l_sum, l_square_sum}; - blockReduce(reduce_val); - __shared__ float s_mean, s_var; - if (threadIdx.x == 0) { - s_mean = reduce_val[0] / mean_dim; - if (means != nullptr) { - means[blockIdx.x] = s_mean; - } - s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; - vars[blockIdx.x] = s_var; - s_var = rsqrtf(s_var); - } - __syncthreads(); - - // step 2. layer norm result - float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; - for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - // load scale, bias, input - float4 scale_f4 = __ldg((const float4 *)scale + idx); - __half2 *scale_h2 = (__half2 *)(&scale_f4); - float4 bias_f4 = __ldg((const float4 *)bias + idx); - __half2 *bias_h2 = (__half2 *)(&bias_f4); - float4 val_f4 = inp_f4[idx]; - __half2 *val_h2 = (__half2 *)(&val_f4); - -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 scale_f2 = __half22float2(scale_h2[i]); - float2 bias_f2 = __half22float2(bias_h2[i]); - float2 val_f2 = __half22float2(val_h2[i]); - val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; - val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; - val_h2[i] = __float22half2_rn(val_f2); - } - output_f4[idx] = val_f4; - } -} - -// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. 
compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; -// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x -// * val_f2_1.x + val_f2_1.y * val_f2_1.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 2; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. 
layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; -// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * -// 2) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_h2[i] = __float22half2_rn(val_f2); -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// } -// } - -// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, -// __half *means, const __half *inp, -// const __half *scale, const __half -// *bias, int hidden_size) { -// // step 0. 
compute local sum -// float l_sum = 0; -// float l_square_sum = 0; -// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// float4 val_f4 = inp_f4[idx]; -// float4 val_f4_1 = inp_f4[idx+1]; -// float4 val_f4_2 = inp_f4[idx+2]; -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = __half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + -// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * -// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x -// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + -// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + -// val_f2_3.y * val_f2_3.y; -// } -// } - -// // step 1. compute reduce sum -// float mean_dim = float(hidden_size) * 8.f * 4; -// float reduce_val[2] = {l_sum, l_square_sum}; -// blockReduce(reduce_val); -// __shared__ float s_mean, s_var; -// if (threadIdx.x == 0) { -// s_mean = reduce_val[0] / mean_dim; -// if (means != nullptr) { -// means[blockIdx.x] = s_mean; -// } -// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; -// vars[blockIdx.x] = s_var; -// s_var = rsqrtf(s_var); -// } -// __syncthreads(); - -// // step 2. 
layer norm result -// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; -// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * -// 4) { -// // load scale, bias, input -// float4 scale_f4 = __ldg((const float4 *)scale + idx); -// __half2 *scale_h2 = (__half2 *)(&scale_f4); -// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); -// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); -// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); -// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); -// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); -// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); -// float4 bias_f4 = __ldg((const float4 *)bias + idx); -// __half2 *bias_h2 = (__half2 *)(&bias_f4); -// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); -// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); -// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); -// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); -// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); -// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); -// float4 val_f4 = inp_f4[idx]; -// __half2 *val_h2 = (__half2 *)(&val_f4); -// float4 val_f4_1 = inp_f4[idx+1]; -// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); -// float4 val_f4_2 = inp_f4[idx+2]; -// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); -// float4 val_f4_3 = inp_f4[idx+3]; -// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); - -// #pragma unroll -// for (int i = 0; i < 4; i++) { -// float2 scale_f2 = __half22float2(scale_h2[i]); -// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); -// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); -// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); -// float2 bias_f2 = __half22float2(bias_h2[i]); -// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); -// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); -// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); -// float2 val_f2 = __half22float2(val_h2[i]); -// float2 val_f2_1 = 
__half22float2(val_h2_1[i]); -// float2 val_f2_2 = __half22float2(val_h2_2[i]); -// float2 val_f2_3 = __half22float2(val_h2_3[i]); -// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; -// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; -// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + -// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y -// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * -// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var -// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * -// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) -// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = -// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); -// val_h2_2[i] = __float22half2_rn(val_f2_2); -// val_h2_3[i] = __float22half2_rn(val_f2_3); -// } -// output_f4[idx] = val_f4; -// output_f4[idx+1] = val_f4_1; -// output_f4[idx+2] = val_f4_2; -// output_f4[idx+3] = val_f4_3; -// } -// } - -template <> -void launch_layer_norm(float *ln_res, float *vars, float *means, - const float *inp, const float *scale, - const float *bias, int batch_size, int hidden_dim, - cudaStream_t stream) { - if (hidden_dim % 4 != 0) { - throw std::runtime_error("violate hidden_dim % 4 = 0"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 block_dim(nthread); - - ker_layer_norm<<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, - const __half *inp, const __half *scale, - const __half *bias, int batch_size, - int hidden_dim, cudaStream_t stream) { - if (hidden_dim % 8 != 0) { - throw std::runtime_error("violate hidden_dim % 8 = 0"); - } - hidden_dim >>= 3; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - dim3 grid_dim(batch_size); - dim3 
block_dim(nthread); - - ker_layer_norm<__half><<>>( - ln_res, vars, means, inp, scale, bias, hidden_dim); - // if (hidden_dim % 8 != 0) { - // throw std::runtime_error("violate hidden_dim % 8 = 0"); - // } - // hidden_dim >>= 3; - - // if (hidden_dim * 8 < 8192) { - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm<__half><<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { - // hidden_dim >>= 1; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x2<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { - // hidden_dim >>= 2; - // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - // dim3 grid_dim(batch_size); - // dim3 block_dim(nthread); - // ker_layer_norm_x4<<>>( - // ln_res, vars, means, inp, scale, bias, hidden_dim); - // } else { - // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - // } -} - -/** -@brief: ker_ln_bw_dgamma_dbetta -Layer norm backword kernel, compute the gradient of gamma and betta. 
-dbetta = sum(dout, dim=0) -dgamma = sum(xhat * dout, dim=0) -xhat = (input - mean) * rsqrt(var) or - (output - betta) / gamma - - -@thread -gridDim.x = hidden_size / 32 -blockDim.x = 32 -blockDim.y = 32 - -@param -gamma_grad: [hidden_size], gradient of gamma -betta_grad: [hidden_size], gradient of betta -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat, maybe nullptr -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat, maybe nullptr -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -(gamma && betta) ^ (vars && means) should be true -*/ -template -__global__ void -ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, const T *out_grad, - const T *inp_or_out, const T *gamma, const T *betta, - const T *vars, const T *means, int rows, int width) { - __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - // Loop across inp height - float dbetta = 0; - float dgamma = 0; - float dout, val; - if (idx < width) { - if (means == nullptr) { - float vbetta = (float)betta[idx]; - float vgamma = (float)gamma[idx]; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // inp_or_out is output - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - vbetta) / add_eps(vgamma) * dout); - offset += y_stride; - } - } else { - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - dout = (float)out_grad[offset]; - // 
inp_or_out is input - val = (float)inp_or_out[offset]; - dbetta += dout; - dgamma += ((val - (float)means[r]) * - rsqrtf((float)vars[r] + LN_EPSILON) * dout); - offset += y_stride; - } - } - } - - // Sum the shared buffer. - betta_buffer[threadIdx.x][threadIdx.y] = dbetta; - gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; - __syncthreads(); - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - __syncthreads(); - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (threadIdx.x == 0 && idx < width) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/** -@brief: ker_ln_bw_dinp -Layer norm backword kernel, compute the gradient of input. -dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) - * rsqrt(var) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dxhat = dout * gamma - - -@thread -gridDim.x = batch_size * seq_len -blockDim.x = hidden_size - -@param -inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output -residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, - usually appear in pre-layer-norm for transformer layer, maybe nullptr -inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr - ln input if means is not nullptr -gamma: [hidden_size], gamma of ln, - used to compute xhat and dxhat -betta: [hidden_size], betta of ln, - used to compute xhat, maybe nullptr -vars: [batch_size * seq_len], variance of ln forward, - used to compute xhat and dinp -means: [batch_size * seq_len], mean of ln forward, - used to compute xhat, maybe nullptr -*/ -template -__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, - const T *residual_grad, const T *inp_or_out, - const T *gamma, const T 
*betta, const T *vars, - const T *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - float4 dxhat, xhat; - float var_rsqrt; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - dxhat = ((const float4 *)out_grad)[offset]; - float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; - dxhat.x *= vgamma.x; - dxhat.y *= vgamma.y; - dxhat.z *= vgamma.z; - dxhat.w *= vgamma.w; - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - xhat = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); - xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); - xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); - xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; - xhat.x = (xhat.x - fmean) * var_rsqrt; - xhat.y = (xhat.y - fmean) * var_rsqrt; - xhat.z = (xhat.z - fmean) * var_rsqrt; - xhat.w = (xhat.w - fmean) * var_rsqrt; - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - float reduce_val[2] = {0.f, 0.f}; - if (threadIdx.x < hidden_dim) { - reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; - reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + - dxhat.w * xhat.w; - } - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. 
compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; - dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - dxhat.x += dresidual.x; - dxhat.y += dresidual.y; - dxhat.z += dresidual.z; - dxhat.w += dresidual.w; - } - ((float4 *)inp_grad)[offset] = dxhat; -} - -template <> -__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, - int hidden_dim) { - int offset = blockIdx.x * hidden_dim + threadIdx.x; - - float2 dxhat[4], xhat[4]; - float var_rsqrt; - float4 vtmp; - __half2 *tmp_h2; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y; - } - - /* - step 1. 
xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[threadIdx.x]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vbetta = __half22float2(betta_h2[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. 
compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; -} - -__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float var_rsqrt; - float4 vtmp, vtmp_1; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. 
dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat[i].y = (vout.y - 
vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 2; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. 
compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; -} - -__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, - const __half *residual_grad, - const __half *inp_or_out, const __half *gamma, - const __half *betta, const __half *vars, - const __half *means, int hidden_dim) { - int offset = blockIdx.x * 
hidden_dim * 4 + threadIdx.x * 4; - - float2 dxhat[4], xhat[4]; - float2 dxhat_1[4], xhat_1[4]; - float2 dxhat_2[4], xhat_2[4]; - float2 dxhat_3[4], xhat_3[4]; - float var_rsqrt; - float4 vtmp, vtmp_1, vtmp_2, vtmp_3; - __half2 *tmp_h2; - __half2 *tmp_h2_1; - __half2 *tmp_h2_2; - __half2 *tmp_h2_3; - float reduce_val[2] = {0.f, 0.f}; - - if (threadIdx.x < hidden_dim) { - // step 0. dxhat = dout * gamma - vtmp = ((const float4 *)out_grad)[offset]; - vtmp_1 = ((const float4 *)out_grad)[offset + 1]; - vtmp_2 = ((const float4 *)out_grad)[offset + 2]; - vtmp_3 = ((const float4 *)out_grad)[offset + 3]; - tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); - tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); - tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); - tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); - float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; - float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; - float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; - float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; - __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); - __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); - __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); - __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vdout = __half22float2(tmp_h2[i]); - float2 vdout_1 = __half22float2(tmp_h2_1[i]); - float2 vdout_2 = __half22float2(tmp_h2_2[i]); - float2 vdout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - dxhat[i].x = vdout.x * vgamma.x; - dxhat[i].y = vdout.y * vgamma.y; - dxhat_1[i].x = vdout_1.x * vgamma_1.x; - dxhat_1[i].y = vdout_1.y * vgamma_1.y; - dxhat_2[i].x = vdout_2.x * vgamma_2.x; - dxhat_2[i].y = vdout_2.y * vgamma_2.y; 
- dxhat_3[i].x = vdout_3.x * vgamma_3.x; - dxhat_3[i].y = vdout_3.y * vgamma_3.y; - reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + - dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + - dxhat_3[i].y; - } - - /* - step 1. xhat = (output - betta) / gamma or - (input - mean) * rsqrtf(var) - */ - vtmp = ((const float4 *)inp_or_out)[offset]; - vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; - vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; - vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; - var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); - if (means == nullptr) { - // inp_or_out is output, xhat = (output - betta) / gamma - float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; - float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; - float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; - float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; - __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); - __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); - __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); - __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vout = __half22float2(tmp_h2[i]); - float2 vout_1 = __half22float2(tmp_h2_1[i]); - float2 vout_2 = __half22float2(tmp_h2_2[i]); - float2 vout_3 = __half22float2(tmp_h2_3[i]); - float2 vgamma = __half22float2(gamma_h2[i]); - float2 vgamma_1 = __half22float2(gamma_h2_1[i]); - float2 vgamma_2 = __half22float2(gamma_h2_2[i]); - float2 vgamma_3 = __half22float2(gamma_h2_3[i]); - float2 vbetta = __half22float2(betta_h2[i]); - float2 vbetta_1 = __half22float2(betta_h2_1[i]); - float2 vbetta_2 = __half22float2(betta_h2_2[i]); - float2 vbetta_3 = __half22float2(betta_h2_3[i]); - xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); - xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); - xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); - xhat_3[i].x 
= (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); - xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); - xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); - xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); - xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } else { - // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) - float fmean = (float)means[blockIdx.x]; -#pragma unroll - for (int i = 0; i < 4; i++) { - float2 vinp = __half22float2(tmp_h2[i]); - float2 vinp_1 = __half22float2(tmp_h2_1[i]); - float2 vinp_2 = __half22float2(tmp_h2_2[i]); - float2 vinp_3 = __half22float2(tmp_h2_3[i]); - xhat[i].x = (vinp.x - fmean) * var_rsqrt; - xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; - xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; - xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; - xhat[i].y = (vinp.y - fmean) * var_rsqrt; - xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; - xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; - xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; - reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; - reduce_val[1] += - xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; - reduce_val[1] += - xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; - reduce_val[1] += - xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; - } - } - } - - /* step2. block reduce sum for dxhat and dxhat*xhat */ - blockReduce(reduce_val); - __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; - if (threadIdx.x == 0) { - float mean_dim = hidden_dim * 8 * 4; - s_sum_dxhat = reduce_val[0] / mean_dim; - s_sum_dxhat_xhat = reduce_val[1] / mean_dim; - } - __syncthreads(); - - /* - step3. 
compute input gradient - (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) - */ - if (threadIdx.x >= hidden_dim) { - return; - } - if (residual_grad) { - // Add the residual grad, - // usually in pre-layer-norm for transformer layer - float4 dresidual = ((const float4 *)residual_grad)[offset]; - float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; - float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; - float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; - __half *hdres = reinterpret_cast<__half *>(&dresidual); - __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); - __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); - __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); -#pragma unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i])); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i])); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_2[2 * i])); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_3[2 * i])); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres[2 * i + 1])); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt + - __half2float(hdres_1[2 * i + 1])); - } - } else { -#pragma 
unroll - for (int i = 0; i < 4; i++) { - tmp_h2[i].x = __float2half( - (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].x = __float2half( - (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].x = __float2half( - (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].x = __float2half( - (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2[i].y = __float2half( - (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_1[i].y = __float2half( - (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_2[i].y = __float2half( - (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - tmp_h2_3[i].y = __float2half( - (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * - var_rsqrt); - } - } - ((float4 *)inp_grad)[offset] = vtmp; - ((float4 *)inp_grad)[offset + 1] = vtmp_1; - ((float4 *)inp_grad)[offset + 2] = vtmp_2; - ((float4 *)inp_grad)[offset + 3] = vtmp_3; -} - -/** -Layer norm backword, - compute the gradient of gamma, betta and input. -dbetta = sum(dout, dim=0) -xhat = (input - mean) * rsqrt(var) if mean is not nullptr - (output - betta) / gamma if mean is nullptr -dgamma = sum(xhat * dout, dim=0) -dxhat = dout * gamma -dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) - * rsqrt(var) - -residual_grad, means, betta can be nullptr. 
-residual_grad will be added to dinp if it is not nullptr - which is useful in transformer layer when pre-ln -means and betta are only used to compute xhat, - (means == nullptr) ^ (betta == nullptr) should be true -*/ -template <> -void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, - const float *out_grad, const float *residual_grad, - const float *inp_or_out, const float *gamma, - const float *betta, const float *vars, - const float *means, int batch, int hidden_dim, - cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 4 != 0 || hidden_dim > 4096) { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); - } - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, - hidden_dim); -} - -template <> -void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, - __half *inp_grad, const __half *out_grad, - const __half *residual_grad, const __half *inp_or_out, - const __half *gamma, const __half *betta, - const __half *vars, const __half *means, int batch, - int hidden_dim, cudaStream_t stream[2]) { - // compute grad of gamma and betta - dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - ker_ln_bw_dgamma_dbetta<__half><<>>( - gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, - batch, hidden_dim); - - // compute grad of input - if (hidden_dim % 8 != 0) { - throw std::runtime_error("hidden_dim % 8 != 0"); - } - hidden_dim >>= 3; - - if (hidden_dim * 8 <= 8192) { - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - 
ker_ln_bw_dinp<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { - hidden_dim >>= 1; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x2<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { - hidden_dim >>= 2; - int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); - ker_ln_bw_dinp_x4<<>>( - inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, - means, hidden_dim); - } else { - throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); - } -} +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float LN_EPSILON = 1e-8f; +#define TILE_DIM 32 + +template +__forceinline__ __device__ T add_eps(T x) { + return fabsf(x) > LN_EPSILON ? x : (x < 0 ? -LN_EPSILON : LN_EPSILON); +} + +/** +@brief: ker_layer_norm +Standard layer normalization. +It will not only output the layer norm result, + but also outputs variance. + may also output means, depends on whether + the means argument is nullptr + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +ln_res: [batch_size* seq_len, hidden_size], ln result. +vars: [batch_size* seq_len], variance per token +means: [batch_size* seq_len], means per token, can be nullput +inp: [batch_size * seq_len, hidden_size], ln input. +scale: [hidden_size], ln scale +bias: [hidden_size], ln bias +*/ +template +__global__ void ker_layer_norm(T *ln_res, T *vars, T *means, const T *inp, + const T *scale, const T *bias, int hidden_size) { + // step 0. 
compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val = inp_f4[idx]; + l_sum += val.x + val.y + val.z + val.w; + l_square_sum += + val.x * val.x + val.y * val.y + val.z * val.z + val.w * val.w; + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 4.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 vscale = __ldg((const float4 *)scale + idx); + float4 vbias = __ldg((const float4 *)bias + idx); + float4 val = inp_f4[idx]; + val.x = (val.x - s_mean) * s_var * vscale.x + vbias.x; + val.y = (val.y - s_mean) * s_var * vscale.y + vbias.y; + val.z = (val.z - s_mean) * s_var * vscale.z + vbias.z; + val.w = (val.w - s_mean) * s_var * vscale.w + vbias.w; + output_f4[idx] = val; + } +} + +template <> +__global__ void ker_layer_norm<__half>(__half *ln_res, __half *vars, + __half *means, const __half *inp, + const __half *scale, const __half *bias, + int hidden_size) { + // step 0. 
compute local sum + float l_sum = 0; + float l_square_sum = 0; + const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 val_f2 = __half22float2(val_h2[i]); + l_sum += val_f2.x + val_f2.y; + l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y; + } + } + + // step 1. compute reduce sum + float mean_dim = float(hidden_size) * 8.f; + float reduce_val[2] = {l_sum, l_square_sum}; + blockReduce(reduce_val); + __shared__ float s_mean, s_var; + if (threadIdx.x == 0) { + s_mean = reduce_val[0] / mean_dim; + if (means != nullptr) { + means[blockIdx.x] = s_mean; + } + s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; + vars[blockIdx.x] = s_var; + s_var = rsqrtf(s_var); + } + __syncthreads(); + + // step 2. layer norm result + float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size; + for (uint idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + // load scale, bias, input + float4 scale_f4 = __ldg((const float4 *)scale + idx); + __half2 *scale_h2 = (__half2 *)(&scale_f4); + float4 bias_f4 = __ldg((const float4 *)bias + idx); + __half2 *bias_h2 = (__half2 *)(&bias_f4); + float4 val_f4 = inp_f4[idx]; + __half2 *val_h2 = (__half2 *)(&val_f4); + +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 scale_f2 = __half22float2(scale_h2[i]); + float2 bias_f2 = __half22float2(bias_h2[i]); + float2 val_f2 = __half22float2(val_h2[i]); + val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; + val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; + val_h2[i] = __float22half2_rn(val_f2); + } + output_f4[idx] = val_f4; + } +} + +// __global__ void ker_layer_norm_x2(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. 
compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * 2 * hidden_size; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y; +// l_square_sum += val_f2.x * val_f2.x + val_f2.y * val_f2.y + val_f2_1.x +// * val_f2_1.x + val_f2_1.y * val_f2_1.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 2; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. 
layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 2; +// for (uint idx = 2 * threadIdx.x; idx < hidden_size * 2; idx += blockDim.x * +// 2) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_h2[i] = __float22half2_rn(val_f2); +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_h2_1[i] = __float22half2_rn(val_f2_1); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// } +// } + +// __global__ void ker_layer_norm_x4(__half *ln_res, __half *vars, +// __half *means, const __half *inp, +// const __half *scale, const __half +// *bias, int hidden_size) { +// // step 0. 
compute local sum +// float l_sum = 0; +// float l_square_sum = 0; +// const float4 *inp_f4 = (const float4 *)inp + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// float4 val_f4 = inp_f4[idx]; +// float4 val_f4_1 = inp_f4[idx+1]; +// float4 val_f4_2 = inp_f4[idx+2]; +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = __half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// l_sum += val_f2.x + val_f2.y + val_f2_1.x + val_f2_1.y + val_f2_2.x + +// val_f2_2.y + val_f2_3.x + val_f2_3.y; l_square_sum += val_f2.x * +// val_f2.x + val_f2.y * val_f2.y; l_square_sum += val_f2_1.x * val_f2_1.x +// + val_f2_1.y * val_f2_1.y; l_square_sum += val_f2_2.x * val_f2_2.x + +// val_f2_2.y * val_f2_2.y; l_square_sum += val_f2_3.x * val_f2_3.x + +// val_f2_3.y * val_f2_3.y; +// } +// } + +// // step 1. compute reduce sum +// float mean_dim = float(hidden_size) * 8.f * 4; +// float reduce_val[2] = {l_sum, l_square_sum}; +// blockReduce(reduce_val); +// __shared__ float s_mean, s_var; +// if (threadIdx.x == 0) { +// s_mean = reduce_val[0] / mean_dim; +// if (means != nullptr) { +// means[blockIdx.x] = s_mean; +// } +// s_var = reduce_val[1] / mean_dim - s_mean * s_mean + LN_EPSILON; +// vars[blockIdx.x] = s_var; +// s_var = rsqrtf(s_var); +// } +// __syncthreads(); + +// // step 2. 
layer norm result +// float4 *output_f4 = (float4 *)ln_res + blockIdx.x * hidden_size * 4; +// for (uint idx = 4 * threadIdx.x; idx < hidden_size * 4; idx += blockDim.x * +// 4) { +// // load scale, bias, input +// float4 scale_f4 = __ldg((const float4 *)scale + idx); +// __half2 *scale_h2 = (__half2 *)(&scale_f4); +// float4 scale_f4_1 = __ldg((const float4 *)scale + idx + 1); +// __half2 *scale_h2_1 = (__half2 *)(&scale_f4_1); +// float4 scale_f4_2 = __ldg((const float4 *)scale + idx + 2); +// __half2 *scale_h2_2 = (__half2 *)(&scale_f4_2); +// float4 scale_f4_3 = __ldg((const float4 *)scale + idx + 3); +// __half2 *scale_h2_3 = (__half2 *)(&scale_f4_3); +// float4 bias_f4 = __ldg((const float4 *)bias + idx); +// __half2 *bias_h2 = (__half2 *)(&bias_f4); +// float4 bias_f4_1 = __ldg((const float4 *)bias + idx + 1); +// __half2 *bias_h2_1 = (__half2 *)(&bias_f4_1); +// float4 bias_f4_2 = __ldg((const float4 *)bias + idx + 2); +// __half2 *bias_h2_2 = (__half2 *)(&bias_f4_2); +// float4 bias_f4_3 = __ldg((const float4 *)bias + idx + 3); +// __half2 *bias_h2_3 = (__half2 *)(&bias_f4_3); +// float4 val_f4 = inp_f4[idx]; +// __half2 *val_h2 = (__half2 *)(&val_f4); +// float4 val_f4_1 = inp_f4[idx+1]; +// __half2 *val_h2_1 = (__half2 *)(&val_f4_1); +// float4 val_f4_2 = inp_f4[idx+2]; +// __half2 *val_h2_2 = (__half2 *)(&val_f4_2); +// float4 val_f4_3 = inp_f4[idx+3]; +// __half2 *val_h2_3 = (__half2 *)(&val_f4_3); + +// #pragma unroll +// for (int i = 0; i < 4; i++) { +// float2 scale_f2 = __half22float2(scale_h2[i]); +// float2 scale_f2_1 = __half22float2(scale_h2_1[i]); +// float2 scale_f2_2 = __half22float2(scale_h2_2[i]); +// float2 scale_f2_3 = __half22float2(scale_h2_3[i]); +// float2 bias_f2 = __half22float2(bias_h2[i]); +// float2 bias_f2_1 = __half22float2(bias_h2_1[i]); +// float2 bias_f2_2 = __half22float2(bias_h2_2[i]); +// float2 bias_f2_3 = __half22float2(bias_h2_3[i]); +// float2 val_f2 = __half22float2(val_h2[i]); +// float2 val_f2_1 = 
__half22float2(val_h2_1[i]); +// float2 val_f2_2 = __half22float2(val_h2_2[i]); +// float2 val_f2_3 = __half22float2(val_h2_3[i]); +// val_f2.x = (val_f2.x - s_mean) * s_var * scale_f2.x + bias_f2.x; +// val_f2.y = (val_f2.y - s_mean) * s_var * scale_f2.y + bias_f2.y; +// val_f2_1.x = (val_f2_1.x - s_mean) * s_var * scale_f2_1.x + +// bias_f2_1.x; val_f2_1.y = (val_f2_1.y - s_mean) * s_var * scale_f2_1.y +// + bias_f2_1.y; val_f2_2.x = (val_f2_2.x - s_mean) * s_var * +// scale_f2_2.x + bias_f2_2.x; val_f2_2.y = (val_f2_2.y - s_mean) * s_var +// * scale_f2_2.y + bias_f2_2.y; val_f2_3.x = (val_f2_3.x - s_mean) * +// s_var * scale_f2_3.x + bias_f2_3.x; val_f2_3.y = (val_f2_3.y - s_mean) +// * s_var * scale_f2_3.y + bias_f2_3.y; val_h2[i] = +// __float22half2_rn(val_f2); val_h2_1[i] = __float22half2_rn(val_f2_1); +// val_h2_2[i] = __float22half2_rn(val_f2_2); +// val_h2_3[i] = __float22half2_rn(val_f2_3); +// } +// output_f4[idx] = val_f4; +// output_f4[idx+1] = val_f4_1; +// output_f4[idx+2] = val_f4_2; +// output_f4[idx+3] = val_f4_3; +// } +// } + +template <> +void launch_layer_norm(float *ln_res, float *vars, float *means, + const float *inp, const float *scale, + const float *bias, int batch_size, int hidden_dim, + cudaStream_t stream) { + if (hidden_dim % 4 != 0) { + throw std::runtime_error("violate hidden_dim % 4 = 0"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 block_dim(nthread); + + ker_layer_norm<<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); +} + +template <> +void launch_layer_norm<__half>(__half *ln_res, __half *vars, __half *means, + const __half *inp, const __half *scale, + const __half *bias, int batch_size, + int hidden_dim, cudaStream_t stream) { + if (hidden_dim % 8 != 0) { + throw std::runtime_error("violate hidden_dim % 8 = 0"); + } + hidden_dim >>= 3; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + dim3 grid_dim(batch_size); + dim3 
block_dim(nthread); + + ker_layer_norm<__half><<>>( + ln_res, vars, means, inp, scale, bias, hidden_dim); + // if (hidden_dim % 8 != 0) { + // throw std::runtime_error("violate hidden_dim % 8 = 0"); + // } + // hidden_dim >>= 3; + + // if (hidden_dim * 8 < 8192) { + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm<__half><<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 >= 8192 && hidden_dim * 8 <= 8192 * 2) { + // hidden_dim >>= 1; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x2<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else if (hidden_dim * 8 > 8192 * 2 && hidden_dim * 8 <= 8192 * 4) { + // hidden_dim >>= 2; + // int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + // dim3 grid_dim(batch_size); + // dim3 block_dim(nthread); + // ker_layer_norm_x4<<>>( + // ln_res, vars, means, inp, scale, bias, hidden_dim); + // } else { + // throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + // } +} + +/** +@brief: ker_ln_bw_dgamma_dbetta +Layer norm backword kernel, compute the gradient of gamma and betta. 
+dbetta = sum(dout, dim=0) +dgamma = sum(xhat * dout, dim=0) +xhat = (input - mean) * rsqrt(var) or + (output - betta) / gamma + + +@thread +gridDim.x = hidden_size / 32 +blockDim.x = 32 +blockDim.y = 32 + +@param +gamma_grad: [hidden_size], gradient of gamma +betta_grad: [hidden_size], gradient of betta +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat, maybe nullptr +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat, maybe nullptr +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +(gamma && betta) ^ (vars && means) should be true +*/ +template +__global__ void ker_ln_bw_dgamma_dbetta(T *gamma_grad, T *betta_grad, + const T *out_grad, const T *inp_or_out, + const T *gamma, const T *betta, + const T *vars, const T *means, int rows, + int width) { + __shared__ float betta_buffer[TILE_DIM][TILE_DIM]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + // Loop across inp height + float dbetta = 0; + float dgamma = 0; + float dout, val; + if (idx < width) { + if (means == nullptr) { + float vbetta = (float)betta[idx]; + float vgamma = (float)gamma[idx]; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // inp_or_out is output + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - vbetta) / add_eps(vgamma) * dout); + offset += y_stride; + } + } else { + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + dout = (float)out_grad[offset]; + // 
inp_or_out is input + val = (float)inp_or_out[offset]; + dbetta += dout; + dgamma += ((val - (float)means[r]) * + rsqrtf((float)vars[r] + LN_EPSILON) * dout); + offset += y_stride; + } + } + } + + // Sum the shared buffer. + betta_buffer[threadIdx.x][threadIdx.y] = dbetta; + gamma_buffer[threadIdx.x][threadIdx.y] = dgamma; + __syncthreads(); + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + __syncthreads(); + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (threadIdx.x == 0 && idx < width) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/** +@brief: ker_ln_bw_dinp +Layer norm backword kernel, compute the gradient of input. +dinp = (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / hidden_dim) + * rsqrt(var) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dxhat = dout * gamma + + +@thread +gridDim.x = batch_size * seq_len +blockDim.x = hidden_size + +@param +inp_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +out_grad: [batch_size * seq_len, hidden_size], gradient of betta ln output +residual_grad: [batch_size * seq_len, hidden_size], gradient of residual input, + usually appear in pre-layer-norm for transformer layer, maybe nullptr +inp_or_out: [batch_size * seq_len, hidden_size], ln output if means is nullptr + ln input if means is not nullptr +gamma: [hidden_size], gamma of ln, + used to compute xhat and dxhat +betta: [hidden_size], betta of ln, + used to compute xhat, maybe nullptr +vars: [batch_size * seq_len], variance of ln forward, + used to compute xhat and dinp +means: [batch_size * seq_len], mean of ln forward, + used to compute xhat, maybe nullptr +*/ +template +__global__ void ker_ln_bw_dinp(T *inp_grad, const T *out_grad, + const T *residual_grad, const T *inp_or_out, + const T *gamma, const T 
*betta, const T *vars, + const T *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + float4 dxhat, xhat; + float var_rsqrt; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + dxhat = ((const float4 *)out_grad)[offset]; + float4 vgamma = ((const float4 *)gamma)[threadIdx.x]; + dxhat.x *= vgamma.x; + dxhat.y *= vgamma.y; + dxhat.z *= vgamma.z; + dxhat.w *= vgamma.w; + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + xhat = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + xhat.x = (xhat.x - vbetta.x) / add_eps(vgamma.x); + xhat.y = (xhat.y - vbetta.y) / add_eps(vgamma.y); + xhat.z = (xhat.z - vbetta.z) / add_eps(vgamma.z); + xhat.w = (xhat.w - vbetta.w) / add_eps(vgamma.w); + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; + xhat.x = (xhat.x - fmean) * var_rsqrt; + xhat.y = (xhat.y - fmean) * var_rsqrt; + xhat.z = (xhat.z - fmean) * var_rsqrt; + xhat.w = (xhat.w - fmean) * var_rsqrt; + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + float reduce_val[2] = {0.f, 0.f}; + if (threadIdx.x < hidden_dim) { + reduce_val[0] = dxhat.x + dxhat.y + dxhat.z + dxhat.w; + reduce_val[1] = dxhat.x * xhat.x + dxhat.y * xhat.y + dxhat.z * xhat.z + + dxhat.w * xhat.w; + } + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. 
compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + dxhat.x = (dxhat.x - s_sum_dxhat - xhat.x * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.y = (dxhat.y - s_sum_dxhat - xhat.y * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.z = (dxhat.z - s_sum_dxhat - xhat.z * s_sum_dxhat_xhat) * var_rsqrt; + dxhat.w = (dxhat.w - s_sum_dxhat - xhat.w * s_sum_dxhat_xhat) * var_rsqrt; + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + dxhat.x += dresidual.x; + dxhat.y += dresidual.y; + dxhat.z += dresidual.z; + dxhat.w += dresidual.w; + } + ((float4 *)inp_grad)[offset] = dxhat; +} + +template <> +__global__ void ker_ln_bw_dinp<__half>(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, + int hidden_dim) { + int offset = blockIdx.x * hidden_dim + threadIdx.x; + + float2 dxhat[4], xhat[4]; + float var_rsqrt; + float4 vtmp; + __half2 *tmp_h2; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y; + } + + /* + step 1. 
xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[threadIdx.x]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vbetta = __half22float2(betta_h2[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. 
compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; +} + +__global__ void ker_ln_bw_dinp_x2(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * hidden_dim * 2 + threadIdx.x * 2; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float var_rsqrt; + float4 vtmp, vtmp_1; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. 
dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 2]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 2 + 1]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[2 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[2 * threadIdx.x + 1]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat[i].y = (vout.y - 
vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 2; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. 
compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; +} + +__global__ void ker_ln_bw_dinp_x4(__half *inp_grad, const __half *out_grad, + const __half *residual_grad, + const __half *inp_or_out, const __half *gamma, + const __half *betta, const __half *vars, + const __half *means, int hidden_dim) { + int offset = blockIdx.x * 
hidden_dim * 4 + threadIdx.x * 4; + + float2 dxhat[4], xhat[4]; + float2 dxhat_1[4], xhat_1[4]; + float2 dxhat_2[4], xhat_2[4]; + float2 dxhat_3[4], xhat_3[4]; + float var_rsqrt; + float4 vtmp, vtmp_1, vtmp_2, vtmp_3; + __half2 *tmp_h2; + __half2 *tmp_h2_1; + __half2 *tmp_h2_2; + __half2 *tmp_h2_3; + float reduce_val[2] = {0.f, 0.f}; + + if (threadIdx.x < hidden_dim) { + // step 0. dxhat = dout * gamma + vtmp = ((const float4 *)out_grad)[offset]; + vtmp_1 = ((const float4 *)out_grad)[offset + 1]; + vtmp_2 = ((const float4 *)out_grad)[offset + 2]; + vtmp_3 = ((const float4 *)out_grad)[offset + 3]; + tmp_h2 = reinterpret_cast<__half2 *>(&vtmp); + tmp_h2_1 = reinterpret_cast<__half2 *>(&vtmp_1); + tmp_h2_2 = reinterpret_cast<__half2 *>(&vtmp_2); + tmp_h2_3 = reinterpret_cast<__half2 *>(&vtmp_3); + float4 gamma_f4 = ((const float4 *)gamma)[threadIdx.x * 4]; + float4 gamma_f4_1 = ((const float4 *)gamma)[threadIdx.x * 4 + 1]; + float4 gamma_f4_2 = ((const float4 *)gamma)[threadIdx.x * 4 + 2]; + float4 gamma_f4_3 = ((const float4 *)gamma)[threadIdx.x * 4 + 3]; + __half2 *gamma_h2 = reinterpret_cast<__half2 *>(&gamma_f4); + __half2 *gamma_h2_1 = reinterpret_cast<__half2 *>(&gamma_f4_1); + __half2 *gamma_h2_2 = reinterpret_cast<__half2 *>(&gamma_f4_2); + __half2 *gamma_h2_3 = reinterpret_cast<__half2 *>(&gamma_f4_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vdout = __half22float2(tmp_h2[i]); + float2 vdout_1 = __half22float2(tmp_h2_1[i]); + float2 vdout_2 = __half22float2(tmp_h2_2[i]); + float2 vdout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + dxhat[i].x = vdout.x * vgamma.x; + dxhat[i].y = vdout.y * vgamma.y; + dxhat_1[i].x = vdout_1.x * vgamma_1.x; + dxhat_1[i].y = vdout_1.y * vgamma_1.y; + dxhat_2[i].x = vdout_2.x * vgamma_2.x; + dxhat_2[i].y = vdout_2.y * vgamma_2.y; 
+ dxhat_3[i].x = vdout_3.x * vgamma_3.x; + dxhat_3[i].y = vdout_3.y * vgamma_3.y; + reduce_val[0] += dxhat[i].x + dxhat[i].y + dxhat_1[i].x + dxhat_1[i].y + + dxhat_2[i].x + dxhat_2[i].y + dxhat_3[i].x + + dxhat_3[i].y; + } + + /* + step 1. xhat = (output - betta) / gamma or + (input - mean) * rsqrtf(var) + */ + vtmp = ((const float4 *)inp_or_out)[offset]; + vtmp_1 = ((const float4 *)inp_or_out)[offset + 1]; + vtmp_2 = ((const float4 *)inp_or_out)[offset + 2]; + vtmp_3 = ((const float4 *)inp_or_out)[offset + 3]; + var_rsqrt = rsqrtf((float)vars[blockIdx.x] + LN_EPSILON); + if (means == nullptr) { + // inp_or_out is output, xhat = (output - betta) / gamma + float4 vbetta = ((const float4 *)betta)[4 * threadIdx.x]; + float4 vbetta_1 = ((const float4 *)betta)[4 * threadIdx.x + 1]; + float4 vbetta_2 = ((const float4 *)betta)[4 * threadIdx.x + 2]; + float4 vbetta_3 = ((const float4 *)betta)[4 * threadIdx.x + 3]; + __half2 *betta_h2 = reinterpret_cast<__half2 *>(&vbetta); + __half2 *betta_h2_1 = reinterpret_cast<__half2 *>(&vbetta_1); + __half2 *betta_h2_2 = reinterpret_cast<__half2 *>(&vbetta_2); + __half2 *betta_h2_3 = reinterpret_cast<__half2 *>(&vbetta_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vout = __half22float2(tmp_h2[i]); + float2 vout_1 = __half22float2(tmp_h2_1[i]); + float2 vout_2 = __half22float2(tmp_h2_2[i]); + float2 vout_3 = __half22float2(tmp_h2_3[i]); + float2 vgamma = __half22float2(gamma_h2[i]); + float2 vgamma_1 = __half22float2(gamma_h2_1[i]); + float2 vgamma_2 = __half22float2(gamma_h2_2[i]); + float2 vgamma_3 = __half22float2(gamma_h2_3[i]); + float2 vbetta = __half22float2(betta_h2[i]); + float2 vbetta_1 = __half22float2(betta_h2_1[i]); + float2 vbetta_2 = __half22float2(betta_h2_2[i]); + float2 vbetta_3 = __half22float2(betta_h2_3[i]); + xhat[i].x = (vout.x - vbetta.x) / add_eps(vgamma.x); + xhat_1[i].x = (vout_1.x - vbetta_1.x) / add_eps(vgamma_1.x); + xhat_2[i].x = (vout_2.x - vbetta_2.x) / add_eps(vgamma_2.x); + xhat_3[i].x 
= (vout_3.x - vbetta_3.x) / add_eps(vgamma_3.x); + xhat[i].y = (vout.y - vbetta.y) / add_eps(vgamma.y); + xhat_1[i].y = (vout_1.y - vbetta_1.y) / add_eps(vgamma_1.y); + xhat_2[i].y = (vout_2.y - vbetta_2.y) / add_eps(vgamma_2.y); + xhat_3[i].y = (vout_3.y - vbetta_3.y) / add_eps(vgamma_3.y); + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } else { + // inp_or_out is input, xhat = (input - mean) * rsqrtf(var) + float fmean = (float)means[blockIdx.x]; +#pragma unroll + for (int i = 0; i < 4; i++) { + float2 vinp = __half22float2(tmp_h2[i]); + float2 vinp_1 = __half22float2(tmp_h2_1[i]); + float2 vinp_2 = __half22float2(tmp_h2_2[i]); + float2 vinp_3 = __half22float2(tmp_h2_3[i]); + xhat[i].x = (vinp.x - fmean) * var_rsqrt; + xhat_1[i].x = (vinp_1.x - fmean) * var_rsqrt; + xhat_2[i].x = (vinp_2.x - fmean) * var_rsqrt; + xhat_3[i].x = (vinp_3.x - fmean) * var_rsqrt; + xhat[i].y = (vinp.y - fmean) * var_rsqrt; + xhat_1[i].y = (vinp_1.y - fmean) * var_rsqrt; + xhat_2[i].y = (vinp_2.y - fmean) * var_rsqrt; + xhat_3[i].y = (vinp_3.y - fmean) * var_rsqrt; + reduce_val[1] += xhat[i].x * dxhat[i].x + xhat[i].y * dxhat[i].y; + reduce_val[1] += + xhat_1[i].x * dxhat_1[i].x + xhat_1[i].y * dxhat_1[i].y; + reduce_val[1] += + xhat_2[i].x * dxhat_2[i].x + xhat_2[i].y * dxhat_2[i].y; + reduce_val[1] += + xhat_3[i].x * dxhat_3[i].x + xhat_3[i].y * dxhat_3[i].y; + } + } + } + + /* step2. block reduce sum for dxhat and dxhat*xhat */ + blockReduce(reduce_val); + __shared__ float s_sum_dxhat, s_sum_dxhat_xhat; + if (threadIdx.x == 0) { + float mean_dim = hidden_dim * 8 * 4; + s_sum_dxhat = reduce_val[0] / mean_dim; + s_sum_dxhat_xhat = reduce_val[1] / mean_dim; + } + __syncthreads(); + + /* + step3. 
compute input gradient + (dxhat - (sum(dxhat) + xhat * sum(dxhat * xhat)) / mean_dim) * rsqrt(var) + */ + if (threadIdx.x >= hidden_dim) { + return; + } + if (residual_grad) { + // Add the residual grad, + // usually in pre-layer-norm for transformer layer + float4 dresidual = ((const float4 *)residual_grad)[offset]; + float4 dresidual_1 = ((const float4 *)residual_grad)[offset + 1]; + float4 dresidual_2 = ((const float4 *)residual_grad)[offset + 2]; + float4 dresidual_3 = ((const float4 *)residual_grad)[offset + 3]; + __half *hdres = reinterpret_cast<__half *>(&dresidual); + __half *hdres_1 = reinterpret_cast<__half *>(&dresidual_1); + __half *hdres_2 = reinterpret_cast<__half *>(&dresidual_2); + __half *hdres_3 = reinterpret_cast<__half *>(&dresidual_3); +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i])); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i])); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_2[2 * i])); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_3[2 * i])); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres[2 * i + 1])); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt + + __half2float(hdres_1[2 * i + 1])); + } + } else { +#pragma 
unroll + for (int i = 0; i < 4; i++) { + tmp_h2[i].x = __float2half( + (dxhat[i].x - s_sum_dxhat - xhat[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].x = __float2half( + (dxhat_1[i].x - s_sum_dxhat - xhat_1[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].x = __float2half( + (dxhat_2[i].x - s_sum_dxhat - xhat_2[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].x = __float2half( + (dxhat_3[i].x - s_sum_dxhat - xhat_3[i].x * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2[i].y = __float2half( + (dxhat[i].y - s_sum_dxhat - xhat[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_1[i].y = __float2half( + (dxhat_1[i].y - s_sum_dxhat - xhat_1[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_2[i].y = __float2half( + (dxhat_2[i].y - s_sum_dxhat - xhat_2[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + tmp_h2_3[i].y = __float2half( + (dxhat_3[i].y - s_sum_dxhat - xhat_3[i].y * s_sum_dxhat_xhat) * + var_rsqrt); + } + } + ((float4 *)inp_grad)[offset] = vtmp; + ((float4 *)inp_grad)[offset + 1] = vtmp_1; + ((float4 *)inp_grad)[offset + 2] = vtmp_2; + ((float4 *)inp_grad)[offset + 3] = vtmp_3; +} + +/** +Layer norm backword, + compute the gradient of gamma, betta and input. +dbetta = sum(dout, dim=0) +xhat = (input - mean) * rsqrt(var) if mean is not nullptr + (output - betta) / gamma if mean is nullptr +dgamma = sum(xhat * dout, dim=0) +dxhat = dout * gamma +dinp = (dxhat - (sum(dxhat, 1) + xhat * sum(dxhat * xhat, 1)) / hidden_dim) + * rsqrt(var) + +residual_grad, means, betta can be nullptr. 
+residual_grad will be added to dinp if it is not nullptr + which is useful in transformer layer when pre-ln +means and betta are only used to compute xhat, + (means == nullptr) ^ (betta == nullptr) should be true +*/ +template <> +void launch_ln_bw(float *gamma_grad, float *betta_grad, float *inp_grad, + const float *out_grad, const float *residual_grad, + const float *inp_or_out, const float *gamma, + const float *betta, const float *vars, + const float *means, int batch, int hidden_dim, + cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 4 != 0 || hidden_dim > 4096) { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 4096"); + } + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, means, + hidden_dim); +} + +template <> +void launch_ln_bw<__half>(__half *gamma_grad, __half *betta_grad, + __half *inp_grad, const __half *out_grad, + const __half *residual_grad, const __half *inp_or_out, + const __half *gamma, const __half *betta, + const __half *vars, const __half *means, int batch, + int hidden_dim, cudaStream_t stream[2]) { + // compute grad of gamma and betta + dim3 grid_dim(((hidden_dim + TILE_DIM - 1) / TILE_DIM) * TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + ker_ln_bw_dgamma_dbetta<__half><<>>( + gamma_grad, betta_grad, out_grad, inp_or_out, gamma, betta, vars, means, + batch, hidden_dim); + + // compute grad of input + if (hidden_dim % 8 != 0) { + throw std::runtime_error("hidden_dim % 8 != 0"); + } + hidden_dim >>= 3; + + if (hidden_dim * 8 <= 8192) { + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + 
ker_ln_bw_dinp<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 8192 && hidden_dim * 8 <= 8192 * 2) { + hidden_dim >>= 1; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x2<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else if (hidden_dim * 8 > 2 * 8192 && hidden_dim * 8 <= 8192 * 4) { + hidden_dim >>= 2; + int nthread = min(((hidden_dim + 31) / 32) * 32, MAX_THREADS); + ker_ln_bw_dinp_x4<<>>( + inp_grad, out_grad, residual_grad, inp_or_out, gamma, betta, vars, + means, hidden_dim); + } else { + throw std::runtime_error("hidden_dim % 4 != 0 || hidden_dim > 32768"); + } +} diff --git a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu index 98af433fe397..3862a699d3c3 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/softmax_kernels.cu @@ -1,365 +1,365 @@ -#include -#include - -#include -#include - -#include "block_reduce.h" -#include "kernels.h" - -namespace cg = cooperative_groups; -const float EPSILON = 1e-8f; - -/** -@brief: softmax_kernel -Softmax forward kernel for - enc-self-attn, dec-self-attn, encdec-attn - -@thread -gridDim.x = dynamic -gridDim.y = batch_size -gridDim.z = nhead -blockDim.x = from_len - -@param -inp: [batch_size, nhead, from_len, to_len], softmax input. -attn_mask: [batch_size, to_len], padding tokens are -inf, - non padding tokens are 0. 
- attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template -__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // block reduce max - blockReduce(l_max); - // write shared - __shared__ float s_max[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_max[i] = l_max[i]; - } - } - __syncthreads(); - - /* step 2. 
compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - s_max[i]); - l_sum[i] += val[i][j]; - } - } - // block reduce sum - blockReduce(l_sum); - // write shared - __shared__ float s_sum[token_per_reduce]; - if (threadIdx.x == 0) { - for (int i = 0; i < token_per_reduce; i++) { - s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - } - } - __syncthreads(); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * s_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -template -__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, - int to_len, bool mask_future) { - int batch_id = blockIdx.y; - int head_id = blockIdx.z; - const int nhead = gridDim.z; - const int token_per_reduce = 1; - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - T mval[ele_per_thread]; - if (attn_mask) { - attn_mask += batch_id * to_len; - BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); - } - - inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); - for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; - token_id += gridDim.x * token_per_reduce) { - T inp_val[token_per_reduce][ele_per_thread]; - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, - REDUCE_FLOAT_INF_NEG); - } - - /* step 1. 
compute max */ - // thread local max - float val[token_per_reduce][ele_per_thread]; - float l_max[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_max[i] = REDUCE_FLOAT_INF_NEG; - for (int j = 0; j < ele_per_thread; j++) { - if (attn_mask) { - val[i][j] = (float)inp_val[i][j] + (float)mval[j]; - } else { - if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { - val[i][j] = REDUCE_FLOAT_INF_NEG; - } else { - val[i][j] = (float)inp_val[i][j]; - } - } - l_max[i] = fmaxf(l_max[i], val[i][j]); - } - } - // warp reduce max - warpReduce(l_max); - - /* step 2. compute sum */ - // thread local sum - float l_sum[token_per_reduce]; - for (int i = 0; i < token_per_reduce; i++) { - l_sum[i] = 0.f; - for (int j = 0; j < ele_per_thread; j++) { - val[i][j] = __expf(val[i][j] - l_max[i]); - l_sum[i] += val[i][j]; - } - } - // warp reduce sum - warpReduce(l_sum); - - /* step 3. compute final result */ - for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { - l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); - for (int j = 0; j < ele_per_thread; j++) { - inp_val[i][j] = (T)(val[i][j] * l_sum[i]); - } - BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], - to_len); - } - } // blockIdx.x -} - -/* - attn_mask!=nullptr for enc-self-attn and enc-dec-attn - attn_mask=nullptr and mask_future=ture for dec-self-attn training - attn_mask=nullptr and mask_future=false for dec-self-attn infer -*/ -template <> -void launch_attn_softmax(float *inp, const float *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 16; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, 
mask_future); - } else if (to_len <= 256) { - grid_dim.x = 32; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 64; - ker_attn_softmax<<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -template <> -void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, - int batch_size, int nhead, int from_len, - int to_len, bool mask_future, - cudaStream_t stream) { - dim3 grid_dim(1, batch_size, nhead); - if (to_len <= 32) { - ker_attn_softmax_lt32<__half, 32, 1><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 64) { - ker_attn_softmax_lt32<__half, 32, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 128) { - grid_dim.x = 8; - ker_attn_softmax<__half, 64, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 256) { - grid_dim.x = 16; - ker_attn_softmax<__half, 128, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else if (to_len <= 512) { - grid_dim.x = 32; - ker_attn_softmax<__half, 256, 2><<>>( - inp, attn_mask, from_len, to_len, mask_future); - } else { - throw std::runtime_error( - "Sequence length greater than 512 is currently not supported"); - } -} - -/** -@brief: ker_attn_softmax_bw -Softmax backward in self attention. - -@thread -gridDim.x = batch_size * nhead * seq_len / warps_per_block -blockDim.x = WARP_SIZE -blockDim.y = warps_per_block - -@param -grad: [batch_size, nhead, seq_len, seq_len], output grad. -output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. 
-*/ -template -__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - inp += offset; - - T grad_reg[ITERATIONS]; - T inp_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - inp_reg[i] = inp[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)inp_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); - } -} - -template -void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, - int softmax_len, cudaStream_t stream) { - const int warps_per_block = 4; - // rows = batch_size * nhead * from_len - dim3 grid_dim(rows / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (softmax_len <= 32) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 64) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 128) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 256) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 384) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 512) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 768) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else if (softmax_len <= 1024) - ker_attn_softmax_bw - <<>>(out_grad, 
soft_inp, softmax_len); - else if (softmax_len <= 2048) - ker_attn_softmax_bw - <<>>(out_grad, soft_inp, softmax_len); - else - throw std::runtime_error( - std::string( - "Special sequence length found in softmax backward, seq_len: ") + - std::to_string(softmax_len)); -} - -template void launch_attn_softmax_bw<__half>(__half *out_grad, - const __half *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); -template void launch_attn_softmax_bw(float *out_grad, - const float *soft_inp, int rows, - int softmax_len, - cudaStream_t stream); +#include +#include + +#include +#include + +#include "block_reduce.h" +#include "kernels.h" + +namespace cg = cooperative_groups; +const float EPSILON = 1e-8f; + +/** +@brief: softmax_kernel +Softmax forward kernel for + enc-self-attn, dec-self-attn, encdec-attn + +@thread +gridDim.x = dynamic +gridDim.y = batch_size +gridDim.z = nhead +blockDim.x = from_len + +@param +inp: [batch_size, nhead, from_len, to_len], softmax input. +attn_mask: [batch_size, to_len], padding tokens are -inf, + non padding tokens are 0. 
+ attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template +__global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // block reduce max + blockReduce(l_max); + // write shared + __shared__ float s_max[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_max[i] = l_max[i]; + } + } + __syncthreads(); + + /* step 2. 
compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - s_max[i]); + l_sum[i] += val[i][j]; + } + } + // block reduce sum + blockReduce(l_sum); + // write shared + __shared__ float s_sum[token_per_reduce]; + if (threadIdx.x == 0) { + for (int i = 0; i < token_per_reduce; i++) { + s_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + } + } + __syncthreads(); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * s_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +template +__global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len, + int to_len, bool mask_future) { + int batch_id = blockIdx.y; + int head_id = blockIdx.z; + const int nhead = gridDim.z; + const int token_per_reduce = 1; + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + T mval[ele_per_thread]; + if (attn_mask) { + attn_mask += batch_id * to_len; + BlockLoad(ts_load).Load(attn_mask, mval, to_len, REDUCE_FLOAT_INF_NEG); + } + + inp += flat_3dim(batch_id, head_id, 0, nhead, from_len * to_len); + for (int token_id = blockIdx.x * token_per_reduce; token_id < from_len; + token_id += gridDim.x * token_per_reduce) { + T inp_val[token_per_reduce][ele_per_thread]; + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + BlockLoad(ts_load).Load(inp + (token_id + i) * to_len, inp_val[i], to_len, + REDUCE_FLOAT_INF_NEG); + } + + /* step 1. 
compute max */ + // thread local max + float val[token_per_reduce][ele_per_thread]; + float l_max[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_max[i] = REDUCE_FLOAT_INF_NEG; + for (int j = 0; j < ele_per_thread; j++) { + if (attn_mask) { + val[i][j] = (float)inp_val[i][j] + (float)mval[j]; + } else { + if (mask_future && ele_per_thread * threadIdx.x + j > token_id + i) { + val[i][j] = REDUCE_FLOAT_INF_NEG; + } else { + val[i][j] = (float)inp_val[i][j]; + } + } + l_max[i] = fmaxf(l_max[i], val[i][j]); + } + } + // warp reduce max + warpReduce(l_max); + + /* step 2. compute sum */ + // thread local sum + float l_sum[token_per_reduce]; + for (int i = 0; i < token_per_reduce; i++) { + l_sum[i] = 0.f; + for (int j = 0; j < ele_per_thread; j++) { + val[i][j] = __expf(val[i][j] - l_max[i]); + l_sum[i] += val[i][j]; + } + } + // warp reduce sum + warpReduce(l_sum); + + /* step 3. compute final result */ + for (int i = 0; i < token_per_reduce && (token_id + i) < from_len; i++) { + l_sum[i] = __fdividef(1.0f, l_sum[i] + EPSILON); + for (int j = 0; j < ele_per_thread; j++) { + inp_val[i][j] = (T)(val[i][j] * l_sum[i]); + } + BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i], + to_len); + } + } // blockIdx.x +} + +/* + attn_mask!=nullptr for enc-self-attn and enc-dec-attn + attn_mask=nullptr and mask_future=ture for dec-self-attn training + attn_mask=nullptr and mask_future=false for dec-self-attn infer +*/ +template <> +void launch_attn_softmax(float *inp, const float *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 16; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, 
mask_future); + } else if (to_len <= 256) { + grid_dim.x = 32; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 64; + ker_attn_softmax<<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +template <> +void launch_attn_softmax<__half>(__half *inp, const __half *attn_mask, + int batch_size, int nhead, int from_len, + int to_len, bool mask_future, + cudaStream_t stream) { + dim3 grid_dim(1, batch_size, nhead); + if (to_len <= 32) { + ker_attn_softmax_lt32<__half, 32, 1><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 64) { + ker_attn_softmax_lt32<__half, 32, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 128) { + grid_dim.x = 8; + ker_attn_softmax<__half, 64, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 256) { + grid_dim.x = 16; + ker_attn_softmax<__half, 128, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else if (to_len <= 512) { + grid_dim.x = 32; + ker_attn_softmax<__half, 256, 2><<>>( + inp, attn_mask, from_len, to_len, mask_future); + } else { + throw std::runtime_error( + "Sequence length greater than 512 is currently not supported"); + } +} + +/** +@brief: ker_attn_softmax_bw +Softmax backward in self attention. + +@thread +gridDim.x = batch_size * nhead * seq_len / warps_per_block +blockDim.x = WARP_SIZE +blockDim.y = warps_per_block + +@param +grad: [batch_size, nhead, seq_len, seq_len], output grad. +output: [batch_size, nhead, seq_len, seq_len], output of softmax forward. 
+*/ +template +__global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + inp += offset; + + T grad_reg[ITERATIONS]; + T inp_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + inp_reg[i] = inp[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)inp_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (T)((float)inp_reg[i] * ((float)grad_reg[i] - sum)); + } +} + +template +void launch_attn_softmax_bw(T *out_grad, const T *soft_inp, int rows, + int softmax_len, cudaStream_t stream) { + const int warps_per_block = 4; + // rows = batch_size * nhead * from_len + dim3 grid_dim(rows / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (softmax_len <= 32) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 64) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 128) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 256) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 384) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 512) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 768) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else if (softmax_len <= 1024) + ker_attn_softmax_bw + <<>>(out_grad, 
soft_inp, softmax_len); + else if (softmax_len <= 2048) + ker_attn_softmax_bw + <<>>(out_grad, soft_inp, softmax_len); + else + throw std::runtime_error( + std::string( + "Special sequence length found in softmax backward, seq_len: ") + + std::to_string(softmax_len)); +} + +template void launch_attn_softmax_bw<__half>(__half *out_grad, + const __half *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); +template void launch_attn_softmax_bw(float *out_grad, + const float *soft_inp, int rows, + int softmax_len, + cudaStream_t stream); diff --git a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu index d03084b22e12..04de3c092ee0 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu +++ b/colossalai/kernel/cuda_native/csrc/kernels/transform_kernels.cu @@ -1,312 +1,314 @@ -#include -#include -#include - -#include "kernels.h" - -using namespace cub; - -/** -@brief: transform_0213 -Split the attention heads and reshape input -during backward progress of encoder self-attention - -@thread -gridDim.x = batch_size -gridDim.y = seq_len -blockDim.x = min(hidden_dim, MAX_THREADS) - -@param -input: [batch_size, seq_len, hidden_dim] -output: [batch_size, nhead, seq_len, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -*/ - -template -__global__ void transform_0213(T *output, const T *input, int hidden_dim, - int head_dim); - -template <> -__global__ void transform_0213(float *output, const float *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, 
head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -template <> -__global__ void transform_0213<__half>(__half *output, const __half *input, - int hidden_dim, int head_dim) { - int batch_id = blockIdx.x; - int token_id = blockIdx.y; - int seq_len = gridDim.y; - int nhead = hidden_dim / head_dim; - - // [b, s, h] - int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); - // [b, nh, s, ad] - int trg_offset = - flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - float4 vinput4; - - for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { - vinput4 = input4[src_offset + i]; - - int head_id = i / head_dim; - int dim_id = i % head_dim; - int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); - res4[trg_offset + cur_trg_offset] = vinput4; - } -} - -// [b, s, h] -> [b, nh, s, ad] -template <> -void launch_transform_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - transform_0213 - <<>>(output, input, hidden_dim, head_dim); -} - -template <> -void launch_transform_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - - dim3 grid_dim(batch_size, seq_len); - dim3 block_dim(min(hidden_dim, MAX_THREADS)); - - 
transform_0213<__half> - <<>>(output, input, hidden_dim, head_dim); -} - -/** -@brief: bias_add_transform_20314 -Add bias to input, transform from -[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] - -@thread -gridDim.x = dim_0 -gridDim.y = dim_1 -gridDim.z = dim_2 -blockDim.x = min(dim_3 * dim_4, MAX_THREADS) - -@param -input: [dim_0, dim_1, dim_2, dim_3, dim_4] -bias: [dim_2, dim_3, dim_4] -output: [dim_2, dim_0, dim_3, dim_1, dim_4] -*/ -template -__global__ void bias_add_transform_20314(T *output, const T *input, - const T *bias, int dim_3, int dim_4); - -template <> -__global__ void -bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_3, int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - vres4.x = vqkv4.x + vbias4.x; - vres4.y = vqkv4.y + vbias4.y; - vres4.z = vqkv4.z + vbias4.z; - vres4.w = vqkv4.w + vbias4.w; - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -template <> -__global__ void -bias_add_transform_20314<__half>(__half *output, const __half *input, - const __half *bias, int dim_3, int dim_4) { - int id0 = blockIdx.x; - int id1 = blockIdx.y; - int id2 = blockIdx.z; - int dim_0 = gridDim.x; - int dim_1 = gridDim.y; - int dim_2 = gridDim.z; - int dim_34 = dim_3 * 
dim_4; - - int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); - int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); - int bias_offset = flat_2dim(id2, 0, dim_34); - - const float4 *qkv4 = reinterpret_cast(input); - const float4 *bias4 = reinterpret_cast(bias); - float4 *res4 = reinterpret_cast(output); - float4 vqkv4; - float4 vbias4; - float4 vres4; - __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); - __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); - __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); - - for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { - vqkv4 = qkv4[src_offset + i]; - vbias4 = bias4[bias_offset + i]; - h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); - h2_res[1] = __hadd2(h2_qkv[1], h2_bias[1]); - h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); - h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); - - int id3 = i / dim_4; - int id4 = i % dim_4; - int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); - res4[trg_offset + cur_trg_offset] = vres4; - } -} - -// [b, s, 3, h] -> [3, b, nh, s, ad] -template <> -void launch_bias_add_transform_20314(float *output, const float *input, - const float *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 2; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314 - <<>>(output, input, bias, dim_3, dim_4); -} - -template <> -void launch_bias_add_transform_20314<__half>(__half *output, - const __half *input, - const __half *bias, int dim_0, - int dim_1, int dim_2, int dim_3, - int dim_4, cudaStream_t stream) { - dim_4 >>= 3; - - dim3 grid_dim(dim_0, dim_1, dim_2); - dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); - - bias_add_transform_20314<__half> - <<>>(output, input, bias, dim_3, dim_4); -} - -/** -@brief: transform4d_0213 -Reshape the input matrix to merge the heads - -@thread -gridDim.x = (num_all + max_block_thread - 1) / 
max_block_thread -blockDim.x = max_block_thread - -@param -input: [trans_count, batch_size, nhead, seq_len, head_dim] -output: [batch_size, seq_len, trans_count, nhead, head_dim] -batch_size: the size of the current batch -seq_len: the sequence length of the current batch -hidden_dim: dim of the hidden tensor -nhead: number of attention heads -trans_count: 1 or 3, the count of matrice need to be transformed -*/ -template -__global__ void transform4d_0213(T *output, const T *input, int batch_size, - int seq_len, int trans_count, int nhead, - int head_dim, int num_all) { - int offset = blockIdx.x * blockDim.x + threadIdx.x; - if (offset >= num_all) { - return; - } - int trans_id, batch_id, head_id, token_id, dim_id; - decompose_5dim(offset, batch_size, nhead, seq_len, head_dim, &trans_id, - &batch_id, &head_id, &token_id, &dim_id); - // [b, s, tc, nh, ad] - int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, - seq_len, trans_count, nhead, head_dim); - - const float4 *input4 = reinterpret_cast(input); - float4 *res4 = reinterpret_cast(output); - res4[trg_offset] = input4[offset]; -} - -// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] -template <> -void launch_transform4d_0213(float *output, const float *input, - int batch_size, int seq_len, int hidden_dim, - int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 2; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; - - transform4d_0213<<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} - -template <> -void launch_transform4d_0213<__half>(__half *output, const __half *input, - int batch_size, int seq_len, - int hidden_dim, int nhead, int trans_count, - cudaStream_t stream) { - hidden_dim >>= 3; - int head_dim = hidden_dim / nhead; - int num_all = batch_size * seq_len * trans_count * hidden_dim; - int nblock = (num_all + MAX_THREADS - 1) / 
MAX_THREADS; - - transform4d_0213<__half><<>>( - output, input, batch_size, seq_len, trans_count, nhead, head_dim, - num_all); -} +#include +#include +#include + +#include "kernels.h" + +using namespace cub; + +/** +@brief: transform_0213 +Split the attention heads and reshape input +during backward progress of encoder self-attention + +@thread +gridDim.x = batch_size +gridDim.y = seq_len +blockDim.x = min(hidden_dim, MAX_THREADS) + +@param +input: [batch_size, seq_len, hidden_dim] +output: [batch_size, nhead, seq_len, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +*/ + +template +__global__ void transform_0213(T *output, const T *input, int hidden_dim, + int head_dim); + +template <> +__global__ void transform_0213(float *output, const float *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +template <> +__global__ void transform_0213<__half>(__half *output, const __half *input, + int hidden_dim, int head_dim) { + int batch_id = blockIdx.x; + int token_id = blockIdx.y; + int seq_len = gridDim.y; + int nhead = hidden_dim / head_dim; + + // [b, s, h] + int src_offset = flat_3dim(batch_id, token_id, 0, seq_len, hidden_dim); + // [b, 
nh, s, ad] + int trg_offset = + flat_4dim(batch_id, 0, token_id, 0, nhead, seq_len, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + float4 vinput4; + + for (std::size_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) { + vinput4 = input4[src_offset + i]; + + int head_id = i / head_dim; + int dim_id = i % head_dim; + int cur_trg_offset = flat_3dim(head_id, 0, dim_id, seq_len, head_dim); + res4[trg_offset + cur_trg_offset] = vinput4; + } +} + +// [b, s, h] -> [b, nh, s, ad] +template <> +void launch_transform_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213 + <<>>(output, input, hidden_dim, head_dim); +} + +template <> +void launch_transform_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + + dim3 grid_dim(batch_size, seq_len); + dim3 block_dim(min(hidden_dim, MAX_THREADS)); + + transform_0213<__half> + <<>>(output, input, hidden_dim, head_dim); +} + +/** +@brief: bias_add_transform_20314 +Add bias to input, transform from +[0, 1, 2, 3, 4] to [2, 0, 3, 1, 4] + +@thread +gridDim.x = dim_0 +gridDim.y = dim_1 +gridDim.z = dim_2 +blockDim.x = min(dim_3 * dim_4, MAX_THREADS) + +@param +input: [dim_0, dim_1, dim_2, dim_3, dim_4] +bias: [dim_2, dim_3, dim_4] +output: [dim_2, dim_0, dim_3, dim_1, dim_4] +*/ +template +__global__ void bias_add_transform_20314(T *output, const T *input, + const T *bias, int dim_3, int dim_4); + +template <> +__global__ void bias_add_transform_20314(float *output, + const float *input, + const float *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int 
dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + vres4.x = vqkv4.x + vbias4.x; + vres4.y = vqkv4.y + vbias4.y; + vres4.z = vqkv4.z + vbias4.z; + vres4.w = vqkv4.w + vbias4.w; + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +template <> +__global__ void bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_3, + int dim_4) { + int id0 = blockIdx.x; + int id1 = blockIdx.y; + int id2 = blockIdx.z; + int dim_0 = gridDim.x; + int dim_1 = gridDim.y; + int dim_2 = gridDim.z; + int dim_34 = dim_3 * dim_4; + + int src_offset = flat_4dim(id0, id1, id2, 0, dim_1, dim_2, dim_34); + int trg_offset = flat_5dim(id2, id0, 0, id1, 0, dim_0, dim_3, dim_1, dim_4); + int bias_offset = flat_2dim(id2, 0, dim_34); + + const float4 *qkv4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *res4 = reinterpret_cast(output); + float4 vqkv4; + float4 vbias4; + float4 vres4; + __half2 *h2_qkv = reinterpret_cast<__half2 *>(&vqkv4); + __half2 *h2_bias = reinterpret_cast<__half2 *>(&vbias4); + __half2 *h2_res = reinterpret_cast<__half2 *>(&vres4); + + for (std::size_t i = threadIdx.x; i < dim_34; i += blockDim.x) { + vqkv4 = qkv4[src_offset + i]; + vbias4 = bias4[bias_offset + i]; + h2_res[0] = __hadd2(h2_qkv[0], h2_bias[0]); + h2_res[1] = 
__hadd2(h2_qkv[1], h2_bias[1]); + h2_res[2] = __hadd2(h2_qkv[2], h2_bias[2]); + h2_res[3] = __hadd2(h2_qkv[3], h2_bias[3]); + + int id3 = i / dim_4; + int id4 = i % dim_4; + int cur_trg_offset = flat_3dim(id3, 0, id4, dim_1, dim_4); + res4[trg_offset + cur_trg_offset] = vres4; + } +} + +// [b, s, 3, h] -> [3, b, nh, s, ad] +template <> +void launch_bias_add_transform_20314(float *output, const float *input, + const float *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 2; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314 + <<>>(output, input, bias, dim_3, dim_4); +} + +template <> +void launch_bias_add_transform_20314<__half>(__half *output, + const __half *input, + const __half *bias, int dim_0, + int dim_1, int dim_2, int dim_3, + int dim_4, cudaStream_t stream) { + dim_4 >>= 3; + + dim3 grid_dim(dim_0, dim_1, dim_2); + dim3 block_dim(min(dim_3 * dim_4, MAX_THREADS)); + + bias_add_transform_20314<__half> + <<>>(output, input, bias, dim_3, dim_4); +} + +/** +@brief: transform4d_0213 +Reshape the input matrix to merge the heads + +@thread +gridDim.x = (num_all + max_block_thread - 1) / max_block_thread +blockDim.x = max_block_thread + +@param +input: [trans_count, batch_size, nhead, seq_len, head_dim] +output: [batch_size, seq_len, trans_count, nhead, head_dim] +batch_size: the size of the current batch +seq_len: the sequence length of the current batch +hidden_dim: dim of the hidden tensor +nhead: number of attention heads +trans_count: 1 or 3, the count of matrice need to be transformed +*/ +template +__global__ void transform4d_0213(T *output, const T *input, int batch_size, + int seq_len, int trans_count, int nhead, + int head_dim, int num_all) { + int offset = blockIdx.x * blockDim.x + threadIdx.x; + if (offset >= num_all) { + return; + } + int trans_id, batch_id, head_id, token_id, dim_id; + decompose_5dim(offset, batch_size, nhead, 
seq_len, head_dim, &trans_id, + &batch_id, &head_id, &token_id, &dim_id); + // [b, s, tc, nh, ad] + int trg_offset = flat_5dim(batch_id, token_id, trans_id, head_id, dim_id, + seq_len, trans_count, nhead, head_dim); + + const float4 *input4 = reinterpret_cast(input); + float4 *res4 = reinterpret_cast(output); + res4[trg_offset] = input4[offset]; +} + +// [tc, b, nh, s, ad] -> [b, s, tc, nh, ad] +template <> +void launch_transform4d_0213(float *output, const float *input, + int batch_size, int seq_len, int hidden_dim, + int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 2; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} + +template <> +void launch_transform4d_0213<__half>(__half *output, const __half *input, + int batch_size, int seq_len, + int hidden_dim, int nhead, int trans_count, + cudaStream_t stream) { + hidden_dim >>= 3; + int head_dim = hidden_dim / nhead; + int num_all = batch_size * seq_len * trans_count * hidden_dim; + int nblock = (num_all + MAX_THREADS - 1) / MAX_THREADS; + + transform4d_0213<__half><<>>( + output, input, batch_size, seq_len, trans_count, nhead, head_dim, + num_all); +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp index 4690277e63db..15a07bb0c7ac 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda.cpp @@ -138,4 +138,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu 
b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu index ad7066bbd9df..72b84d6ca40f 100644 --- a/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu @@ -680,4 +680,4 @@ void cuda_layer_norm_gradient(at::Tensor* dout, at::Tensor* mean, grad_input->DATA_PTR(), gamma != NULL ? grad_gamma->DATA_PTR() : NULL, gamma != NULL ? grad_beta->DATA_PTR() : NULL);) -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp index 61c8a725052f..8c0b89eb06d1 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda.cpp @@ -1,97 +1,97 @@ -#include - -torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); - -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx); - -torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); - -#define CHECK_CUDA(x) \ - TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -torch::Tensor moe_dispatch_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, torch::Tensor dest_idx) { - CHECK_INPUT(batch_tokens); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_dispatch_cuda_forward(s, ec, h, 
batch_tokens, mask, dest_idx); -} - -torch::Tensor moe_dispatch_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(expert_grad); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); -} - -torch::Tensor moe_combine_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(expert_tokens); - CHECK_INPUT(logits); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, - dest_idx); -} - -std::vector moe_combine_backward(int s, int e, int c, int h, - torch::Tensor tokens_grad, - torch::Tensor expert_tokens, - torch::Tensor logits, - torch::Tensor mask, - torch::Tensor dest_idx) { - CHECK_INPUT(tokens_grad); - CHECK_INPUT(logits); - CHECK_CUDA(mask); - CHECK_CUDA(dest_idx); - - return moe_combine_cuda_backward(s, e, c, h, tokens_grad, expert_tokens, - logits, mask, dest_idx); -} - -torch::Tensor moe_cumsum(torch::Tensor mask) { - CHECK_INPUT(mask); - return cumsum_sub_one_in_dim0(mask); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); - m.def("dispatch_forward", &moe_dispatch_forward, - "Forward operation in MoE dispatch function"); - m.def("dispatch_backward", &moe_dispatch_backward, - "Backward operation in MoE dispatch function"); - m.def("combine_forward", &moe_combine_forward, - "Combine operation in MoE combine function"); - m.def("combine_backward", &moe_combine_backward, - "Combine operation in MoE combine function"); -} +#include + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + 
torch::Tensor dest_idx); + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx); + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); + +#define CHECK_CUDA(x) \ + TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +torch::Tensor moe_dispatch_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, torch::Tensor dest_idx) { + CHECK_INPUT(batch_tokens); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_forward(s, ec, h, batch_tokens, mask, dest_idx); +} + +torch::Tensor moe_dispatch_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_grad); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_dispatch_cuda_backward(s, ec, h, expert_grad, mask, dest_idx); +} + +torch::Tensor moe_combine_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(expert_tokens); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_forward(s, e, c, h, expert_tokens, logits, mask, + dest_idx); +} + +std::vector moe_combine_backward(int s, int e, int c, int h, + torch::Tensor tokens_grad, + torch::Tensor expert_tokens, + torch::Tensor logits, + torch::Tensor mask, + torch::Tensor dest_idx) { + CHECK_INPUT(tokens_grad); + CHECK_INPUT(logits); + CHECK_CUDA(mask); + CHECK_CUDA(dest_idx); + + return moe_combine_cuda_backward(s, e, c, h, 
tokens_grad, expert_tokens, + logits, mask, dest_idx); +} + +torch::Tensor moe_cumsum(torch::Tensor mask) { + CHECK_INPUT(mask); + return cumsum_sub_one_in_dim0(mask); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cumsum_sub_one", &moe_cumsum, "Fast cumsum operation in dim0"); + m.def("dispatch_forward", &moe_dispatch_forward, + "Forward operation in MoE dispatch function"); + m.def("dispatch_backward", &moe_dispatch_backward, + "Backward operation in MoE dispatch function"); + m.def("combine_forward", &moe_combine_forward, + "Combine operation in MoE combine function"); + m.def("combine_backward", &moe_combine_backward, + "Combine operation in MoE combine function"); +} diff --git a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu index 0454377a2fad..66c1e6bd260e 100644 --- a/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/moe_cuda_kernel.cu @@ -1,659 +1,659 @@ -#include -#include -#include - -#include - -#include "block_reduce.h" - -template -__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - BlockStore(ts_store).Store(dst_row + idx, pack); - } -} - -template -__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename 
BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, pack); - BlockStore(ts_store).Store(src_row + idx, pack); - } -} - -template -__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - BlockStore(ts_store).Store(dst_row1 + idx, pack); - BlockStore(ts_store).Store(dst_row2 + idx, pack); - } -} - -template -__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack1[pack_size], pack2[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row1 + idx, pack1); - BlockLoad(ts_load).Load(dst_row2 + idx, pack2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack1[i] += pack2[i]; - } - - BlockStore(ts_store).Store(src_row + idx, pack1); - } -} - -template -__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; 
- __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row + idx, pack); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - pack[i] *= weight; - } - - BlockStore(ts_store).Store(dst_row + idx, pack); - } -} - -template -__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, - T *weight_grad, const T weight, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T grad[pack_size], tokens[pack_size]; - float thread_sum = 0; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, grad); - BlockLoad(ts_load).Load(tks_row + idx, tokens); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum += grad[i] * tokens[i]; - grad[i] *= weight; - } - - BlockStore(ts_store).Store(src_row + idx, grad); - } - - blockReduce(&thread_sum); - - if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); -} - -template -__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, - const T weight1, const T weight2, - const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T pack1[pack_size], pack2[pack_size]; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(src_row1 + idx, pack1); - BlockLoad(ts_load).Load(src_row2 + idx, pack2); - -#pragma unroll - for (int i = 0; i < 
pack_size; ++i) { - pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; - } - - BlockStore(ts_store).Store(dst_row + idx, pack1); - } -} - -template -__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, - T *tks_row1, T *tks_row2, T *weight_grad1, - T *weight_grad2, const T weight1, - const T weight2, const int cols) { - assert(cols % pack_size == 0); - const int bpack_size = block_size * pack_size; - - typedef cub::BlockLoad - BlockLoad; - __shared__ typename BlockLoad::TempStorage ts_load; - - typedef cub::BlockStore - BlockStore; - __shared__ typename BlockStore::TempStorage ts_store; - - int tps = threadIdx.x * pack_size; - T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], - sgrad2[pack_size]; - float thread_sum[2] = {0, 0}; - for (int idx = 0; idx + tps < cols; idx += bpack_size) { - BlockLoad(ts_load).Load(dst_row + idx, grad); - BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); - BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); - -#pragma unroll - for (int i = 0; i < pack_size; ++i) { - thread_sum[0] += grad[i] * tokens1[i]; - thread_sum[1] += grad[i] * tokens2[i]; - sgrad1[i] = weight1 * grad[i]; - sgrad2[i] = weight2 * grad[i]; - } - - BlockStore(ts_store).Store(src_row1 + idx, sgrad1); - BlockStore(ts_store).Store(src_row2 + idx, sgrad2); - } - - blockReduce(thread_sum); - - if (threadIdx.x == 0) - *weight_grad1 = static_cast(thread_sum[0]); - else if (threadIdx.x == 1) - *weight_grad2 = static_cast(thread_sum[1]); -} - -// DISPATCH KERNELS -------------------------------- - -template -__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, - const int cols, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_dpch_two_fwd(src_row, dst_row1, dst_row2, - cols); - else if (indicator1 != 0) - moe_dpch_one_fwd(src_row, dst_row1, cols); - else if (indicator2 != 0) - moe_dpch_one_fwd(src_row, dst_row2, cols); - else - return; -} - -template -__device__ 
void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, - const int cols, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_dpch_two_bwd(src_row, dst_row1, dst_row2, - cols); - else if (indicator1 != 0) - moe_dpch_one_bwd(src_row, dst_row1, cols); - else if (indicator2 != 0) - moe_dpch_one_bwd(src_row, dst_row2, cols); - else - return; -} - -template -__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, - int *mask1, int *mask2, int *dest1, - int *dest2, const int h) { - int row = blockIdx.x; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - moe_dpch_fwd_selector( - batch_tokens + (row * h), expert_input + (dest1[row] * h), - expert_input + (dest2[row] * h), h, mask1[row], indicator2); -} - -template -__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, - int *mask2, int *dest1, int *dest2, - const int h) { - int row = blockIdx.x; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - moe_dpch_bwd_selector( - tokens_grad + (row * h), expert_grad + (dest1[row] * h), - expert_grad + (dest2[row] * h), h, mask1[row], indicator2); -} - -// COMBINE KERNELS -------------------------------- - -template -__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, - const int cols, const T weight1, - const T weight2, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - moe_cb_two_fwd(src_row1, src_row2, dst_row, - weight1, weight2, cols); - else if (indicator1 != 0) - moe_cb_one_fwd(src_row1, dst_row, weight1, cols); - else if (indicator2 != 0) - moe_cb_one_fwd(src_row2, dst_row, weight2, cols); - else - return; -} - -template -__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, - const int cols, T *tks_row1, T *tks_row2, - T *wt_grad1, T *wt_grad2, const T weight1, - const T weight2, const int indicator1, - const int indicator2) { - if (indicator1 != 0 && indicator2 != 0) - 
moe_cb_two_bwd(src_row1, src_row2, dst_row, - tks_row1, tks_row2, wt_grad1, - wt_grad2, weight1, weight2, cols); - else if (indicator1 != 0) - moe_cb_one_bwd(src_row1, dst_row, tks_row1, - wt_grad1, weight1, cols); - else if (indicator2 != 0) - moe_cb_one_bwd(src_row2, dst_row, tks_row2, - wt_grad2, weight2, cols); - else - return; -} - -template -__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, - T *logits, int *mask1, int *mask2, int *dest1, - int *dest2, const int e, const int c, - const int h) { - int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; - int indicator2 = mask2 == nullptr ? 0 : mask2[row]; - T *row_log = logits + (row * e); - moe_cb_fwd_selector( - expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), - combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], - indicator2); -} - -template -__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, - T *logits, T *logits_grad, int *mask1, - int *mask2, int *dest1, int *dest2, - const int e, const int c, const int h) { - int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; - int indicator2 = mask2 == nullptr ? 
0 : mask2[row]; - T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); - moe_cb_bwd_selector( - expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), - tokens_grad + (row * h), h, tks + (dest1[row] * h), - tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], - row_log[eid2], mask1[row], indicator2); -} - -// CUMSUM KERNEL -------------------------------- - -template -__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, - const int e) { - assert(s % pack_size == 0); - constexpr int bpack_size = block_size * pack_size; - int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; - __shared__ int temp[block_size + 1]; - int pack[pack_size]; - - for (int idx = 0; idx < s; idx += bpack_size) { - int offset = 1; - - if (idx + tps < s) { - temp[tid] = inputs[tps * e + bid]; -#pragma unroll - for (int i = 1; i < pack_size; ++i) { - pack[i] = inputs[(tps + i) * e + bid]; - } -#pragma unroll - for (int i = 1; i < pack_size; ++i) { - temp[tid] += pack[i]; - } - } - - for (int i = block_size >> 1; i > 0; i >>= 1) { - __syncthreads(); - if (tid < i) { - int j = offset * (2 * tid + 1) - 1; - temp[j + offset] += temp[j]; - } - offset <<= 1; - } - - if (tid == 0) { - temp[block_size] = temp[block_size - 1]; - temp[block_size - 1] = 0; - } - - for (int i = 1; i < block_size; i <<= 1) { - offset >>= 1; - __syncthreads(); - if (tid < i) { - int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; - temp[j] = temp[k]; - temp[k] += ts; - } - } - __syncthreads(); - - if (tid == 0) temp[0] = temp[block_size]; - __syncthreads(); - - if (idx + tps < s) { - temp[tid + 1] += last_sum; -#pragma unroll - for (int i = pack_size - 1; i > 0; --i) { - outputs[(tps + i) * e + bid] = temp[tid + 1]; - temp[tid + 1] -= pack[i]; - } - outputs[tps * e + bid] = temp[tid + 1]; - } - __syncthreads(); - - last_sum += temp[0]; - inputs += bpack_size * e; - outputs += bpack_size * e; - } -} - -// LAUNCH FUNCTIONS 
-------------------------------- - -template -void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, - int *mask2, int *dest1, int *dest2, const int s, - const int h) { - if (h < 256) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 512) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 1024) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else if (h < 2048) - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); - else - moe_dpch_fwd_kernel - <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); -} - -template -void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, - int *dest1, int *dest2, const int s, const int h) { - if (h < 256) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 512) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 1024) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else if (h < 2048) - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); - else - moe_dpch_bwd_kernel - <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); -} - -template -void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, - int *mask1, int *mask2, int *dest1, int *dest2, - const int s, const int e, const int c, const int h) { - if (h < 256) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 512) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 1024) - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else if (h < 2048) - 
moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, dest2, - e, c, h); - else - moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, - logits, mask1, mask2, dest1, - dest2, e, c, h); -} - -template -void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, - T *logits_grad, int *mask1, int *mask2, int *dest1, - int *dest2, const int s, const int e, const int c, - const int h) { - if (h < 256) - moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, - logits, logits_grad, mask1, mask2, - dest1, dest2, e, c, h); - else // if (h < 512) - moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, - logits, logits_grad, mask1, mask2, - dest1, dest2, e, c, h); - // else if (h < 1024) - // moe_cb_bwd_kernel<<>> - // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, - // dest1, dest2, e, c, h); - // else - // moe_cb_bwd_kernel<<>> - // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, - // dest1, dest2, e, c, h); -} - -void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { - if (s <= 256) - cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); - else if (s <= 512) - cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); - else if (s <= 1024) - cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); - else if (s <= 2048) - cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); - else - cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); -} - -// API FUNCTIONS -------------------------------- - -#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) 
\ - switch (TYPE) { \ - case at::ScalarType::Float: { \ - using scalar_t = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented yet for specific data type."); \ - } - -torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, - torch::Tensor batch_tokens, - torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - auto res = torch::zeros( - {ec, h}, - torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - batch_tokens.scalar_type(), "moe dispatch forward", - moe_dpch_fwd_launch( - batch_tokens.data(), res.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); - - return res; -} - -torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, - torch::Tensor expert_grad, - torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - auto res = torch::zeros( - {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - expert_grad.scalar_type(), "moe dispatch backward", - moe_dpch_bwd_launch( - res.data(), expert_grad.data(), - mask[0].data(), k == 1 ? nullptr : mask[1].data(), - dest_idx[0].data(), - k == 1 ? 
dest_idx[0].data() : dest_idx[1].data(), s, h)); - - return res; -} - -torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, - torch::Tensor expert_tokens, - torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - assert(expert_tokens.dtype() == logits.dtype()); - - auto res = torch::zeros( - {s, h}, - torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - expert_tokens.scalar_type(), "moe combine forward", - moe_cb_fwd_launch( - expert_tokens.data(), res.data(), - logits.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, - h)); - - return res; -} - -std::vector moe_combine_cuda_backward( - int s, int e, int c, int h, torch::Tensor tokens_grad, - torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, - torch::Tensor dest_idx) { - assert(h % 16 == 0); - assert(tokens_grad.dtype() == expert_tokens.dtype()); - assert(expert_tokens.dtype() == logits.dtype()); - - auto egrad = torch::zeros( - {e * c, h}, - torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), - wgrad = torch::zeros( - {s, e}, torch::dtype(logits.dtype()).device(logits.device())); - auto k = mask.size(0); - - DISPATCH_FLOAT_AND_HALF( - tokens_grad.scalar_type(), "moe combine backward", - moe_cb_bwd_launch( - tokens_grad.data(), egrad.data(), - expert_tokens.data(), logits.data(), - wgrad.data(), mask[0].data(), - k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), - k == 1 ? 
dest_idx[0].data() : dest_idx[1].data(), s, e, c, - h)); - - return {egrad, wgrad}; -} - -torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { - assert(mask.dim() == 2); - assert(mask.dtype() == torch::kInt32); - - const int s = mask.size(0), e = mask.size(1); - auto res = - torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); - cumsum_launch(mask.data(), res.data(), s, e); - - return res; -} +#include +#include +#include + +#include + +#include "block_reduce.h" + +template +__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, pack); + BlockStore(ts_store).Store(src_row + idx, pack); + } +} + +template +__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + 
__shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + BlockStore(ts_store).Store(dst_row1 + idx, pack); + BlockStore(ts_store).Store(dst_row2 + idx, pack); + } +} + +template +__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row1 + idx, pack1); + BlockLoad(ts_load).Load(dst_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] += pack2[i]; + } + + BlockStore(ts_store).Store(src_row + idx, pack1); + } +} + +template +__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row + idx, pack); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack[i] *= weight; + } + + BlockStore(ts_store).Store(dst_row + idx, pack); + } +} + +template +__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, + T *weight_grad, const T weight, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + 
+ typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens[pack_size]; + float thread_sum = 0; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row + idx, tokens); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum += grad[i] * tokens[i]; + grad[i] *= weight; + } + + BlockStore(ts_store).Store(src_row + idx, grad); + } + + blockReduce(&thread_sum); + + if (threadIdx.x == 0) *weight_grad = static_cast(thread_sum); +} + +template +__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, + const T weight1, const T weight2, + const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T pack1[pack_size], pack2[pack_size]; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(src_row1 + idx, pack1); + BlockLoad(ts_load).Load(src_row2 + idx, pack2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + pack1[i] = pack1[i] * weight1 + pack2[i] * weight2; + } + + BlockStore(ts_store).Store(dst_row + idx, pack1); + } +} + +template +__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, + T *tks_row1, T *tks_row2, T *weight_grad1, + T *weight_grad2, const T weight1, + const T weight2, const int cols) { + assert(cols % pack_size == 0); + const int bpack_size = block_size * pack_size; + + typedef cub::BlockLoad + BlockLoad; + __shared__ typename BlockLoad::TempStorage ts_load; + + typedef cub::BlockStore + BlockStore; + __shared__ typename 
BlockStore::TempStorage ts_store; + + int tps = threadIdx.x * pack_size; + T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size], + sgrad2[pack_size]; + float thread_sum[2] = {0, 0}; + for (int idx = 0; idx + tps < cols; idx += bpack_size) { + BlockLoad(ts_load).Load(dst_row + idx, grad); + BlockLoad(ts_load).Load(tks_row1 + idx, tokens1); + BlockLoad(ts_load).Load(tks_row2 + idx, tokens2); + +#pragma unroll + for (int i = 0; i < pack_size; ++i) { + thread_sum[0] += grad[i] * tokens1[i]; + thread_sum[1] += grad[i] * tokens2[i]; + sgrad1[i] = weight1 * grad[i]; + sgrad2[i] = weight2 * grad[i]; + } + + BlockStore(ts_store).Store(src_row1 + idx, sgrad1); + BlockStore(ts_store).Store(src_row2 + idx, sgrad2); + } + + blockReduce(thread_sum); + + if (threadIdx.x == 0) + *weight_grad1 = static_cast(thread_sum[0]); + else if (threadIdx.x == 1) + *weight_grad2 = static_cast(thread_sum[1]); +} + +// DISPATCH KERNELS -------------------------------- + +template +__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_fwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_fwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_fwd(src_row, dst_row2, cols); + else + return; +} + +template +__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, + const int cols, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_dpch_two_bwd(src_row, dst_row1, dst_row2, + cols); + else if (indicator1 != 0) + moe_dpch_one_bwd(src_row, dst_row1, cols); + else if (indicator2 != 0) + moe_dpch_one_bwd(src_row, dst_row2, cols); + else + return; +} + +template +__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, + int *mask1, int *mask2, int *dest1, + int *dest2, const int h) { + int row = blockIdx.x; + int 
indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_fwd_selector( + batch_tokens + (row * h), expert_input + (dest1[row] * h), + expert_input + (dest2[row] * h), h, mask1[row], indicator2); +} + +template +__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int h) { + int row = blockIdx.x; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + moe_dpch_bwd_selector( + tokens_grad + (row * h), expert_grad + (dest1[row] * h), + expert_grad + (dest2[row] * h), h, mask1[row], indicator2); +} + +// COMBINE KERNELS -------------------------------- + +template +__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_fwd(src_row1, src_row2, dst_row, + weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_fwd(src_row1, dst_row, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_fwd(src_row2, dst_row, weight2, cols); + else + return; +} + +template +__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, + const int cols, T *tks_row1, T *tks_row2, + T *wt_grad1, T *wt_grad2, const T weight1, + const T weight2, const int indicator1, + const int indicator2) { + if (indicator1 != 0 && indicator2 != 0) + moe_cb_two_bwd(src_row1, src_row2, dst_row, + tks_row1, tks_row2, wt_grad1, + wt_grad2, weight1, weight2, cols); + else if (indicator1 != 0) + moe_cb_one_bwd(src_row1, dst_row, tks_row1, + wt_grad1, weight1, cols); + else if (indicator2 != 0) + moe_cb_one_bwd(src_row2, dst_row, tks_row2, + wt_grad2, weight2, cols); + else + return; +} + +template +__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, + T *logits, int *mask1, int *mask2, int *dest1, + int *dest2, const int e, const int c, + const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + 
int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e); + moe_cb_fwd_selector( + expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h), + combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row], + indicator2); +} + +template +__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, + T *logits, T *logits_grad, int *mask1, + int *mask2, int *dest1, int *dest2, + const int e, const int c, const int h) { + int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; + int indicator2 = mask2 == nullptr ? 0 : mask2[row]; + T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); + moe_cb_bwd_selector( + expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h), + tokens_grad + (row * h), h, tks + (dest1[row] * h), + tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1], + row_log[eid2], mask1[row], indicator2); +} + +// CUMSUM KERNEL -------------------------------- + +template +__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, + const int e) { + assert(s % pack_size == 0); + constexpr int bpack_size = block_size * pack_size; + int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; + __shared__ int temp[block_size + 1]; + int pack[pack_size]; + + for (int idx = 0; idx < s; idx += bpack_size) { + int offset = 1; + + if (idx + tps < s) { + temp[tid] = inputs[tps * e + bid]; +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + pack[i] = inputs[(tps + i) * e + bid]; + } +#pragma unroll + for (int i = 1; i < pack_size; ++i) { + temp[tid] += pack[i]; + } + } + + for (int i = block_size >> 1; i > 0; i >>= 1) { + __syncthreads(); + if (tid < i) { + int j = offset * (2 * tid + 1) - 1; + temp[j + offset] += temp[j]; + } + offset <<= 1; + } + + if (tid == 0) { + temp[block_size] = temp[block_size - 1]; + temp[block_size - 1] = 0; + } + + for (int i = 1; i < block_size; i <<= 1) { + offset >>= 1; + __syncthreads(); 
+ if (tid < i) { + int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j]; + temp[j] = temp[k]; + temp[k] += ts; + } + } + __syncthreads(); + + if (tid == 0) temp[0] = temp[block_size]; + __syncthreads(); + + if (idx + tps < s) { + temp[tid + 1] += last_sum; +#pragma unroll + for (int i = pack_size - 1; i > 0; --i) { + outputs[(tps + i) * e + bid] = temp[tid + 1]; + temp[tid + 1] -= pack[i]; + } + outputs[tps * e + bid] = temp[tid + 1]; + } + __syncthreads(); + + last_sum += temp[0]; + inputs += bpack_size * e; + outputs += bpack_size * e; + } +} + +// LAUNCH FUNCTIONS -------------------------------- + +template +void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, + int *mask2, int *dest1, int *dest2, const int s, + const int h) { + if (h < 256) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); + else + moe_dpch_fwd_kernel + <<>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); +} + +template +void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, + int *dest1, int *dest2, const int s, const int h) { + if (h < 256) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 512) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 1024) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else if (h < 2048) + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); + else + moe_dpch_bwd_kernel + <<>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); +} + +template +void 
moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, + int *mask1, int *mask2, int *dest1, int *dest2, + const int s, const int e, const int c, const int h) { + if (h < 256) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 512) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 1024) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else if (h < 2048) + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, dest2, + e, c, h); + else + moe_cb_fwd_kernel<<>>(expert_tokens, combine_tokens, + logits, mask1, mask2, dest1, + dest2, e, c, h); +} + +template +void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, + T *logits_grad, int *mask1, int *mask2, int *dest1, + int *dest2, const int s, const int e, const int c, + const int h) { + if (h < 256) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + else // if (h < 512) + moe_cb_bwd_kernel<<>>(tokens_grad, expert_grad, tks, + logits, logits_grad, mask1, mask2, + dest1, dest2, e, c, h); + // else if (h < 1024) + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); + // else + // moe_cb_bwd_kernel<<>> + // (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2, + // dest1, dest2, e, c, h); +} + +void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { + if (s <= 256) + cumsum_kernel<256, 1><<>>(inputs, outputs, s, e); + else if (s <= 512) + cumsum_kernel<512, 1><<>>(inputs, outputs, s, e); + else if (s <= 1024) + cumsum_kernel<1024, 1><<>>(inputs, outputs, s, e); + else if (s <= 2048) + cumsum_kernel<1024, 2><<>>(inputs, outputs, s, e); + else + cumsum_kernel<1024, 4><<>>(inputs, outputs, s, e); +} + +// 
API FUNCTIONS -------------------------------- + +#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \ + switch (TYPE) { \ + case at::ScalarType::Float: { \ + using scalar_t = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented yet for specific data type."); \ + } + +torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, + torch::Tensor batch_tokens, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {ec, h}, + torch::dtype(batch_tokens.dtype()).device(batch_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + batch_tokens.scalar_type(), "moe dispatch forward", + moe_dpch_fwd_launch( + batch_tokens.data(), res.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, + torch::Tensor expert_grad, + torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + auto res = torch::zeros( + {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_grad.scalar_type(), "moe dispatch backward", + moe_dpch_bwd_launch( + res.data(), expert_grad.data(), + mask[0].data(), k == 1 ? nullptr : mask[1].data(), + dest_idx[0].data(), + k == 1 ? 
dest_idx[0].data() : dest_idx[1].data(), s, h)); + + return res; +} + +torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, + torch::Tensor expert_tokens, + torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(expert_tokens.dtype() == logits.dtype()); + + auto res = torch::zeros( + {s, h}, + torch::dtype(expert_tokens.dtype()).device(expert_tokens.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + expert_tokens.scalar_type(), "moe combine forward", + moe_cb_fwd_launch( + expert_tokens.data(), res.data(), + logits.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return res; +} + +std::vector moe_combine_cuda_backward( + int s, int e, int c, int h, torch::Tensor tokens_grad, + torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, + torch::Tensor dest_idx) { + assert(h % 16 == 0); + assert(tokens_grad.dtype() == expert_tokens.dtype()); + assert(expert_tokens.dtype() == logits.dtype()); + + auto egrad = torch::zeros( + {e * c, h}, + torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())), + wgrad = torch::zeros( + {s, e}, torch::dtype(logits.dtype()).device(logits.device())); + auto k = mask.size(0); + + DISPATCH_FLOAT_AND_HALF( + tokens_grad.scalar_type(), "moe combine backward", + moe_cb_bwd_launch( + tokens_grad.data(), egrad.data(), + expert_tokens.data(), logits.data(), + wgrad.data(), mask[0].data(), + k == 1 ? nullptr : mask[1].data(), dest_idx[0].data(), + k == 1 ? 
dest_idx[0].data() : dest_idx[1].data(), s, e, c, + h)); + + return {egrad, wgrad}; +} + +torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { + assert(mask.dim() == 2); + assert(mask.dtype() == torch::kInt32); + + const int s = mask.size(0), e = mask.size(1); + auto res = + torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device())); + cumsum_launch(mask.data(), res.data(), s, e); + + return res; +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu index 49ab83e8fc81..85f935152f8a 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_l2norm_kernel.cu @@ -379,4 +379,4 @@ void multi_tensor_norm_out_cuda( norm_type, alpha, beta); return; -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu index 54c4220190d8..63771cf40bcb 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_lamb.cu @@ -351,4 +351,4 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, lr, weight_decay, use_nvlamb);) AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu index 360485dcd02f..2f58a0f16dce 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu @@ -122,4 +122,4 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, AT_CUDA_CHECK(cudaGetLastError()); // AT_CUDA_CHECK(cudaDeviceSynchronize()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu 
b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu index 35f2c9b4ed15..7f48dbd5d497 100644 --- a/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu +++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_sgd_kernel.cu @@ -164,4 +164,4 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, } AT_CUDA_CHECK(cudaGetLastError()); -} \ No newline at end of file +} diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp index 4ae3c853ca5e..8c2982b0cff9 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.cpp @@ -3,82 +3,68 @@ #include #include + #include namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { +torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, + int attn_heads); + +torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, + float scale_factor) { AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are 
supported"); AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); return fwd_cuda(input, mask, scale_factor); } -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - +torch::Tensor bwd(torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, float scale_factor) { AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); return bwd_cuda(output_grads, softmax_results, scale_factor); } -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, + attn_heads); } -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- 
Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); + m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); + &multihead_attn::fused_softmax::scaled_masked_softmax:: + get_batch_per_block, + "Return Batch per block size."); } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h index 1583030b8235..d3e6f04e6093 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_masked_softmax.h @@ -4,12 +4,12 @@ #pragma once #include +#include #include +#include + #include #include -#include -#include -#include namespace { @@ -17,37 +17,53 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void 
copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *((half2 *)dst) = *((half2 *)src); +} int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; @@ -55,438 +71,468 @@ struct Max { }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Explicit masking + */ +template __global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } + output_t *dst, const input_t *src, const uint8_t *mask, const acc_t scale, + int micro_batch_size, int element_count, int pad_batches) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = + (blockDim.y * + (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + + threadIdx.y) * + WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = + (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * + WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } - // micro_batch_size might not be a multiple of WARP_BATCH. 
Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. 
compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i * element_count + it * WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; } + copy_vector( + dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } } + } } -template +template __global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int 
element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. 
compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = + first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; } + } } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < 
ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count + it * WARP_SIZE, out); + } } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, + int attn_heads) { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; } -} // end of anonymous namespace -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ +template +void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads, + int pad_batches) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { int log2_elements = log2_ceil(key_seq_len); const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + // use 128 threads per block to maximimize gpu utilization constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - 
scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } + TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0); + dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>( + dst, src, mask, scale, 
batch_count, key_seq_len, pad_batches); + break; + default: + break; } + } } -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - 
scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } +template +void dispatch_scaled_masked_softmax_backward(output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, int key_seq_len, + int batches, int attn_heads) { + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, 
batch_count, key_seq_len); + break; + default: + break; } + } } diff --git a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h index 3af487f9de0f..54c8e9133a1b 100644 --- a/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h +++ b/colossalai/kernel/cuda_native/csrc/scaled_upper_triang_masked_softmax.h @@ -4,11 +4,12 @@ #pragma once #include +#include #include +#include + #include #include -#include -#include namespace { @@ -16,53 +17,78 @@ template __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - +__device__ __inline__ void copy_vector( + c10::BFloat16 *dst, const c10::BFloat16 *src) { + *((float2 *)dst) = *((float2 *)src); +} + template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } +__device__ __inline__ void copy_vector(c10::Half *dst, + const c10::Half *src) { + *((float2 *)dst) = *((float2 *)src); +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t *src) { + *dst = *src; +} template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } +__device__ __inline__ void copy_vector(uint8_t *dst, + const uint8_t 
*src) { + *((half2 *)dst) = *((half2 *)src); +} template __device__ __inline__ void copy_zero_vector(Datatype *dst); template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } +__device__ __inline__ void copy_zero_vector( + c10::BFloat16 *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *dst = 0.0; +} template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { + *((float2 *)dst) = make_float2(0.0f, 0.0f); +} int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; } -template +template struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } + __device__ __forceinline__ T operator()(T a, T b) const { return a + b; } }; -template +template struct Max { __device__ __forceinline__ T operator()(T a, T b) const { return a < b ? 
b : a; @@ -70,431 +96,505 @@ struct Max { }; template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ +__device__ __forceinline__ T +WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, + unsigned int mask = 0xffffffff) { #if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); + return __shfl_xor_sync(mask, value, laneMask, width); #else - return __shfl_xor(value, laneMask, width); + return __shfl_xor(value, laneMask, width); #endif } -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t *sum) { + ReduceOp r; +#pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); } + } } /* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) + * Extended softmax (from native aten pytorch) with following additional + * features 1) input scaling 2) Implicit time (diagonal masking) */ -template +template __global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 
0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } + output_t *dst, const input_t *src, const acc_t scale, int micro_batch_size, + int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = + (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1) / WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. 
compute the index within the + // batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector( + temp_data, src + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } } + } else { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } } + } - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } + // compute max_value + acc_t max_value[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = + (max_value[i] > elements[i][it]) ? 
max_value[i] : elements[i][it]; } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH]{0.0f}; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = 
elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } } + copy_vector( + dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector( + dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } } + } } -template +template __global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. 
compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } + output_t *gradInput, input_t *grad, const input_t *output, acc_t scale, + int micro_batch_size, int stride, int element_count) { + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = + (next_power_of_two < C10_WARP_SIZE) ? 
next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; + + int first_batch = + (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the + // batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]{0.0f}; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 
0 : local_seq; + +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector( + temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector( + temp_output, output + i * element_count * stride + it * WARP_SIZE); + +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = + (acc_t)temp_grad[element] * output_reg[i][it + element]; + } } + } } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } + } + + acc_t sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + +// store result +#pragma unroll + for (int 
i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; +#pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = + (output_t)(scale * (grad_reg[i][it + element] - + output_reg[i][it + element] * sum[i])); } + copy_vector( + gradInput + i * element_count * stride + it * WARP_SIZE, out); + } } + } } -} // end of anonymous namespace +} // end of anonymous namespace -template +template void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - 
scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } + output_t *dst, const input_t *src, const input_t scale, + int softmax_elements, int softmax_elements_stride, int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_forward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximize gpu utilization + + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 9: 
// 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>( + dst, src, scale, batch_count, softmax_elements_stride, + softmax_elements); + break; + default: + break; } + } } -template +template void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - 
<<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } + output_t *grad_input, input_t *grad, const input_t *output, + const acc_t scale, int softmax_elements, int softmax_elements_stride, + int attn_batches) { + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside + // softmax_warp_backward. + int warp_size = + (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside + // softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 
2 : 1; + + // use 128 threads per block to maximize gpu utilization + + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + 
scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>( + grad_input, grad, output, scale, batch_count, + softmax_elements_stride, softmax_elements); + break; + default: + break; } + } } diff --git a/colossalai/kernel/jit/__init__.py b/colossalai/kernel/jit/__init__.py index 57b8fb7b2e99..072ac927a918 100644 --- a/colossalai/kernel/jit/__init__.py +++ b/colossalai/kernel/jit/__init__.py @@ -1,8 +1,7 @@ -from .option import set_jit_fusion_options -from .bias_dropout_add import bias_dropout_add_fused_train, bias_dropout_add_fused_inference +from .bias_dropout_add import bias_dropout_add_fused_inference, bias_dropout_add_fused_train from .bias_gelu import bias_gelu_impl +from .option import set_jit_fusion_options __all__ = [ - "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", - "set_jit_fusion_options" + "bias_dropout_add_fused_train", "bias_dropout_add_fused_inference", "bias_gelu_impl", "set_jit_fusion_options" ] diff --git a/colossalai/kernel/jit/bias_dropout_add.py b/colossalai/kernel/jit/bias_dropout_add.py index 3687dde79a08..44ff443cc41d 100644 --- a/colossalai/kernel/jit/bias_dropout_add.py +++ b/colossalai/kernel/jit/bias_dropout_add.py @@ -9,16 +9,12 @@ def bias_dropout_add(x, bias, residual, prob, training): @torch.jit.script -def bias_dropout_add_fused_train(x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, +def bias_dropout_add_fused_train(x: torch.Tensor, bias: 
torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) @torch.jit.script -def bias_dropout_add_fused_inference(x: torch.Tensor, - bias: torch.Tensor, - residual: torch.Tensor, +def bias_dropout_add_fused_inference(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, False) diff --git a/colossalai/logging/__init__.py b/colossalai/logging/__init__.py index 97fe4f89ded3..ef9b3384868e 100644 --- a/colossalai/logging/__init__.py +++ b/colossalai/logging/__init__.py @@ -12,7 +12,7 @@ def get_dist_logger(name: str = 'colossalai') -> DistributedLogger: Args: name (str): name of the logger, name must be unique - + Returns: :class:`colossalai.logging.DistributedLogger`: A distributed logger singleton instance. """ diff --git a/colossalai/nn/_ops/_utils.py b/colossalai/nn/_ops/_utils.py index 56bb5f465184..24877bbb552f 100644 --- a/colossalai/nn/_ops/_utils.py +++ b/colossalai/nn/_ops/_utils.py @@ -1,12 +1,11 @@ -import torch -from typing import Union, Optional, List -from colossalai.tensor import ColoTensor +from typing import List, Optional, Union + import torch import torch.distributed as dist -from colossalai.global_variables import tensor_parallel_env as env +from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn.layer.utils import divide -from colossalai.tensor import ProcessGroup, ColoTensorSpec +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup GeneralTensor = Union[ColoTensor, torch.Tensor] Number = Union[int, float] @@ -135,7 +134,7 @@ def backward(ctx, grad_output): class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. - + Args: input_: input matrix. process_group: parallel mode. 
diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py index fe2eb0c999a1..660b48a71d57 100644 --- a/colossalai/nn/_ops/addmm.py +++ b/colossalai/nn/_ops/addmm.py @@ -1,9 +1,9 @@ import torch + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor -from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec -from ._utils import GeneralTensor, Number, convert_to_colo_tensor -from ._utils import reduce_input, reduce_grad + +from ._utils import GeneralTensor, Number, convert_to_colo_tensor, reduce_grad, reduce_input def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number, @@ -69,9 +69,13 @@ def colo_addmm(input_tensor: GeneralTensor, if not mat2.has_compute_spec(): # No Model Parallel Applied assert mat2.is_replicate(), 'Invalid mat2 spec for native addmm op' assert input_tensor.is_replicate(), 'Invalid input spec for native addmm op' - ret_tensor = ColoTensor.from_torch_tensor( - tensor=torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha, **kargs), - spec=ColoTensorSpec(mat2.get_process_group())) + ret_tensor = ColoTensor.from_torch_tensor(tensor=torch.addmm(input_tensor, + mat1, + mat2, + beta=beta, + alpha=alpha, + **kargs), + spec=ColoTensorSpec(mat2.get_process_group())) elif mat2.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied if mat2.is_shard_1drow() and input_tensor.is_replicate(): mode = 'row' diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py index a045f305b5dc..b145d1763380 100644 --- a/colossalai/nn/_ops/embedding.py +++ b/colossalai/nn/_ops/embedding.py @@ -1,8 +1,10 @@ -import torch.nn.functional as F from typing import Optional + +import torch.nn.functional as F + +from colossalai.tensor import ColoTensor, 
ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, \ - ReplicaSpec + from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py index 0e8aa8fecb01..0c909381b1ff 100644 --- a/colossalai/nn/_ops/embedding_bag.py +++ b/colossalai/nn/_ops/embedding_bag.py @@ -1,9 +1,11 @@ -import torch.nn.functional as F from typing import Optional + +import torch.nn.functional as F from torch import Tensor + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ReplicaSpec, ShardSpec, distspec from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, \ - ShardSpec, ReplicaSpec + from ._utils import GeneralTensor, convert_to_colo_tensor diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/nn/_ops/layernorm.py index 2b761b84e3ee..9960c5d48096 100644 --- a/colossalai/nn/_ops/layernorm.py +++ b/colossalai/nn/_ops/layernorm.py @@ -1,7 +1,10 @@ from typing import List, Optional + import torch.nn.functional as F + +from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec, distspec from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec + from ._utils import GeneralTensor, convert_to_colo_tensor diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py index 1e54f662859c..45d4b97e8bea 100644 --- a/colossalai/nn/_ops/loss.py +++ b/colossalai/nn/_ops/loss.py @@ -1,9 +1,12 @@ +from typing import Optional + import torch import torch.nn.functional as F -from typing import Optional -from colossalai.tensor.op_wrapper import colo_op_impl -from colossalai.tensor import 
ColoTensor, ColoTensorSpec + from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D +from colossalai.tensor import ColoTensor, ColoTensorSpec +from colossalai.tensor.op_wrapper import colo_op_impl + from ._utils import GeneralTensor, convert_to_colo_tensor diff --git a/colossalai/nn/init.py b/colossalai/nn/init.py index 559b7038fc35..3b414a93f91c 100644 --- a/colossalai/nn/init.py +++ b/colossalai/nn/init.py @@ -1,8 +1,8 @@ import math import warnings -from torch import Tensor import torch.nn as nn +from torch import Tensor def zeros_(): diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py index b705632f8040..09c6615ea2ad 100644 --- a/colossalai/nn/layer/__init__.py +++ b/colossalai/nn/layer/__init__.py @@ -1,10 +1,10 @@ from .colossalai_layer import * +from .moe import * from .parallel_1d import * from .parallel_2d import * from .parallel_2p5d import * from .parallel_3d import * from .parallel_sequence import * -from .moe import * from .utils import * from .vanilla import * from .wrapper import * diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py index c85f53cc44c3..5234b6b1a1b5 100644 --- a/colossalai/nn/layer/base_layer.py +++ b/colossalai/nn/layer/base_layer.py @@ -1,11 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from contextlib import contextmanager + import torch.nn as nn from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from contextlib import contextmanager class ParallelLayer(nn.Module): diff --git a/colossalai/nn/layer/colossalai_layer/__init__.py b/colossalai/nn/layer/colossalai_layer/__init__.py index 2ae1b07a75b2..ed743820ddbc 100644 --- a/colossalai/nn/layer/colossalai_layer/__init__.py +++ b/colossalai/nn/layer/colossalai_layer/__init__.py @@ -1,7 +1,7 @@ -from ._utils import partition_batch -from .dropout import Dropout -from .embedding import Embedding, PatchEmbedding -from .linear import Classifier, Linear -from 
.normalization import LayerNorm - -__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch'] +from ._utils import partition_batch +from .dropout import Dropout +from .embedding import Embedding, PatchEmbedding +from .linear import Classifier, Linear +from .normalization import LayerNorm + +__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch'] diff --git a/colossalai/nn/layer/colossalai_layer/embedding.py b/colossalai/nn/layer/colossalai_layer/embedding.py index e5c9c46e0ff1..a0ad0848c292 100644 --- a/colossalai/nn/layer/colossalai_layer/embedding.py +++ b/colossalai/nn/layer/colossalai_layer/embedding.py @@ -1,151 +1,152 @@ -import math -from typing import Callable - -from colossalai.utils import get_current_device -from torch import dtype, nn - -from ... import init as init -from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D -from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D -from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D -from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D -from ..utils import get_tensor_parallel_mode -from ..vanilla import VanillaPatchEmbedding -from ._utils import ColossalaiModule - -_parallel_embedding = { - '1d': Embedding1D, - '2d': Embedding2D, - '2.5d': Embedding2p5D, - '3d': Embedding3D, -} - -_vocab_parallel_embedding = { - '1d': VocabParallelEmbedding1D, - '2d': VocabParallelEmbedding2D, - '2.5d': VocabParallelEmbedding2p5D, - '3d': VocabParallelEmbedding3D -} - -_parallel_patchembedding = { - None: VanillaPatchEmbedding, - '1d': PatchEmbedding1D, - '2d': PatchEmbedding2D, - '2.5d': PatchEmbedding2p5D, - '3d': PatchEmbedding3D -} - - -class Embedding(ColossalaiModule): - r"""Embedding for colossalai. - - Args: - num_embeddings (int): number of embeddings. 
- embedding_dim (int): dimension of embedding. - padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; - therefore, the embedding vector at padding_idx is not updated during training, - i.e. it remains as a fixed “pad”, defaults to None. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - weight_initializer (:class:`typing.Callable`, optional): - he initializer of weight, defaults to normal initializer. - - The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: - :: - - max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is - renormalized to have norm max_norm. Note: this will modify weight in-place. - norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. - scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse - of frequency of the words in the mini-batch. Default False. - sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. - - More details about ``args`` and ``kwargs`` could be found in - `Embedding `_. 
- - More details about ``initializer`` please refer to - `init `_ - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - padding_idx: int = None, - dtype: dtype = None, - weight_initializer: Callable = init.normal_(), - vocab_parallel_limit: int = 2048, - *args, - **kwargs) -> None: - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel is None: - embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, - **kwargs).to(dtype).to(get_current_device()) - weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) - elif num_embeddings <= vocab_parallel_limit: - embed = _parallel_embedding[tensor_parallel]( - num_embeddings, - embedding_dim, - padding_idx=padding_idx, - dtype=dtype, - weight_initializer=weight_initializer, - *args, - **kwargs, - ) - else: - embed = _vocab_parallel_embedding[tensor_parallel]( - num_embeddings, - embedding_dim, - padding_idx=padding_idx, - dtype=dtype, - weight_initializer=weight_initializer, - *args, - **kwargs, - ) - super().__init__(embed) - - -class PatchEmbedding(ColossalaiModule): - """2D Image to Patch Embedding. - - Args: - img_size (int): image size. - patch_size (int): patch size. - in_chans (int): number of channels of input image. - embed_size (int): size of embedding. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - flatten (bool, optional): whether to flatten output tensor, defaults to True. - weight_initializer (:class:`typing.Callable`, optional): - The initializer of weight, defaults to kaiming uniform initializer. - bias_initializer (:class:`typing.Callable`, optional): - The initializer of bias, defaults to xavier uniform initializer. - position_embed_initializer (:class:`typing.Callable`, optional): - The initializer of position embedding, defaults to zeros initializer. - - More details about ``initializer`` please refer to - `init `_. 
- """ - - def __init__( - self, - img_size: int, - patch_size: int, - in_chans: int, - embed_size: int, - dtype: dtype = None, - flatten: bool = True, - weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), - position_embed_initializer: Callable = init.zeros_() - ) -> None: - tensor_parallel = get_tensor_parallel_mode() - embed = _parallel_patchembedding[tensor_parallel]( - img_size, - patch_size, - in_chans, - embed_size, - dtype=dtype, - flatten=flatten, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - position_embed_initializer=position_embed_initializer, - ) - super().__init__(embed) +import math +from typing import Callable + +from torch import dtype, nn + +from colossalai.utils import get_current_device + +from ... import init as init +from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D +from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D +from ..parallel_2p5d import Embedding2p5D, PatchEmbedding2p5D, VocabParallelEmbedding2p5D +from ..parallel_3d import Embedding3D, PatchEmbedding3D, VocabParallelEmbedding3D +from ..utils import get_tensor_parallel_mode +from ..vanilla import VanillaPatchEmbedding +from ._utils import ColossalaiModule + +_parallel_embedding = { + '1d': Embedding1D, + '2d': Embedding2D, + '2.5d': Embedding2p5D, + '3d': Embedding3D, +} + +_vocab_parallel_embedding = { + '1d': VocabParallelEmbedding1D, + '2d': VocabParallelEmbedding2D, + '2.5d': VocabParallelEmbedding2p5D, + '3d': VocabParallelEmbedding3D +} + +_parallel_patchembedding = { + None: VanillaPatchEmbedding, + '1d': PatchEmbedding1D, + '2d': PatchEmbedding2D, + '2.5d': PatchEmbedding2p5D, + '3d': PatchEmbedding3D +} + + +class Embedding(ColossalaiModule): + r"""Embedding for colossalai. + + Args: + num_embeddings (int): number of embeddings. + embedding_dim (int): dimension of embedding. 
+ padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient; + therefore, the embedding vector at padding_idx is not updated during training, + i.e. it remains as a fixed “pad”, defaults to None. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + weight_initializer (:class:`typing.Callable`, optional): + he initializer of weight, defaults to normal initializer. + + The ``args`` and ``kwargs`` used in :class:`torch.nn.functional.embedding` should contain: + :: + + max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is + renormalized to have norm max_norm. Note: this will modify weight in-place. + norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Default 2. + scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse + of frequency of the words in the mini-batch. Default False. + sparse (bool, optional): If True, gradient w.r.t. weight will be a sparse tensor. Default False. + + More details about ``args`` and ``kwargs`` could be found in + `Embedding `_. 
+ + More details about ``initializer`` please refer to + `init `_ + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + dtype: dtype = None, + weight_initializer: Callable = init.normal_(), + vocab_parallel_limit: int = 2048, + *args, + **kwargs) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is None: + embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, + **kwargs).to(dtype).to(get_current_device()) + weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim) + elif num_embeddings <= vocab_parallel_limit: + embed = _parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + else: + embed = _vocab_parallel_embedding[tensor_parallel]( + num_embeddings, + embedding_dim, + padding_idx=padding_idx, + dtype=dtype, + weight_initializer=weight_initializer, + *args, + **kwargs, + ) + super().__init__(embed) + + +class PatchEmbedding(ColossalaiModule): + """2D Image to Patch Embedding. + + Args: + img_size (int): image size. + patch_size (int): patch size. + in_chans (int): number of channels of input image. + embed_size (int): size of embedding. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. + flatten (bool, optional): whether to flatten output tensor, defaults to True. + weight_initializer (:class:`typing.Callable`, optional): + The initializer of weight, defaults to kaiming uniform initializer. + bias_initializer (:class:`typing.Callable`, optional): + The initializer of bias, defaults to xavier uniform initializer. + position_embed_initializer (:class:`typing.Callable`, optional): + The initializer of position embedding, defaults to zeros initializer. + + More details about ``initializer`` please refer to + `init `_. 
+ """ + + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + embed_size: int, + dtype: dtype = None, + flatten: bool = True, + weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + position_embed_initializer: Callable = init.zeros_() + ) -> None: + tensor_parallel = get_tensor_parallel_mode() + embed = _parallel_patchembedding[tensor_parallel]( + img_size, + patch_size, + in_chans, + embed_size, + dtype=dtype, + flatten=flatten, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + position_embed_initializer=position_embed_initializer, + ) + super().__init__(embed) diff --git a/colossalai/nn/layer/colossalai_layer/normalization.py b/colossalai/nn/layer/colossalai_layer/normalization.py index 86861d30214a..f8e317e723f1 100644 --- a/colossalai/nn/layer/colossalai_layer/normalization.py +++ b/colossalai/nn/layer/colossalai_layer/normalization.py @@ -1,41 +1,42 @@ -from colossalai.utils import get_current_device -from torch import nn - -from ..parallel_1d import LayerNorm1D -from ..parallel_2d import LayerNorm2D -from ..parallel_2p5d import LayerNorm2p5D -from ..parallel_3d import LayerNorm3D -from ..utils import get_tensor_parallel_mode -from ..vanilla import VanillaLayerNorm -from ._utils import ColossalaiModule - -_parallel_layernorm = { - None: VanillaLayerNorm, - "1d": LayerNorm1D, - "2d": LayerNorm2D, - "2.5d": LayerNorm2p5D, - "3d": LayerNorm3D, -} - - -class LayerNorm(ColossalaiModule): - r"""Layer Normalization for colossalai. - - Args: - normalized_shape (int): input shape from an expected input of size. - :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] - \times \ldots \times \text{normalized_shape}[-1]]` - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. 
- eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. - bias (bool, optional): Whether to add a bias, defaults to ``True``. - dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. - """ - - def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel is None: - norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) - else: - norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) - super().__init__(norm) +from torch import nn + +from colossalai.utils import get_current_device + +from ..parallel_1d import LayerNorm1D +from ..parallel_2d import LayerNorm2D +from ..parallel_2p5d import LayerNorm2p5D +from ..parallel_3d import LayerNorm3D +from ..utils import get_tensor_parallel_mode +from ..vanilla import VanillaLayerNorm +from ._utils import ColossalaiModule + +_parallel_layernorm = { + None: VanillaLayerNorm, + "1d": LayerNorm1D, + "2d": LayerNorm2D, + "2.5d": LayerNorm2p5D, + "3d": LayerNorm3D, +} + + +class LayerNorm(ColossalaiModule): + r"""Layer Normalization for colossalai. + + Args: + normalized_shape (int): input shape from an expected input of size. + :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] + \times \ldots \times \text{normalized_shape}[-1]]` + If a single integer is used, it is treated as a singleton list, and this module will + normalize over the last dimension which is expected to be of that specific size. + eps (float): a value added to the denominator for numerical stability, defaults to 1e-05. + bias (bool, optional): Whether to add a bias, defaults to ``True``. + dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None. 
+ """ + + def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None: + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel is None: + norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device()) + else: + norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype) + super().__init__(norm) diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py index 2a51344c31a4..c4f9dcd0b313 100644 --- a/colossalai/nn/layer/moe/__init__.py +++ b/colossalai/nn/layer/moe/__init__.py @@ -1,9 +1,9 @@ -from .experts import Experts, FFNExperts, TPExperts -from .layers import MoeLayer, MoeModule -from .routers import MoeRouter, Top1Router, Top2Router -from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts - -__all__ = [ - 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', - 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter' -] +from .experts import Experts, FFNExperts, TPExperts +from .layers import MoeLayer, MoeModule +from .routers import MoeRouter, Top1Router, Top2Router +from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts + +__all__ = [ + 'Experts', 'FFNExperts', 'TPExperts', 'Top1Router', 'Top2Router', 'MoeLayer', 'NormalNoiseGenerator', + 'UniformNoiseGenerator', 'build_ffn_experts', 'MoeModule', 'MoeRouter' +] diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 055afded9a20..d391a2fb4eaf 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -1,172 +1,173 @@ -import math - -import torch -import torch.nn as nn -from colossalai.context import ParallelMode, seed -from colossalai.utils import get_current_device -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.zero.init_ctx import no_shard_zero_decrator -from typing import Type - - -class 
MoeExperts(nn.Module): - """Basic class for experts in MoE. It stores what kind of communication expersts use - to exchange tokens, how many experts in a single GPU and parallel information such as - expert parallel size, data parallel size and their distributed communication groups. - """ - - def __init__(self, comm_name: str, num_experts: int): - super().__init__() - assert comm_name in {"all_to_all", "all_gather"}, \ - "This kind of communication has not been implemented yet.\n Please use Experts build function." - self.comm_name = comm_name - # Get the configuration of experts' deployment and parallel information from moe contex - self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) - - -@no_shard_zero_decrator(is_replicated=False) -class Experts(MoeExperts): - """A wrapper class to create experts. It will create E experts across the - moe model parallel group, where E is the number of experts. Every expert - is a instence of the class, 'expert' in initialization parameters. 
- - Args: - expert_cls (:class:`torch.nn.Module`): The class of all experts - num_experts (int): The number of experts - expert_args: Args used to initialize experts, the args could be found in corresponding expert class - """ - - def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): - super().__init__("all_to_all", num_experts) - - # Use seed to make every expert different from others - with seed(ParallelMode.TENSOR): - self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) - - # Attach parallel information for all parameters in Experts - for exp in self.experts: - for param in exp.parameters(): - param.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs: torch.Tensor): - # Split inputs for each expert - expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) - expert_output = [] - - # Get outputs from each expert - for i in range(self.num_local_experts): - expert_output.append(self.experts[i](expert_input[i])) - - # Concatenate all outputs together - output = torch.cat(expert_output, dim=1).contiguous() - return output - - -class FFNExperts(MoeExperts): - """Use torch.bmm to speed up for multiple experts. 
- """ - - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_to_all", num_experts) - - self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) - - self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) - - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) - - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - nn.init.trunc_normal_(self.b2, std=s2) - - self.act = nn.GELU() if activation is None else activation - self.drop = nn.Dropout(p=drop_rate) - - for param in self.parameters(): - param.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, el, c, h] - - el = inputs.size(1) - h = inputs.size(-1) - - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(el, -1, h) - - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) - - out_model = torch.baddbmm(self.b2, out_inter, self.w2) - with seed(ParallelMode.TENSOR): - outputs = self.drop(out_model) # outputs [el, gc, h] - - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs - - -class TPExperts(MoeExperts): - """Use tensor parallelism to split each expert evenly, which can deploy experts in - case that the number of experts can't be divied by maximum expert parallel size or - maximum expert parallel size can't be divied by the number of experts. 
- """ - - def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - super().__init__("all_gather", MOE_CONTEXT.max_ep_size) - - assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ - "d_ff should be divied by maximum expert parallel size" - - p_ff = d_ff // MOE_CONTEXT.max_ep_size - - self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) - self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) - - self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device())) - self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device())) - - s1 = math.sqrt(0.1 / d_model) - s2 = math.sqrt(0.1 / d_ff) - - with seed(ParallelMode.TENSOR): - nn.init.trunc_normal_(self.w1, std=s1) - nn.init.trunc_normal_(self.b1, std=s1) - nn.init.trunc_normal_(self.w2, std=s2) - - nn.init.trunc_normal_(self.b2, std=s2) - - self.act = nn.GELU() if activation is None else activation - self.drop = nn.Dropout(p=drop_rate) - - self.w1.__setattr__('moe_info', self.dist_info) - self.w2.__setattr__('moe_info', self.dist_info) - self.b1.__setattr__('moe_info', self.dist_info) - - def forward(self, inputs): # inputs [g, e, c, h] - - e = inputs.size(1) - h = inputs.size(-1) - - inputs = inputs.transpose(0, 1) - inshape = inputs.shape - inputs = inputs.reshape(e, -1, h) - - out_ff = torch.baddbmm(self.b1, inputs, self.w1) - out_act = self.act(out_ff) - with seed(ParallelMode.TENSOR): - out_inter = self.drop(out_act) - - out_model = torch.baddbmm(self.b2, out_inter, self.w2) - outputs = self.drop(out_model) # outputs [e, gc, h] - - outputs = outputs.reshape(inshape) - outputs = outputs.transpose(0, 1).contiguous() - return outputs # outputs [g, e, c, h] +import math +from typing import Type + +import torch +import torch.nn as nn + +from colossalai.context import ParallelMode, seed +from colossalai.context.moe_context import MOE_CONTEXT +from 
colossalai.utils import get_current_device +from colossalai.zero.init_ctx import no_shard_zero_decrator + + +class MoeExperts(nn.Module): + """Basic class for experts in MoE. It stores what kind of communication expersts use + to exchange tokens, how many experts in a single GPU and parallel information such as + expert parallel size, data parallel size and their distributed communication groups. + """ + + def __init__(self, comm_name: str, num_experts: int): + super().__init__() + assert comm_name in {"all_to_all", "all_gather"}, \ + "This kind of communication has not been implemented yet.\n Please use Experts build function." + self.comm_name = comm_name + # Get the configuration of experts' deployment and parallel information from moe contex + self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts) + + +@no_shard_zero_decrator(is_replicated=False) +class Experts(MoeExperts): + """A wrapper class to create experts. It will create E experts across the + moe model parallel group, where E is the number of experts. Every expert + is a instence of the class, 'expert' in initialization parameters. 
+ + Args: + expert_cls (:class:`torch.nn.Module`): The class of all experts + num_experts (int): The number of experts + expert_args: Args used to initialize experts, the args could be found in corresponding expert class + """ + + def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args): + super().__init__("all_to_all", num_experts) + + # Use seed to make every expert different from others + with seed(ParallelMode.TENSOR): + self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)]) + + # Attach parallel information for all parameters in Experts + for exp in self.experts: + for param in exp.parameters(): + param.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs: torch.Tensor): + # Split inputs for each expert + expert_input = torch.chunk(inputs, self.num_local_experts, dim=1) + expert_output = [] + + # Get outputs from each expert + for i in range(self.num_local_experts): + expert_output.append(self.experts[i](expert_input[i])) + + # Concatenate all outputs together + output = torch.cat(expert_output, dim=1).contiguous() + return output + + +class FFNExperts(MoeExperts): + """Use torch.bmm to speed up for multiple experts. 
+ """ + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__("all_to_all", num_experts) + + self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + for param in self.parameters(): + param.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs): # inputs [g, el, c, h] + + el = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(el, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + out_inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, out_inter, self.w2) + with seed(ParallelMode.TENSOR): + outputs = self.drop(out_model) # outputs [el, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs + + +class TPExperts(MoeExperts): + """Use tensor parallelism to split each expert evenly, which can deploy experts in + case that the number of experts can't be divied by maximum expert parallel size or + maximum expert parallel size can't be divied by the number of experts. 
+ """ + + def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + super().__init__("all_gather", MOE_CONTEXT.max_ep_size) + + assert d_ff % MOE_CONTEXT.max_ep_size == 0, \ + "d_ff should be divied by maximum expert parallel size" + + p_ff = d_ff // MOE_CONTEXT.max_ep_size + + self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device())) + self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device())) + + self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device())) + self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device())) + + s1 = math.sqrt(0.1 / d_model) + s2 = math.sqrt(0.1 / d_ff) + + with seed(ParallelMode.TENSOR): + nn.init.trunc_normal_(self.w1, std=s1) + nn.init.trunc_normal_(self.b1, std=s1) + nn.init.trunc_normal_(self.w2, std=s2) + + nn.init.trunc_normal_(self.b2, std=s2) + + self.act = nn.GELU() if activation is None else activation + self.drop = nn.Dropout(p=drop_rate) + + self.w1.__setattr__('moe_info', self.dist_info) + self.w2.__setattr__('moe_info', self.dist_info) + self.b1.__setattr__('moe_info', self.dist_info) + + def forward(self, inputs): # inputs [g, e, c, h] + + e = inputs.size(1) + h = inputs.size(-1) + + inputs = inputs.transpose(0, 1) + inshape = inputs.shape + inputs = inputs.reshape(e, -1, h) + + out_ff = torch.baddbmm(self.b1, inputs, self.w1) + out_act = self.act(out_ff) + with seed(ParallelMode.TENSOR): + out_inter = self.drop(out_act) + + out_model = torch.baddbmm(self.b2, out_inter, self.w2) + outputs = self.drop(out_model) # outputs [e, gc, h] + + outputs = outputs.reshape(inshape) + outputs = outputs.transpose(0, 1).contiguous() + return outputs # outputs [g, e, c, h] diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 259f53f1adf5..d7ad8cf04615 100644 --- a/colossalai/nn/layer/moe/layers.py +++ 
b/colossalai/nn/layer/moe/layers.py @@ -1,203 +1,210 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.utils import get_current_device -from colossalai.nn.layer.moe._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, \ - ReduceScatter, MoeDispatch, MoeCombine -from colossalai.nn.layer.moe.experts import MoeExperts, Experts -from colossalai.nn.layer.moe.utils import UniformNoiseGenerator, NormalNoiseGenerator -from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router -from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator -from typing import Optional, Type, Tuple - - -@no_shard_zero_decrator(is_replicated=True) -class MoeLayer(nn.Module): - """A MoE layer, that puts its input tensor to its gate and uses the output logits - to router all tokens, is mainly used to exchange all tokens for every expert across - the moe tensor group by all to all comunication. Then it will get the output of all - experts and exchange the output. At last returns the output of the moe system. - - Args: - dim_model (int): Dimension of model. - num_experts (int): The number of experts. - router (MoeRouter): Instance of router used in routing. - experts (MoeExperts): Instance of experts generated by Expert. 
- """ - - def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): - super().__init__() - self.d_model = dim_model - self.num_experts = num_experts - self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) - self.router: MoeRouter = router - self.experts: MoeExperts = experts - self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False - self.ep_group = experts.dist_info.ep_group - self.ep_size = experts.dist_info.ep_size - self.num_local_experts = experts.num_local_experts - - nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) - - def a2a_process(self, dispatch_data: torch.Tensor): - expert_input = AllToAll.apply(dispatch_data, self.ep_group) - input_shape = expert_input.shape - expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) - expert_output = self.experts(expert_input) - expert_output = expert_output.reshape(input_shape) - expert_output = AllToAll.apply(expert_output, self.ep_group) - return expert_output - - def tp_process(self, dispatch_data: torch.Tensor): - expert_in = AllGather.apply(dispatch_data, self.ep_group) - expert_out = self.experts(expert_in) - expert_out = ReduceScatter.apply(expert_out, self.ep_group) - return expert_out - - def forward(self, inputs: torch.Tensor) -> Tuple: - # reshape the input tokens - tokens = inputs.reshape(-1, self.d_model) - - # the data type of the inputs in the gating should be fp32 - fp32_input = tokens.to(torch.float) - fp32_weight = self.gate_weight.to(torch.float) - gate_output = F.linear(fp32_input, fp32_weight) - - # the result from the router - route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) - - if self.use_kernel: - dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:]) - dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) - else: - sec_mask_f = 
route_result_list[1].type_as(inputs) - dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) - - # dispatch_data [e, c, h] - if self.experts.comm_name == "all_to_all": - expert_output = self.a2a_process(dispatch_data) - elif self.experts.comm_name == "all_gather": - expert_output = self.tp_process(dispatch_data) - else: - raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " - "build function.") - # expert_output [e, c, h] - if self.use_kernel: - expert_output = expert_output.reshape(-1, self.d_model) - ans = MoeCombine.apply(expert_output, *route_result_list) - else: - combine_weights = route_result_list[0].type_as(inputs) - combine_weights = combine_weights.view(combine_weights.shape[0], -1) - expert_output = expert_output.view(-1, expert_output.shape[-1]) - ans = torch.matmul(combine_weights, expert_output) - - ans = ans.reshape(inputs.shape) - l_aux = self.router.pop_routing_loss() - return ans, l_aux - - -class MoeModule(nn.Module): - """A class for users to create MoE modules in their models. - - Args: - dim_model (int): Hidden dimension of training model - num_experts (int): The number experts - top_k (int, optional): The number of experts for dispatchment of each token - capacity_factor_train (float, optional): Capacity factor in routing during training - capacity_factor_eval (float, optional): Capacity factor in routing during evaluation - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'. - 'Jitter' can be found in `Switch Transformer paper`_. - 'Gaussian' can be found in `ViT-MoE paper`_. - drop_tks (bool, optional): Whether drops tokens in evaluation - use_residual (bool, optional): Makes this MoE layer a Residual MoE. - More information can be found in `Microsoft paper`_. 
- residual_instance (nn.Module, optional): The instance of residual module in Resiual MoE - expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer - expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given - expert_args (optional): The args of expert when no instance is given - - .. _Switch Transformer paper: - https://arxiv.org/abs/2101.03961 - .. _ViT-MoE paper: - https://arxiv.org/abs/2106.05974 - .. _Microsoft paper: - https://arxiv.org/abs/2201.05596 - """ - - def __init__(self, - dim_model: int, - num_experts: int, - top_k: int = 1, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_policy: Optional[str] = None, - drop_tks: bool = True, - use_residual: bool = False, - residual_instance: Optional[nn.Module] = None, - expert_instance: Optional[MoeExperts] = None, - expert_cls: Optional[Type[nn.Module]] = None, - **expert_args): - super().__init__() - - noisy_func = None - if noisy_policy is not None: - if noisy_policy == 'Jitter': - noisy_func = UniformNoiseGenerator() - elif noisy_policy == 'Gaussian': - noisy_func = NormalNoiseGenerator(num_experts) - else: - raise NotImplementedError("Unsupported input noisy policy") - - if top_k == 1: - moe_router_cls = Top1Router - elif top_k == 2: - moe_router_cls = Top2Router - else: - raise NotImplementedError("top_k > 2 is not supported yet") - - self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - self.use_residual = use_residual - if use_residual: - if residual_instance is not None: - self.residual_module = residual_instance - else: - assert expert_cls is not None, \ - "Expert class can't be None when residual instance is not given" - self.residual_module = expert_cls(**expert_args) - - with no_shard_zero_context(): - self.residual_combine = 
nn.Linear(dim_model, 2, device=get_current_device()) - - if expert_instance is not None: - self.experts = expert_instance - else: - assert expert_cls is not None, \ - "Expert class can't be None when experts instance is not given" - self.experts = Experts(expert_cls, num_experts, **expert_args) - - self.moe_layer = MoeLayer(dim_model=dim_model, - num_experts=num_experts, - router=self.moe_router, - experts=self.experts) - - def forward(self, inputs: torch.Tensor): - moe_output, l_aux = self.moe_layer(inputs) - - if self.use_residual: - residual_output = self.residual_module(inputs) - combine_coef = self.residual_combine(inputs) - combine_coef = F.softmax(combine_coef, dim=-1) - output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:] - else: - output = moe_output - - return output, l_aux +import math +from typing import Optional, Tuple, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import ( + COL_MOE_KERNEL_FLAG, + AllGather, + AllToAll, + MoeCombine, + MoeDispatch, + ReduceScatter, +) +from colossalai.nn.layer.moe.experts import Experts, MoeExperts +from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router +from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator +from colossalai.utils import get_current_device +from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator + + +@no_shard_zero_decrator(is_replicated=True) +class MoeLayer(nn.Module): + """A MoE layer, that puts its input tensor to its gate and uses the output logits + to router all tokens, is mainly used to exchange all tokens for every expert across + the moe tensor group by all to all comunication. Then it will get the output of all + experts and exchange the output. At last returns the output of the moe system. + + Args: + dim_model (int): Dimension of model. 
+ num_experts (int): The number of experts. + router (MoeRouter): Instance of router used in routing. + experts (MoeExperts): Instance of experts generated by Expert. + """ + + def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts): + super().__init__() + self.d_model = dim_model + self.num_experts = num_experts + self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model)) + self.router: MoeRouter = router + self.experts: MoeExperts = experts + self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False + self.ep_group = experts.dist_info.ep_group + self.ep_size = experts.dist_info.ep_size + self.num_local_experts = experts.num_local_experts + + nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model)) + + def a2a_process(self, dispatch_data: torch.Tensor): + expert_input = AllToAll.apply(dispatch_data, self.ep_group) + input_shape = expert_input.shape + expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model) + expert_output = self.experts(expert_input) + expert_output = expert_output.reshape(input_shape) + expert_output = AllToAll.apply(expert_output, self.ep_group) + return expert_output + + def tp_process(self, dispatch_data: torch.Tensor): + expert_in = AllGather.apply(dispatch_data, self.ep_group) + expert_out = self.experts(expert_in) + expert_out = ReduceScatter.apply(expert_out, self.ep_group) + return expert_out + + def forward(self, inputs: torch.Tensor) -> Tuple: + # reshape the input tokens + tokens = inputs.reshape(-1, self.d_model) + + # the data type of the inputs in the gating should be fp32 + fp32_input = tokens.to(torch.float) + fp32_weight = self.gate_weight.to(torch.float) + gate_output = F.linear(fp32_input, fp32_weight) + + # the result from the router + route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group) + + if self.use_kernel: + dispatch_data = 
MoeDispatch.apply(tokens, *route_result_list[1:]) + dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model) + else: + sec_mask_f = route_result_list[1].type_as(inputs) + dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens) + + # dispatch_data [e, c, h] + if self.experts.comm_name == "all_to_all": + expert_output = self.a2a_process(dispatch_data) + elif self.experts.comm_name == "all_gather": + expert_output = self.tp_process(dispatch_data) + else: + raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts " + "build function.") + # expert_output [e, c, h] + if self.use_kernel: + expert_output = expert_output.reshape(-1, self.d_model) + ans = MoeCombine.apply(expert_output, *route_result_list) + else: + combine_weights = route_result_list[0].type_as(inputs) + combine_weights = combine_weights.view(combine_weights.shape[0], -1) + expert_output = expert_output.view(-1, expert_output.shape[-1]) + ans = torch.matmul(combine_weights, expert_output) + + ans = ans.reshape(inputs.shape) + l_aux = self.router.pop_routing_loss() + return ans, l_aux + + +class MoeModule(nn.Module): + """A class for users to create MoE modules in their models. + + Args: + dim_model (int): Hidden dimension of training model + num_experts (int): The number experts + top_k (int, optional): The number of experts for dispatchment of each token + capacity_factor_train (float, optional): Capacity factor in routing during training + capacity_factor_eval (float, optional): Capacity factor in routing during evaluation + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'. + 'Jitter' can be found in `Switch Transformer paper`_. + 'Gaussian' can be found in `ViT-MoE paper`_. 
+ drop_tks (bool, optional): Whether drops tokens in evaluation + use_residual (bool, optional): Makes this MoE layer a Residual MoE. + More information can be found in `Microsoft paper`_. + residual_instance (nn.Module, optional): The instance of residual module in Resiual MoE + expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer + expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given + expert_args (optional): The args of expert when no instance is given + + .. _Switch Transformer paper: + https://arxiv.org/abs/2101.03961 + .. _ViT-MoE paper: + https://arxiv.org/abs/2106.05974 + .. _Microsoft paper: + https://arxiv.org/abs/2201.05596 + """ + + def __init__(self, + dim_model: int, + num_experts: int, + top_k: int = 1, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_policy: Optional[str] = None, + drop_tks: bool = True, + use_residual: bool = False, + residual_instance: Optional[nn.Module] = None, + expert_instance: Optional[MoeExperts] = None, + expert_cls: Optional[Type[nn.Module]] = None, + **expert_args): + super().__init__() + + noisy_func = None + if noisy_policy is not None: + if noisy_policy == 'Jitter': + noisy_func = UniformNoiseGenerator() + elif noisy_policy == 'Gaussian': + noisy_func = NormalNoiseGenerator(num_experts) + else: + raise NotImplementedError("Unsupported input noisy policy") + + if top_k == 1: + moe_router_cls = Top1Router + elif top_k == 2: + moe_router_cls = Top2Router + else: + raise NotImplementedError("top_k > 2 is not supported yet") + + self.moe_router = moe_router_cls(capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + self.use_residual = use_residual + if use_residual: + if residual_instance is not None: + self.residual_module = residual_instance + else: + assert expert_cls is not None, \ + 
"Expert class can't be None when residual instance is not given" + self.residual_module = expert_cls(**expert_args) + + with no_shard_zero_context(): + self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device()) + + if expert_instance is not None: + self.experts = expert_instance + else: + assert expert_cls is not None, \ + "Expert class can't be None when experts instance is not given" + self.experts = Experts(expert_cls, num_experts, **expert_args) + + self.moe_layer = MoeLayer(dim_model=dim_model, + num_experts=num_experts, + router=self.moe_router, + experts=self.experts) + + def forward(self, inputs: torch.Tensor): + moe_output, l_aux = self.moe_layer(inputs) + + if self.use_residual: + residual_output = self.residual_module(inputs) + combine_coef = self.residual_combine(inputs) + combine_coef = F.softmax(combine_coef, dim=-1) + output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:] + else: + output = moe_output + + return output, l_aux diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py index c522c655a511..2bc6a5142256 100644 --- a/colossalai/nn/layer/moe/routers.py +++ b/colossalai/nn/layer/moe/routers.py @@ -1,226 +1,227 @@ -import math -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.distributed as dist -from colossalai.utils import get_current_device -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer.moe._operation import moe_cumsum -from typing import Callable, Optional -from torch.distributed import ProcessGroup - - -class MoeRouter(nn.Module, ABC): - """Base class for all MoE routers. - Args: - k_value (int): The value of top_k. - capacity_factor_train (float): Capacity factor in routing of training. - capacity_factor_eval (float): Capacity factor in routing of evaluation. - min_capacity (int): The minimum number of the capacity of each expert. 
- noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - k_value: int, - capacity_factor_train: float, - capacity_factor_eval: float, - min_capacity: int, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__() - self.k_value = k_value - self.capacity_factor_train = capacity_factor_train - self.capacity_factor_eval = capacity_factor_eval - self.min_capacity = min_capacity - self.noisy_func = noisy_func - self.drop_tks = drop_tks - self._routing_loss = None - - def get_capacity(self, logits_shape): - capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval - capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) - capacity += capacity % 2 - capacity = max(capacity, self.min_capacity) - assert capacity > 0 - return capacity - - def set_routing_loss(self, aux_loss: torch.Tensor) -> None: - assert self._routing_loss is None - self._routing_loss = aux_loss - - def pop_routing_loss(self) -> torch.Tensor: - assert self._routing_loss is not None - reservation = self._routing_loss - self._routing_loss = None - return reservation - - -class Top1Router(MoeRouter): - """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More deailted function can be found in the paper about Switch Transformer - of Google. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert. - select_policy (str, optional): The policy about tokens selection. - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
- drop_tks (bool, optional): Whether drops tokens in evaluation - """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - select_policy: str = "first", - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=1, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - self.select_policy = select_policy - assert select_policy in {"first", "random"} - if select_policy == "random": - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), - high=torch.tensor(1.0, - device=get_current_device())).rsample - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(inputs, dim=-1) - mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(mask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(mask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - if self.select_policy == "random": - rand_mask = mask * self.uniform(mask.shape) - _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) - mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) - ranks = moe_cumsum(mask) - elif self.select_policy == "first": - ranks = moe_cumsum(mask) - mask = mask * torch.lt(ranks, capacity) - else: - raise NotImplementedError("Not support 
such select policy yet.") - - ranks = torch.sum(mask * ranks, dim=-1) - - if use_kernel: - mask = torch.sum(mask, dim=-1) - mask = torch.stack([mask], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) - return logits, mask, dest_idx, num_experts * capacity - else: - ranks = F.one_hot(ranks, num_classes=capacity) - weight = mask * logits.type_as(inputs) - combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) - sec_mask = combine_weights.bool() - return combine_weights, sec_mask - - -class Top2Router(MoeRouter): - """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] - for routing usage. More deailted function can be found in the paper about ViT-MoE. - Args: - capacity_factor_train (float, optional): Capacity factor in routing of training. - capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. - min_capacity (int, optional): The minimum number of the capacity of each expert - noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. - drop_tks (bool, optional): Whether drops tokens in evaluation. 
- """ - - def __init__(self, - capacity_factor_train: float = 1.25, - capacity_factor_eval: float = 2.0, - min_capacity: int = 4, - noisy_func: Callable = None, - drop_tks: bool = True): - super().__init__(k_value=2, - capacity_factor_train=capacity_factor_train, - capacity_factor_eval=capacity_factor_eval, - min_capacity=min_capacity, - noisy_func=noisy_func, - drop_tks=drop_tks) - - def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): - # inputs: [s, h] - if self.noisy_func is not None and self.training: - inputs = self.noisy_func(inputs) - - assert inputs.dtype == torch.float - logits = F.softmax(inputs, dim=-1) # logits: [s, e] - num_experts = logits.size(-1) - capacity = self.get_capacity(logits.shape) - - top1_idx = torch.argmax(logits, dim=-1) - mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) - logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) - top2_idx = torch.argmax(logits_except1, dim=-1) - mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) - - cmask = (mask1 + mask2) # loss: [s, e] - - # caculate the auxiliary loss - me = torch.mean(logits, dim=0) - ce = torch.mean(cmask.float(), dim=0) - l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 - self.set_routing_loss(l_aux) - - if not self.training and not self.drop_tks: - max_num = torch.max(torch.sum(cmask, dim=0)) - dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) - capacity = max_num.item() - - rank1 = moe_cumsum(mask1) # rank1: [s, e] - rank2 = moe_cumsum(mask2) - rank2 += torch.sum(mask1, dim=-2, keepdim=True) - - mask1 *= torch.lt(rank1, capacity) - mask2 *= torch.lt(rank2, capacity) - - rank1 = torch.sum(mask1 * rank1, dim=-1) - rank2 = torch.sum(mask2 * rank2, dim=-1) - - if use_kernel: - mask1 = torch.sum(mask1, dim=-1) - mask2 = torch.sum(mask2, dim=-1) - - mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) - dest_idx = torch.stack([top1_idx * 
capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) - - return logits, mask, dest_idx, num_experts * capacity - else: - weight1 = mask1 * logits.type_as(inputs) - weight2 = mask2 * logits.type_as(inputs) - rank1_sc = F.one_hot(rank1, num_classes=capacity) - rank2_sc = F.one_hot(rank2, num_classes=capacity) - - cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) - cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) - cb_weight = cb_weight1 + cb_weight2 - sec_mask = cb_weight.bool() - - return cb_weight, sec_mask +import math +from abc import ABC +from typing import Callable, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import ProcessGroup + +from colossalai.context import MOE_CONTEXT +from colossalai.nn.layer.moe._operation import moe_cumsum +from colossalai.utils import get_current_device + + +class MoeRouter(nn.Module, ABC): + """Base class for all MoE routers. + Args: + k_value (int): The value of top_k. + capacity_factor_train (float): Capacity factor in routing of training. + capacity_factor_eval (float): Capacity factor in routing of evaluation. + min_capacity (int): The minimum number of the capacity of each expert. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
+ drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__(self, + k_value: int, + capacity_factor_train: float, + capacity_factor_eval: float, + min_capacity: int, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__() + self.k_value = k_value + self.capacity_factor_train = capacity_factor_train + self.capacity_factor_eval = capacity_factor_eval + self.min_capacity = min_capacity + self.noisy_func = noisy_func + self.drop_tks = drop_tks + self._routing_loss = None + + def get_capacity(self, logits_shape): + capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval + capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1]) + capacity += capacity % 2 + capacity = max(capacity, self.min_capacity) + assert capacity > 0 + return capacity + + def set_routing_loss(self, aux_loss: torch.Tensor) -> None: + assert self._routing_loss is None + self._routing_loss = aux_loss + + def pop_routing_loss(self) -> torch.Tensor: + assert self._routing_loss is not None + reservation = self._routing_loss + self._routing_loss = None + return reservation + + +class Top1Router(MoeRouter): + """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More deailted function can be found in the paper about Switch Transformer + of Google. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert. + select_policy (str, optional): The policy about tokens selection. + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. 
+ drop_tks (bool, optional): Whether drops tokens in evaluation + """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + select_policy: str = "first", + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=1, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + self.select_policy = select_policy + assert select_policy in {"first", "random"} + if select_policy == "random": + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=get_current_device()), + high=torch.tensor(1.0, + device=get_current_device())).rsample + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(inputs, dim=-1) + mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(mask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(mask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + if self.select_policy == "random": + rand_mask = mask * self.uniform(mask.shape) + _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0) + mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1) + ranks = moe_cumsum(mask) + elif self.select_policy == "first": + ranks = moe_cumsum(mask) + mask = mask * torch.lt(ranks, capacity) + else: + raise NotImplementedError("Not support 
such select policy yet.") + + ranks = torch.sum(mask * ranks, dim=-1) + + if use_kernel: + mask = torch.sum(mask, dim=-1) + mask = torch.stack([mask], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32) + return logits, mask, dest_idx, num_experts * capacity + else: + ranks = F.one_hot(ranks, num_classes=capacity) + weight = mask * logits.type_as(inputs) + combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1) + sec_mask = combine_weights.bool() + return combine_weights, sec_mask + + +class Top2Router(MoeRouter): + """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c] + for routing usage. More deailted function can be found in the paper about ViT-MoE. + Args: + capacity_factor_train (float, optional): Capacity factor in routing of training. + capacity_factor_eval (float, optional): Capacity factor in routing of evaluation. + min_capacity (int, optional): The minimum number of the capacity of each expert + noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits. + drop_tks (bool, optional): Whether drops tokens in evaluation. 
+ """ + + def __init__(self, + capacity_factor_train: float = 1.25, + capacity_factor_eval: float = 2.0, + min_capacity: int = 4, + noisy_func: Callable = None, + drop_tks: bool = True): + super().__init__(k_value=2, + capacity_factor_train=capacity_factor_train, + capacity_factor_eval=capacity_factor_eval, + min_capacity=min_capacity, + noisy_func=noisy_func, + drop_tks=drop_tks) + + def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None): + # inputs: [s, h] + if self.noisy_func is not None and self.training: + inputs = self.noisy_func(inputs) + + assert inputs.dtype == torch.float + logits = F.softmax(inputs, dim=-1) # logits: [s, e] + num_experts = logits.size(-1) + capacity = self.get_capacity(logits.shape) + + top1_idx = torch.argmax(logits, dim=-1) + mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32) + logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) + top2_idx = torch.argmax(logits_except1, dim=-1) + mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32) + + cmask = (mask1 + mask2) # loss: [s, e] + + # caculate the auxiliary loss + me = torch.mean(logits, dim=0) + ce = torch.mean(cmask.float(), dim=0) + l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1 + self.set_routing_loss(l_aux) + + if not self.training and not self.drop_tks: + max_num = torch.max(torch.sum(cmask, dim=0)) + dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group) + capacity = max_num.item() + + rank1 = moe_cumsum(mask1) # rank1: [s, e] + rank2 = moe_cumsum(mask2) + rank2 += torch.sum(mask1, dim=-2, keepdim=True) + + mask1 *= torch.lt(rank1, capacity) + mask2 *= torch.lt(rank2, capacity) + + rank1 = torch.sum(mask1 * rank1, dim=-1) + rank2 = torch.sum(mask2 * rank2, dim=-1) + + if use_kernel: + mask1 = torch.sum(mask1, dim=-1) + mask2 = torch.sum(mask2, dim=-1) + + mask = torch.stack([mask1, mask2], dim=0).to(torch.int32) + dest_idx = torch.stack([top1_idx * 
capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32) + + return logits, mask, dest_idx, num_experts * capacity + else: + weight1 = mask1 * logits.type_as(inputs) + weight2 = mask2 * logits.type_as(inputs) + rank1_sc = F.one_hot(rank1, num_classes=capacity) + rank2_sc = F.one_hot(rank2, num_classes=capacity) + + cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1) + cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1) + cb_weight = cb_weight1 + cb_weight2 + sec_mask = cb_weight.bool() + + return cb_weight, sec_mask diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index 9362347414e0..9fa6d0b94e8e 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -1,68 +1,70 @@ -import torch -import torch.nn.functional as F -from colossalai.utils import get_current_device -from colossalai.context.moe_context import MOE_CONTEXT -from .experts import FFNExperts, TPExperts - - -class ForceFP32Parameter(torch.nn.Parameter): - - def half(self, memory_format=None): - return self.data.clone() - - -class NormalNoiseGenerator: - """Generates a random noisy mask for logtis tensor. - - All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where - `E = the number of experts`. - - Args: - num_experts (int): The number of experts. - """ - - def __init__(self, num_experts: int): - self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), - scale=torch.tensor(1.0 / num_experts**2, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.normal(inputs.shape) - return inputs + noisy - - -class UniformNoiseGenerator: - """Generates a random noisy mask for logtis tensor. - copied from mesh tensorflow: - Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. - Makes models more resilient to rounding errors introduced by bfloat16. - This seems particularly important for logits. 
- - Args: - eps (float, optional): Epsilon in generator, defaults 1e-2. - """ - - def __init__(self, eps: float = 1e-2): - self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), - high=torch.tensor(1.0 + eps, - device=get_current_device())).rsample - - def __call__(self, inputs: torch.Tensor): - noisy = self.uniform(inputs.shape) - return inputs * noisy - - -def autocast_softmax(logit: torch.Tensor, dim: int): - if logit.dtype != torch.float32: - logit = logit.float() - return F.softmax(logit, dim=dim) - - -def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): - mep_size = MOE_CONTEXT.max_ep_size - if num_experts % mep_size == 0 or mep_size % num_experts == 0: - return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) - elif d_ff % mep_size == 0: - return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) - else: - raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") +import torch +import torch.nn.functional as F + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.utils import get_current_device + +from .experts import FFNExperts, TPExperts + + +class ForceFP32Parameter(torch.nn.Parameter): + + def half(self, memory_format=None): + return self.data.clone() + + +class NormalNoiseGenerator: + """Generates a random noisy mask for logtis tensor. + + All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where + `E = the number of experts`. + + Args: + num_experts (int): The number of experts. 
+ """ + + def __init__(self, num_experts: int): + self.normal = torch.distributions.normal.Normal(loc=torch.tensor(0.0, device=get_current_device()), + scale=torch.tensor(1.0 / num_experts**2, + device=get_current_device())).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.normal(inputs.shape) + return inputs + noisy + + +class UniformNoiseGenerator: + """Generates a random noisy mask for logtis tensor. + copied from mesh tensorflow: + Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`. + Makes models more resilient to rounding errors introduced by bfloat16. + This seems particularly important for logits. + + Args: + eps (float, optional): Epsilon in generator, defaults 1e-2. + """ + + def __init__(self, eps: float = 1e-2): + self.uniform = torch.distributions.uniform.Uniform(low=torch.tensor(1.0 - eps, device=get_current_device()), + high=torch.tensor(1.0 + eps, + device=get_current_device())).rsample + + def __call__(self, inputs: torch.Tensor): + noisy = self.uniform(inputs.shape) + return inputs * noisy + + +def autocast_softmax(logit: torch.Tensor, dim: int): + if logit.dtype != torch.float32: + logit = logit.float() + return F.softmax(logit, dim=dim) + + +def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0): + mep_size = MOE_CONTEXT.max_ep_size + if num_experts % mep_size == 0 or mep_size % num_experts == 0: + return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate) + elif d_ff % mep_size == 0: + return TPExperts(num_experts, d_model, d_ff, activation, drop_rate) + else: + raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.") diff --git a/colossalai/nn/layer/parallel_1d/__init__.py b/colossalai/nn/layer/parallel_1d/__init__.py index 2353851df665..9cffd4d339f5 100644 --- a/colossalai/nn/layer/parallel_1d/__init__.py +++ b/colossalai/nn/layer/parallel_1d/__init__.py @@ -1,5 +1,15 @@ -from .layers import 
(Classifier1D, Dropout1D, Embedding1D, LayerNorm1D, Linear1D, Linear1D_Col, Linear1D_Row, - PatchEmbedding1D, VocabParallelClassifier1D, VocabParallelEmbedding1D) +from .layers import ( + Classifier1D, + Dropout1D, + Embedding1D, + LayerNorm1D, + Linear1D, + Linear1D_Col, + Linear1D_Row, + PatchEmbedding1D, + VocabParallelClassifier1D, + VocabParallelEmbedding1D, +) __all__ = [ 'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D', diff --git a/colossalai/nn/layer/parallel_1d/_operation.py b/colossalai/nn/layer/parallel_1d/_operation.py index 394334558275..300baf9c12ba 100644 --- a/colossalai/nn/layer/parallel_1d/_operation.py +++ b/colossalai/nn/layer/parallel_1d/_operation.py @@ -1,5 +1,6 @@ import torch import torch.distributed as dist + from colossalai.core import global_context as gpc try: diff --git a/colossalai/nn/layer/parallel_1d/_utils.py b/colossalai/nn/layer/parallel_1d/_utils.py index 1212d595635d..fddf4e73db51 100644 --- a/colossalai/nn/layer/parallel_1d/_utils.py +++ b/colossalai/nn/layer/parallel_1d/_utils.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist + from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env @@ -124,7 +125,7 @@ def backward(ctx, grad_output): class _SplitForwardGatherBackward(torch.autograd.Function): """ Split the input and keep only the corresponding chuck to the rank. - + Args: input_: input matrix. parallel_mode: parallel mode. 
diff --git a/colossalai/nn/layer/parallel_2d/__init__.py b/colossalai/nn/layer/parallel_2d/__init__.py index 5562d1a70036..9c65f3608710 100644 --- a/colossalai/nn/layer/parallel_2d/__init__.py +++ b/colossalai/nn/layer/parallel_2d/__init__.py @@ -1,6 +1,13 @@ from ._operation import reduce_by_batch_2d, split_batch_2d -from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D, - VocabParallelEmbedding2D) +from .layers import ( + Classifier2D, + Embedding2D, + LayerNorm2D, + Linear2D, + PatchEmbedding2D, + VocabParallelClassifier2D, + VocabParallelEmbedding2D, +) __all__ = [ 'split_batch_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py index 306577dbd933..0b856c18ace0 100644 --- a/colossalai/nn/layer/parallel_2d/_operation.py +++ b/colossalai/nn/layer/parallel_2d/_operation.py @@ -2,13 +2,14 @@ import torch import torch.distributed as dist -from colossalai.communication.collective import (all_gather, all_reduce, reduce, reduce_scatter) -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import get_current_device from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd + +from colossalai.communication.collective import all_gather, all_reduce, reduce, reduce_scatter +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env +from colossalai.utils import get_current_device def matmul_2d( diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py index f3a4d2bbbc32..d49d67dbf4f3 100644 --- a/colossalai/nn/layer/parallel_2d/layers.py +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -5,6 +5,9 @@ import torch import 
torch.nn as nn import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + from colossalai.communication import broadcast from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc @@ -13,13 +16,19 @@ from colossalai.registry import LAYERS from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict from colossalai.utils.cuda import get_current_device -from torch import Tensor -from torch.nn import Parameter from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d, - reduce_scatter_tensor_2d, split_batch_2d) +from ._operation import ( + Matmul_AB_2D, + Matmul_ABT_2D, + add_bias_2d, + all_gather_tensor_2d, + classifier_2d, + layernorm_2d, + reduce_scatter_tensor_2d, + split_batch_2d, +) from ._utils import assert_summa_initialization, get_summa_dim_from_env diff --git a/colossalai/nn/layer/parallel_2p5d/__init__.py b/colossalai/nn/layer/parallel_2p5d/__init__.py index bec3b1c4b0b8..23e47e6ed06b 100644 --- a/colossalai/nn/layer/parallel_2p5d/__init__.py +++ b/colossalai/nn/layer/parallel_2p5d/__init__.py @@ -1,6 +1,13 @@ from ._operation import reduce_by_batch_2p5d, split_batch_2p5d -from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D, - VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D) +from .layers import ( + Classifier2p5D, + Embedding2p5D, + LayerNorm2p5D, + Linear2p5D, + PatchEmbedding2p5D, + VocabParallelClassifier2p5D, + VocabParallelEmbedding2p5D, +) __all__ = [ 'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D', diff --git a/colossalai/nn/layer/parallel_2p5d/_operation.py b/colossalai/nn/layer/parallel_2p5d/_operation.py index 
5a0f537cd6d9..c0c0a4ba20b8 100644 --- a/colossalai/nn/layer/parallel_2p5d/_operation.py +++ b/colossalai/nn/layer/parallel_2p5d/_operation.py @@ -2,12 +2,13 @@ import torch import torch.distributed as dist -from colossalai.communication.collective import (all_gather, all_reduce, reduce_scatter) +from torch import Tensor +from torch.cuda.amp import custom_bwd, custom_fwd + +from colossalai.communication.collective import all_gather, all_reduce, reduce_scatter from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import get_current_device -from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd def get_parallel_group(parallel_mode: ParallelMode): diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py index f849cbbe7b0d..ffbaedb14d4c 100644 --- a/colossalai/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/nn/layer/parallel_2p5d/layers.py @@ -5,22 +5,34 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor +from torch.nn import Parameter + from colossalai.communication import broadcast from colossalai.context import ParallelMode, seed from colossalai.core import global_context as gpc from colossalai.global_variables import tensor_parallel_env as env from colossalai.nn import init as init from colossalai.registry import LAYERS -from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict, - partition_tensor_parallel_state_dict) +from colossalai.utils.checkpointing import ( + broadcast_state_dict, + gather_tensor_parallel_state_dict, + partition_tensor_parallel_state_dict, +) from colossalai.utils.cuda import get_current_device -from torch import Tensor -from torch.nn import Parameter from ..base_layer import ParallelLayer from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple -from ._operation import (Matmul_AB_2p5D, 
Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d, - layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d) +from ._operation import ( + Matmul_AB_2p5D, + Matmul_ABT_2p5D, + add_bias_2p5d, + all_gather_tensor_2p5d, + classifier_2p5d, + layernorm_2p5d, + reduce_scatter_tensor_2p5d, + split_batch_2p5d, +) from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env diff --git a/colossalai/nn/layer/parallel_3d/__init__.py b/colossalai/nn/layer/parallel_3d/__init__.py index 9ae255b449ee..17fe8403c585 100644 --- a/colossalai/nn/layer/parallel_3d/__init__.py +++ b/colossalai/nn/layer/parallel_3d/__init__.py @@ -1,6 +1,13 @@ from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d -from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D, - VocabParallelEmbedding3D) +from .layers import ( + Classifier3D, + Embedding3D, + LayerNorm3D, + Linear3D, + PatchEmbedding3D, + VocabParallelClassifier3D, + VocabParallelEmbedding3D, +) __all__ = [ 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', diff --git a/colossalai/nn/layer/parallel_sequence/__init__.py b/colossalai/nn/layer/parallel_sequence/__init__.py index 4fa9eed6f34b..d92d66d40a8e 100644 --- a/colossalai/nn/layer/parallel_sequence/__init__.py +++ b/colossalai/nn/layer/parallel_sequence/__init__.py @@ -1,4 +1,4 @@ -from ._operation import RingQK, RingAV +from ._operation import RingAV, RingQK from .layers import TransformerSelfAttentionRing __all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK'] diff --git a/colossalai/nn/layer/parallel_sequence/_operation.py b/colossalai/nn/layer/parallel_sequence/_operation.py index fc80494224c6..5b905bae03e9 100644 --- a/colossalai/nn/layer/parallel_sequence/_operation.py +++ b/colossalai/nn/layer/parallel_sequence/_operation.py @@ -3,13 +3,13 @@ import torch from torch import distributed as dist 
+from torch.cuda.amp import custom_bwd, custom_fwd from colossalai.communication import ring_forward from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range +from colossalai.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range from colossalai.utils import get_current_device -from torch.cuda.amp import custom_bwd, custom_fwd class RingQK(torch.autograd.Function): diff --git a/colossalai/nn/layer/parallel_sequence/layers.py b/colossalai/nn/layer/parallel_sequence/layers.py index d9486217bbc9..e2d42306797a 100644 --- a/colossalai/nn/layer/parallel_sequence/layers.py +++ b/colossalai/nn/layer/parallel_sequence/layers.py @@ -2,20 +2,20 @@ # -*- encoding: utf-8 -*- import math -import colossalai import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import Parameter +import colossalai +from colossalai.context import seed from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV -from colossalai.registry import LAYERS -from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType from colossalai.kernel import FusedScaleMaskSoftmax -from colossalai.context import seed +from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType +from colossalai.nn.layer.parallel_sequence._operation import RingAV, RingQK +from colossalai.registry import LAYERS @LAYERS.register_module diff --git a/colossalai/nn/layer/utils/__init__.py b/colossalai/nn/layer/utils/__init__.py index 7e999ee82149..56e969bfd0bd 100644 --- a/colossalai/nn/layer/utils/__init__.py +++ b/colossalai/nn/layer/utils/__init__.py @@ -1,7 +1,15 @@ -from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode, - 
set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple) - -__all__ = [ - 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size', - 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple' -] +from .common import ( + ACT2FN, + CheckpointModule, + _ntuple, + divide, + get_tensor_parallel_mode, + set_tensor_parallel_attribute_by_partition, + set_tensor_parallel_attribute_by_size, + to_2tuple, +) + +__all__ = [ + 'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size', + 'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple' +] diff --git a/colossalai/nn/layer/utils/common.py b/colossalai/nn/layer/utils/common.py index f2297304fdc9..d8f3ad2a7eca 100644 --- a/colossalai/nn/layer/utils/common.py +++ b/colossalai/nn/layer/utils/common.py @@ -6,10 +6,11 @@ import numpy as np import torch +from torch import Tensor, nn + from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS from colossalai.global_variables import tensor_parallel_env as env from colossalai.utils import checkpoint -from torch import Tensor, nn class CheckpointModule(nn.Module): diff --git a/colossalai/nn/layer/wrapper/pipeline_wrapper.py b/colossalai/nn/layer/wrapper/pipeline_wrapper.py index ef1d794cc68f..68fea8622c5c 100644 --- a/colossalai/nn/layer/wrapper/pipeline_wrapper.py +++ b/colossalai/nn/layer/wrapper/pipeline_wrapper.py @@ -1,6 +1,8 @@ -import torch.nn as nn -import torch.distributed as dist from typing import List, Tuple, Union + +import torch.distributed as dist +import torch.nn as nn + from colossalai.context import ParallelMode from colossalai.core import global_context as gpc diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py index 373e4ec9468b..8722b1eb0fc4 100644 --- a/colossalai/nn/loss/__init__.py +++ b/colossalai/nn/loss/__init__.py @@ -1,9 +1,10 @@ -from colossalai.global_variables 
import tensor_parallel_env as env -from colossalai.nn.layer.utils import get_tensor_parallel_mode from torch import nn from torch.nn.modules.loss import * from torch.nn.modules.loss import _Loss +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn.layer.utils import get_tensor_parallel_mode + from .loss_1d import VocabParallelCrossEntropyLoss1D from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D diff --git a/colossalai/nn/loss/loss_1d.py b/colossalai/nn/loss/loss_1d.py index 2fabd954f8fb..3d0aefb52be4 100644 --- a/colossalai/nn/loss/loss_1d.py +++ b/colossalai/nn/loss/loss_1d.py @@ -1,105 +1,106 @@ -import torch -import torch.distributed as dist -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.registry import LOSSES -from torch.cuda.amp import custom_bwd, custom_fwd -from torch.nn.modules.loss import _Loss - - -class _VocabParallelCrossEntropy1D(torch.autograd.Function): - - @staticmethod - @custom_fwd(cast_inputs=torch.float32) - def forward(ctx, vocab_parallel_logits, targets, process_group): - if process_group is None: - process_group = gpc.get_group(ParallelMode.PARALLEL_1D) - - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) - # Subtract the maximum value. - vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) - - # Get the partition's vocab indecies - partition_vocab_size = vocab_parallel_logits.size()[-1] - rank = dist.get_rank(process_group) - vocab_start_index = partition_vocab_size * rank - vocab_end_index = vocab_start_index + partition_vocab_size - - # Create a mask of valid vocab ids (1 means it needs to be masked). 
- target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index) - masked_target = targets.clone() - vocab_start_index - masked_target[target_mask] = 0 - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(targets) - predicted_logits[target_mask] = 0.0 - # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) - - # Sum of exponential of logits along vocab dimension across all GPUs. - exp_logits = torch.exp(vocab_parallel_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) - - # Loss = log(sum(exp(logits))) - predicted-logit. - loss = torch.log(sum_exp_logits) - predicted_logits - # Store softmax, target-mask and masked-target for backward pass. - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - return loss - - @staticmethod - @custom_bwd - def backward(ctx, grad_output): - - # Retreive tensors from the forward path. - softmax, target_mask, masked_target_1d = ctx.saved_tensors - - # All the inputs have softmax as thier gradient. - grad_input = softmax - # For simplicity, work with the 2D gradient. - partition_vocab_size = softmax.size()[-1] - grad_2d = grad_input.view(-1, partition_vocab_size) - - # Add the gradient from matching classes. 
- arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) - - # Finally elementwise multiplication with the output gradients. - grad_input.mul_(grad_output.unsqueeze(dim=-1)) - - return grad_input, None, None - - -@LOSSES.register_module -class VocabParallelCrossEntropyLoss1D(_Loss): - """Vocab parallel cross entropy loss for 1D parallelism. - - Args: - reduction (bool, optional): whether to average the loss, defaults to True. - """ - - def __init__(self, reduction=True): - super().__init__() - self.reduction_mean = reduction - - def forward(self, logits, targets, process_group=None): - """Calculate loss between logits and targets. - - Args: - logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. - """ - loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group) - if self.reduction_mean: - loss = loss.mean() - return loss +import torch +import torch.distributed as dist +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.modules.loss import _Loss + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.registry import LOSSES + + +class _VocabParallelCrossEntropy1D(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, vocab_parallel_logits, targets, process_group): + if process_group is None: + process_group = gpc.get_group(ParallelMode.PARALLEL_1D) + + # Maximum value along vocab dimension across all GPUs. + logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] + torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group) + # Subtract the maximum value. 
+ vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1)) + + # Get the partition's vocab indecies + partition_vocab_size = vocab_parallel_logits.size()[-1] + rank = dist.get_rank(process_group) + vocab_start_index = partition_vocab_size * rank + vocab_end_index = vocab_start_index + partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index) + masked_target = targets.clone() - vocab_start_index + masked_target[target_mask] = 0 + + # Get predicted-logits = logits[target]. + # For Simplicity, we convert logits to a 2-D tensor with size + # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(targets) + predicted_logits[target_mask] = 0.0 + # All reduce is needed to get the chunks from other GPUs. + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Sum of exponential of logits along vocab dimension across all GPUs. + exp_logits = torch.exp(vocab_parallel_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group) + + # Loss = log(sum(exp(logits))) - predicted-logit. + loss = torch.log(sum_exp_logits) - predicted_logits + # Store softmax, target-mask and masked-target for backward pass. + exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) + return loss + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + + # Retreive tensors from the forward path. 
+ softmax, target_mask, masked_target_1d = ctx.saved_tensors + + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) + + # Finally elementwise multiplication with the output gradients. + grad_input.mul_(grad_output.unsqueeze(dim=-1)) + + return grad_input, None, None + + +@LOSSES.register_module +class VocabParallelCrossEntropyLoss1D(_Loss): + """Vocab parallel cross entropy loss for 1D parallelism. + + Args: + reduction (bool, optional): whether to average the loss, defaults to True. + """ + + def __init__(self, reduction=True): + super().__init__() + self.reduction_mean = reduction + + def forward(self, logits, targets, process_group=None): + """Calculate loss between logits and targets. + + Args: + logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + targets (:class:`torch.tensor`): Ground truth class indices or class probabilities. 
+ """ + loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group) + if self.reduction_mean: + loss = loss.mean() + return loss diff --git a/colossalai/nn/loss/loss_2d.py b/colossalai/nn/loss/loss_2d.py index cb12e723c323..3c67cbbb595e 100644 --- a/colossalai/nn/loss/loss_2d.py +++ b/colossalai/nn/loss/loss_2d.py @@ -1,14 +1,15 @@ import torch import torch.distributed as dist +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.registry import LOSSES from colossalai.utils import get_current_device -from torch.cuda.amp import custom_bwd, custom_fwd -from torch.nn.functional import cross_entropy -from torch.nn.modules.loss import _Loss @LOSSES.register_module diff --git a/colossalai/nn/loss/loss_2p5d.py b/colossalai/nn/loss/loss_2p5d.py index f8e3324fc5ff..c5fd363bacb5 100644 --- a/colossalai/nn/loss/loss_2p5d.py +++ b/colossalai/nn/loss/loss_2p5d.py @@ -1,14 +1,15 @@ import torch import torch.distributed as dist +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization from colossalai.registry import LOSSES from colossalai.utils import get_current_device -from torch.cuda.amp import custom_bwd, custom_fwd -from torch.nn.functional import cross_entropy -from torch.nn.modules.loss import _Loss @LOSSES.register_module diff --git a/colossalai/nn/loss/loss_3d.py 
b/colossalai/nn/loss/loss_3d.py index e76439191fdb..1cf8200296ac 100644 --- a/colossalai/nn/loss/loss_3d.py +++ b/colossalai/nn/loss/loss_3d.py @@ -1,14 +1,15 @@ import torch import torch.distributed as dist -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D +from torch.cuda.amp import custom_bwd, custom_fwd +from torch.nn.functional import cross_entropy +from torch.nn.modules.loss import _Loss + +from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.core import global_context as gpc from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env from colossalai.registry import LOSSES from colossalai.utils import get_current_device -from torch.cuda.amp import custom_bwd, custom_fwd -from torch.nn.functional import cross_entropy -from torch.nn.modules.loss import _Loss @LOSSES.register_module diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py index a8b18a3e37ee..d96b1965bb76 100644 --- a/colossalai/nn/loss/loss_moe.py +++ b/colossalai/nn/loss/loss_moe.py @@ -1,80 +1,81 @@ -import torch.nn as nn -from colossalai.registry import LOSSES -from torch.nn.modules.loss import _Loss -from colossalai.context.moe_context import MOE_CONTEXT - - -@LOSSES.register_module -class MoeCrossEntropyLoss(_Loss): - r"""torch.nn.CrossEntropyLoss added with auxiliary loss. - - Args: - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01. 
- - The ``args`` and ``kwargs`` should include parameters below: - :: - - weight (Tensor, optional) - size_average (bool, optional) - ignore_index (int, optional) - reduce (bool, optional) - reduction (str, optional) - label_smoothing (float, optional) - - More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in - `Cross_entropy `_. - """ - - def __init__(self, aux_weight: float = 0.01, *args, **kwargs): - super().__init__() - self.loss = nn.CrossEntropyLoss(*args, **kwargs) - self.aux_weight = aux_weight - - def forward(self, *args): - """ - The ``args`` should at least include parameters below: - :: - - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - - More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in - `Cross_entropy `_. - """ - main_loss = self.loss(*args) - aux_loss = MOE_CONTEXT.get_loss() - return main_loss + self.aux_weight * aux_loss - - -@LOSSES.register_module -class MoeLoss(_Loss): - """A wrapper class for any loss module to add with auxiliary loss. - - Args: - aux_weight (float): Weight of auxiliary loss in total loss. - loss_fn (``Callable``): Loss function. - args (list): Args in loss function. - kwargs (dict): Kwargs in loss function - """ - - def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): - super().__init__() - self.loss_fn = loss_fn(*args, **kwargs) - self.aux_weight = aux_weight - - def forward(self, *args, **kwargs): - """ - The ``args`` and ``kwargs`` should at least include parameters below: - :: - - input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). - target (:class:`torch.tensor`): Ground truth class indices or class probabilities. - - Note: - The ``args`` and ``kwargs`` may include different parameters varying with different loss function. 
- """ - main_loss = self.loss_fn(*args, **kwargs) - aux_loss = MOE_CONTEXT.get_loss() - return main_loss + self.aux_weight * aux_loss +import torch.nn as nn +from torch.nn.modules.loss import _Loss + +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.registry import LOSSES + + +@LOSSES.register_module +class MoeCrossEntropyLoss(_Loss): + r"""torch.nn.CrossEntropyLoss added with auxiliary loss. + + Args: + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01. + + The ``args`` and ``kwargs`` should include parameters below: + :: + + weight (Tensor, optional) + size_average (bool, optional) + ignore_index (int, optional) + reduce (bool, optional) + reduction (str, optional) + label_smoothing (float, optional) + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + + def __init__(self, aux_weight: float = 0.01, *args, **kwargs): + super().__init__() + self.loss = nn.CrossEntropyLoss(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args): + """ + The ``args`` should at least include parameters below: + :: + + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in + `Cross_entropy `_. + """ + main_loss = self.loss(*args) + aux_loss = MOE_CONTEXT.get_loss() + return main_loss + self.aux_weight * aux_loss + + +@LOSSES.register_module +class MoeLoss(_Loss): + """A wrapper class for any loss module to add with auxiliary loss. + + Args: + aux_weight (float): Weight of auxiliary loss in total loss. 
+ loss_fn (``Callable``): Loss function. + args (list): Args in loss function. + kwargs (dict): Kwargs in loss function + """ + + def __init__(self, aux_weight: float, loss_fn, *args, **kwargs): + super().__init__() + self.loss_fn = loss_fn(*args, **kwargs) + self.aux_weight = aux_weight + + def forward(self, *args, **kwargs): + """ + The ``args`` and ``kwargs`` should at least include parameters below: + :: + + input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits). + target (:class:`torch.tensor`): Ground truth class indices or class probabilities. + + Note: + The ``args`` and ``kwargs`` may include different parameters varying with different loss function. + """ + main_loss = self.loss_fn(*args, **kwargs) + aux_loss = MOE_CONTEXT.get_loss() + return main_loss + self.aux_weight * aux_loss diff --git a/colossalai/nn/lr_scheduler/__init__.py b/colossalai/nn/lr_scheduler/__init__.py index 34731ee901a0..cab06eba2034 100644 --- a/colossalai/nn/lr_scheduler/__init__.py +++ b/colossalai/nn/lr_scheduler/__init__.py @@ -3,7 +3,7 @@ from .multistep import MultiStepLR, MultiStepWarmupLR from .onecycle import OneCycleLR from .poly import PolynomialLR, PolynomialWarmupLR -from .torch import LambdaLR, MultiplicativeLR, StepLR, ExponentialLR +from .torch import ExponentialLR, LambdaLR, MultiplicativeLR, StepLR __all__ = [ 'CosineAnnealingLR', 'CosineAnnealingWarmupLR', 'FlatAnnealingLR', 'FlatAnnealingWarmupLR', 'LinearWarmupLR', diff --git a/colossalai/nn/lr_scheduler/cosine.py b/colossalai/nn/lr_scheduler/cosine.py index aab523bef8b3..01a265de9cd5 100644 --- a/colossalai/nn/lr_scheduler/cosine.py +++ b/colossalai/nn/lr_scheduler/cosine.py @@ -1,6 +1,7 @@ from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR from colossalai.registry import LR_SCHEDULERS + from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler diff --git a/colossalai/nn/lr_scheduler/multistep.py 
b/colossalai/nn/lr_scheduler/multistep.py index 29531a9e3855..a5950ee3f083 100644 --- a/colossalai/nn/lr_scheduler/multistep.py +++ b/colossalai/nn/lr_scheduler/multistep.py @@ -3,6 +3,7 @@ from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR from colossalai.registry import LR_SCHEDULERS + from .delayed import WarmupScheduler diff --git a/colossalai/nn/lr_scheduler/poly.py b/colossalai/nn/lr_scheduler/poly.py index 16352bc5175f..7ef567d2fb4a 100644 --- a/colossalai/nn/lr_scheduler/poly.py +++ b/colossalai/nn/lr_scheduler/poly.py @@ -1,6 +1,7 @@ from torch.optim.lr_scheduler import _LRScheduler from colossalai.registry import LR_SCHEDULERS + from .delayed import WarmupScheduler diff --git a/colossalai/nn/lr_scheduler/torch.py b/colossalai/nn/lr_scheduler/torch.py index 05d2a49c1ea5..7e0172f032b0 100644 --- a/colossalai/nn/lr_scheduler/torch.py +++ b/colossalai/nn/lr_scheduler/torch.py @@ -1,7 +1,7 @@ +from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR from torch.optim.lr_scheduler import LambdaLR as _LambdaLR from torch.optim.lr_scheduler import MultiplicativeLR as _MultiplicativeLR from torch.optim.lr_scheduler import StepLR as _StepLR -from torch.optim.lr_scheduler import ExponentialLR as _ExponentialLR from colossalai.registry import LR_SCHEDULERS diff --git a/colossalai/nn/metric/__init__.py b/colossalai/nn/metric/__init__.py index 00833b6119c1..e441409a8a27 100644 --- a/colossalai/nn/metric/__init__.py +++ b/colossalai/nn/metric/__init__.py @@ -1,26 +1,28 @@ -from torch import nn - -from ._utils import calc_acc -from .accuracy_2d import Accuracy2D -from .accuracy_2p5d import Accuracy2p5D -from .accuracy_3d import Accuracy3D -from colossalai.nn.layer.utils import get_tensor_parallel_mode - -_parallel_accuracy = { - '2d': Accuracy2D, - '2.5d': Accuracy2p5D, - '3d': Accuracy3D, -} - - -class Accuracy(nn.Module): - def __init__(self): - super().__init__() - tensor_parallel = get_tensor_parallel_mode() - if tensor_parallel not in 
_parallel_accuracy: - self.acc = calc_acc - else: - self.acc = _parallel_accuracy[tensor_parallel]() - - def forward(self, *args): - return self.acc(*args) +from torch import nn + +from colossalai.nn.layer.utils import get_tensor_parallel_mode + +from ._utils import calc_acc +from .accuracy_2d import Accuracy2D +from .accuracy_2p5d import Accuracy2p5D +from .accuracy_3d import Accuracy3D + +_parallel_accuracy = { + '2d': Accuracy2D, + '2.5d': Accuracy2p5D, + '3d': Accuracy3D, +} + + +class Accuracy(nn.Module): + + def __init__(self): + super().__init__() + tensor_parallel = get_tensor_parallel_mode() + if tensor_parallel not in _parallel_accuracy: + self.acc = calc_acc + else: + self.acc = _parallel_accuracy[tensor_parallel]() + + def forward(self, *args): + return self.acc(*args) diff --git a/colossalai/nn/metric/_utils.py b/colossalai/nn/metric/_utils.py index eac591b64c65..8706ffc101b0 100644 --- a/colossalai/nn/metric/_utils.py +++ b/colossalai/nn/metric/_utils.py @@ -1,7 +1,7 @@ -import torch - - -def calc_acc(logits, targets): - preds = torch.argmax(logits, dim=-1) - correct = torch.sum(targets == preds) - return correct +import torch + + +def calc_acc(logits, targets): + preds = torch.argmax(logits, dim=-1) + correct = torch.sum(targets == preds) + return correct diff --git a/colossalai/nn/metric/accuracy_2d.py b/colossalai/nn/metric/accuracy_2d.py index a86832973cfd..22c95724acb1 100644 --- a/colossalai/nn/metric/accuracy_2d.py +++ b/colossalai/nn/metric/accuracy_2d.py @@ -1,7 +1,8 @@ import torch -from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d from torch import nn +from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d + from ._utils import calc_acc diff --git a/colossalai/nn/metric/accuracy_2p5d.py b/colossalai/nn/metric/accuracy_2p5d.py index 3044da065de1..0c166b65973c 100644 --- a/colossalai/nn/metric/accuracy_2p5d.py +++ b/colossalai/nn/metric/accuracy_2p5d.py @@ -1,7 +1,8 @@ import torch -from 
colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d from torch import nn +from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d + from ._utils import calc_acc diff --git a/colossalai/nn/metric/accuracy_3d.py b/colossalai/nn/metric/accuracy_3d.py index 5506fc1d2ffc..1c1e7ec26c55 100644 --- a/colossalai/nn/metric/accuracy_3d.py +++ b/colossalai/nn/metric/accuracy_3d.py @@ -1,33 +1,35 @@ -import torch -from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D -from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d -from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env -from torch import nn - -from ._utils import calc_acc - - -class Accuracy3D(nn.Module): - """Accuracy for 3D parallelism - """ - def __init__(self): - super().__init__() - self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) - self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) - - def forward(self, logits, targets): - """Calculate the accuracy of predicted labels. - - Args: - logits (:class:`torch.tensor`): Predicted labels. - targets (:class:`torch.tensor`): True labels from data. - - Returns: - float: the accuracy of prediction. 
- """ - with torch.no_grad(): - targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) - targets = split_tensor_3d(targets, 0, self.input_parallel_mode) - correct = calc_acc(logits, targets) - correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode) - return correct +import torch +from torch import nn + +from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D +from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d +from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env + +from ._utils import calc_acc + + +class Accuracy3D(nn.Module): + """Accuracy for 3D parallelism + """ + + def __init__(self): + super().__init__() + self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) + self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) + + def forward(self, logits, targets): + """Calculate the accuracy of predicted labels. + + Args: + logits (:class:`torch.tensor`): Predicted labels. + targets (:class:`torch.tensor`): True labels from data. + + Returns: + float: the accuracy of prediction. 
+ """ + with torch.no_grad(): + targets = split_tensor_3d(targets, 0, self.weight_parallel_mode) + targets = split_tensor_3d(targets, 0, self.input_parallel_mode) + correct = calc_acc(logits, targets) + correct = reduce_by_batch_3d(correct, self.input_parallel_mode, self.weight_parallel_mode) + return correct diff --git a/colossalai/nn/optimizer/__init__.py b/colossalai/nn/optimizer/__init__.py index 06072648beba..b92be5a3b9d1 100644 --- a/colossalai/nn/optimizer/__init__.py +++ b/colossalai/nn/optimizer/__init__.py @@ -1,10 +1,10 @@ from .colossalai_optimizer import ColossalaiOptimizer +from .cpu_adam import CPUAdam from .fused_adam import FusedAdam from .fused_lamb import FusedLAMB from .fused_sgd import FusedSGD +from .hybrid_adam import HybridAdam from .lamb import Lamb from .lars import Lars -from .cpu_adam import CPUAdam -from .hybrid_adam import HybridAdam __all__ = ['ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam'] diff --git a/colossalai/nn/optimizer/colossalai_optimizer.py b/colossalai/nn/optimizer/colossalai_optimizer.py index 34f5a9541975..8e0dd9e03cf4 100644 --- a/colossalai/nn/optimizer/colossalai_optimizer.py +++ b/colossalai/nn/optimizer/colossalai_optimizer.py @@ -2,6 +2,7 @@ import torch.nn as nn from torch import Tensor from torch.optim import Optimizer + from colossalai.utils import clip_grad_norm_fp32 diff --git a/colossalai/nn/optimizer/lars.py b/colossalai/nn/optimizer/lars.py index 212f66671a0d..9f86a1431a2a 100644 --- a/colossalai/nn/optimizer/lars.py +++ b/colossalai/nn/optimizer/lars.py @@ -22,28 +22,24 @@ class Lars(Optimizer): weight_decay (float, optional): weight decay (L2 penalty) (default: 0) """ - def __init__( - self, - params: Iterable[torch.nn.Parameter], - lr=1e-3, - momentum=0, - eeta=1e-3, - weight_decay=0, - epsilon=0.0 - ) -> None: + def __init__(self, + params: Iterable[torch.nn.Parameter], + lr=1e-3, + momentum=0, + eeta=1e-3, + weight_decay=0, + epsilon=0.0) -> None: if 
not isinstance(lr, float) or lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) if weight_decay < 0.0: - raise ValueError( - "Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if eeta <= 0 or eeta > 1: raise ValueError("Invalid eeta value: {}".format(eeta)) if epsilon < 0: raise ValueError("Invalid epsilon value: {}".format(epsilon)) - defaults = dict(lr=lr, momentum=momentum, - weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True) + defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay, eeta=eeta, epsilon=epsilon, lars=True) super().__init__(params, defaults) @@ -76,11 +72,9 @@ def step(self, closure=None): if lars: w_norm = torch.norm(p) g_norm = torch.norm(p.grad) - trust_ratio = torch.where( - w_norm > 0 and g_norm > 0, - eeta * w_norm / (g_norm + weight_decay * w_norm + eps), - torch.ones_like(w_norm) - ) + trust_ratio = torch.where(w_norm > 0 and g_norm > 0, + eeta * w_norm / (g_norm + weight_decay * w_norm + eps), + torch.ones_like(w_norm)) trust_ratio.clamp_(0.0, 50) scaled_lr *= trust_ratio.item() if weight_decay != 0: @@ -90,8 +84,7 @@ def step(self, closure=None): if momentum != 0: param_state = self.state[p] if 'momentum_buffer' not in param_state: - buf = param_state['momentum_buffer'] = torch.clone( - decayed_grad).detach() + buf = param_state['momentum_buffer'] = torch.clone(decayed_grad).detach() else: buf = param_state['momentum_buffer'] buf.mul_(momentum).add_(decayed_grad) diff --git a/colossalai/nn/optimizer/nvme_optimizer.py b/colossalai/nn/optimizer/nvme_optimizer.py index cbb435a90f61..786f1ecd7b5e 100644 --- a/colossalai/nn/optimizer/nvme_optimizer.py +++ b/colossalai/nn/optimizer/nvme_optimizer.py @@ -1,9 +1,10 @@ -import torch +import math import os import tempfile -import math +from typing import Callable, Dict, List, Optional + +import 
torch from torch.nn.parameter import Parameter -from typing import Optional, List, Dict, Callable class NVMeOptimizer(torch.optim.Optimizer): diff --git a/colossalai/nn/parallel/layers/__init__.py b/colossalai/nn/parallel/layers/__init__.py index 29b8353e63c5..f38124efedf7 100644 --- a/colossalai/nn/parallel/layers/__init__.py +++ b/colossalai/nn/parallel/layers/__init__.py @@ -1,10 +1,17 @@ +from .cache_embedding import ( + CachedEmbeddingBag, + CachedParamMgr, + EvictionStrategy, + LimitBuffIndexCopyer, + ParallelCachedEmbeddingBag, + ParallelCachedEmbeddingBagTablewise, + ParallelCachedEmbeddingBagTablewiseSpiltCache, + TablewiseEmbeddingBagConfig, +) from .colo_module import ColoModule -from .linear import ColoLinear from .embedding import ColoEmbedding -from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module - -from .cache_embedding import CachedEmbeddingBag, ParallelCachedEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer, EvictionStrategy, \ - ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig, ParallelCachedEmbeddingBagTablewiseSpiltCache +from .linear import ColoLinear +from .module_utils import check_colo_module, get_colo_module, init_colo_module, is_colo_module, register_colo_module __all__ = [ 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', diff --git a/colossalai/nn/parallel/layers/cache_embedding/__init__.py b/colossalai/nn/parallel/layers/cache_embedding/__init__.py index 5bbc931a79dc..d87930c1c6b3 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/__init__.py +++ b/colossalai/nn/parallel/layers/cache_embedding/__init__.py @@ -1,8 +1,8 @@ from .cache_mgr import CachedParamMgr, EvictionStrategy -from .copyer import LimitBuffIndexCopyer from .cached_embedding import CachedEmbeddingBag -from .parallel_cached_embedding import ParallelCachedEmbeddingBag +from .copyer import LimitBuffIndexCopyer from 
.embedding_config import TablewiseEmbeddingBagConfig +from .parallel_cached_embedding import ParallelCachedEmbeddingBag from .parallel_cached_embedding_tablewise import ParallelCachedEmbeddingBagTablewise from .parallel_cached_embedding_tablewise_split_cache import ParallelCachedEmbeddingBagTablewiseSpiltCache diff --git a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py index 705835a0ed22..9558c541e703 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/base_embedding.py @@ -1,4 +1,5 @@ import abc + import torch.nn as nn diff --git a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py index da043df368ae..e58d2847efb2 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cache_mgr.py @@ -1,12 +1,14 @@ +import sys +from contextlib import contextmanager +from enum import Enum +from typing import List, Optional + import numpy as np import torch -from torch.profiler import record_function -from typing import List, Optional from contexttimer import Timer +from torch.profiler import record_function + from .copyer import LimitBuffIndexCopyer -from enum import Enum -import sys -from contextlib import contextmanager class EvictionStrategy(Enum): @@ -35,7 +37,7 @@ def _wait_for_data(t, stream: Optional[torch.cuda.streams.Stream]) -> None: class CachedParamMgr(torch.nn.Module): """ Manage Embedding Weights on CPU and CUDA memory uses a software cache. - CPU maintains the entire original weight. + CPU maintains the entire original weight. CUDA maintains a fraction of the weights used in the upcoming computation. The row number in CUDA is controlled by `cuda_row_num`. During training, GPU needs to transmit embedding rows between CPU and GPU. 
Args: @@ -115,7 +117,7 @@ def timer(self, name): self._elapsed_dict[name] += t.elapsed def _find_evict_gpu_idxs(self, evict_num: int) -> torch.Tensor: - """_find_evict_gpu_idxs + """_find_evict_gpu_idxs Find the gpu idxs to be evicted, according to their freq. Args: evict_num (int): how many rows has to be evicted @@ -202,7 +204,7 @@ def reorder(self, ids_freq_mapping: Optional[List[int]] = None, warmup_ratio=0.7 """reorder reorder the weight according to ids' frequency in dataset before training. Execute only once before training, also known as warmup phase. - + Note: If you would like to use the DATASET as the eviction strategy, you must call this function. Note: @@ -516,7 +518,7 @@ def _evict(self) -> int: """ deprecated evict one row from cuda to cpu. - Returns: + Returns: (int) : the slot id be evicted. """ mask = torch.logical_or(torch.isin(self.cached_idx_map, self.evict_backlist), self.cached_idx_map == -1) diff --git a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py index a0c45d8e80c0..64bc69148692 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/cached_embedding.py @@ -1,10 +1,11 @@ +from typing import Iterator, List, Optional, Tuple, Union + import torch import torch.nn.functional as F -from typing import List, Optional, Iterator, Tuple, Union +from torch.nn.parameter import Parameter from .base_embedding import BaseEmbeddingBag from .cache_mgr import CachedParamMgr, EvictionStrategy -from torch.nn.parameter import Parameter class CachedEmbeddingBag(BaseEmbeddingBag): @@ -27,7 +28,7 @@ class CachedEmbeddingBag(BaseEmbeddingBag): include_last_offset (bool, optional): if True, offsets has one additional element, where the last element is equivalent to the size of indices. This matches the CSR format.. Defaults to False. 
dtype (torch.dtype, optional): data type of the cpu weight initialization. Defaults to None meaning float32. device (torch.device, optional): device type to the cpu weight. Defaults to None meaning cpu. - cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row + cache_ratio (float, float): cache ratio of the #cuda_weight_row / #cpu_weight_row ids_freq_mapping (Union[List, torch.Tensor], optional): the frequency of each embedding vector occures in dataset. Defaults to None. warmup_ratio (float, optional): the ratio of cuda cache is warmuped with. Defaults to 0.7. buffer_size (int, optional): the max number of vectors in transmitter buffer. If set to 0, the buffer is not used. Defaults to 0. @@ -85,10 +86,10 @@ def _preprocess(self, buffer_size=50_000, pin_weight=False): """ - Called after initialized. + Called after initialized. Reorder the weight rows according to the ids_freq_mapping. Then, let the weights of the Module be managed by a CachedParamMgr. - + Args: cuda_row_num (int): number of rows can be hosted in CUDA memory ids_freq_mapping (List[int]): a list, idx is id number, value is freq diff --git a/colossalai/nn/parallel/layers/cache_embedding/copyer.py b/colossalai/nn/parallel/layers/cache_embedding/copyer.py index b586be1dc6d9..d1be7248e729 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/copyer.py +++ b/colossalai/nn/parallel/layers/cache_embedding/copyer.py @@ -3,7 +3,7 @@ class LimitBuffIndexCopyer(object): - """LimitBuffIndexCopyer + """LimitBuffIndexCopyer Index Copy using limited temp buffer on CUDA. Args: @@ -15,7 +15,7 @@ def __init__(self, size: int) -> None: @torch.no_grad() def index_copy(self, dim: int, src_index: LongTensor, tgt_index: LongTensor, src: torch.Tensor, tgt: torch.Tensor): - """copy + """copy src tensor[src_index] -(index_select)-> tmp -(index_copy_)-> tgt tensor [tgt_index] The valid rows in the src tensor are continous, while rows in tgt tensor is scattered. 
diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py index d7f77e195f4b..c2b8a33210f8 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding.py @@ -1,12 +1,13 @@ +from typing import Iterator, List, Optional, Tuple + import torch import torch.nn.functional as F -from typing import List, Optional, Iterator, Tuple -from .cached_embedding import CachedEmbeddingBag from colossalai.nn._ops._utils import dual_all_to_all +from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ComputePattern, ProcessGroup, ShardSpec -from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec, ColoTensor from .cache_mgr import CachedParamMgr, EvictionStrategy +from .cached_embedding import CachedEmbeddingBag def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py index 949f85ad4baf..702d546edb09 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise.py @@ -1,15 +1,16 @@ +import time +from typing import List + import torch import torch.distributed as dist import torch.nn.functional as F -from .cached_embedding import CachedEmbeddingBag -from .cache_mgr import EvictionStrategy -from .embedding_config import TablewiseEmbeddingBagConfig -from colossalai.tensor import ProcessGroup from colossalai.nn._ops._utils import dual_all_to_all_tablewise +from colossalai.tensor import ProcessGroup -from typing import List -import time +from .cache_mgr import EvictionStrategy +from .cached_embedding 
import CachedEmbeddingBag +from .embedding_config import TablewiseEmbeddingBagConfig class ParallelCachedEmbeddingBagTablewise(CachedEmbeddingBag): diff --git a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py index cb4647028d47..d23eb3e2a040 100644 --- a/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py +++ b/colossalai/nn/parallel/layers/cache_embedding/parallel_cached_embedding_tablewise_split_cache.py @@ -1,17 +1,17 @@ +import abc +from typing import List + import torch import torch.distributed as dist import torch.nn as nn from torch.profiler import record_function -from .cached_embedding import CachedEmbeddingBag - -from colossalai.tensor import ProcessGroup from colossalai.nn._ops._utils import dual_all_to_all_tablewise -from .embedding_config import TablewiseEmbeddingBagConfig -from .cache_mgr import EvictionStrategy +from colossalai.tensor import ProcessGroup -from typing import List -import abc +from .cache_mgr import EvictionStrategy +from .cached_embedding import CachedEmbeddingBag +from .embedding_config import TablewiseEmbeddingBagConfig class ParallelCachedEmbeddingBagTablewiseSpiltCache(abc.ABC, nn.Module): diff --git a/colossalai/nn/parallel/layers/colo_module.py b/colossalai/nn/parallel/layers/colo_module.py index 8f0f5d5f520a..a0a3eb40cf08 100644 --- a/colossalai/nn/parallel/layers/colo_module.py +++ b/colossalai/nn/parallel/layers/colo_module.py @@ -1,6 +1,7 @@ -from colossalai.tensor.distspec import _DistSpec +from typing import Dict, List + from colossalai.tensor import ComputePattern -from typing import List, Dict +from colossalai.tensor.distspec import _DistSpec class ColoModule(object): diff --git a/colossalai/nn/parallel/layers/embedding.py b/colossalai/nn/parallel/layers/embedding.py index ccacc1ead297..3e4e7ffd8de7 100644 --- 
a/colossalai/nn/parallel/layers/embedding.py +++ b/colossalai/nn/parallel/layers/embedding.py @@ -1,5 +1,6 @@ +from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec + from .colo_module import ColoModule -from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec class ColoEmbedding(ColoModule): diff --git a/colossalai/nn/parallel/layers/linear.py b/colossalai/nn/parallel/layers/linear.py index 84a8c042587d..e391cf808933 100644 --- a/colossalai/nn/parallel/layers/linear.py +++ b/colossalai/nn/parallel/layers/linear.py @@ -1,5 +1,6 @@ +from colossalai.tensor import ComputePattern, ProcessGroup, ShardSpec, distspec + from .colo_module import ColoModule -from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec class ColoLinear(ColoModule): diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/nn/parallel/layers/module_utils.py index 38d128cc705e..191266fa70fd 100644 --- a/colossalai/nn/parallel/layers/module_utils.py +++ b/colossalai/nn/parallel/layers/module_utils.py @@ -1,9 +1,11 @@ from typing import Dict -from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup -from colossalai.tensor import distspec -from . import ColoModule + import torch +from colossalai.tensor import ColoParameter, ComputeSpec, ProcessGroup, distspec + +from . 
import ColoModule + _COLOSSAL_MODULES: Dict[type, ColoModule] = {} diff --git a/colossalai/pipeline/__init__.py b/colossalai/pipeline/__init__.py index 0fcde9707646..f36f54ac9307 100644 --- a/colossalai/pipeline/__init__.py +++ b/colossalai/pipeline/__init__.py @@ -1,4 +1,4 @@ -from .pipelinable import PipelinableContext, PipelinableModel from .layer_spec import LayerSpec +from .pipelinable import PipelinableContext, PipelinableModel -__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] \ No newline at end of file +__all__ = ['PipelinableModel', 'PipelinableContext', 'LayerSpec'] diff --git a/colossalai/pipeline/layer_spec.py b/colossalai/pipeline/layer_spec.py index 7e9169efff78..3960debd7f72 100644 --- a/colossalai/pipeline/layer_spec.py +++ b/colossalai/pipeline/layer_spec.py @@ -1,9 +1,11 @@ import torch + from colossalai.utils.model.utils import call_to_str + class LayerSpec: """ - + """ def __init__(self, typename, *module_args, **module_kwargs): @@ -52,4 +54,4 @@ def count_params(self): return self._param_count def reset_param_count(self): - self._param_count = 0 \ No newline at end of file + self._param_count = 0 diff --git a/colossalai/pipeline/middleware/__init__.py b/colossalai/pipeline/middleware/__init__.py index 79e19f9eaf77..481741bfee31 100644 --- a/colossalai/pipeline/middleware/__init__.py +++ b/colossalai/pipeline/middleware/__init__.py @@ -1,3 +1,3 @@ -from .topo import Topo, Partition, PartitionOutputVal, PartitionInputVal +from .topo import Partition, PartitionInputVal, PartitionOutputVal, Topo -__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] \ No newline at end of file +__all__ = ['Topo', 'Partition', 'PartitionOutputVal', 'PartitionInputVal'] diff --git a/colossalai/pipeline/middleware/adaptor/__init__.py b/colossalai/pipeline/middleware/adaptor/__init__.py index 949700a2c49d..0b0d36d2ffe5 100644 --- a/colossalai/pipeline/middleware/adaptor/__init__.py +++ 
b/colossalai/pipeline/middleware/adaptor/__init__.py @@ -1,3 +1,3 @@ from .fx import get_topology as get_fx_topology -__all__ = ['get_fx_topology'] \ No newline at end of file +__all__ = ['get_fx_topology'] diff --git a/colossalai/pipeline/middleware/adaptor/fx.py b/colossalai/pipeline/middleware/adaptor/fx.py index 8437c5194762..80d7c9c9b1a1 100644 --- a/colossalai/pipeline/middleware/adaptor/fx.py +++ b/colossalai/pipeline/middleware/adaptor/fx.py @@ -1,6 +1,8 @@ +import torch from torch.fx.graph_module import GraphModule + from colossalai.pipeline.middleware.topo import Partition, PartitionInputVal, PartitionOutputVal, Topo -import torch + def partition_name_to_id(partition_name, is_input=False, is_output=False): if is_input: @@ -12,6 +14,7 @@ def partition_name_to_id(partition_name, is_input=False, is_output=False): partition_id = int(partition_name.split(prefix)[-1]) + 2 return partition_id + # There are two kinds of def in fx.graph # 1. non direct_use & non direct_def, which means the output is used by next partition with a temporary mid value. # e.g. submod1 = call_module(...) 
@@ -45,9 +48,10 @@ def find_input_in_partition(node, partitions, input_partitions=None): partition_id = partition_name_to_id(partition.name) p_input_val = PartitionInputVal(partition_id=partition_id, offset=offset) return p_input_val - + return p_input_val - + + def find_output_in_partition(node, partitions, output_partitions=None): p_output_val = PartitionOutputVal() for user in node.users: @@ -70,7 +74,7 @@ def find_output_in_partition(node, partitions, output_partitions=None): if arg == user: p_output_val.add(partition_id=partition_id, offset=i) break - + # user is output if output_partitions is not None: output_node = output_partitions[0] @@ -84,10 +88,11 @@ def find_output_in_partition(node, partitions, output_partitions=None): break return p_output_val + def get_topology(gm: GraphModule): topo = Topo() topo_output_partition = Partition() - + input_partitions = [] partitions = [] output_partitions = [] @@ -109,7 +114,7 @@ def get_topology(gm: GraphModule): topo_input_partition.add_output_val(p_output_val) topo.set_partitions(partition_id=0, partition=topo_input_partition) topo.set_input_partition_id(partition_id=0) - + for i, partition in enumerate(partitions): topo_mid_partition = Partition() # set input for submodule @@ -131,15 +136,16 @@ def get_topology(gm: GraphModule): for user in partition.users: cur_node = user p_output_val = find_output_in_partition(cur_node, partitions, output_partitions) - topo_mid_partition.add_output_val(p_output_val) - topo.set_partitions(partition_id=i+2, partition=topo_mid_partition) - + topo_mid_partition.add_output_val(p_output_val) + topo.set_partitions(partition_id=i + 2, partition=topo_mid_partition) + # set input for output_partition for partition in output_partitions: topo_output_partition = Partition() - torch.fx.graph.map_arg(partition.args[0], lambda n: topo_output_partition.add_input_val( - find_input_in_partition(n, partitions, input_partitions))) + torch.fx.graph.map_arg( + partition.args[0], + lambda n: 
topo_output_partition.add_input_val(find_input_in_partition(n, partitions, input_partitions))) topo.set_partitions(partition_id=1, partition=topo_output_partition) topo.set_output_partition_id(partition_id=1) - return topo \ No newline at end of file + return topo diff --git a/colossalai/pipeline/middleware/topo.py b/colossalai/pipeline/middleware/topo.py index e798e2ed9cab..3c21cce6dc0e 100644 --- a/colossalai/pipeline/middleware/topo.py +++ b/colossalai/pipeline/middleware/topo.py @@ -1,49 +1,54 @@ -from typing import Dict, List from dataclasses import dataclass +from typing import Dict, List # This file includes data structure used by Pipeline Middleware. + @dataclass class ValPosition: partition_id: int offset: int - + def __str__(self) -> str: res = f'[partition_id:{self.partition_id},offset:{self.offset}]' return res - + def __repr__(self) -> str: return self.__str__() + class PartitionInputVal(object): + def __init__(self, partition_id, offset) -> None: # every input from which partition_id and which offset val_pos = ValPosition(partition_id, offset) self._from_partition_and_offset: ValPosition = val_pos - + def get(self): return self._from_partition_and_offset - + def __str__(self) -> str: res = '' res += f'<-({self._from_partition_and_offset})' return res - + def __repr__(self) -> str: return self.__str__() - + + class PartitionOutputVal(object): + def __init__(self) -> None: # every output to which partition_id and which offset self._to_partition_and_offset: List[ValPosition] = [] - + def add(self, partition_id, offset): val_pos = ValPosition(partition_id, offset) self._to_partition_and_offset.append(val_pos) - + def get(self): return self._to_partition_and_offset - + def __str__(self) -> str: res = '' res += '->(' @@ -51,27 +56,29 @@ def __str__(self) -> str: res += f'{val_pos},' res += ')' return res - + def __repr__(self) -> str: return self.__str__() + class Partition(object): + def __init__(self) -> None: self._input_vals: List[PartitionInputVal] = 
[] self._output_vals: List[PartitionOutputVal] = [] - + def add_input_val(self, input_val: PartitionInputVal): self._input_vals.append(input_val) - + def add_output_val(self, output_val: PartitionOutputVal): self._output_vals.append(output_val) - + def get_input_vals(self): return self._input_vals - + def get_output_vals(self): return self._output_vals - + # get the output offsets sent to dst_partition_id def get_output_offsets(self, dst_partition_id): res = [] @@ -80,9 +87,9 @@ def get_output_offsets(self, dst_partition_id): for val_pos in outputs: if val_pos.partition_id == dst_partition_id: res.append(offset) - + return res - + # get all input dst partition_ids def get_input_partition_ids(self): res = [] @@ -91,7 +98,7 @@ def get_input_partition_ids(self): if val_pos.partition_id not in res: res.append(val_pos.partition_id) return res - + # get all output dst partition_ids def get_output_partition_ids(self): res = [] @@ -101,24 +108,25 @@ def get_output_partition_ids(self): if val_pos.partition_id not in res: res.append(val_pos.partition_id) return res - + def __str__(self) -> str: res = '' res += f' input:\n' res += f' length:{len(self._input_vals)}\n' for i, input_val in enumerate(self._input_vals): res += f' offset={i}:{input_val}\n' - + res += f' output:\n' res += f' length:{len(self._output_vals)}\n' for i, output_val in enumerate(self._output_vals): res += f' offset={i}:{output_val}\n' - + return res - + def __repr__(self) -> str: return self.__str__() + # This class is a middleware between partition splitter # and Pipeline Scheduler. It records the graph info about # partition input/output and provides it to scheduler. 
@@ -132,42 +140,43 @@ def __repr__(self) -> str: # _input_partition_id: the key represents input_partition # _output_partition_id: the key represents output_partition class Topo(object): + def __init__(self, input_partition_id=None, output_partition_id=None) -> None: self._partitions: Dict[int, Partition] = {} self._input_partition_id = input_partition_id self._output_partition_id = output_partition_id - + def set_input_partition_id(self, partition_id: int): self._input_partition_id = partition_id - + def set_output_partition_id(self, partition_id: int): self._output_partition_id = partition_id - + def get_input_partition_id(self): return self._input_partition_id - + def get_output_partition_id(self): return self._output_partition_id - + def set_partitions(self, partition_id: int, partition: Partition): self._partitions[partition_id] = partition - + def get_mid_partitions(self): - res = {} #{partition_id: Partition} + res = {} #{partition_id: Partition} for partition_id, partition in self._partitions.items(): if self._input_partition_id == partition_id or self._output_partition_id == partition_id: continue res[partition_id] = partition return res - + def get_mid_partition_ids(self): return list(self.get_mid_partitions().keys()) - + def get_input_partition(self): if self._input_partition_id is not None: return self._partitions[self._input_partition_id] return None - + def get_output_partition(self): if self._output_partition_id is not None: return self._partitions[self._output_partition_id] @@ -175,7 +184,7 @@ def get_output_partition(self): def get_partition_by_id(self, partition_id): return self._partitions[partition_id] - + def __str__(self) -> str: res = '' if len(self._partitions) == 0: @@ -186,21 +195,20 @@ def __str__(self) -> str: res += '{\n' res += f'InputPartition:\n partition_id={self._input_partition_id}\n{input_part}' res += '}\n' - + mid_parts = self.get_mid_partitions() for i, (partition_id, part) in enumerate(mid_parts.items()): res += '{\n' res += 
f'SubPartition_{i}:\n partition_id={partition_id}\n {part}' res += '}\n' - + output_part = self.get_output_partition() if output_part is not None: res += '{\n' res += f'OutputPartition:\n partition_id={self._output_partition_id}\n{output_part}' res += '}\n' - + return res - + def __repr__(self) -> str: return self.__str__() - \ No newline at end of file diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py index 9731530a6e15..433f2352fb92 100644 --- a/colossalai/pipeline/pipelinable.py +++ b/colossalai/pipeline/pipelinable.py @@ -1,15 +1,24 @@ -import torch import inspect -from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses -from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, \ - build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, \ - call_module, customized_partition +import torch + +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from colossalai.nn.layer.utils import CheckpointModule from colossalai.tensor import ColoParameter -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses + from .layer_spec import LayerSpec +from .utils import ( + build_kwargs_for_function, + build_kwargs_for_module, + call_module, + customized_partition, + exec_func_with_kwargs, + exec_funcs_with_kwargs, + partition_balanced, + partition_uniform, +) class PipelinableContext(InsertPostInitMethodToModuleSubClasses): diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/pipeline/pipeline_process_group.py index c61d97ebabfa..c0ee0286787f 100644 --- a/colossalai/pipeline/pipeline_process_group.py +++ b/colossalai/pipeline/pipeline_process_group.py @@ -1,9 +1,9 @@ -from typing import List, Dict, Tuple import os import threading +from typing import Dict, List, Tuple -from torch.distributed import rpc 
import torch.distributed as dist +from torch.distributed import rpc from colossalai.tensor import ProcessGroup diff --git a/colossalai/pipeline/rpc/__init__.py b/colossalai/pipeline/rpc/__init__.py index 9d9e9d44f46c..15b65a4138a8 100644 --- a/colossalai/pipeline/rpc/__init__.py +++ b/colossalai/pipeline/rpc/__init__.py @@ -1,4 +1,4 @@ -from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine +from ._pipeline_schedule import ChimeraPipelineEngine, FillDrainPipelineEngine, OneFOneBPipelineEngine from .utils import pytree_map -__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] \ No newline at end of file +__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py index df7226644a7a..a1ea560d61a3 100644 --- a/colossalai/pipeline/utils.py +++ b/colossalai/pipeline/utils.py @@ -1,12 +1,13 @@ import heapq import inspect +from collections import OrderedDict +from typing import List + import torch from colossalai.logging import get_dist_logger from colossalai.nn.layer.utils import CheckpointModule -from typing import List -from collections import OrderedDict def _binary_partition(weights: List, start: int, end: int): """Returns the binary partition position of `weights`, given the start @@ -162,7 +163,7 @@ def build_kwargs_for_module(function, input_tensor, kw_dict): kwargs_offset = 1 elif isinstance(input_tensor, (tuple, OrderedDict)): #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' - # Huggingface will take their own structures based on OrderedDict as the output + # Huggingface will take their own structures based on OrderedDict as the output # between layers so we've to close this check. 
kwargs_offset = len(input_tensor) args_name_list = list(sig.parameters.keys()) @@ -256,7 +257,7 @@ def call_module(module, args=None, kwargs=None): def customized_partition(exec_seq): ''' - This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an + This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an annotation to note the partition point. ''' customized_parts = {} diff --git a/colossalai/registry/registry.py b/colossalai/registry/registry.py index 8a4173f7ab99..50d6b74c5617 100644 --- a/colossalai/registry/registry.py +++ b/colossalai/registry/registry.py @@ -6,7 +6,7 @@ class Registry: - """This is a registry class used to register classes and modules so that a universal + """This is a registry class used to register classes and modules so that a universal object builder can be enabled. Args: @@ -42,7 +42,7 @@ def register_module(self, module_class): return module_class def get_module(self, module_name: str): - """Retrieves a module with name `module_name` and returns the module if it has + """Retrieves a module with name `module_name` and returns the module if it has already been registered before. 
Args: diff --git a/colossalai/tensor/op_wrapper.py b/colossalai/tensor/op_wrapper.py index 1c00066f7465..63ebaa264279 100644 --- a/colossalai/tensor/op_wrapper.py +++ b/colossalai/tensor/op_wrapper.py @@ -1,8 +1,5 @@ -from typing import ( - Callable, - Dict, -) import functools +from typing import Callable, Dict # Custom sharded ops _COLOSSAL_OPS: Dict[str, Callable] = {} diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index e3dd500dea8e..1caf3b94d777 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -1,5 +1,5 @@ -from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group -from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use, skip_if_not_enough_gpus +from .comparison import assert_close, assert_close_loose, assert_equal, assert_equal_in_group, assert_not_equal +from .utils import parameterize, rerun_if_address_is_in_use, rerun_on_exception, skip_if_not_enough_gpus __all__ = [ 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', diff --git a/colossalai/testing/pytest_wrapper.py b/colossalai/testing/pytest_wrapper.py index a472eb3723ec..9732f89b3aab 100644 --- a/colossalai/testing/pytest_wrapper.py +++ b/colossalai/testing/pytest_wrapper.py @@ -1,12 +1,13 @@ """ This file will not be automatically imported by `colossalai.testing` -as this file has a dependency on `pytest`. Therefore, you need to +as this file has a dependency on `pytest`. 
Therefore, you need to explicitly import this file `from colossalai.testing.pytest_wrapper import `.from """ -import pytest import os +import pytest + def run_on_environment_flag(name: str): """ diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py index 64c1d6e7bcd0..953b4b74a3be 100644 --- a/colossalai/testing/utils.py +++ b/colossalai/testing/utils.py @@ -1,8 +1,9 @@ import re -import torch -from typing import Callable, List, Any from functools import partial from inspect import signature +from typing import Any, Callable, List + +import torch from packaging import version @@ -43,7 +44,7 @@ def say_something(person, msg): # > davis: hello # > davis: bye # > davis: stop - + Args: argument (str): the name of the argument to parameterize values (List[Any]): a list of values to iterate for this argument @@ -85,13 +86,13 @@ def test_method(): def test_method(): print('hey') raise RuntimeError('Address already in use') - + # rerun for infinite times if Runtime error occurs @rerun_on_exception(exception_type=RuntimeError, max_try=None) def test_method(): print('hey') raise RuntimeError('Address already in use') - + # rerun only the exception message is matched with pattern # for infinite times if Runtime error occurs @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") @@ -101,10 +102,10 @@ def test_method(): Args: exception_type (Exception, Optional): The type of exception to detect for rerun - pattern (str, Optional): The pattern to match the exception message. + pattern (str, Optional): The pattern to match the exception message. If the pattern is not None and matches the exception message, the exception will be detected for rerun - max_try (int, Optional): Maximum reruns for this function. The default value is 5. + max_try (int, Optional): Maximum reruns for this function. The default value is 5. 
If max_try is None, it will rerun foreven if exception keeps occurings """ diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py index 60bbc4eeee32..972224248b19 100644 --- a/colossalai/trainer/_trainer.py +++ b/colossalai/trainer/_trainer.py @@ -1,4 +1,4 @@ -from typing import Union, List, Any +from typing import Any, List, Union import torch from torch.utils.data import DataLoader @@ -6,9 +6,8 @@ from colossalai.engine import Engine from colossalai.logging import DistributedLogger -from colossalai.utils import MultiTimer -from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage from colossalai.trainer.hooks import BaseHook +from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0 class Trainer: diff --git a/colossalai/trainer/hooks/__init__.py b/colossalai/trainer/hooks/__init__.py index 4d36093833d9..bf9cc6421b67 100644 --- a/colossalai/trainer/hooks/__init__.py +++ b/colossalai/trainer/hooks/__init__.py @@ -1,7 +1,12 @@ from ._base_hook import BaseHook from ._checkpoint_hook import SaveCheckpointHook -from ._log_hook import (LogMemoryByEpochHook, LogMetricByEpochHook, LogMetricByStepHook, LogTimingByEpochHook, - TensorboardHook) +from ._log_hook import ( + LogMemoryByEpochHook, + LogMetricByEpochHook, + LogMetricByStepHook, + LogTimingByEpochHook, + TensorboardHook, +) from ._lr_scheduler_hook import LRSchedulerHook from ._metric_hook import AccuracyHook, LossHook, MetricHook, ThroughputHook diff --git a/colossalai/trainer/hooks/_checkpoint_hook.py b/colossalai/trainer/hooks/_checkpoint_hook.py index 3bcb32cd2dcb..ab1d423d84c4 100644 --- a/colossalai/trainer/hooks/_checkpoint_hook.py +++ b/colossalai/trainer/hooks/_checkpoint_hook.py @@ -1,11 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- import torch -from colossalai.logging import get_dist_logger +from colossalai.logging import get_dist_logger from colossalai.registry import HOOKS from colossalai.trainer.hooks import 
BaseHook from colossalai.utils.checkpointing import save_checkpoint + from ._lr_scheduler_hook import LRSchedulerHook diff --git a/colossalai/trainer/hooks/_log_hook.py b/colossalai/trainer/hooks/_log_hook.py index 5b1f33983422..b00d8217012d 100644 --- a/colossalai/trainer/hooks/_log_hook.py +++ b/colossalai/trainer/hooks/_log_hook.py @@ -3,17 +3,17 @@ import os import os.path as osp - from typing import List + from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.registry import HOOKS from colossalai.logging import DistributedLogger -from colossalai.utils import report_memory_usage, is_dp_rank_0, \ - is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer +from colossalai.registry import HOOKS +from colossalai.trainer.hooks._metric_hook import ThroughputMetric +from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage + from ._base_hook import BaseHook from ._commons_ import _format_number -from colossalai.trainer.hooks._metric_hook import ThroughputMetric class LogByEpochHook(BaseHook): diff --git a/colossalai/trainer/hooks/_lr_scheduler_hook.py b/colossalai/trainer/hooks/_lr_scheduler_hook.py index c6da33442dc3..0d19ab08a822 100644 --- a/colossalai/trainer/hooks/_lr_scheduler_hook.py +++ b/colossalai/trainer/hooks/_lr_scheduler_hook.py @@ -1,6 +1,7 @@ -from colossalai.registry import HOOKS from torch import Tensor +from colossalai.registry import HOOKS + from ._metric_hook import LearningRateMetric, MetricHook diff --git a/colossalai/trainer/hooks/_metric_hook.py b/colossalai/trainer/hooks/_metric_hook.py index 526d6c746ec6..96def4172fed 100644 --- a/colossalai/trainer/hooks/_metric_hook.py +++ b/colossalai/trainer/hooks/_metric_hook.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist + from colossalai.communication import all_reduce from colossalai.context import ParallelMode from colossalai.core import global_context as gpc @@ -19,8 +20,8 @@ 
class Metric(ABC): """A basic class of metric collectors. It collects a specific metric during training or evaluation and would always be used with - :class:`MetricHook` to help it update its states and show the - metric. So please use corresponding hook class to make the metric + :class:`MetricHook` to help it update its states and show the + metric. So please use corresponding hook class to make the metric collector works. Args: @@ -220,9 +221,9 @@ def is_better(a, b) -> bool: class MetricHook(BaseHook): - """Specialized hook classes for :class:`Metric`. - Some help metric collectors initialize, reset and - update their states. Others are used to display and + """Specialized hook classes for :class:`Metric`. + Some help metric collectors initialize, reset and + update their states. Others are used to display and record the metric. Args: diff --git a/colossalai/utils/activation_checkpoint.py b/colossalai/utils/activation_checkpoint.py index fa9ed827a8a7..ef0be929932b 100644 --- a/colossalai/utils/activation_checkpoint.py +++ b/colossalai/utils/activation_checkpoint.py @@ -1,13 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import weakref + import torch from torch.utils.checkpoint import check_backward_validity, detach_variable -from colossalai.context.random import get_states, get_current_mode, set_seed_states, set_mode, sync_states -from .cuda import get_current_device +from colossalai.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states -import weakref +from .cuda import get_current_device def copy_to_device(obj, device): @@ -143,7 +144,7 @@ def checkpoint(function, activation_offload, *args, use_reentrant: bool = True): Args: function: Describe the forward pass function. It should know how to handle the input tuples. 
- activation_offload: The variable to check whether we should offload activation to cpu + activation_offload: The variable to check whether we should offload activation to cpu args (list): Tuple containing the parameters of the function use_reentrant: Bool type to check if we need to use_reentrant, if use_reentrant=False, there might be more flexibility for user to define there checkpoint function diff --git a/colossalai/utils/checkpoint/__init__.py b/colossalai/utils/checkpoint/__init__.py index 1795b4ce36f4..558a956b31ac 100644 --- a/colossalai/utils/checkpoint/__init__.py +++ b/colossalai/utils/checkpoint/__init__.py @@ -1,3 +1,3 @@ -from .module_checkpoint import save_checkpoint, load_checkpoint +from .module_checkpoint import load_checkpoint, save_checkpoint __all__ = ['save_checkpoint', 'load_checkpoint'] diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py index a109b3702577..d5b1ec762ffd 100644 --- a/colossalai/utils/checkpoint/module_checkpoint.py +++ b/colossalai/utils/checkpoint/module_checkpoint.py @@ -1,9 +1,11 @@ +from typing import Dict, Optional + import torch import torch.distributed as dist -from colossalai.tensor import ColoTensor + from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.tensor import ColoTensor from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor -from typing import Optional, Dict def save_checkpoint(path: str, @@ -13,7 +15,7 @@ def save_checkpoint(path: str, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, *args, **kwargs): - """save_checkpoint + """save_checkpoint save a model, whose parameters are `ColoTensor`s. Args: path (str): directory to save the checkpoint files. 
@@ -78,7 +80,7 @@ def load_checkpoint(path: str, lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, torch_load_kwargs: Optional[Dict] = None, load_state_dict_kwargs: Optional[Dict] = None): - """load_checkpoint + """load_checkpoint load a model, whose parameters are `ColoTensor`s. Args: path (str): directory to save the checkpoint files. diff --git a/colossalai/utils/checkpoint/utils.py b/colossalai/utils/checkpoint/utils.py index 5652600ffd9b..61794975d1f5 100644 --- a/colossalai/utils/checkpoint/utils.py +++ b/colossalai/utils/checkpoint/utils.py @@ -1,63 +1,64 @@ -import torch -import torch.distributed as dist -from colossalai.tensor import ColoTensor, ColoTensorSpec -from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern - - -def robust_broadcast(tensor): - with torch.no_grad(): - is_cpu_ten = tensor.device.type == 'cpu' - if is_cpu_ten: - b_data = tensor.cuda() - else: - b_data = tensor - - dist.broadcast(b_data, 0) - - if is_cpu_ten: - tensor.copy_(b_data) - - -def gather_tensor(colo_tensor: ColoTensor) -> None: - """Make colo_tensor replicated when the rank is 0 - """ - if not colo_tensor.is_replicate(): - pg = colo_tensor.get_process_group() - # for the group which contains rank 0 - if pg.dp_local_rank() == 0: - old_dist_spec = colo_tensor.dist_spec - colo_tensor.to_replicate_() - if dist.get_rank() != 0: - colo_tensor.set_dist_spec(old_dist_spec) - - # synchronize all processes for unexpected problems - dist.barrier() - - if dist.get_rank() == 0: - setattr(colo_tensor, 'save_ready', True) # set saving signitrue - - -def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: - """Reversal operation of `gather_tensor`. 
- """ - if dist_spec.placement == DistPlacementPattern.REPLICATE: - robust_broadcast(colo_tensor.data) - else: - global_size = colo_tensor.size_global() - - if dist.get_rank() == 0: - entire_data = colo_tensor.data - else: - entire_data = torch.empty(global_size, device=colo_tensor.device) - robust_broadcast(entire_data) - - if dist.get_rank() == 0: - colo_tensor.set_dist_spec(dist_spec) - else: - rep_tensor = ColoTensor( - entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) - rep_tensor.set_dist_spec(dist_spec) - with torch.no_grad(): - colo_tensor.data.copy_(rep_tensor.data) - # synchronize all processes for unexpected problems - dist.barrier() +import torch +import torch.distributed as dist + +from colossalai.tensor import ColoTensor, ColoTensorSpec +from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec + + +def robust_broadcast(tensor): + with torch.no_grad(): + is_cpu_ten = tensor.device.type == 'cpu' + if is_cpu_ten: + b_data = tensor.cuda() + else: + b_data = tensor + + dist.broadcast(b_data, 0) + + if is_cpu_ten: + tensor.copy_(b_data) + + +def gather_tensor(colo_tensor: ColoTensor) -> None: + """Make colo_tensor replicated when the rank is 0 + """ + if not colo_tensor.is_replicate(): + pg = colo_tensor.get_process_group() + # for the group which contains rank 0 + if pg.dp_local_rank() == 0: + old_dist_spec = colo_tensor.dist_spec + colo_tensor.to_replicate_() + if dist.get_rank() != 0: + colo_tensor.set_dist_spec(old_dist_spec) + + # synchronize all processes for unexpected problems + dist.barrier() + + if dist.get_rank() == 0: + setattr(colo_tensor, 'save_ready', True) # set saving signitrue + + +def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None: + """Reversal operation of `gather_tensor`. 
+ """ + if dist_spec.placement == DistPlacementPattern.REPLICATE: + robust_broadcast(colo_tensor.data) + else: + global_size = colo_tensor.size_global() + + if dist.get_rank() == 0: + entire_data = colo_tensor.data + else: + entire_data = torch.empty(global_size, device=colo_tensor.device) + robust_broadcast(entire_data) + + if dist.get_rank() == 0: + colo_tensor.set_dist_spec(dist_spec) + else: + rep_tensor = ColoTensor( + entire_data, ColoTensorSpec(pg=colo_tensor.get_process_group(), compute_attr=colo_tensor.compute_spec)) + rep_tensor.set_dist_spec(dist_spec) + with torch.no_grad(): + colo_tensor.data.copy_(rep_tensor.data) + # synchronize all processes for unexpected problems + dist.barrier() diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py index fe030866894f..df7144902ab5 100644 --- a/colossalai/utils/checkpoint_io/__init__.py +++ b/colossalai/utils/checkpoint_io/__init__.py @@ -1,2 +1,2 @@ from .io import load, merge, redist, save -from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta) +from .meta import ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py index 529ceb86829b..a8341c0f5172 100644 --- a/colossalai/utils/checkpoint_io/convertor.py +++ b/colossalai/utils/checkpoint_io/convertor.py @@ -6,7 +6,7 @@ from .distributed import merge_param, unmerge_param from .meta import ParamDistMeta, RedistMeta -from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none) +from .utils import ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none class CheckpointConvertor(ABC): diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py index bf720437c41a..7d0e171920d4 100644 --- a/colossalai/utils/checkpoint_io/distributed.py +++ 
b/colossalai/utils/checkpoint_io/distributed.py @@ -1,8 +1,10 @@ +from collections import defaultdict +from typing import List, Optional, Tuple + import torch from numpy import prod from torch import Tensor -from typing import List, Optional, Tuple -from collections import defaultdict + from .meta import ParamDistMeta, ParamRedistMeta diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py index f00212cdf859..e5c00a5054d3 100644 --- a/colossalai/utils/checkpoint_io/io.py +++ b/colossalai/utils/checkpoint_io/io.py @@ -6,8 +6,13 @@ from torch.optim import Optimizer from .backend import get_backend -from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger, - OptimizerCheckpointRedistor) +from .convertor import ( + CheckpointConvertor, + ModelCheckpointMerger, + ModelCheckpointRedistor, + OptimizerCheckpointMerger, + OptimizerCheckpointRedistor, +) from .meta import ParamDistMeta, RedistMeta from .utils import build_checkpoints, optimizer_load_state_dict diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py index 994f08b4b5e4..ef825868a6f4 100644 --- a/colossalai/utils/checkpoint_io/meta.py +++ b/colossalai/utils/checkpoint_io/meta.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, Set, Dict +from typing import Dict, List, Optional, Set @dataclass diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py index 4552accde470..e63f206355b8 100644 --- a/colossalai/utils/checkpoint_io/writer.py +++ b/colossalai/utils/checkpoint_io/writer.py @@ -1,8 +1,16 @@ +import os from abc import ABC, abstractmethod from typing import Optional -from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME + import torch -import os + +from .constant import ( + GLOBAL_META_FILE_NAME, + META_CKPT_FILE_NAME, + 
MODEL_CKPT_FILE_NAME, + OPTIM_CKPT_FILE_NAME, + OTHER_CKPT_FILE_NAME, +) class CheckpointWriter(ABC): diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py index d1c6b6370ede..9f56dcaeb28d 100644 --- a/colossalai/utils/checkpointing.py +++ b/colossalai/utils/checkpointing.py @@ -3,9 +3,11 @@ import torch import torch.distributed as dist + +from colossalai.constants import IS_TENSOR_PARALLEL from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.constants import IS_TENSOR_PARALLEL + try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py index 60f3ccb60883..a0eff35966fc 100644 --- a/colossalai/utils/cuda.py +++ b/colossalai/utils/cuda.py @@ -23,7 +23,7 @@ def set_to_cuda(models): def get_current_device() -> torch.device: """ Returns currently selected device (gpu/cpu). - If cuda available, return gpu, otherwise return cpu. + If cuda available, return gpu, otherwise return cpu. """ if torch.cuda.is_available(): return torch.device(f'cuda:{torch.cuda.current_device()}') diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py index 945dc54b397a..570475df32da 100644 --- a/colossalai/utils/data_sampler/data_parallel_sampler.py +++ b/colossalai/utils/data_sampler/data_parallel_sampler.py @@ -4,11 +4,11 @@ import math import random -import numpy as np -from typing import TypeVar, Iterator +from typing import Iterator, TypeVar +import numpy as np import torch -from torch.utils.data import Sampler, Dataset, DataLoader +from torch.utils.data import DataLoader, Dataset, Sampler from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc @@ -30,11 +30,7 @@ class DataParallelSampler(Sampler): the batch size, then the last batch will be smaller, defaults to False. 
""" - def __init__(self, - dataset: Dataset, - shuffle: bool = False, - seed: int = 0, - drop_last: bool = False) -> None: + def __init__(self, dataset: Dataset, shuffle: bool = False, seed: int = 0, drop_last: bool = False) -> None: self.dataset = dataset self.num_replicas = gpc.get_world_size(ParallelMode.DATA) self.rank = gpc.get_local_rank(ParallelMode.DATA) @@ -54,8 +50,7 @@ def __init__(self, self.num_replicas # type: ignore[arg-type] ) else: - self.num_samples = math.ceil( - len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle self.seed = seed @@ -72,7 +67,7 @@ def __iter__(self) -> Iterator[T_co]: # set_epoch manually self.epoch += 1 else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] + indices = list(range(len(self.dataset))) # type: ignore[arg-type] if not self.drop_last: # add extra samples to make it evenly divisible @@ -80,8 +75,7 @@ def __iter__(self) -> Iterator[T_co]: if padding_size <= len(indices): indices += indices[:padding_size] else: - indices += (indices * math.ceil(padding_size / - len(indices)))[:padding_size] + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] else: # remove tail of data to make it evenly divisible. 
indices = indices[:self.total_size] @@ -109,8 +103,8 @@ def set_epoch(self, epoch: int) -> None: def get_dataloader(dataset, shuffle=False, - seed=1024, - add_sampler=True, + seed=1024, + add_sampler=True, drop_last=False, pin_memory=False, num_workers=0, diff --git a/colossalai/utils/memory.py b/colossalai/utils/memory.py index 434e90edd3b9..c884a7572e0b 100644 --- a/colossalai/utils/memory.py +++ b/colossalai/utils/memory.py @@ -1,14 +1,14 @@ -import torch import gc -import psutil from collections import namedtuple +import psutil +import torch +from packaging import version + from colossalai.context.parallel_mode import ParallelMode -from colossalai.utils import get_current_device from colossalai.core import global_context as gpc -from colossalai.context.parallel_mode import ParallelMode from colossalai.logging import get_dist_logger -from packaging import version +from colossalai.utils import get_current_device _GLOBAL_CUDA_MEM_FRACTION = 1.0 _GLOBAL_CPU_MEM_CAPACITY = -1 @@ -138,7 +138,7 @@ def colo_device_memory_used(device: torch.device) -> int: def colo_set_process_memory_fraction(ratio: float) -> None: - """colo_set_process_memory_fraction + """colo_set_process_memory_fraction set how much cuda memory used on the gpu belonging to the current process. diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py index 90783e5d9b8e..eb64d3d0467e 100644 --- a/colossalai/utils/moe.py +++ b/colossalai/utils/moe.py @@ -1,52 +1,55 @@ -import torch.nn as nn -import torch.distributed as dist -from colossalai.core import global_context as gpc -from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.context import ParallelMode -from .common import is_using_ddp -from typing import Dict, List - - -def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: - """Returns a parameter dictionary, the key of which is the expert parallel - size of every parameter. 
Since the parameters in data parallelism is replicated - in each GPU, we set their ep_size to 1. - - Args: - model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. - """ - epsize_param_dict = dict() - for param in model.parameters(): - if not hasattr(param, 'moe_info'): - ep_size = 1 # set ep_size to 1 for dp parameters - else: - ep_size = param.moe_info.ep_size - if ep_size not in epsize_param_dict: - epsize_param_dict[ep_size] = [] - epsize_param_dict[ep_size].append(param) - - return epsize_param_dict - - -def sync_moe_model_param(model: nn.Module): - """Make sure model parameters are consistent in MoE parallel context. - - Args: - model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. - """ - if is_using_ddp(): - - param_dict = get_moe_epsize_param_dict(model) - - # synchrosize the parameters whose dp_group is the whole world - if 1 in param_dict: - src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] - for param in param_dict[1]: - dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) - - for ep_size in param_dict: - # When ep_size = world_size, communication is not needed - if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: - src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) - for param in param_dict[ep_size]: - dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) +from typing import Dict, List + +import torch.distributed as dist +import torch.nn as nn + +from colossalai.context import ParallelMode +from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.core import global_context as gpc + +from .common import is_using_ddp + + +def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]: + """Returns a parameter dictionary, the key of which is the expert parallel + size of every parameter. Since the parameters in data parallelism is replicated + in each GPU, we set their ep_size to 1. 
+ + Args: + model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict. + """ + epsize_param_dict = dict() + for param in model.parameters(): + if not hasattr(param, 'moe_info'): + ep_size = 1 # set ep_size to 1 for dp parameters + else: + ep_size = param.moe_info.ep_size + if ep_size not in epsize_param_dict: + epsize_param_dict[ep_size] = [] + epsize_param_dict[ep_size].append(param) + + return epsize_param_dict + + +def sync_moe_model_param(model: nn.Module): + """Make sure model parameters are consistent in MoE parallel context. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + """ + if is_using_ddp(): + + param_dict = get_moe_epsize_param_dict(model) + + # synchrosize the parameters whose dp_group is the whole world + if 1 in param_dict: + src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0] + for param in param_dict[1]: + dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA)) + + for ep_size in param_dict: + # When ep_size = world_size, communication is not needed + if ep_size != 1 and ep_size != MOE_CONTEXT.world_size: + src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group) + for param in param_dict[ep_size]: + dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group) diff --git a/colossalai/utils/profiler/legacy/__init__.py b/colossalai/utils/profiler/legacy/__init__.py index 849c7fca3053..88beed86d7de 100644 --- a/colossalai/utils/profiler/legacy/__init__.py +++ b/colossalai/utils/profiler/legacy/__init__.py @@ -1,6 +1,6 @@ -from .comm_profiler import CommProfiler -from .pcie_profiler import PcieProfiler -from .prof_utils import ProfilerContext, BaseProfiler -from .mem_profiler import MemProfiler - -__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] +from .comm_profiler import CommProfiler +from .mem_profiler import MemProfiler +from .pcie_profiler import PcieProfiler +from 
.prof_utils import BaseProfiler, ProfilerContext + +__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext'] diff --git a/colossalai/utils/profiler/legacy/comm_profiler.py b/colossalai/utils/profiler/legacy/comm_profiler.py index a4f5729c97ec..0b6f5e31218d 100644 --- a/colossalai/utils/profiler/legacy/comm_profiler.py +++ b/colossalai/utils/profiler/legacy/comm_profiler.py @@ -1,308 +1,311 @@ -import inspect -from pathlib import Path -from functools import partial -import torch -from torch.autograd.profiler import profile -import torch.distributed as dist -from torch.distributed import ReduceOp -from colossalai.utils import get_current_device -from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth -from typing import List, Optional - - -def _get_code_location(depth: int): - ret = [] - length = min(len(inspect.stack()), depth + 1) - for i in range(3, length): - upper_frame = inspect.stack()[i] - function_name = inspect.stack()[i - 1].function - ret.append(upper_frame.filename) - ret.append('(') - ret.append(str(upper_frame.lineno)) - ret.append('): ') - ret.append(function_name) - if i != length - 1: - ret.append('\n') - - return ''.join(ret) - - -torch_all_reduce = dist.all_reduce -torch_all_gather = dist.all_gather -torch_reduce_scatter = dist.reduce_scatter -torch_broadcast = dist.broadcast -torch_reduce = dist.reduce - - -class CommEvent(object): - """Communication Event. Used for communication time and communication - volume recording. - """ - - def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): - self.self_count = count - self.self_comm_vol = comm_vol - self.self_cuda_time = cuda_time - - def add(self, rhs): - self.self_count += rhs.self_count - self.self_comm_vol += rhs.self_comm_vol - self.self_cuda_time += rhs.self_cuda_time - - -class CommProfiler(BaseProfiler): - """Communication profiler. Records all communication events. 
- """ - - def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): - super().__init__(profiler_name="Collective_Communication", priority=0) - self.depth = 3 + depth - self.total_count = total_count - self.total_comm_vol = total_comm_vol - self.total_cuda_time = total_cuda_time - - self.ops_record = dict() - self.profiler = None - self.pending_op = None - self.pending_metadata = None - self.warn_flag = False - - def reset(self): - self.total_count = 0 - self.total_comm_vol = 0 - self.total_cuda_time = 0 - - self.ops_record = dict() - self.profiler = None - self.pending_op = None - self.pending_metadata = None - self.warn_flag = False - - def enable(self): - dist.all_reduce = partial(all_reduce, profiler=self) - dist.all_gather = partial(all_gather, profiler=self) - dist.reduce_scatter = partial(reduce_scatter, profiler=self) - dist.broadcast = partial(broadcast, profiler=self) - dist.reduce = partial(reduce, profiler=self) - - def disable(self): - dist.all_reduce = torch_all_reduce - dist.all_gather = torch_all_gather - dist.reduce_scatter = torch_reduce_scatter - dist.broadcast = torch_broadcast - dist.reduce = torch_reduce - - def to_tensorboard(self, writer): - writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n")) - - def to_file(self, filename: Path): - with open(filename, "w") as f: - f.write(self.result_str()) - - def show(self): - print(self.result_str()) - - def result_str(self, sep: str = "\n"): - res = [] - - def append(s: str = None): - if s is not None: - res.append(s) - res.append(sep) - - if self.warn_flag: - append("Warnning: there exists multiple communication operations in the same time. As a result, " - "the profiling result is not accurate.") - - if self.total_cuda_time == 0: - return "No collective communication has been called yet!" 
- - append("Collective communication profiling result:") - append("total cuda time: {}".format(_format_time(self.total_cuda_time))) - append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time))) - append("total number of calls: {}".format(self.total_count)) - append("All events:") - - seperation = '-' * 74 - row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 - - append(seperation) - append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) - append(seperation) - - show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) - for location, event in show_list: - append(location) - append( - row_format.format('', _format_time(event.self_cuda_time), - '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), - _format_memory(event.self_comm_vol), - _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) - append() - - return ''.join(res) - - @property - def has_aync_op(self): - return self.pending_op is not None - - def activate_profiler(self, kn: str, vol: float): - self.pending_metadata = (kn, _get_code_location(self.depth), vol) - self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True) - self.profiler.__enter__() - - def close_profiler(self, group=None): - assert self.profiler is not None, "There is no running dist op" - kernel_name, code_location, vol = self.pending_metadata - self.profiler.__exit__(None, None, None) - - if self.profiler.enabled and dist.get_world_size(group) > 1: - assert_flag = 0 - current_comm_event = None - events = self.profiler.function_events - for event in events: - if kernel_name in event.name: - assert assert_flag == 0, "Multiple dist ops has been called " - current_comm_event = CommEvent(1, vol, event.self_cuda_time_total) - assert_flag += 1 - - assert current_comm_event is not None, "dist op has not been found" - - buffer = 
torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) - torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) - current_comm_event.self_cuda_time = buffer.item() - - self.total_count += current_comm_event.self_count - self.total_comm_vol += current_comm_event.self_comm_vol - self.total_cuda_time += current_comm_event.self_cuda_time - if code_location in self.ops_record: - self.ops_record[code_location].add(current_comm_event) - else: - self.ops_record[code_location] = current_comm_event - - self.profiler = None - self.pending_op = None - self.pending_metadata = None - - def wait_async_op(self): - if self.pending_op is not None: - op = self.pending_op - op.wait() - self.close_profiler() - - -class CommHandler(object): - """Communication handler. A dummy handler to wait aync operations. - """ - - def __init__(self, profiler: CommProfiler): - super().__init__() - self.prof = profiler - - def wait(self): - self.prof.wait_async_op() - - -def async_check(profiler: CommProfiler): - if profiler.pending_op is not None: - profiler.warn_flag = True - profiler.wait_async_op() - - -def all_reduce(tensor: torch.Tensor, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = 2 * (comm_size - 1) / comm_size - comm_vol = correction * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol) - profiler.pending_op = torch_all_reduce(tensor, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def reduce_scatter(output: torch.Tensor, - input_list: List[torch.Tensor], - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = (comm_size - 1) / comm_size - 
comm_vol = 0 - for tensor in input_list: - comm_vol += tensor.element_size() * tensor.numel() - comm_vol *= correction - profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) - profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def all_gather(tensor_list: List[torch.Tensor], - tensor: torch.Tensor, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_size = dist.get_world_size(group) - correction = (comm_size - 1) / comm_size - comm_vol = 0 - for ten in tensor_list: - comm_vol += ten.element_size() * ten.numel() - comm_vol *= correction - profiler.activate_profiler("ncclKernel_AllGather_", comm_vol) - profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def broadcast(tensor: torch.Tensor, - src: int, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_vol = 1.0 * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol) - profiler.pending_op = torch_broadcast(tensor, src, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) - - -def reduce(tensor: torch.Tensor, - dst: int, - op: ReduceOp = ReduceOp.SUM, - group=None, - async_op: bool = False, - profiler: CommProfiler = None) -> Optional[CommHandler]: - async_check(profiler) - - comm_vol = 1.0 * tensor.element_size() * tensor.numel() - profiler.activate_profiler("ncclKernel_Reduce_", comm_vol) - profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op) - - if async_op: - return CommHandler(profiler) - - profiler.close_profiler(group) +import inspect +from functools import partial +from pathlib import Path +from typing 
import List, Optional + +import torch +import torch.distributed as dist +from torch.autograd.profiler import profile +from torch.distributed import ReduceOp + +from colossalai.utils import get_current_device + +from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time + + +def _get_code_location(depth: int): + ret = [] + length = min(len(inspect.stack()), depth + 1) + for i in range(3, length): + upper_frame = inspect.stack()[i] + function_name = inspect.stack()[i - 1].function + ret.append(upper_frame.filename) + ret.append('(') + ret.append(str(upper_frame.lineno)) + ret.append('): ') + ret.append(function_name) + if i != length - 1: + ret.append('\n') + + return ''.join(ret) + + +torch_all_reduce = dist.all_reduce +torch_all_gather = dist.all_gather +torch_reduce_scatter = dist.reduce_scatter +torch_broadcast = dist.broadcast +torch_reduce = dist.reduce + + +class CommEvent(object): + """Communication Event. Used for communication time and communication + volume recording. + """ + + def __init__(self, count: int = 0, comm_vol: float = 0., cuda_time: int = 0): + self.self_count = count + self.self_comm_vol = comm_vol + self.self_cuda_time = cuda_time + + def add(self, rhs): + self.self_count += rhs.self_count + self.self_comm_vol += rhs.self_comm_vol + self.self_cuda_time += rhs.self_cuda_time + + +class CommProfiler(BaseProfiler): + """Communication profiler. Records all communication events. 
+ """ + + def __init__(self, depth: int = 0, total_count: int = 0, total_comm_vol: float = 0, total_cuda_time: int = 0): + super().__init__(profiler_name="Collective_Communication", priority=0) + self.depth = 3 + depth + self.total_count = total_count + self.total_comm_vol = total_comm_vol + self.total_cuda_time = total_cuda_time + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def reset(self): + self.total_count = 0 + self.total_comm_vol = 0 + self.total_cuda_time = 0 + + self.ops_record = dict() + self.profiler = None + self.pending_op = None + self.pending_metadata = None + self.warn_flag = False + + def enable(self): + dist.all_reduce = partial(all_reduce, profiler=self) + dist.all_gather = partial(all_gather, profiler=self) + dist.reduce_scatter = partial(reduce_scatter, profiler=self) + dist.broadcast = partial(broadcast, profiler=self) + dist.reduce = partial(reduce, profiler=self) + + def disable(self): + dist.all_reduce = torch_all_reduce + dist.all_gather = torch_all_gather + dist.reduce_scatter = torch_reduce_scatter + dist.broadcast = torch_broadcast + dist.reduce = torch_reduce + + def to_tensorboard(self, writer): + writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + res.append(s) + res.append(sep) + + if self.warn_flag: + append("Warnning: there exists multiple communication operations in the same time. As a result, " + "the profiling result is not accurate.") + + if self.total_cuda_time == 0: + return "No collective communication has been called yet!" 
+ + append("Collective communication profiling result:") + append("total cuda time: {}".format(_format_time(self.total_cuda_time))) + append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time))) + append("total number of calls: {}".format(self.total_count)) + append("All events:") + + seperation = '-' * 74 + row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2 + + append(seperation) + append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls')) + append(seperation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format('', _format_time(event.self_cuda_time), + '{:.1f}%'.format(event.self_cuda_time / self.total_cuda_time * 100.0), + _format_memory(event.self_comm_vol), + _format_bandwidth(event.self_comm_vol, event.self_cuda_time), event.self_count)) + append() + + return ''.join(res) + + @property + def has_aync_op(self): + return self.pending_op is not None + + def activate_profiler(self, kn: str, vol: float): + self.pending_metadata = (kn, _get_code_location(self.depth), vol) + self.profiler = profile(enabled=True, use_cuda=True, use_cpu=True, use_kineto=True) + self.profiler.__enter__() + + def close_profiler(self, group=None): + assert self.profiler is not None, "There is no running dist op" + kernel_name, code_location, vol = self.pending_metadata + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled and dist.get_world_size(group) > 1: + assert_flag = 0 + current_comm_event = None + events = self.profiler.function_events + for event in events: + if kernel_name in event.name: + assert assert_flag == 0, "Multiple dist ops has been called " + current_comm_event = CommEvent(1, vol, event.self_cuda_time_total) + assert_flag += 1 + + assert current_comm_event is not None, "dist op has not been found" + + buffer = 
torch.tensor([current_comm_event.self_cuda_time], device=get_current_device()) + torch_all_reduce(buffer, op=ReduceOp.MIN, group=group) + current_comm_event.self_cuda_time = buffer.item() + + self.total_count += current_comm_event.self_count + self.total_comm_vol += current_comm_event.self_comm_vol + self.total_cuda_time += current_comm_event.self_cuda_time + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + + self.profiler = None + self.pending_op = None + self.pending_metadata = None + + def wait_async_op(self): + if self.pending_op is not None: + op = self.pending_op + op.wait() + self.close_profiler() + + +class CommHandler(object): + """Communication handler. A dummy handler to wait aync operations. + """ + + def __init__(self, profiler: CommProfiler): + super().__init__() + self.prof = profiler + + def wait(self): + self.prof.wait_async_op() + + +def async_check(profiler: CommProfiler): + if profiler.pending_op is not None: + profiler.warn_flag = True + profiler.wait_async_op() + + +def all_reduce(tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = 2 * (comm_size - 1) / comm_size + comm_vol = correction * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_AllReduce_", comm_vol) + profiler.pending_op = torch_all_reduce(tensor, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce_scatter(output: torch.Tensor, + input_list: List[torch.Tensor], + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + 
comm_vol = 0 + for tensor in input_list: + comm_vol += tensor.element_size() * tensor.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_ReduceScatter_", comm_vol) + profiler.pending_op = torch_reduce_scatter(output, input_list, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def all_gather(tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_size = dist.get_world_size(group) + correction = (comm_size - 1) / comm_size + comm_vol = 0 + for ten in tensor_list: + comm_vol += ten.element_size() * ten.numel() + comm_vol *= correction + profiler.activate_profiler("ncclKernel_AllGather_", comm_vol) + profiler.pending_op = torch_all_gather(tensor_list, tensor, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def broadcast(tensor: torch.Tensor, + src: int, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Broadcast_", comm_vol) + profiler.pending_op = torch_broadcast(tensor, src, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) + + +def reduce(tensor: torch.Tensor, + dst: int, + op: ReduceOp = ReduceOp.SUM, + group=None, + async_op: bool = False, + profiler: CommProfiler = None) -> Optional[CommHandler]: + async_check(profiler) + + comm_vol = 1.0 * tensor.element_size() * tensor.numel() + profiler.activate_profiler("ncclKernel_Reduce_", comm_vol) + profiler.pending_op = torch_reduce(tensor, dst, op, group, async_op) + + if async_op: + return CommHandler(profiler) + + profiler.close_profiler(group) diff --git a/colossalai/utils/profiler/legacy/pcie_profiler.py 
b/colossalai/utils/profiler/legacy/pcie_profiler.py index 526222941ef9..b50a6e4c054a 100644 --- a/colossalai/utils/profiler/legacy/pcie_profiler.py +++ b/colossalai/utils/profiler/legacy/pcie_profiler.py @@ -1,148 +1,150 @@ -from pathlib import Path -from torch.autograd.profiler import profile -from .prof_utils import BaseProfiler, _format_time, _format_memory, _format_bandwidth -from typing import List - - -def _get_size(dtype: str): - if dtype == "fp16": - return 2 - elif dtype == "fp32": - return 4 - else: - raise NotImplementedError - - -def _get_numel(my_list: List[int]) -> int: - from functools import reduce - from operator import mul - return reduce(mul, my_list) - - -def _reduce_location(locations: List[str]) -> str: - ret = [] - for lo in locations: - ret.append(lo) - ret.append("\n") - ret = ret[:-1] - return ''.join(ret) - - -class PcieEvent(object): - """Pcie Event. - """ - - def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): - self.count = count - self.pcie_vol = pcie_vol - self.cuda_time = cuda_time - - def add(self, rhs): - self.count += rhs.count - self.pcie_vol += rhs.pcie_vol - self.cuda_time += rhs.cuda_time - - -class PcieProfiler(BaseProfiler): - """Pcie profiler. Records all data transmission between CPU and GPU. 
- - TODO: Merge pcie profiler into communication profiler - """ - - def __init__(self, dtype: str = "fp32", depth: int = 1): - super().__init__(profiler_name="Pcie", priority=10) - self.depth = depth - self.data_size = _get_size(dtype) - self.h2d_count = 0 - self.h2d_time = 0 - self.d2h_count = 0 - self.d2h_time = 0 - - self.ops_record = dict() - self.profiler = None - - def reset(self): - self.h2d_count = 0 - self.h2d_time = 0 - self.d2h_count = 0 - self.d2h_time = 0 - - self.ops_record = dict() - self.profiler = None - - def enable(self): - self.profiler = profile(enabled=True, - use_cuda=True, - use_cpu=True, - use_kineto=True, - record_shapes=True, - with_stack=True) - self.profiler.__enter__() - - def disable(self): - self.profiler.__exit__(None, None, None) - - if self.profiler.enabled: - events = self.profiler.function_events - for event in events: - if event.name == "aten::copy_": - t_shape = event.input_shapes[0] - if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: - continue - current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) - code_location = _reduce_location(event.stack[:self.depth]) - if code_location in self.ops_record: - self.ops_record[code_location].add(current_comm_event) - else: - self.ops_record[code_location] = current_comm_event - elif 'Memcpy HtoD' in event.name: - self.h2d_count += 1 - self.h2d_time += event.cuda_time_total - elif 'Memcpy DtoH' in event.name: - self.d2h_count += 1 - self.d2h_time += event.cuda_time_total - - self.profiler = None - - def to_tensorboard(self, writer): - writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n")) - - def to_file(self, filename: Path): - with open(filename, "w") as f: - f.write(self.result_str()) - - def show(self): - print(self.result_str()) - - def result_str(self, sep: str = "\n"): - res = [] - - def append(s: str = None): - if s is not None: - res.append(s) - res.append(sep) - - append("Pcie profiling 
result:") - append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time))) - append("number of transmission (CPU -> GPU): {}".format(self.h2d_count)) - append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time))) - append("number of transmission (GPU -> CPU): {}".format(self.d2h_count)) - - append("Possible data transmission events in PCIE:") - - seperation = '-' * 62 - row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 - - append(seperation) - append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) - append(seperation) - - show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) - for location, event in show_list: - append(location) - append( - row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), - _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) - append() - - return ''.join(res) +from pathlib import Path +from typing import List + +from torch.autograd.profiler import profile + +from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time + + +def _get_size(dtype: str): + if dtype == "fp16": + return 2 + elif dtype == "fp32": + return 4 + else: + raise NotImplementedError + + +def _get_numel(my_list: List[int]) -> int: + from functools import reduce + from operator import mul + return reduce(mul, my_list) + + +def _reduce_location(locations: List[str]) -> str: + ret = [] + for lo in locations: + ret.append(lo) + ret.append("\n") + ret = ret[:-1] + return ''.join(ret) + + +class PcieEvent(object): + """Pcie Event. + """ + + def __init__(self, count: int = 0, pcie_vol: int = 0, cuda_time: int = 0): + self.count = count + self.pcie_vol = pcie_vol + self.cuda_time = cuda_time + + def add(self, rhs): + self.count += rhs.count + self.pcie_vol += rhs.pcie_vol + self.cuda_time += rhs.cuda_time + + +class PcieProfiler(BaseProfiler): + """Pcie profiler. 
Records all data transmission between CPU and GPU. + + TODO: Merge pcie profiler into communication profiler + """ + + def __init__(self, dtype: str = "fp32", depth: int = 1): + super().__init__(profiler_name="Pcie", priority=10) + self.depth = depth + self.data_size = _get_size(dtype) + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def reset(self): + self.h2d_count = 0 + self.h2d_time = 0 + self.d2h_count = 0 + self.d2h_time = 0 + + self.ops_record = dict() + self.profiler = None + + def enable(self): + self.profiler = profile(enabled=True, + use_cuda=True, + use_cpu=True, + use_kineto=True, + record_shapes=True, + with_stack=True) + self.profiler.__enter__() + + def disable(self): + self.profiler.__exit__(None, None, None) + + if self.profiler.enabled: + events = self.profiler.function_events + for event in events: + if event.name == "aten::copy_": + t_shape = event.input_shapes[0] + if len(t_shape) == 0 or event.cuda_time_total == 0 or len(event.stack) == 0: + continue + current_comm_event = PcieEvent(1, self.data_size * _get_numel(t_shape), event.cuda_time_total) + code_location = _reduce_location(event.stack[:self.depth]) + if code_location in self.ops_record: + self.ops_record[code_location].add(current_comm_event) + else: + self.ops_record[code_location] = current_comm_event + elif 'Memcpy HtoD' in event.name: + self.h2d_count += 1 + self.h2d_time += event.cuda_time_total + elif 'Memcpy DtoH' in event.name: + self.d2h_count += 1 + self.d2h_time += event.cuda_time_total + + self.profiler = None + + def to_tensorboard(self, writer): + writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n")) + + def to_file(self, filename: Path): + with open(filename, "w") as f: + f.write(self.result_str()) + + def show(self): + print(self.result_str()) + + def result_str(self, sep: str = "\n"): + res = [] + + def append(s: str = None): + if s is not None: + 
res.append(s) + res.append(sep) + + append("Pcie profiling result:") + append("time of data transmission (CPU -> GPU): {}".format(_format_time(self.h2d_time))) + append("number of transmission (CPU -> GPU): {}".format(self.h2d_count)) + append("time of data transmission (GPU -> CPU): {}".format(_format_time(self.d2h_time))) + append("number of transmission (GPU -> CPU): {}".format(self.d2h_count)) + + append("Possible data transmission events in PCIE:") + + seperation = '-' * 62 + row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2 + + append(seperation) + append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls')) + append(seperation) + + show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time) + for location, event in show_list: + append(location) + append( + row_format.format('', _format_time(event.cuda_time), _format_memory(event.pcie_vol), + _format_bandwidth(event.pcie_vol, event.cuda_time), event.count)) + append() + + return ''.join(res) diff --git a/colossalai/utils/profiler/legacy/prof_utils.py b/colossalai/utils/profiler/legacy/prof_utils.py index 87ad644a7ecc..2b613bbd611e 100644 --- a/colossalai/utils/profiler/legacy/prof_utils.py +++ b/colossalai/utils/profiler/legacy/prof_utils.py @@ -1,131 +1,132 @@ -from abc import ABC, abstractmethod -from pathlib import Path -from typing import Union, List -from colossalai.core import global_context as gpc - - -# copied from high version pytorch to support low version -def _format_time(time_us): - """Defines how to format time in FunctionEvent""" - US_IN_SECOND = 1000.0 * 1000.0 - US_IN_MS = 1000.0 - if time_us >= US_IN_SECOND: - return '{:.3f}s'.format(time_us / US_IN_SECOND) - if time_us >= US_IN_MS: - return '{:.3f}ms'.format(time_us / US_IN_MS) - return '{:.3f}us'.format(time_us) - - -# copied from high version pytorch to support low version -def _format_memory(nbytes): - """Returns a formatted memory size string""" - KB = 1024 - MB = 1024 * KB - 
GB = 1024 * MB - if (abs(nbytes) >= GB): - return '{:.2f} GB'.format(nbytes * 1.0 / GB) - elif (abs(nbytes) >= MB): - return '{:.2f} MB'.format(nbytes * 1.0 / MB) - elif (abs(nbytes) >= KB): - return '{:.2f} KB'.format(nbytes * 1.0 / KB) - else: - return str(nbytes) + ' B' - - -def _format_bandwidth(volme: float or int, time_us: int): - sec_div_mb = (1000.0 / 1024.0)**2 - mb_per_sec = volme / time_us * sec_div_mb - - if mb_per_sec >= 1024.0: - return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) - else: - return '{:.3f} MB/s'.format(mb_per_sec) - - -class BaseProfiler(ABC): - - def __init__(self, profiler_name: str, priority: int): - self.name = profiler_name - self.priority = priority - - @abstractmethod - def enable(self): - pass - - @abstractmethod - def disable(self): - pass - - @abstractmethod - def to_tensorboard(self, writer): - pass - - @abstractmethod - def to_file(self, filename: Path): - pass - - @abstractmethod - def show(self): - pass - - -class ProfilerContext(object): - """Profiler context manager - - Usage:: - - world_size = 4 - inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device()) - outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device()) - outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0)) - - cc_prof = CommProfiler() - - with ProfilerContext([cc_prof]) as prof: - op = dist.all_reduce(inputs, async_op=True) - dist.all_gather(outputs_list, inputs) - op.wait() - dist.reduce_scatter(inputs, outputs_list) - dist.broadcast(inputs, 0) - dist.reduce(inputs, 0) - - prof.show() - """ - - def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True): - self.enable = enable - self.profilers = sorted(profilers, key=lambda prof: prof.priority) - - def __enter__(self): - if self.enable: - for prof in self.profilers: - prof.enable() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.enable: - for prof in self.profilers: - prof.disable() - - def 
to_tensorboard(self, writer): - from torch.utils.tensorboard import SummaryWriter - - assert isinstance(writer, SummaryWriter), \ - f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' - - for prof in self.profilers: - prof.to_tensorboard(writer) - - def to_file(self, log_dir: Union[str, Path]): - if isinstance(log_dir, str): - log_dir = Path(log_dir) - - if not log_dir.exists(): - log_dir.mkdir(parents=True, exist_ok=True) - for prof in self.profilers: - log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') - prof.to_file(log_file) - - def show(self): - for prof in self.profilers: - prof.show() +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Union + +from colossalai.core import global_context as gpc + + +# copied from high version pytorch to support low version +def _format_time(time_us): + """Defines how to format time in FunctionEvent""" + US_IN_SECOND = 1000.0 * 1000.0 + US_IN_MS = 1000.0 + if time_us >= US_IN_SECOND: + return '{:.3f}s'.format(time_us / US_IN_SECOND) + if time_us >= US_IN_MS: + return '{:.3f}ms'.format(time_us / US_IN_MS) + return '{:.3f}us'.format(time_us) + + +# copied from high version pytorch to support low version +def _format_memory(nbytes): + """Returns a formatted memory size string""" + KB = 1024 + MB = 1024 * KB + GB = 1024 * MB + if (abs(nbytes) >= GB): + return '{:.2f} GB'.format(nbytes * 1.0 / GB) + elif (abs(nbytes) >= MB): + return '{:.2f} MB'.format(nbytes * 1.0 / MB) + elif (abs(nbytes) >= KB): + return '{:.2f} KB'.format(nbytes * 1.0 / KB) + else: + return str(nbytes) + ' B' + + +def _format_bandwidth(volme: float or int, time_us: int): + sec_div_mb = (1000.0 / 1024.0)**2 + mb_per_sec = volme / time_us * sec_div_mb + + if mb_per_sec >= 1024.0: + return '{:.3f} GB/s'.format(mb_per_sec / 1024.0) + else: + return '{:.3f} MB/s'.format(mb_per_sec) + + +class BaseProfiler(ABC): + + def __init__(self, profiler_name: str, priority: int): + 
self.name = profiler_name + self.priority = priority + + @abstractmethod + def enable(self): + pass + + @abstractmethod + def disable(self): + pass + + @abstractmethod + def to_tensorboard(self, writer): + pass + + @abstractmethod + def to_file(self, filename: Path): + pass + + @abstractmethod + def show(self): + pass + + +class ProfilerContext(object): + """Profiler context manager + + Usage:: + + world_size = 4 + inputs = torch.randn(10, 10, dtype=torch.float32, device=get_current_device()) + outputs = torch.empty(world_size, 10, 10, dtype=torch.float32, device=get_current_device()) + outputs_list = list(torch.chunk(outputs, chunks=world_size, dim=0)) + + cc_prof = CommProfiler() + + with ProfilerContext([cc_prof]) as prof: + op = dist.all_reduce(inputs, async_op=True) + dist.all_gather(outputs_list, inputs) + op.wait() + dist.reduce_scatter(inputs, outputs_list) + dist.broadcast(inputs, 0) + dist.reduce(inputs, 0) + + prof.show() + """ + + def __init__(self, profilers: List[BaseProfiler] = None, enable: bool = True): + self.enable = enable + self.profilers = sorted(profilers, key=lambda prof: prof.priority) + + def __enter__(self): + if self.enable: + for prof in self.profilers: + prof.enable() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enable: + for prof in self.profilers: + prof.disable() + + def to_tensorboard(self, writer): + from torch.utils.tensorboard import SummaryWriter + + assert isinstance(writer, SummaryWriter), \ + f'torch.utils.tensorboard.SummaryWriter is required, but found {type(writer)}.' 
+ + for prof in self.profilers: + prof.to_tensorboard(writer) + + def to_file(self, log_dir: Union[str, Path]): + if isinstance(log_dir, str): + log_dir = Path(log_dir) + + if not log_dir.exists(): + log_dir.mkdir(parents=True, exist_ok=True) + for prof in self.profilers: + log_file = log_dir.joinpath(f'{prof.name}_rank_{gpc.get_global_rank()}.log') + prof.to_file(log_file) + + def show(self): + for prof in self.profilers: + prof.show() diff --git a/colossalai/utils/profiler/profiler.py b/colossalai/utils/profiler/profiler.py index 8f43a0b96de0..c355127b26fa 100644 --- a/colossalai/utils/profiler/profiler.py +++ b/colossalai/utils/profiler/profiler.py @@ -1,17 +1,17 @@ -import os -from typing import List -from colossalai.engine import Engine -from torch.profiler import profile as torch_profile -from torch.profiler.profiler import ProfilerAction -from typing import Any, Callable, Iterable, Optional -from torch.autograd import ProfilerActivity +import gzip import json import os import tempfile -import gzip +from typing import Any, Callable, Iterable, List, Optional + +from torch.autograd import ProfilerActivity +from torch.profiler import profile as torch_profile +from torch.profiler.profiler import ProfilerAction + +from colossalai.engine import Engine +from colossalai.logging import get_dist_logger from colossalai.utils.profiler.extention import ProfilerExtension from colossalai.utils.profiler.stateful_tensor_mem_extention import StatefulTensorMemoryProfilerExtention -from colossalai.logging import get_dist_logger class profile(torch_profile): diff --git a/colossalai/utils/profiler/stateful_tensor_mem_extention.py b/colossalai/utils/profiler/stateful_tensor_mem_extention.py index 127055c8c1ef..26a6b35268e3 100644 --- a/colossalai/utils/profiler/stateful_tensor_mem_extention.py +++ b/colossalai/utils/profiler/stateful_tensor_mem_extention.py @@ -1,12 +1,14 @@ import os import threading import time -import torch from enum import Enum from typing import List -from 
colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.gemini.ophooks import BaseOpHook + +import torch + from colossalai.engine import Engine +from colossalai.gemini.ophooks import BaseOpHook +from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.utils.profiler.extention import ProfilerExtension diff --git a/colossalai/utils/rank_recorder/README.md b/colossalai/utils/rank_recorder/README.md index e30a925d2a92..65c1297ed6f2 100644 --- a/colossalai/utils/rank_recorder/README.md +++ b/colossalai/utils/rank_recorder/README.md @@ -1,7 +1,7 @@ # Rank Recorder This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily. -Before using the tool, you should ensure dist.is_initialized() return true before exit of program. +Before using the tool, you should ensure dist.is_initialized() return true before exit of program. ## Usage @@ -58,10 +58,10 @@ def worker(rank): with recorder("calc_1(x100)", rank) as r: calc(100, 100) - + with recorder("calc_2(x400)", rank) as r: calc(400, 400) - + with recorder("calc_2(x200)", rank) as r: calc(200, 200) @@ -69,4 +69,4 @@ if __name__ == "__main__": mp.spawn(worker, nprocs=WORLD_SIZE) ``` -run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. \ No newline at end of file +run the script directly and you will get `kernel_select.json` and `kernel_select.png` in your current folder. 
diff --git a/colossalai/utils/rank_recorder/__init__.py b/colossalai/utils/rank_recorder/__init__.py index 1274d0e7dbc5..1d347075a8ce 100644 --- a/colossalai/utils/rank_recorder/__init__.py +++ b/colossalai/utils/rank_recorder/__init__.py @@ -1,3 +1,3 @@ from colossalai.utils.rank_recorder.rank_recorder import recorder -__all__ = ["recorder"] \ No newline at end of file +__all__ = ["recorder"] diff --git a/colossalai/utils/rank_recorder/rank_recorder.py b/colossalai/utils/rank_recorder/rank_recorder.py index c088ceeb2e87..7726d51cbbfc 100644 --- a/colossalai/utils/rank_recorder/rank_recorder.py +++ b/colossalai/utils/rank_recorder/rank_recorder.py @@ -1,18 +1,15 @@ -import time -from typing import List, Dict +import atexit import json import os -import time import shutil -import atexit +import time +from typing import Dict, List +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt import torch import torch.distributed as dist -import json -import matplotlib.pyplot as plt -import matplotlib.colors as mcolors - cmap = list(mcolors.TABLEAU_COLORS.values()) LOG_FOLDER = "record.log" diff --git a/colossalai/utils/tensor_detector/__init__.py b/colossalai/utils/tensor_detector/__init__.py index cafc19b67c5c..c6c68aa4009b 100644 --- a/colossalai/utils/tensor_detector/__init__.py +++ b/colossalai/utils/tensor_detector/__init__.py @@ -1 +1 @@ -from .tensor_detector import TensorDetector +from .tensor_detector import TensorDetector diff --git a/colossalai/utils/tensor_detector/readme.md b/colossalai/utils/tensor_detector/readme.md index 840dc8f4eca6..8acf2ec833b5 100644 --- a/colossalai/utils/tensor_detector/readme.md +++ b/colossalai/utils/tensor_detector/readme.md @@ -14,7 +14,7 @@ class MLP(nn.Module): super().__init__() self.mlp = nn.Sequential(nn.Linear(64, 8), nn.ReLU(), - nn.Linear(8, 32)) + nn.Linear(8, 32)) def forward(self, x): return self.mlp(x) ``` @@ -125,4 +125,3 @@ Totle GPU Memery Allocated on cuda:0 is 14.0 KB This tool was inspired by 
https://github.com/Stonesjtu/pytorch_memlab/blob/master/pytorch_memlab/mem_reporter.py and https://github.com/Oldpan/Pytorch-Memory-Utils - diff --git a/colossalai/utils/tensor_detector/tensor_detector.py b/colossalai/utils/tensor_detector/tensor_detector.py index a8186f76834c..8b8916a099c4 100644 --- a/colossalai/utils/tensor_detector/tensor_detector.py +++ b/colossalai/utils/tensor_detector/tensor_detector.py @@ -1,9 +1,10 @@ import gc import inspect +from collections import defaultdict +from typing import Optional + import torch import torch.nn as nn -from typing import Optional -from collections import defaultdict LINE_WIDTH = 108 LINE = '-' * LINE_WIDTH + '\n' diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py index 4b61f4a5ef11..3874c1e0fb5b 100644 --- a/colossalai/utils/timer.py +++ b/colossalai/utils/timer.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- import time from typing import Tuple + from .cuda import synchronize diff --git a/colossalai/zero/shard_utils/base_shard_strategy.py b/colossalai/zero/shard_utils/base_shard_strategy.py index 7c2f4c9f6659..8c2ccf9124ff 100644 --- a/colossalai/zero/shard_utils/base_shard_strategy.py +++ b/colossalai/zero/shard_utils/base_shard_strategy.py @@ -2,6 +2,7 @@ from typing import List, Optional import torch.distributed as dist + from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py index a7bd7cf538e7..d7bf0438bd83 100644 --- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py @@ -2,17 +2,18 @@ import torch import torch.distributed as dist +from torch._utils import _flatten_dense_tensors as flatten + from colossalai.utils import get_current_device from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from torch._utils import _flatten_dense_tensors as flatten from 
.tensor_shard_strategy import TensorShardStrategy class BucketTensorShardStrategy(TensorShardStrategy): - """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, - which will fully utilize network bandwidth. - It is especially useful when sub-module contains bias, + """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, + which will fully utilize network bandwidth. + It is especially useful when sub-module contains bias, since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small). """ diff --git a/colossalai/zero/shard_utils/commons.py b/colossalai/zero/shard_utils/commons.py index 71cef44c177f..ea2536c8ea7f 100644 --- a/colossalai/zero/shard_utils/commons.py +++ b/colossalai/zero/shard_utils/commons.py @@ -1,6 +1,7 @@ +from typing import Tuple + import torch import torch.nn.functional as F -from typing import Tuple def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py index 5bdd95400d82..55c983dab968 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py @@ -2,11 +2,12 @@ import torch import torch.distributed as dist + +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline from colossalai.utils import get_current_device from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils.commons import get_shard from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline class TensorShardStrategy(BaseShardStrategy): @@ -27,7 +28,7 @@ def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGr Args: t (ShardedTensor): a tensor to be sharded. 
- process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. + process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. Defaults to None. """ if t.is_sharded: diff --git a/colossalai/zero/sharded_model/__init__.py b/colossalai/zero/sharded_model/__init__.py index 725179295c60..93120bdc34b4 100644 --- a/colossalai/zero/sharded_model/__init__.py +++ b/colossalai/zero/sharded_model/__init__.py @@ -1,3 +1,3 @@ from .sharded_model_v2 import ShardedModelV2 -__all__ = ['ShardedModelV2'] \ No newline at end of file +__all__ = ['ShardedModelV2'] diff --git a/colossalai/zero/sharded_model/_utils.py b/colossalai/zero/sharded_model/_utils.py index 85a3ab73dd1b..bb4ae009b2a0 100644 --- a/colossalai/zero/sharded_model/_utils.py +++ b/colossalai/zero/sharded_model/_utils.py @@ -1,8 +1,8 @@ -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Tuple, Union import torch import torch.nn.functional as F -from typing import Union + from colossalai.gemini.stateful_tensor import StatefulTensor diff --git a/colossalai/zero/sharded_model/utils.py b/colossalai/zero/sharded_model/utils.py index 69f5a23ac920..fdd1f4c4beee 100644 --- a/colossalai/zero/sharded_model/utils.py +++ b/colossalai/zero/sharded_model/utils.py @@ -1,7 +1,8 @@ +import copy + import torch -from colossalai.zero.sharded_model import ShardedModelV2 -import copy +from colossalai.zero.sharded_model import ShardedModelV2 def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py index 5642a504acf7..a68073b91d6f 100644 --- a/colossalai/zero/sharded_param/__init__.py +++ b/colossalai/zero/sharded_param/__init__.py @@ -1,4 +1,4 @@ -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 +from 
colossalai.zero.sharded_param.sharded_tensor import ShardedTensor __all__ = ['ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py index db0f2d149431..3b7731a3c41a 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -1,9 +1,10 @@ +from typing import List, Optional, Tuple + import torch -from typing import Optional, Tuple -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.gemini.tensor_utils import colo_tensor_mem_usage + from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState -from typing import List +from colossalai.gemini.tensor_utils import colo_tensor_mem_usage +from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor EMPTY_TENSOR_DICT = {} diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py index 77f4aec30f32..ae14f92975c2 100644 --- a/colossalai/zero/sharded_param/sharded_tensor.py +++ b/colossalai/zero/sharded_param/sharded_tensor.py @@ -1,4 +1,5 @@ import torch + from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState diff --git a/colossalai/zero/utils/__init__.py b/colossalai/zero/utils/__init__.py index c4e687228957..566396b1c827 100644 --- a/colossalai/zero/utils/__init__.py +++ b/colossalai/zero/utils/__init__.py @@ -1,3 +1,3 @@ from .zero_hook import ZeroHook -__all__ = ['ZeroHook'] \ No newline at end of file +__all__ = ['ZeroHook'] diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md index a70792b9f4a4..22be84a9944d 100644 --- a/examples/images/diffusion/README.md +++ b/examples/images/diffusion/README.md @@ -56,7 +56,7 @@ pip install transformers diffusers invisible-watermark #### Step 2: install lightning -Install Lightning version later than 2022.01.04. We suggest you install lightning from source. 
Notice that the default download path of pip should be within the conda environment, or you may need to specify using 'which pip' and redirect the path into conda environment. +Install Lightning version later than 2022.01.04. We suggest you install lightning from source. Notice that the default download path of pip should be within the conda environment, or you may need to specify using 'which pip' and redirect the path into conda environment. ##### From Source ``` diff --git a/examples/images/diffusion/configs/train_ddp.yaml b/examples/images/diffusion/configs/train_ddp.yaml index a63df887e719..d73a86eb478f 100644 --- a/examples/images/diffusion/configs/train_ddp.yaml +++ b/examples/images/diffusion/configs/train_ddp.yaml @@ -93,7 +93,7 @@ data: lightning: trainer: - accelerator: 'gpu' + accelerator: 'gpu' devices: 8 log_gpu_memory: all max_epochs: 2 diff --git a/examples/images/diffusion/ldm/data/cifar10.py b/examples/images/diffusion/ldm/data/cifar10.py index 53cd61263b47..5bac752e4648 100644 --- a/examples/images/diffusion/ldm/data/cifar10.py +++ b/examples/images/diffusion/ldm/data/cifar10.py @@ -1,15 +1,17 @@ +import json +from pathlib import Path from typing import Dict + import numpy as np -from omegaconf import DictConfig, ListConfig import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from torchvision import transforms +from datasets import load_dataset from einops import rearrange from ldm.util import instantiate_from_config -from datasets import load_dataset +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders @@ -22,7 +24,7 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): assert caption_files is None, \ "Caption files not yet supported for repeats" for folder_path, repeats 
in paths.items(): - list_of_paths.extend([folder_path]*repeats) + list_of_paths.extend([folder_path] * repeats) paths = list_of_paths if caption_files is not None: @@ -31,8 +33,11 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) + class FolderData(Dataset): - def __init__(self, + + def __init__( + self, root_dir, caption_file=None, image_transforms=[], @@ -40,7 +45,7 @@ def __init__(self, default_caption="", postprocess=None, return_paths=False, - ) -> None: + ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) @@ -75,12 +80,12 @@ def __init__(self, self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms - def __len__(self): if self.captions is not None: return len(self.captions.keys()) @@ -94,7 +99,7 @@ def __getitem__(self, index): caption = self.captions.get(chosen, None) if caption is None: caption = self.default_caption - filename = self.root_dir/chosen + filename = self.root_dir / chosen else: filename = self.paths[index] @@ -119,6 +124,7 @@ def process_im(self, im): im = im.convert("RGB") return self.tform(im) + def hf_dataset( name, image_transforms=[], @@ -128,13 +134,14 @@ def hf_dataset( split='train', image_key='image', caption_key='txt', - ): +): """Make huggingface dataset with appropriate list of transforms applied """ ds = load_dataset(name, split=split) image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) tform = transforms.Compose(image_transforms) assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" @@ -144,7 +151,18 @@ def pre_process(examples): processed = {} processed[image_key] = [tform(im) for im in examples[image_column]] - label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"} + label_to_text_dict = { + 0: "airplane", + 1: "automobile", + 2: "bird", + 3: "cat", + 4: "deer", + 5: "dog", + 6: "frog", + 7: "horse", + 8: "ship", + 9: "truck" + } processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]] @@ -153,7 +171,9 @@ def pre_process(examples): ds.set_transform(pre_process) return ds + class TextOnly(Dataset): + def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): """Returns only captions with dummy images""" self.output_size = output_size @@ -166,7 +186,7 @@ def __init__(self, captions, output_size, image_key="image", caption_key="txt", if n_gpus > 1: # hack to make sure that all the captions appear on each gpu - repeated = [n_gpus*[x] for x in self.captions] + repeated = [n_gpus * [x] for x in self.captions] self.captions = [] [self.captions.extend(x) for x in repeated] @@ -181,4 +201,4 @@ def __getitem__(self, index): def _load_caption_file(self, filename): with open(filename, 'rt') as f: captions = f.readlines() - return [x.strip('\n') for x in captions] \ No newline at end of file + return [x.strip('\n') for x in captions] diff --git a/examples/images/diffusion/ldm/data/imagenet.py b/examples/images/diffusion/ldm/data/imagenet.py index 1c473f9c6965..81591ad84204 100644 --- a/examples/images/diffusion/ldm/data/imagenet.py +++ b/examples/images/diffusion/ldm/data/imagenet.py @@ -1,35 +1,39 @@ -import os, yaml, pickle, shutil, tarfile, glob -import cv2 +import glob +import os +import pickle +import shutil +import tarfile +from 
functools import partial + import albumentations -import PIL +import cv2 import numpy as np +import PIL +import taming.data.utils as tdu import torchvision.transforms.functional as TF +import yaml +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light from omegaconf import OmegaConf -from functools import partial from PIL import Image -from tqdm import tqdm +from taming.data.imagenet import ImagePaths, download, give_synsets_from_indices, retrieve, str_to_indices from torch.utils.data import Dataset, Subset - -import taming.data.utils as tdu -from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve -from taming.data.imagenet import ImagePaths - -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light +from tqdm import tqdm def synset2idx(path_to_yaml="data/index_synset.yaml"): with open(path_to_yaml) as f: di2s = yaml.load(f) - return dict((v,k) for k,v in di2s.items()) + return dict((v, k) for k, v in di2s.items()) class ImageNetBase(Dataset): + def __init__(self, config=None): self.config = config or OmegaConf.create() - if not type(self.config)==dict: + if not type(self.config) == dict: self.config = OmegaConf.to_container(self.config) self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) - self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths self._prepare() self._prepare_synset_to_human() self._prepare_idx_to_synset() @@ -52,7 +56,7 @@ def _filter_relpaths(self, relpaths): relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] if "sub_indices" in self.config: indices = str_to_indices(self.config["sub_indices"]) - synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + synsets = 
give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) files = [] for rpath in relpaths: @@ -67,8 +71,7 @@ def _prepare_synset_to_human(self): SIZE = 2655750 URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): + if (not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE): download(URL, self.human_dict) def _prepare_idx_to_synset(self): @@ -122,11 +125,12 @@ def _load(self): if self.process_images: self.size = retrieve(self.config, "size", default=256) - self.data = ImagePaths(self.abspaths, - labels=labels, - size=self.size, - random_crop=self.random_crop, - ) + self.data = ImagePaths( + self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) else: self.data = self.abspaths @@ -157,8 +161,7 @@ def _prepare(self): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 1281167 - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -166,7 +169,7 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -187,7 +190,7 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, 
start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) @@ -222,8 +225,7 @@ def _prepare(self): self.datadir = os.path.join(self.root, "data") self.txt_filelist = os.path.join(self.root, "filelist.txt") self.expected_length = 50000 - self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", - default=False) + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False) if not tdu.is_prepared(self.root): # prep print("Preparing dataset {} in {}".format(self.NAME, self.root)) @@ -231,7 +233,7 @@ def _prepare(self): datadir = self.datadir if not os.path.exists(datadir): path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]: import academictorrents as at atpath = at.get(self.AT_HASH, datastore=self.root) assert atpath == path @@ -242,7 +244,7 @@ def _prepare(self): tar.extractall(path=datadir) vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]: download(self.VS_URL, vspath) with open(vspath, "r") as f: @@ -261,18 +263,16 @@ def _prepare(self): filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) filelist = [os.path.relpath(p, start=datadir) for p in filelist] filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" + filelist = "\n".join(filelist) + "\n" with open(self.txt_filelist, "w") as f: f.write(filelist) tdu.mark_prepared(self.root) - class ImageNetSR(Dataset): - def __init__(self, size=None, - degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., - random_crop=True): + + def __init__(self, size=None, degradation=None, downscale_f=4, min_crop_f=0.5, 
max_crop_f=1., random_crop=True): """ Imagenet Superresolution Dataloader Performs following ops in order: @@ -296,12 +296,12 @@ def __init__(self, size=None, self.LR_size = int(size / downscale_f) self.min_crop_f = min_crop_f self.max_crop_f = max_crop_f - assert(max_crop_f <= 1.) + assert (max_crop_f <= 1.) self.center_crop = not random_crop self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) - self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow if degradation == "bsrgan": self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) @@ -311,17 +311,17 @@ def __init__(self, size=None, else: interpolation_fn = { - "cv_nearest": cv2.INTER_NEAREST, - "cv_bilinear": cv2.INTER_LINEAR, - "cv_bicubic": cv2.INTER_CUBIC, - "cv_area": cv2.INTER_AREA, - "cv_lanczos": cv2.INTER_LANCZOS4, - "pil_nearest": PIL.Image.NEAREST, - "pil_bilinear": PIL.Image.BILINEAR, - "pil_bicubic": PIL.Image.BICUBIC, - "pil_box": PIL.Image.BOX, - "pil_hamming": PIL.Image.HAMMING, - "pil_lanczos": PIL.Image.LANCZOS, + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, }[degradation] self.pil_interpolation = degradation.startswith("pil_") @@ -366,13 +366,14 @@ def __getitem__(self, i): else: LR_image = self.degradation_process(image=image)["image"] - example["image"] = (image/127.5 - 1.0).astype(np.float32) - example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image / 127.5 - 1.0).astype(np.float32) return 
example class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): super().__init__(**kwargs) @@ -384,6 +385,7 @@ def get_base(self): class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/examples/images/diffusion/ldm/data/lsun.py b/examples/images/diffusion/ldm/data/lsun.py index 6256e45715ff..4dba734b39fb 100644 --- a/examples/images/diffusion/ldm/data/lsun.py +++ b/examples/images/diffusion/ldm/data/lsun.py @@ -1,4 +1,5 @@ import os + import numpy as np import PIL from PIL import Image @@ -7,13 +8,8 @@ class LSUNBase(Dataset): - def __init__(self, - txt_file, - data_root, - size=None, - interpolation="bicubic", - flip_p=0.5 - ): + + def __init__(self, txt_file, data_root, size=None, interpolation="bicubic", flip_p=0.5): self.data_paths = txt_file self.data_root = data_root with open(self.data_paths, "r") as f: @@ -21,16 +17,16 @@ def __init__(self, self._length = len(self.image_paths) self.labels = { "relative_file_path_": [l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) for l in self.image_paths], } self.size = size - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): @@ -46,8 +42,7 @@ def __getitem__(self, i): img = np.array(image).astype(np.uint8) crop = min(img.shape[0], img.shape[1]) h, w, = img.shape[0], img.shape[1] - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] + img = img[(h - crop) // 2:(h + crop) // 2, (w - crop) // 2:(w + crop) // 2] image = Image.fromarray(img) if self.size is not None: @@ 
-60,33 +55,39 @@ def __getitem__(self, i): class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", - flip_p=flip_p, **kwargs) + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", + data_root="data/lsun/churches", + flip_p=flip_p, + **kwargs) class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", - flip_p=flip_p, **kwargs) + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", flip_p=flip_p, **kwargs) class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", - flip_p=flip_p, **kwargs) + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", flip_p=flip_p, **kwargs) diff --git a/examples/images/diffusion/ldm/data/teyvat.py b/examples/images/diffusion/ldm/data/teyvat.py index 61dc29d56e7c..4661e9e1ae35 100644 --- a/examples/images/diffusion/ldm/data/teyvat.py +++ b/examples/images/diffusion/ldm/data/teyvat.py @@ -1,15 +1,17 @@ +import json +from pathlib import Path from typing import Dict + import numpy as np -from omegaconf import DictConfig, ListConfig import torch -from torch.utils.data import Dataset -from pathlib import Path -import json -from PIL import Image -from 
torchvision import transforms +from datasets import load_dataset from einops import rearrange from ldm.util import instantiate_from_config -from datasets import load_dataset +from omegaconf import DictConfig, ListConfig +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + def make_multi_folder_data(paths, caption_files=None, **kwargs): """Make a concat dataset from multiple folders @@ -22,7 +24,7 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): assert caption_files is None, \ "Caption files not yet supported for repeats" for folder_path, repeats in paths.items(): - list_of_paths.extend([folder_path]*repeats) + list_of_paths.extend([folder_path] * repeats) paths = list_of_paths if caption_files is not None: @@ -31,8 +33,11 @@ def make_multi_folder_data(paths, caption_files=None, **kwargs): datasets = [FolderData(p, **kwargs) for p in paths] return torch.utils.data.ConcatDataset(datasets) + class FolderData(Dataset): - def __init__(self, + + def __init__( + self, root_dir, caption_file=None, image_transforms=[], @@ -40,7 +45,7 @@ def __init__(self, default_caption="", postprocess=None, return_paths=False, - ) -> None: + ) -> None: """Create a dataset from a folder of images. If you pass in a root directory it will be searched for images ending in ext (ext can be a list) @@ -75,12 +80,12 @@ def __init__(self, self.paths.extend(list(self.root_dir.rglob(f"*.{e}"))) if isinstance(image_transforms, ListConfig): image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms.extend( + [transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) image_transforms = transforms.Compose(image_transforms) self.tform = image_transforms - def __len__(self): if self.captions is not None: return len(self.captions.keys()) @@ -94,7 +99,7 @@ def __getitem__(self, index): caption = self.captions.get(chosen, None) if caption is None: caption = self.default_caption - filename = self.root_dir/chosen + filename = self.root_dir / chosen else: filename = self.paths[index] @@ -119,23 +124,25 @@ def process_im(self, im): im = im.convert("RGB") return self.tform(im) + def hf_dataset( - path = "Fazzie/Teyvat", + path="Fazzie/Teyvat", image_transforms=[], image_column="image", text_column="text", image_key='image', caption_key='txt', - ): +): """Make huggingface dataset with appropriate list of transforms applied """ ds = load_dataset(path, name="train") ds = ds["train"] image_transforms = [instantiate_from_config(tt) for tt in image_transforms] - image_transforms.extend([transforms.Resize((256, 256)), - transforms.ToTensor(), - transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))] - ) + image_transforms.extend([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c')) + ]) tform = transforms.Compose(image_transforms) assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" @@ -149,4 +156,4 @@ def pre_process(examples): return processed ds.set_transform(pre_process) - return ds \ No newline at end of file + return ds diff --git a/examples/images/diffusion/ldm/lr_scheduler.py b/examples/images/diffusion/ldm/lr_scheduler.py index be39da9ca6da..3bbb5b877f4a 100644 --- a/examples/images/diffusion/ldm/lr_scheduler.py +++ b/examples/images/diffusion/ldm/lr_scheduler.py @@ -5,6 +5,7 @@ class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start @@ -16,7 +17,8 @@ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, ver def schedule(self, n, **kwargs): if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr @@ -24,13 +26,12 @@ def schedule(self, n, **kwargs): else: t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) self.last_lr = lr return lr def __call__(self, n, **kwargs): - return self.schedule(n,**kwargs) + return self.schedule(n, **kwargs) class LambdaWarmUpCosineScheduler2: @@ -38,6 +39,7 @@ class LambdaWarmUpCosineScheduler2: supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. 
""" + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) self.lr_warm_up_steps = warm_up_steps @@ -60,8 +62,9 @@ def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f @@ -69,8 +72,7 @@ def schedule(self, n, **kwargs): else: t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi)) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) self.last_f = f return f @@ -84,15 +86,16 @@ def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * 
(self.cycle_lengths[cycle] - + n) / (self.cycle_lengths[cycle]) self.last_f = f return f - diff --git a/examples/images/diffusion/ldm/models/autoencoder.py b/examples/images/diffusion/ldm/models/autoencoder.py index b1bd8377835b..6b8f49de1e60 100644 --- a/examples/images/diffusion/ldm/models/autoencoder.py +++ b/examples/images/diffusion/ldm/models/autoencoder.py @@ -1,20 +1,21 @@ import torch + try: import lightning.pytorch as pl except: import pytorch_lightning as pl -import torch.nn.functional as F from contextlib import contextmanager -from ldm.modules.diffusionmodules.model import Encoder, Decoder +import torch.nn.functional as F +from ldm.modules.diffusionmodules.model import Decoder, Encoder from ldm.modules.distributions.distributions import DiagonalGaussianDistribution - -from ldm.util import instantiate_from_config from ldm.modules.ema import LitEma +from ldm.util import instantiate_from_config class AutoencoderKL(pl.LightningModule): + def __init__(self, ddconfig, lossconfig, @@ -25,8 +26,7 @@ def __init__(self, colorize_nlabels=None, monitor=None, ema_decay=None, - learn_logvar=False - ): + learn_logvar=False): super().__init__() self.learn_logvar = learn_logvar self.image_key = image_key @@ -34,11 +34,11 @@ def __init__(self, self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim if colorize_nlabels is not None: - assert type(colorize_nlabels)==int + assert type(colorize_nlabels) == int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor @@ -116,16 +116,26 @@ def training_step(self, batch, batch_idx, optimizer_idx): if optimizer_idx == 0: # train encoder+decoder+logvar - 
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + aeloss, log_dict_ae = self.loss(inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train") self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) return aeloss if optimizer_idx == 1: # train the discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") + discloss, log_dict_disc = self.loss(inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train") self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) @@ -140,11 +150,21 @@ def validation_step(self, batch, batch_idx): def _validation_step(self, batch, batch_idx, postfix=""): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val"+postfix) + aeloss, log_dict_ae = self.loss(inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix) + + discloss, log_dict_disc = self.loss(inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix) self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) self.log_dict(log_dict_ae) @@ 
-158,10 +178,8 @@ def configure_optimizers(self): if self.learn_logvar: print(f"{self.__class__.__name__}: Learning logvar") ae_params_list.append(self.loss.logvar) - opt_ae = torch.optim.Adam(ae_params_list, - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) return [opt_ae, opt_disc], [] def get_last_layer(self): @@ -198,11 +216,12 @@ def to_rgb(self, x): if not hasattr(self, "colorize"): self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): self.vq_interface = vq_interface super().__init__() @@ -220,4 +239,3 @@ def quantize(self, x, *args, **kwargs): def forward(self, x, *args, **kwargs): return x - diff --git a/examples/images/diffusion/ldm/models/diffusion/classifier.py b/examples/images/diffusion/ldm/models/diffusion/classifier.py index 612a8371bf20..25cd605e720b 100644 --- a/examples/images/diffusion/ldm/models/diffusion/classifier.py +++ b/examples/images/diffusion/ldm/models/diffusion/classifier.py @@ -1,22 +1,19 @@ import os -import torch +from copy import deepcopy +from glob import glob + import lightning.pytorch as pl +import torch +from einops import rearrange +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import default, instantiate_from_config, ismap, log_txt_as_img +from natsort import natsorted from omegaconf import OmegaConf from torch.nn import functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR -from copy import deepcopy -from einops import rearrange -from glob import 
glob -from natsort import natsorted - -from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config -__models__ = { - 'class_label': EncoderUNetModel, - 'segmentation': UNetModel -} +__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel} def disabled_train(self, mode=True): @@ -114,7 +111,9 @@ def get_x_noisy(self, x, t, noise=None): continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) # todo: make sure t+1 is correct here - return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + return self.diffusion_model.q_sample(x_start=x, + t=t, + noise=noise, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) def forward(self, x_noisy, t, *args, **kwargs): @@ -163,12 +162,8 @@ def write_logs(self, loss, logits, targets): log_prefix = 'train' if self.training else 'val' log = {} log[f"{log_prefix}/loss"] = loss.mean() - log[f"{log_prefix}/acc@1"] = self.compute_top_k( - logits, targets, k=1, reduction="mean" - ) - log[f"{log_prefix}/acc@5"] = self.compute_top_k( - logits, targets, k=5, reduction="mean" - ) + log[f"{log_prefix}/acc@1"] = self.compute_top_k(logits, targets, k=1, reduction="mean") + log[f"{log_prefix}/acc@5"] = self.compute_top_k(logits, targets, k=5, reduction="mean") self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) @@ -200,8 +195,12 @@ def training_step(self, batch, batch_idx): return loss def reset_noise_accs(self): - self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in - range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + self.noisy_acc = { + t: { + 'acc@1': [], + 'acc@5': [] + } for t in range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t) + } def on_validation_start(self): self.reset_noise_accs() @@ 
-224,12 +223,11 @@ def configure_optimizers(self): scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] + scheduler = [{ + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] return [optimizer], scheduler return optimizer diff --git a/examples/images/diffusion/ldm/models/diffusion/ddim.py b/examples/images/diffusion/ldm/models/diffusion/ddim.py index 27ead0ea914c..8eef27f5a1d9 100644 --- a/examples/images/diffusion/ldm/models/diffusion/ddim.py +++ b/examples/images/diffusion/ldm/models/diffusion/ddim.py @@ -1,13 +1,18 @@ """SAMPLING ONLY.""" -import torch import numpy as np +import torch +from ldm.modules.diffusionmodules.util import ( + extract_into_tensor, + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) from tqdm import tqdm -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor - class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): super().__init__() self.model = model @@ -21,8 +26,10 @@ def register_buffer(self, name, attr): setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose) alphas_cumprod = self.model.alphas_cumprod assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' to_torch = lambda x: 
x.clone().detach().to(torch.float32).to(self.model.device) @@ -41,46 +48,48 @@ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., # ddim sampling parameters ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) + eta=ddim_eta, + verbose=verbose) self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * + (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- dynamic_threshold=None, - ucg_schedule=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs): if conditioning is not None: if isinstance(conditioning, dict): ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): ctmp = ctmp[0] + while isinstance(ctmp, list): + ctmp = ctmp[0] cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") @@ -100,11 +109,13 @@ def sample(self, size = (batch_size, C, H, W) print(f'Data shape for DDIM sampling is {size}, eta {eta}') - samples, intermediates = self.ddim_sampling(conditioning, size, + samples, intermediates = self.ddim_sampling(conditioning, + size, callback=callback, img_callback=img_callback, quantize_denoised=quantize_x0, - mask=mask, x0=x0, + mask=mask, + x0=x0, ddim_use_original_steps=False, noise_dropout=noise_dropout, temperature=temperature, @@ -115,17 +126,29 @@ def sample(self, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) + ucg_schedule=ucg_schedule) return samples, intermediates @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - 
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + def ddim_sampling(self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + dynamic_threshold=None, ucg_schedule=None): device = self.model.betas.device b = shape[0] @@ -141,7 +164,7 @@ def ddim_sampling(self, cond, shape, timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running DDIM Sampling with {total_steps} timesteps") @@ -153,23 +176,31 @@ def ddim_sampling(self, cond, shape, if mask is not None: assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img = img_orig * mask + (1. 
- mask) * img if ucg_schedule is not None: assert len(ucg_schedule) == len(time_range) unconditional_guidance_scale = ucg_schedule[i] - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, + outs = self.p_sample_ddim(img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, dynamic_threshold=dynamic_threshold) img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) @@ -178,9 +209,20 @@ def ddim_sampling(self, cond, shape, return img, intermediates @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, + def p_sample_ddim(self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, dynamic_threshold=None): b, *_, device = *x.shape, x.device @@ -194,13 +236,9 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F c_in = dict() for k in c: if isinstance(c[k], list): - c_in[k] = [torch.cat([ - unconditional_conditioning[k][i], - c[k][i]]) for i in 
range(len(c[k]))] + c_in[k] = [torch.cat([unconditional_conditioning[k][i], c[k][i]]) for i in range(len(c[k]))] else: - c_in[k] = torch.cat([ - unconditional_conditioning[k], - c[k]]) + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) elif isinstance(c, list): c_in = list() assert isinstance(unconditional_conditioning, list) @@ -228,7 +266,7 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 if self.model.parameterization != "v": @@ -251,8 +289,15 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0 @torch.no_grad() - def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, - unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + def encode(self, + x0, + c, + t_enc, + use_original_steps=False, + return_intermediates=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + callback=None): num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] assert t_enc <= num_reference_steps @@ -280,17 +325,17 @@ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=No noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next - weighted_noise_pred = alphas_next[i].sqrt() * ( - (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + weighted_noise_pred = alphas_next[i].sqrt() * ((1 / alphas_next[i] - 1).sqrt() - + (1 / alphas[i] - 
1).sqrt()) * noise_pred x_next = xt_weighted + weighted_noise_pred - if return_intermediates and i % ( - num_steps // return_intermediates) == 0 and i < num_steps - 1: + if return_intermediates and i % (num_steps // return_intermediates) == 0 and i < num_steps - 1: intermediates.append(x_next) inter_steps.append(i) elif return_intermediates and i >= num_steps - 2: intermediates.append(x_next) inter_steps.append(i) - if callback: callback(i) + if callback: + callback(i) out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} if return_intermediates: @@ -314,8 +359,14 @@ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False, callback=None): + def decode(self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + callback=None): timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = timesteps[:t_start] @@ -329,8 +380,13 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + x_dec, _ = self.p_sample_ddim(x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) - if callback: callback(i) - return x_dec \ No newline at end of file + if callback: + callback(i) + return x_dec diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py 
b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py index 7427f38c0753..f56611cb5fb3 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/__init__.py @@ -1 +1 @@ -from .sampler import DPMSolverSampler \ No newline at end of file +from .sampler import DPMSolverSampler diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py index 095e5ba3ce0b..a2481b86bb6a 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -1,17 +1,19 @@ +import math + import torch import torch.nn.functional as F -import math from tqdm import tqdm class NoiseScheduleVP: + def __init__( - self, - schedule='discrete', - betas=None, - alphas_cumprod=None, - continuous_beta_0=0.1, - continuous_beta_1=20., + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., ): """Create a wrapper class for the forward SDE (VP type). *** @@ -85,15 +87,18 @@ def __init__( self.total_N = len(log_alphas) self.T = 1. self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) + self.log_alpha_array = log_alphas.reshape(( + 1, + -1, + )) else: self.total_N = 1000 self.beta_0 = continuous_beta_0 self.beta_1 = continuous_beta_1 self.cosine_s = 0.008 self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / + math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. 
+ self.cosine_s) * math.pi / 2.)) self.schedule = schedule if schedule == 'cosine': @@ -111,7 +116,7 @@ def marginal_log_mean_coeff(self, t): return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 elif self.schedule == 'cosine': log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 @@ -143,7 +148,7 @@ def inverse_lambda(self, lamb): """ if self.schedule == 'linear': tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0 ** 2 + tmp + Delta = self.beta_0**2 + tmp return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) elif self.schedule == 'discrete': log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) @@ -153,22 +158,22 @@ def inverse_lambda(self, lamb): else: log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( - 1. + self.cosine_s) / math.pi - self.cosine_s + 1. + self.cosine_s) / math.pi - self.cosine_s t = t_fn(log_alpha) return t def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, ): """Create a wrapper function for the noise prediction model. 
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to @@ -317,6 +322,7 @@ def model_fn(x, t_continuous): class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): """Construct a DPM-Solver. We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). @@ -358,7 +364,7 @@ def data_prediction_fn(self, x, t): alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) if self.thresholding: - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) x0 = torch.clamp(x0, -s, s) / s @@ -396,7 +402,7 @@ def get_time_steps(self, skip_type, t_T, t_0, N, device): return torch.linspace(t_T, t_0, N + 1).to(device) elif skip_type == 'time_quadratic': t_order = 2 - t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + t = torch.linspace(t_T**(1. / t_order), t_0**(1. 
/ t_order), N + 1).pow(t_order).to(device) return t else: raise ValueError( @@ -435,29 +441,43 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type if order == 3: K = steps // 3 + 1 if steps % 3 == 0: - orders = [3, ] * (K - 2) + [2, 1] + orders = [ + 3, + ] * (K - 2) + [2, 1] elif steps % 3 == 1: - orders = [3, ] * (K - 1) + [1] + orders = [ + 3, + ] * (K - 1) + [1] else: - orders = [3, ] * (K - 1) + [2] + orders = [ + 3, + ] * (K - 1) + [2] elif order == 2: if steps % 2 == 0: K = steps // 2 - orders = [2, ] * K + orders = [ + 2, + ] * K else: K = steps // 2 + 1 - orders = [2, ] * (K - 1) + [1] + orders = [ + 2, + ] * (K - 1) + [1] elif order == 1: K = 1 - orders = [1, ] * steps + orders = [ + 1, + ] * steps else: raise ValueError("'order' must be '1' or '2' or '3'.") if skip_type == 'logSNR': # To reproduce the results in DPM-Solver paper timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, + device)[torch.cumsum(torch.tensor([ + 0, + ] + orders)).to(device)] return timesteps_outer, orders def denoise_to_zero_fn(self, x, s): @@ -491,10 +511,7 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal phi_1 = torch.expm1(-h) if model_s is None: model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) + x_t = (expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s) if return_intermediate: return x_t, {'model_s': model_s} else: @@ -503,16 +520,20 @@ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=Fal phi_1 = torch.expm1(h) if model_s is None: model_s = self.model_fn(x, s) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - 
expand_dims(sigma_t * phi_1, dims) * model_s - ) + x_t = (expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s) if return_intermediate: return x_t, {'model_s': model_s} else: return x_t - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + def singlestep_dpm_solver_second_update(self, + x, + s, + t, + r1=0.5, + model_s=None, + return_intermediate=False, solver_type='dpm_solver'): """ Singlestep solver DPM-Solver-2 from time `s` to time `t`. @@ -550,54 +571,46 @@ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, ret if model_s is None: model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = (expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s) model_s1 = self.model_fn(x_s1, s1) if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) - ) + x_t = (expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s - + (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)) elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( - model_s1 - model_s) - ) + x_t = (expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) 
/ h + 1.), dims) * (model_s1 - model_s)) else: phi_11 = torch.expm1(r1 * h) phi_1 = torch.expm1(h) if model_s is None: model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) + x_s1 = (expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - + expand_dims(sigma_s1 * phi_11, dims) * model_s) model_s1 = self.model_fn(x_s1, s1) if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) - ) + x_t = (expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * + (model_s1 - model_s)) elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) - ) + x_t = (expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)) if return_intermediate: return x_t, {'model_s': model_s, 'model_s1': model_s1} else: return x_t - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, - return_intermediate=False, solver_type='dpm_solver'): + def singlestep_dpm_solver_third_update(self, + x, + s, + t, + r1=1. / 3., + r2=2. / 3., + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type='dpm_solver'): """ Singlestep solver DPM-Solver-3 from time `s` to time `t`. Args: @@ -647,34 +660,21 @@ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. 
/ 3., mo if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) + x_s1 = (expand_dims(sigma_s1 / sigma_s, dims) * x - expand_dims(alpha_s1 * phi_11, dims) * model_s) model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) - ) + x_s2 = (expand_dims(sigma_s2 / sigma_s, dims) * x - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)) model_s2 = self.model_fn(x_s2, s2) if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) - ) + x_t = (expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)) elif solver_type == 'taylor': D1_0 = (1. / r1) * (model_s1 - model_s) D1_1 = (1. / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 - ) + x_t = (expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 - expand_dims(alpha_t * phi_3, dims) * D2) else: phi_11 = torch.expm1(r1 * h) phi_12 = torch.expm1(r2 * h) @@ -686,34 +686,25 @@ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. 
/ 3., mo if model_s is None: model_s = self.model_fn(x, s) if model_s1 is None: - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) + x_s1 = (expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - + expand_dims(sigma_s1 * phi_11, dims) * model_s) model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) - ) + x_s2 = (expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - + expand_dims(sigma_s2 * phi_12, dims) * model_s - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * + (model_s1 - model_s)) model_s2 = self.model_fn(x_s2, s2) if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) - ) + x_t = (expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * + (model_s2 - model_s)) elif solver_type == 'taylor': D1_0 = (1. / r1) * (model_s1 - model_s) D1_1 = (1. / r2) * (model_s2 - model_s) D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) D2 = 2. 
* (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 - ) + x_t = (expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - + expand_dims(sigma_t * phi_1, dims) * model_s - expand_dims(sigma_t * phi_2, dims) * D1 - + expand_dims(sigma_t * phi_3, dims) * D2) if return_intermediate: return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} @@ -751,30 +742,22 @@ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) if self.predict_x0: if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 - ) + x_t = (expand_dims(sigma_t / sigma_prev_0, dims) * x - + expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0) elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 - ) + x_t = (expand_dims(sigma_t / sigma_prev_0, dims) * x - + expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) 
/ h + 1.), dims) * D1_0) else: if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 - ) + x_t = (expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - + expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - + 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0) elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 - ) + x_t = (expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - + expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - + expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0) return x_t def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): @@ -809,22 +792,25 @@ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) if self.predict_x0: - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 - ) + x_t = (expand_dims(sigma_t / sigma_prev_0, dims) * x - + expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - + expand_dims(alpha_t * ((torch.exp(-h) - 1. 
+ h) / h**2 - 0.5), dims) * D2) else: - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 - ) + x_t = (expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - + expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - + expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - + expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2) return x_t - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + def singlestep_dpm_solver_update(self, + x, + s, + t, + order, + return_intermediate=False, + solver_type='dpm_solver', + r1=None, r2=None): """ Singlestep DPM-Solver with the order `order` from time `s` to time `t`. @@ -844,11 +830,20 @@ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False if order == 1: return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1) + return self.singlestep_dpm_solver_second_update(x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1) elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, - solver_type=solver_type, r1=r1, r2=r2) + return self.singlestep_dpm_solver_third_update(x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + r2=r2) else: raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) @@ -875,7 +870,16 @@ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, else: raise ValueError("Solver order 
must be 1 or 2 or 3, got {}".format(order)) - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + def dpm_solver_adaptive(self, + x, + order, + t_T, + t_0, + h_init=0.05, + atol=0.0078, + rtol=0.05, + theta=0.9, + t_err=1e-5, solver_type='dpm_solver'): """ The adaptive step size solver based on singlestep DPM-Solver. @@ -906,17 +910,14 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol if order == 2: r1 = 0.5 lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - solver_type=solver_type, - **kwargs) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, solver_type=solver_type, **kwargs) elif order == 3: r1, r2 = 1. / 3., 2. / 3. - lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, - return_intermediate=True, - solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, - solver_type=solver_type, - **kwargs) + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update( + x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) else: raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) while torch.abs((s - t_0)).mean() > t_err: @@ -936,10 +937,21 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol print('adaptive solver nfe', nfe) return x - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, 
rtol=0.05, - ): + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=3, + skip_type='time_uniform', + method='singlestep', + lower_order_final=True, + denoise_to_zero=False, + solver_type='dpm_solver', + atol=0.0078, + rtol=0.05, + ): """ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. ===================================================== @@ -1039,7 +1051,12 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time device = x.device if method == 'adaptive': with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + x = self.dpm_solver_adaptive(x, + order=order, + t_T=t_T, + t_0=t_0, + atol=atol, + rtol=rtol, solver_type=solver_type) elif method == 'multistep': assert steps >= order @@ -1052,7 +1069,11 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # Init the first `order` values by lower order multistep DPM-Solver. 
for init_order in tqdm(range(1, order), desc="DPM init order"): vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + x = self.multistep_dpm_solver_update(x, + model_prev_list, + t_prev_list, + vec_t, + init_order, solver_type=solver_type) model_prev_list.append(self.model_fn(x, vec_t)) t_prev_list.append(vec_t) @@ -1063,7 +1084,11 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time step_order = min(order, steps + 1 - step) else: step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + x = self.multistep_dpm_solver_update(x, + model_prev_list, + t_prev_list, + vec_t, + step_order, solver_type=solver_type) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] @@ -1074,18 +1099,25 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time model_prev_list[-1] = self.model_fn(x, vec_t) elif method in ['singlestep', 'singlestep_fixed']: if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, + order=order, skip_type=skip_type, - t_T=t_T, t_0=t_0, + t_T=t_T, + t_0=t_0, device=device) elif method == 'singlestep_fixed': K = steps // order - orders = [order, ] * K + orders = [ + order, + ] * K timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) for i, order in enumerate(orders): t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), - N=order, device=device) + timesteps_inner = self.get_time_steps(skip_type=skip_type, + t_T=t_T_inner.item(), + t_0=t_0_inner.item(), + N=order, + device=device) lambda_inner = 
self.noise_schedule.marginal_lambda(timesteps_inner) vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) h = lambda_inner[-1] - lambda_inner[0] @@ -1101,6 +1133,7 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # other utility functions ############################################################# + def interpolate_fn(x, xp, yp): """ A piecewise linear function y = f(x), using xp and yp as keypoints. @@ -1122,7 +1155,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(1, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) @@ -1132,7 +1167,9 @@ def interpolate_fn(x, xp, yp): torch.eq(x_idx, 0), torch.tensor(0, device=x.device), torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, ), ) y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) @@ -1151,4 +1188,4 @@ def expand_dims(v, dims): Returns: a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. 
""" - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file + return v[(...,) + (None,) * (dims - 1)] diff --git a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py index 7d137b8cf367..21350650cb1f 100644 --- a/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py +++ b/examples/images/diffusion/ldm/models/diffusion/dpm_solver/sampler.py @@ -1,16 +1,13 @@ """SAMPLING ONLY.""" import torch -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver +from .dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper - -MODEL_TYPES = { - "eps": "noise", - "v": "v" -} +MODEL_TYPES = {"eps": "noise", "v": "v"} class DPMSolverSampler(object): + def __init__(self, model, **kwargs): super().__init__() self.model = model @@ -24,30 +21,30 @@ def register_buffer(self, name, attr): setattr(self, name, attr) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -82,6 +79,11 @@ def sample(self, ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + x = dpm_solver.sample(img, + steps=S, + skip_type="time_uniform", + method="multistep", + order=2, + lower_order_final=True) - return x.to(device), None \ No newline at end of file + return x.to(device), None diff --git a/examples/images/diffusion/ldm/models/diffusion/plms.py b/examples/images/diffusion/ldm/models/diffusion/plms.py index 7002a365d271..e4f9c56efd57 100644 --- a/examples/images/diffusion/ldm/models/diffusion/plms.py +++ b/examples/images/diffusion/ldm/models/diffusion/plms.py @@ -1,15 +1,16 @@ """SAMPLING ONLY.""" -import torch -import numpy as np -from tqdm import tqdm from functools import partial -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +import numpy as np +import torch from ldm.models.diffusion.sampling_util import norm_thresholding +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from tqdm import tqdm class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): super().__init__() self.model = model @@ -25,8 +26,10 @@ def register_buffer(self, name, attr): def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): if ddim_eta != 0: raise ValueError('ddim_eta must be 0 for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + 
num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose) alphas_cumprod = self.model.alphas_cumprod assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) @@ -45,42 +48,43 @@ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., # ddim sampling parameters ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) + eta=ddim_eta, + verbose=verbose) self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * + (1 - self.alphas_cumprod / self.alphas_cumprod_prev)) self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- dynamic_threshold=None, - **kwargs - ): + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] @@ -96,31 +100,46 @@ def sample(self, size = (batch_size, C, H, W) print(f'Data shape for PLMS sampling is {size}') - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) + samples, intermediates = self.plms_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) return samples, intermediates @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, 
ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, + def plms_sampling(self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, dynamic_threshold=None): device = self.model.betas.device b = shape[0] @@ -136,7 +155,7 @@ def plms_sampling(self, cond, shape, timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] print(f"Running PLMS Sampling with {total_steps} timesteps") @@ -150,23 +169,32 @@ def plms_sampling(self, cond, shape, if mask is not None: assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img = img_orig * mask + (1. 
- mask) * img - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, + outs = self.p_sample_plms(img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next, + old_eps=old_eps, + t_next=ts_next, dynamic_threshold=dynamic_threshold) img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) @@ -175,9 +203,22 @@ def plms_sampling(self, cond, shape, return img, intermediates @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + def p_sample_plms(self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + old_eps=None, + t_next=None, dynamic_threshold=None): b, *_, device = *x.shape, x.device @@ -207,7 +248,7 @@ def get_x_prev_and_pred_x0(e_t, index): a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), 
alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() diff --git a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py index 7eff02be6d7c..8c05c4e90ff9 100644 --- a/examples/images/diffusion/ldm/models/diffusion/sampling_util.py +++ b/examples/images/diffusion/ldm/models/diffusion/sampling_util.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch def append_dims(x, target_dims): @@ -19,4 +19,4 @@ def norm_thresholding(x0, value): def spatial_norm_thresholding(x0, value): # b c h w s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) - return x0 * (value / s) \ No newline at end of file + return x0 * (value / s) diff --git a/examples/images/diffusion/ldm/modules/attention.py b/examples/images/diffusion/ldm/modules/attention.py index d504d939f6a0..1a089f42bfeb 100644 --- a/examples/images/diffusion/ldm/modules/attention.py +++ b/examples/images/diffusion/ldm/modules/attention.py @@ -1,13 +1,12 @@ -from inspect import isfunction import math +from inspect import isfunction +from typing import Any, Optional + import torch import torch.nn.functional as F -from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional, Any - from ldm.modules.diffusionmodules.util import checkpoint - +from torch import einsum, nn try: import xformers @@ -22,7 +21,7 @@ def exists(val): def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -44,6 +43,7 @@ def init_(tensor): # feedforward class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): super().__init__() self.proj = 
nn.Linear(dim_in, dim_out * 2) @@ -54,20 +54,14 @@ def forward(self, x): class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) def forward(self, x): return self.net(x) @@ -87,31 +81,16 @@ def Normalize(in_channels): class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x @@ -121,7 +100,7 @@ def forward(self, x): v = self.v(h_) # compute attention - b,c,h,w = q.shape + b, c, h, w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) @@ -136,26 +115,24 @@ def forward(self, x): h_ = 
rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) - return x+h_ + return x + h_ class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def forward(self, x, context=None, mask=None): h = self.heads @@ -211,11 +188,8 @@ def forward(self, x, context=None, mask=None): b, _, _ = q.shape q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), + lambda t: t.unsqueeze(3).reshape(b, t.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape( + b * self.heads, t.shape[1], self.dim_head).contiguous(), (q, k, v), ) @@ -224,32 +198,41 @@ def forward(self, x, context=None, mask=None): if exists(mask): raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) + out = (out.unsqueeze(0).reshape(b, self.heads, out.shape[1], + self.dim_head).permute(0, 2, 1, 3).reshape(b, out.shape[1], + self.heads * self.dim_head)) return self.to_out(out) class BasicTransformerBlock(nn.Module): ATTENTION_MODES = { - "softmax": CrossAttention, # vanilla attention + "softmax": CrossAttention, # vanilla attention "softmax-xformers": MemoryEfficientCrossAttention } - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, 
gated_ff=True, checkpoint=True, + + def __init__(self, + dim, + n_heads, + d_head, + dropout=0., + context_dim=None, + gated_ff=True, + checkpoint=True, disable_self_attn=False): super().__init__() attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" assert attn_mode in self.ATTENTION_MODES attn_cls = self.ATTENTION_MODES[attn_mode] self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.attn1 = attn_cls(query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else + None) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, + dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) @@ -274,9 +257,16 @@ class SpatialTransformer(nn.Module): Finally, reshape to image NEW: use_linear for more efficiency instead of the 1x1 convs """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None, - disable_self_attn=False, use_linear=False, + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, use_checkpoint=True): super().__init__() if exists(context_dim) and not isinstance(context_dim, list): @@ -285,25 +275,21 @@ def __init__(self, in_channels, n_heads, d_head, inner_dim = n_heads * d_head self.norm = Normalize(in_channels) if not use_linear: - self.proj_in = nn.Conv2d(in_channels, - 
inner_dim, - kernel_size=1, - stride=1, - padding=0) + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) else: self.proj_in = nn.Linear(in_channels, inner_dim) - self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], - disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) - for d in range(depth)] - ) + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock(inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint) for d in range(depth) + ]) if not use_linear: - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) else: self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) self.use_linear = use_linear @@ -328,4 +314,3 @@ def forward(self, x, context=None): if not self.use_linear: x = self.proj_out(x) return x + x_in - diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py index cd639d936046..22090d9c626a 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/openaimodel.py @@ -1,21 +1,20 @@ -from abc import abstractmethod import math +from abc import abstractmethod import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F - +from ldm.modules.attention import SpatialTransformer from ldm.modules.diffusionmodules.util import ( + avg_pool_nd, checkpoint, conv_nd, linear, - avg_pool_nd, - zero_module, normalization, timestep_embedding, + zero_module, ) -from ldm.modules.attention import SpatialTransformer from ldm.util import exists @@ -23,6 +22,7 @@ def 
convert_module_to_f16(x): pass + def convert_module_to_f32(x): pass @@ -41,7 +41,7 @@ def __init__( output_dim: int = None, ): super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels @@ -49,9 +49,9 @@ def __init__( def forward(self, x): b, c, *_spatial = x.shape - x = x.reshape(b, c, -1) # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) x = self.qkv_proj(x) x = self.attention(x) x = self.c_proj(x) @@ -108,25 +108,25 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x + class TransposedUpsample(nn.Module): 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): super().__init__() self.channels = channels self.out_channels = out_channels or channels - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) - def forward(self,x): + def forward(self, x): return self.up(x) @@ -139,7 +139,7 @@ class Downsample(nn.Module): downsampling 
occurs in the inner-two dimensions. """ - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels @@ -147,9 +147,7 @@ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) @@ -225,17 +223,13 @@ def __init__( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) else: self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) @@ -246,10 +240,7 @@ def forward(self, x, emb): :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. 
""" - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) def _forward(self, x, emb): if self.updown: @@ -294,9 +285,8 @@ def __init__( if num_head_channels == -1: self.num_heads = num_heads else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + assert (channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels self.use_checkpoint = use_checkpoint self.norm = normalization(channels) @@ -311,7 +301,8 @@ def __init__( self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + return checkpoint(self._forward, (x,), self.parameters(), + True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! #return pt_checkpoint(self._forward, x) # pytorch def _forward(self, x): @@ -339,7 +330,7 @@ def count_flops_attn(model, _x, y): # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. 
- matmul_ops = 2 * b * (num_spatial ** 2) * c + matmul_ops = 2 * b * (num_spatial**2) * c model.total_ops += th.DoubleTensor([matmul_ops]) @@ -363,9 +354,7 @@ def forward(self, qkv): ch = width // (3 * self.n_heads) q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v) return a.reshape(bs, -1, length) @@ -399,7 +388,7 @@ def forward(self, qkv): "bct,bcs->bts", (q * scale).view(bs * self.n_heads, ch, length), (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards + ) # More stable with f16 than dividing afterwards weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) return a.reshape(bs, -1, length) @@ -461,9 +450,9 @@ def __init__( resblock_updown=False, use_new_attention_order=False, use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, disable_self_attentions=None, num_attention_blocks=None, @@ -505,7 +494,8 @@ def __init__( assert len(disable_self_attentions) == len(channel_mult) if num_attention_blocks is not None: assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all(map(lambda i: self.num_res_blocks[i] >= 
num_attention_blocks[i], range(len(num_attention_blocks)))) + assert all( + map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " f"This option has LESS priority than attention_resolutions {attention_resolutions}, " f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " @@ -540,12 +530,7 @@ def __init__( raise ValueError() self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels @@ -586,12 +571,15 @@ def __init__( num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint ) - ) + if not use_spatial_transformer else SpatialTransformer(ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint)) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) @@ -608,13 +596,7 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) + ) if resblock_updown else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch))) ch = out_ch input_block_chans.append(ch) ds *= 2 @@ -643,11 +625,15 @@ def __init__( num_heads=num_heads, 
num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint - ), + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint), ResBlock( ch, time_embed_dim, @@ -697,12 +683,15 @@ def __init__( num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint ) - ) + if not use_spatial_transformer else SpatialTransformer(ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint)) if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( @@ -715,10 +704,7 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) + ) if resblock_updown else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch @@ -730,10 +716,10 @@ def __init__( ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), + normalization(ch), + conv_nd(dims, 
model_channels, n_embed, 1), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) + ) def convert_to_fp16(self): """ @@ -751,7 +737,7 @@ def convert_to_fp32(self): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. @@ -760,9 +746,8 @@ def forward(self, x, timesteps=None, context=None, y=None,**kwargs): :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" + assert (y is not None) == (self.num_classes + is not None), "must specify y if and only if the model is class-conditional" hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) t_emb = t_emb.type(self.dtype) diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py index 03816662098c..757b9e0ded57 100644 --- a/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py +++ b/examples/images/diffusion/ldm/modules/diffusionmodules/upscaling.py @@ -1,8 +1,8 @@ -import torch -import torch.nn as nn -import numpy as np from functools import partial +import numpy as np +import torch +import torch.nn as nn from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule from ldm.util import default @@ -14,9 +14,16 @@ def __init__(self, noise_schedule_config=None): if noise_schedule_config is not None: self.register_schedule(**noise_schedule_config) - def register_schedule(self, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - betas = 
make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + def register_schedule(self, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, cosine_s=cosine_s) alphas = 1. - betas alphas_cumprod = np.cumprod(alphas, axis=0) @@ -65,6 +72,7 @@ def forward(self, x): class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): super().__init__(noise_schedule_config=noise_schedule_config) self.max_noise_level = max_noise_level @@ -76,6 +84,3 @@ def forward(self, x, noise_level=None): assert isinstance(noise_level, torch.Tensor) z = self.q_sample(x, noise_level) return z, noise_level - - - diff --git a/examples/images/diffusion/ldm/modules/distributions/distributions.py b/examples/images/diffusion/ldm/modules/distributions/distributions.py index f2b8ef901130..a08f15dc8cc7 100644 --- a/examples/images/diffusion/ldm/modules/distributions/distributions.py +++ b/examples/images/diffusion/ldm/modules/distributions/distributions.py @@ -1,8 +1,9 @@ -import torch import numpy as np +import torch class AbstractDistribution: + def sample(self): raise NotImplementedError() @@ -11,6 +12,7 @@ def mode(self): class DiracDistribution(AbstractDistribution): + def __init__(self, value): self.value = value @@ -22,6 +24,7 @@ def mode(self): class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) @@ -41,22 +44,17 @@ def kl(self, other=None): return torch.Tensor([0.]) else: if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) 
else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) + return 0.5 * torch.sum(torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - + self.logvar + other.logvar, + dim=[1, 2, 3]) - def nll(self, sample, dims=[1,2,3]): + def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: return torch.Tensor([0.]) logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) def mode(self): return self.mean @@ -78,15 +76,6 @@ def normal_kl(mean1, logvar1, mean2, logvar2): # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2)**2) * torch.exp(-logvar2)) diff --git a/examples/images/diffusion/ldm/modules/ema.py b/examples/images/diffusion/ldm/modules/ema.py index bded25019b9b..d0627b63d94b 100644 --- a/examples/images/diffusion/ldm/modules/ema.py +++ b/examples/images/diffusion/ldm/modules/ema.py @@ -3,6 +3,7 @@ class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): super().__init__() if decay < 0.0 or decay > 1.0: @@ -10,8 +11,8 @@ def __init__(self, model, decay=0.9999, use_num_upates=True): self.m_name2s_name = {} self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - 
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates - else torch.tensor(-1, dtype=torch.int)) + self.register_buffer('num_updates', + torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int)) for name, p in model.named_parameters(): if p.requires_grad: diff --git a/examples/images/diffusion/ldm/modules/encoders/modules.py b/examples/images/diffusion/ldm/modules/encoders/modules.py index 4edd5496b9e6..23a0f671223e 100644 --- a/examples/images/diffusion/ldm/modules/encoders/modules.py +++ b/examples/images/diffusion/ldm/modules/encoders/modules.py @@ -1,14 +1,13 @@ +import open_clip import torch import torch.nn as nn +from ldm.util import count_params, default from torch.utils.checkpoint import checkpoint - -from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel - -import open_clip -from ldm.util import default, count_params +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class AbstractEncoder(nn.Module): + def __init__(self): super().__init__() @@ -23,6 +22,7 @@ def encode(self, x): class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): super().__init__() self.key = key @@ -37,13 +37,13 @@ def forward(self, batch, key=None, disable_dropout=False): c = batch[key][:, None] if self.ucg_rate > 0. and not disable_dropout: mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) - c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) c = c.long() c = self.embedding(c) return c def get_unconditional_conditioning(self, bs, device="cuda"): - uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 
999, one extra class for ucg (class 1000) uc = torch.ones((bs,), device=device) * uc_class uc = {self.key: uc} return uc @@ -57,12 +57,17 @@ def disabled_train(self, mode=True): class FrozenT5Embedder(AbstractEncoder): """Uses the T5 transformer encoder for text""" - def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + + def __init__(self, + version="google/t5-v1_1-large", + device="cuda", + max_length=77, + freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl super().__init__() self.tokenizer = T5Tokenizer.from_pretrained(version) self.transformer = T5EncoderModel.from_pretrained(version) self.device = device - self.max_length = max_length # TODO: typical value? + self.max_length = max_length # TODO: typical value? if freeze: self.freeze() @@ -73,8 +78,13 @@ def freeze(self): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer(text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") tokens = batch_encoding["input_ids"].to(self.device) outputs = self.transformer(input_ids=tokens) @@ -87,13 +97,15 @@ def encode(self, text): class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" - LAYERS = [ - "last", - "pooled", - "hidden" - ] - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, - freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + LAYERS = ["last", "pooled", "hidden"] + + def __init__(self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None): # 
clip-vit-base-patch32 super().__init__() assert layer in self.LAYERS self.tokenizer = CLIPTokenizer.from_pretrained(version) @@ -115,10 +127,15 @@ def freeze(self): param.requires_grad = False def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + batch_encoding = self.tokenizer(text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") if self.layer == "last": z = outputs.last_hidden_state elif self.layer == "pooled": @@ -136,12 +153,18 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder): Uses the OpenCLIP transformer encoder for text """ LAYERS = [ - #"pooled", + #"pooled", "last", "penultimate" ] - def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, - freeze=True, layer="last"): + + def __init__(self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last"): super().__init__() assert layer in self.LAYERS model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) @@ -171,15 +194,15 @@ def forward(self, text): return z def encode_with_transformer(self, text): - x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] x = x + self.model.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND + x = x.permute(1, 0, 2) # NLD -> LND x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD + x = x.permute(1, 0, 2) # 
LND -> NLD x = self.model.ln_final(x) return x - def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): for i, r in enumerate(self.model.transformer.resblocks): if i == len(self.model.transformer.resblocks) - self.layer_idx: break @@ -194,8 +217,13 @@ def encode(self, text): class FrozenCLIPT5Encoder(AbstractEncoder): - def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", - clip_max_length=77, t5_max_length=77): + + def __init__(self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77): super().__init__() self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) @@ -209,5 +237,3 @@ def forward(self, text): clip_z = self.clip_encoder.encode(text) t5_z = self.t5_encoder.encode(text) return [clip_z, t5_z] - - diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py index 32ef56169978..8b99a8ee3f31 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan.py @@ -10,20 +10,19 @@ # -------------------------------------------- """ -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util def modcrop_np(img, sf): @@ -152,18 
+151,17 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set random eigen-vals (lambdas) and angle (theta) for COV matrix lambda_1 = min_var + np.random.rand() * (max_var - min_var) lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta + theta = np.random.rand() * np.pi # random theta noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) MU = MU[None, None, :, None] # Create meshgrid for Gaussian @@ -254,7 +252,7 @@ def srmd_degradation(x, k, sf=3): year={2018} } ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x @@ -338,9 +336,9 @@ def add_blur(img, sf=4): def add_resize(img, sf=4): rnum = np.random.rand() - if rnum > 0.8: # up + if rnum > 0.8: # up sf1 = random.uniform(1, 2) - elif rnum < 0.7: # down + elif rnum < 0.7: # down sf1 = random.uniform(0.5 / sf, 1) else: sf1 = 1.0 @@ -366,19 +364,20 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() - if rnum > 0.6: # add color Gaussian noise + if rnum > 0.6: # add color Gaussian noise img = img + np.random.normal(0, 
noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: # add grayscale Gaussian noise + elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: # add noise + else: # add noise L = noise_level2 / 255. D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -396,14 +395,14 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): img = np.clip((img * 255.0).round(), 0, 255) / 255. - vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + vals = 10**(2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: @@ -452,7 +451,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: @@ -460,7 +459,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): hq = img.copy() - if sf == 4 and random.random() < scale2_prob: # downsample1 + if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3])) @@ -471,7 +470,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last + if idx1 > idx2: # keep downsample3 last shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: @@ -492,9 +491,9 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') - img = img[0::sf, 0::sf, ...] # nearest downsampling + img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) elif i == 3: @@ -544,12 +543,12 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop h, w = image.shape[:2] hq = image.copy() - if sf == 4 and random.random() < scale2_prob: # downsample1 + if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), interpolation=random.choice([1, 2, 3])) @@ -560,7 +559,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last + if idx1 > idx2: # keep downsample3 last shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: @@ -581,9 +580,9 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') - image = image[0::sf, 0::sf, ...] # nearest downsampling + image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) elif i == 3: @@ -609,7 +608,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) - example = {"image":image} + example = {"image": image} return example @@ -630,7 +629,7 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc """ h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: @@ -702,29 +701,28 @@ def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patc if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') - - + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, 
lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py index 808c7f882cb7..bf9fe222564e 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/bsrgan_light.py @@ -1,19 +1,17 @@ # -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from functools import partial import random -from scipy import ndimage +from functools import partial + +import albumentations +import cv2 +import ldm.modules.image_degradation.utils_image as util +import numpy as np import scipy import scipy.stats as ss +import torch +from scipy import ndimage from scipy.interpolate import interp2d from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util - """ # -------------------------------------------- # Super-Resolution @@ -25,6 +23,7 @@ # -------------------------------------------- """ + def modcrop_np(img, sf): ''' Args: @@ -151,18 +150,17 @@ def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var # Set random eigen-vals (lambdas) and angle (theta) for COV matrix lambda_1 = min_var + np.random.rand() * (max_var - min_var) lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta + theta = np.random.rand() * np.pi # random theta noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) + Q = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] # Set expectation position (shifting kernel for aligned 
image) - MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) MU = MU[None, None, :, None] # Create meshgrid for Gaussian @@ -253,7 +251,7 @@ def srmd_degradation(x, k, sf=3): year={2018} } ''' - x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x @@ -325,8 +323,8 @@ def add_blur(img, sf=4): wd2 = 4.0 + sf wd = 2.0 + 0.2 * sf - wd2 = wd2/4 - wd = wd/4 + wd2 = wd2 / 4 + wd = wd / 4 if random.random() < 0.5: l1 = wd2 * random.random() @@ -341,9 +339,9 @@ def add_blur(img, sf=4): def add_resize(img, sf=4): rnum = np.random.rand() - if rnum > 0.8: # up + if rnum > 0.8: # up sf1 = random.uniform(1, 2) - elif rnum < 0.7: # down + elif rnum < 0.7: # down sf1 = random.uniform(0.5 / sf, 1) else: sf1 = 1.0 @@ -369,19 +367,20 @@ def add_resize(img, sf=4): # img = np.clip(img, 0.0, 1.0) # return img + def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() - if rnum > 0.6: # add color Gaussian noise + if rnum > 0.6: # add color Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: # add grayscale Gaussian noise + elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: # add noise + else: # add noise L = noise_level2 / 255. 
D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -399,14 +398,14 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): img = np.clip((img * 255.0).round(), 0, 255) / 255. - vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + vals = 10**(2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: @@ -455,7 +454,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): sf_ori = sf h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: @@ -463,7 +462,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): hq = img.copy() - if sf == 4 and random.random() < scale2_prob: # downsample1 + if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3])) @@ -474,7 +473,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last + if idx1 > idx2: # keep downsample3 last shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: @@ -495,9 +494,9 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') - img = img[0::sf, 0::sf, ...] # nearest downsampling + img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) elif i == 3: @@ -547,12 +546,12 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): sf_ori = sf h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] 
# mod crop h, w = image.shape[:2] hq = image.copy() - if sf == 4 and random.random() < scale2_prob: # downsample1 + if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), interpolation=random.choice([1, 2, 3])) @@ -563,7 +562,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last + if idx1 > idx2: # keep downsample3 last shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] for i in shuffle_order: @@ -587,9 +586,9 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') - image = image[0::sf, 0::sf, ...] # nearest downsampling + image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -617,13 +616,12 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): image = add_JPEG_noise(image) image = util.single2uint(image) if up: - image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + image = cv2.resize(image, (w1, h1), + interpolation=cv2.INTER_CUBIC) # todo: random, as above? 
want to condition on it then example = {"image": image} return example - - if __name__ == '__main__': print("hey") img = util.imread_uint('utils/test.png', 3) @@ -638,7 +636,8 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): img_lq = deg_fn(img)["image"] img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, + interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] print(img_lq.shape) print("bicubic", img_lq_bicubic.shape) print(img_hq.shape) diff --git a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py index 0175f155ad90..611a1846c0c9 100644 --- a/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py +++ b/examples/images/diffusion/ldm/modules/image_degradation/utils_image.py @@ -1,17 +1,16 @@ -import os import math +import os import random +from datetime import datetime + +import cv2 import numpy as np import torch -import cv2 from torchvision.utils import make_grid -from datetime import datetime -#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py - - -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" ''' # -------------------------------------------- # Kai Zhang (github: https://github.com/cszn) @@ -22,7 +21,6 @@ # -------------------------------------------- ''' - IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] @@ -49,10 +47,10 @@ def surf(Z, cmap='rainbow', figsize=None): ax3 = plt.axes(projection='3d') w, h = Z.shape[:2] - xx = np.arange(0,w,1) - yy = np.arange(0,h,1) + xx = np.arange(0, w, 1) 
+ yy = np.arange(0, h, 1) X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) + ax3.plot_surface(X, Y, Z, cmap=cmap) #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) plt.show() @@ -65,7 +63,7 @@ def surf(Z, cmap='rainbow', figsize=None): def get_image_paths(dataroot): - paths = None # return None if dataroot is None + paths = None # return None if dataroot is None if dataroot is not None: paths = sorted(_get_paths_from_images(dataroot)) return paths @@ -85,7 +83,7 @@ def _get_paths_from_images(path): ''' # -------------------------------------------- -# split large images into small images +# split large images into small images # -------------------------------------------- ''' @@ -94,15 +92,15 @@ def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): w, h = img.shape[:2] patches = [] if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) -# print(w1) -# print(h1) + w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int)) + w1.append(w - p_size) + h1.append(h - p_size) + # print(w1) + # print(h1) for i in w1: for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) + patches.append(img[i:i + p_size, j:j + p_size, :]) else: patches.append(img) @@ -118,7 +116,7 @@ def imssave(imgs, img_path): for i, img in enumerate(imgs): if img.ndim == 3: img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + new_path = os.path.join(os.path.dirname(img_path), img_name + str('_s{:04d}'.format(i)) + '.png') cv2.imwrite(new_path, img) @@ -139,10 +137,11 @@ def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, # img_name, ext = os.path.splitext(os.path.basename(img_path)) img = imread_uint(img_path, n_channels=n_channels) patches 
= patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path))) #if original_dataroot == taget_dataroot: #del img_path + ''' # -------------------------------------------- # makedir @@ -186,14 +185,14 @@ def imread_uint(path, n_channels=3): # input: path # output: HxWx3(RGB or GGG), or HxWx1 (G) if n_channels == 1: - img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE - img = np.expand_dims(img, axis=2) # HxWx1 + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 elif n_channels == 3: - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G if img.ndim == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG else: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB return img @@ -206,6 +205,7 @@ def imsave(img, img_path): img = img[:, :, [2, 1, 0]] cv2.imwrite(img_path, img) + def imwrite(img, img_path): img = np.squeeze(img) if img.ndim == 3: @@ -213,14 +213,13 @@ def imwrite(img, img_path): cv2.imwrite(img_path, img) - # -------------------------------------------- # get single image of size HxWxn_channles (BGR) # -------------------------------------------- def read_img(path): # read image by cv2 # return: Numpy float32, HWC, BGR, [0,1] - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE img = img.astype(np.float32) / 255. 
if img.ndim == 2: img = np.expand_dims(img, axis=2) @@ -240,7 +239,6 @@ def read_img(path): # -------------------------------------------- ''' - # -------------------------------------------- # numpy(single) [0, 1] <---> numpy(unit) # -------------------------------------------- @@ -248,22 +246,22 @@ def read_img(path): def uint2single(img): - return np.float32(img/255.) + return np.float32(img / 255.) def single2uint(img): - return np.uint8((img.clip(0, 1)*255.).round()) + return np.uint8((img.clip(0, 1) * 255.).round()) def uint162single(img): - return np.float32(img/65535.) + return np.float32(img / 65535.) def single2uint16(img): - return np.uint16((img.clip(0, 1)*65535.).round()) + return np.uint16((img.clip(0, 1) * 65535.).round()) # -------------------------------------------- @@ -290,7 +288,7 @@ def tensor2uint(img): img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() if img.ndim == 3: img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) + return np.uint8((img * 255.0).round()) # -------------------------------------------- @@ -316,6 +314,7 @@ def tensor2single(img): return img + # convert torch tensor to single def tensor2single3(img): img = img.data.squeeze().float().cpu().numpy() @@ -345,21 +344,20 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) ''' - tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp - tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] n_dim = tensor.dim() if n_dim == 4: n_img = len(tensor) img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + 
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR elif n_dim == 3: img_np = tensor.numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR elif n_dim == 2: img_np = tensor.numpy() else: - raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) if out_type == np.uint8: img_np = (img_np * 255.0).round() # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. @@ -511,7 +509,7 @@ def shave(img_in, border=0): # img_in: Numpy, HWC or HW img = np.copy(img_in) h, w = img.shape[:2] - img = img[border:h-border, border:w-border] + img = img[border:h - border, border:w - border] return img @@ -541,8 +539,8 @@ def rgb2ycbcr(img, only_y=True): if only_y: rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], - [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214] + ]) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: @@ -585,8 +583,8 @@ def bgr2ycbcr(img, only_y=True): if only_y: rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 else: - rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], - [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0] + ]) / 255.0 + [16, 128, 128] if in_img_type == np.uint8: rlt = rlt.round() else: @@ -596,13 +594,13 @@ def bgr2ycbcr(img, only_y=True): def channel_convert(in_c, tar_type, img_list): # conversion among BGR, gray and y - if in_c == 3 and tar_type == 'gray': # BGR to gray + if in_c == 3 and tar_type == 'gray': # BGR to 
gray gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] return [np.expand_dims(img, axis=2) for img in gray_list] - elif in_c == 3 and tar_type == 'y': # BGR to y + elif in_c == 3 and tar_type == 'y': # BGR to y y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] return [np.expand_dims(img, axis=2) for img in y_list] - elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] else: return img_list @@ -625,8 +623,8 @@ def calculate_psnr(img1, img2, border=0): if not img1.shape == img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border:h - border, border:w - border] + img2 = img2[border:h - border, border:w - border] img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) @@ -649,8 +647,8 @@ def calculate_ssim(img1, img2, border=0): if not img1.shape == img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] - img1 = img1[border:h-border, border:w-border] - img2 = img2[border:h-border, border:w-border] + img1 = img1[border:h - border, border:w - border] + img2 = img2[border:h - border, border:w - border] if img1.ndim == 2: return ssim(img1, img2) @@ -658,7 +656,7 @@ def calculate_ssim(img1, img2, border=0): if img1.shape[2] == 3: ssims = [] for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + ssims.append(ssim(img1[:, :, i], img2[:, :, i])) return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) @@ -675,7 +673,7 @@ def ssim(img1, img2): kernel = cv2.getGaussianKernel(11, 1.5) window = np.outer(kernel, kernel.transpose()) - mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid mu2 = 
cv2.filter2D(img2, -1, window)[5:-5, 5:-5] mu1_sq = mu1**2 mu2_sq = mu2**2 @@ -684,8 +682,7 @@ def ssim(img1, img2): sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) return ssim_map.mean() @@ -729,8 +726,8 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand( + out_length, P) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. @@ -781,10 +778,10 @@ def imresize(img, scale, antialiasing=True): # Now we do not support this. # get weights and indices - weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) - weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(in_H, out_H, scale, kernel, kernel_width, + antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(in_W, out_W, scale, kernel, kernel_width, + antialiasing) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) @@ -856,10 +853,10 @@ def imresize_np(img, scale, antialiasing=True): # Now we do not support this. 
# get weights and indices - weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) - weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(in_H, out_H, scale, kernel, kernel_width, + antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(in_W, out_W, scale, kernel, kernel_width, + antialiasing) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) @@ -913,4 +910,4 @@ def imresize_np(img, scale, antialiasing=True): print('---') # img = imread_uint('test.bmp', 3) # img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file +# img_bicubic = imresize_np(img, 1/4) diff --git a/examples/images/diffusion/ldm/modules/midas/api.py b/examples/images/diffusion/ldm/modules/midas/api.py index b58ebbffd942..ab4abb196ba9 100644 --- a/examples/images/diffusion/ldm/modules/midas/api.py +++ b/examples/images/diffusion/ldm/modules/midas/api.py @@ -3,13 +3,11 @@ import cv2 import torch import torch.nn as nn -from torchvision.transforms import Compose - from ldm.modules.midas.midas.dpt_depth import DPTDepthModel from ldm.modules.midas.midas.midas_net import MidasNet from ldm.modules.midas.midas.midas_net_custom import MidasNet_small -from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet - +from ldm.modules.midas.midas.transforms import NormalizeImage, PrepareForNet, Resize +from torchvision.transforms import Compose ISL_PATHS = { "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", @@ -28,12 +26,12 @@ def disabled_train(self, mode=True): def load_midas_transform(model_type): # https://github.com/isl-org/MiDaS/blob/master/run.py # load transform only - if model_type == "dpt_large": # DPT-Large + if 
model_type == "dpt_large": # DPT-Large net_w, net_h = 384, 384 resize_mode = "minimal" normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - elif model_type == "dpt_hybrid": # DPT-Hybrid + elif model_type == "dpt_hybrid": # DPT-Hybrid net_w, net_h = 384, 384 resize_mode = "minimal" normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) @@ -51,21 +49,19 @@ def load_midas_transform(model_type): else: assert False, f"model_type '{model_type}' not implemented, use: --model_type large" - transform = Compose( - [ - Resize( - net_w, - net_h, - resize_target=None, - keep_aspect_ratio=True, - ensure_multiple_of=32, - resize_method=resize_mode, - image_interpolation_method=cv2.INTER_CUBIC, - ), - normalization, - PrepareForNet(), - ] - ) + transform = Compose([ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ]) return transform @@ -74,7 +70,7 @@ def load_model(model_type): # https://github.com/isl-org/MiDaS/blob/master/run.py # load network model_path = ISL_PATHS[model_type] - if model_type == "dpt_large": # DPT-Large + if model_type == "dpt_large": # DPT-Large model = DPTDepthModel( path=model_path, backbone="vitl16_384", @@ -84,7 +80,7 @@ def load_model(model_type): resize_mode = "minimal" normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - elif model_type == "dpt_hybrid": # DPT-Hybrid + elif model_type == "dpt_hybrid": # DPT-Hybrid model = DPTDepthModel( path=model_path, backbone="vitb_rn50_384", @@ -98,48 +94,42 @@ def load_model(model_type): model = MidasNet(model_path, non_negative=True) net_w, net_h = 384, 384 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) elif model_type == 
"midas_v21_small": - model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, - non_negative=True, blocks={'expand': True}) + model = MidasNet_small(model_path, + features=64, + backbone="efficientnet_lite3", + exportable=True, + non_negative=True, + blocks={'expand': True}) net_w, net_h = 256, 256 resize_mode = "upper_bound" - normalization = NormalizeImage( - mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] - ) + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) else: print(f"model_type '{model_type}' not implemented, use: --model_type large") assert False - transform = Compose( - [ - Resize( - net_w, - net_h, - resize_target=None, - keep_aspect_ratio=True, - ensure_multiple_of=32, - resize_method=resize_mode, - image_interpolation_method=cv2.INTER_CUBIC, - ), - normalization, - PrepareForNet(), - ] - ) + transform = Compose([ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ]) return model.eval(), transform class MiDaSInference(nn.Module): - MODEL_TYPES_TORCH_HUB = [ - "DPT_Large", - "DPT_Hybrid", - "MiDaS_small" - ] + MODEL_TYPES_TORCH_HUB = ["DPT_Large", "DPT_Hybrid", "MiDaS_small"] MODEL_TYPES_ISL = [ "dpt_large", "dpt_hybrid", @@ -167,4 +157,3 @@ def forward(self, x): ) assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) return prediction - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py index 5cf430239b47..43144b6a8f66 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/base_model.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/base_model.py @@ -2,6 +2,7 @@ class BaseModel(torch.nn.Module): + def load(self, path): """Load model from file. 
diff --git a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py index 2145d18fa980..6d00e828e2ae 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/blocks.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/blocks.py @@ -1,21 +1,24 @@ import torch import torch.nn as nn -from .vit import ( - _make_pretrained_vitb_rn50_384, - _make_pretrained_vitl16_384, - _make_pretrained_vitb16_384, - forward_vit, -) - -def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): +from .vit import _make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384, _make_pretrained_vitl16_384, forward_vit + + +def _make_encoder( + backbone, + features, + use_pretrained, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout="ignore", +): if backbone == "vitl16_384": - pretrained = _make_pretrained_vitl16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) - scratch = _make_scratch( - [256, 512, 1024, 1024], features, groups=groups, expand=expand - ) # ViT-L/16 - 85.0% Top1 (backbone) + pretrained = _make_pretrained_vitl16_384(use_pretrained, hooks=hooks, use_readout=use_readout) + scratch = _make_scratch([256, 512, 1024, 1024], features, groups=groups, + expand=expand) # ViT-L/16 - 85.0% Top1 (backbone) elif backbone == "vitb_rn50_384": pretrained = _make_pretrained_vitb_rn50_384( use_pretrained, @@ -23,26 +26,22 @@ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, ex use_vit_only=use_vit_only, use_readout=use_readout, ) - scratch = _make_scratch( - [256, 512, 768, 768], features, groups=groups, expand=expand - ) # ViT-H/16 - 85.0% Top1 (backbone) + scratch = _make_scratch([256, 512, 768, 768], features, groups=groups, + expand=expand) # ViT-H/16 - 85.0% Top1 (backbone) elif backbone == "vitb16_384": - pretrained = 
_make_pretrained_vitb16_384( - use_pretrained, hooks=hooks, use_readout=use_readout - ) - scratch = _make_scratch( - [96, 192, 384, 768], features, groups=groups, expand=expand - ) # ViT-B/16 - 84.6% Top1 (backbone) + pretrained = _make_pretrained_vitb16_384(use_pretrained, hooks=hooks, use_readout=use_readout) + scratch = _make_scratch([96, 192, 384, 768], features, groups=groups, + expand=expand) # ViT-B/16 - 84.6% Top1 (backbone) elif backbone == "resnext101_wsl": pretrained = _make_pretrained_resnext101_wsl(use_pretrained) - scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 elif backbone == "efficientnet_lite3": pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) - scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 else: print(f"Backbone '{backbone}' not implemented") assert False - + return pretrained, scratch @@ -53,56 +52,66 @@ def _make_scratch(in_shape, out_shape, groups=1, expand=False): out_shape2 = out_shape out_shape3 = out_shape out_shape4 = out_shape - if expand==True: + if expand == True: out_shape1 = out_shape - out_shape2 = out_shape*2 - out_shape3 = out_shape*4 - out_shape4 = out_shape*8 - - scratch.layer1_rn = nn.Conv2d( - in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer2_rn = nn.Conv2d( - in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer3_rn = nn.Conv2d( - in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) - scratch.layer4_rn = nn.Conv2d( - in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups - ) + 
out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups) + scratch.layer4_rn = nn.Conv2d(in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups) return scratch def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): - efficientnet = torch.hub.load( - "rwightman/gen-efficientnet-pytorch", - "tf_efficientnet_lite3", - pretrained=use_pretrained, - exportable=exportable - ) + efficientnet = torch.hub.load("rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable) return _make_efficientnet_backbone(efficientnet) def _make_efficientnet_backbone(effnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] - ) + pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]) pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) return pretrained - + def _make_resnet_backbone(resnet): pretrained = nn.Module() - pretrained.layer1 = nn.Sequential( - resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 - ) + pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1) pretrained.layer2 = resnet.layer2 pretrained.layer3 = resnet.layer3 @@ -116,7 +125,6 @@ def _make_pretrained_resnext101_wsl(use_pretrained): return _make_resnet_backbone(resnet) - class 
Interpolate(nn.Module): """Interpolation module. """ @@ -145,9 +153,7 @@ def forward(self, x): tensor: interpolated data """ - x = self.interp( - x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners - ) + x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) return x @@ -164,13 +170,9 @@ def __init__(self, features): """ super().__init__() - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True - ) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) self.relu = nn.ReLU(inplace=True) @@ -219,15 +221,11 @@ def forward(self, *xs): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=True - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True) return output - - class ResidualConvUnit_custom(nn.Module): """Residual convolution module. 
""" @@ -242,17 +240,13 @@ def __init__(self, features, activation, bn): self.bn = bn - self.groups=1 + self.groups = 1 - self.conv1 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) - - self.conv2 = nn.Conv2d( - features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups - ) + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) - if self.bn==True: + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: self.bn1 = nn.BatchNorm2d(features) self.bn2 = nn.BatchNorm2d(features) @@ -269,15 +263,15 @@ def forward(self, x): Returns: tensor: output """ - + out = self.activation(x) out = self.conv1(out) - if self.bn==True: + if self.bn == True: out = self.bn1(out) - + out = self.activation(out) out = self.conv2(out) - if self.bn==True: + if self.bn == True: out = self.bn2(out) if self.groups > 1: @@ -303,18 +297,18 @@ def __init__(self, features, activation, deconv=False, bn=False, expand=False, a self.deconv = deconv self.align_corners = align_corners - self.groups=1 + self.groups = 1 self.expand = expand out_features = features - if self.expand==True: - out_features = features//2 - + if self.expand == True: + out_features = features // 2 + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) - + self.skip_add = nn.quantized.FloatFunctional() def forward(self, *xs): @@ -332,11 +326,8 @@ def forward(self, *xs): output = self.resConfUnit2(output) - output = nn.functional.interpolate( - output, scale_factor=2, mode="bilinear", align_corners=self.align_corners - ) + output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", 
align_corners=self.align_corners) output = self.out_conv(output) return output - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py index 4e9aab5d2767..7460b6a05592 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/dpt_depth.py @@ -3,13 +3,7 @@ import torch.nn.functional as F from .base_model import BaseModel -from .blocks import ( - FeatureFusionBlock, - FeatureFusionBlock_custom, - Interpolate, - _make_encoder, - forward_vit, -) +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder, forward_vit def _make_fusion_block(features, use_bn): @@ -24,6 +18,7 @@ def _make_fusion_block(features, use_bn): class DPT(BaseModel): + def __init__( self, head, @@ -48,7 +43,7 @@ def __init__( self.pretrained, self.scratch = _make_encoder( backbone, features, - False, # Set to true of you want to train from scratch, uses ImageNet weights + False, # Set to true of you want to train from scratch, uses ImageNet weights groups=1, expand=False, exportable=False, @@ -63,7 +58,6 @@ def __init__( self.scratch.output_conv = head - def forward(self, x): if self.channels_last == True: x.contiguous(memory_format=torch.channels_last) @@ -86,6 +80,7 @@ def forward(self, x): class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): features = kwargs["features"] if "features" in kwargs else 256 @@ -102,8 +97,7 @@ def __init__(self, path=None, non_negative=True, **kwargs): super().__init__(head, **kwargs) if path is not None: - self.load(path) + self.load(path) def forward(self, x): return super().forward(x).squeeze(dim=1) - diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py index 8a954977800b..da518e0ba655 100644 --- 
a/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net.py @@ -27,7 +27,9 @@ def __init__(self, path=None, features=256, non_negative=True): use_pretrained = False if path is None else True - self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", + features=features, + use_pretrained=use_pretrained) self.scratch.refinenet4 = FeatureFusionBlock(features) self.scratch.refinenet3 = FeatureFusionBlock(features) diff --git a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py index 50e4acb5e53d..e2cfa5de4e08 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/midas_net_custom.py @@ -13,8 +13,15 @@ class MidasNet_small(BaseModel): """Network for monocular depth estimation. """ - def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, - blocks={'expand': True}): + def __init__(self, + path=None, + features=64, + backbone="efficientnet_lite3", + non_negative=True, + exportable=True, + channels_last=False, + align_corners=True, + blocks={'expand': True}): """Init. 
Args: @@ -27,49 +34,71 @@ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_ne super(MidasNet_small, self).__init__() use_pretrained = False if path else True - + self.channels_last = channels_last self.blocks = blocks self.backbone = backbone self.groups = 1 - features1=features - features2=features - features3=features - features4=features + features1 = features + features2 = features + features3 = features + features4 = features self.expand = False if "expand" in self.blocks and self.blocks['expand'] == True: self.expand = True - features1=features - features2=features*2 - features3=features*4 - features4=features*8 - - self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) - - self.scratch.activation = nn.ReLU(False) - - self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) - self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + features1 = features + features2 = features * 2 + features3 = features * 4 + features4 = features * 8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, + features, + use_pretrained, + groups=self.groups, + expand=self.expand, + exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, + self.scratch.activation, + deconv=False, + bn=False, + expand=self.expand, + align_corners=align_corners) + 
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, + self.scratch.activation, + deconv=False, + bn=False, + expand=self.expand, + align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, + self.scratch.activation, + deconv=False, + bn=False, + expand=self.expand, + align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, + self.scratch.activation, + deconv=False, + bn=False, + align_corners=align_corners) - self.scratch.output_conv = nn.Sequential( - nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1, groups=self.groups), Interpolate(scale_factor=2, mode="bilinear"), - nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), self.scratch.activation, nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), nn.ReLU(True) if non_negative else nn.Identity(), nn.Identity(), ) - + if path: self.load(path) - def forward(self, x): """Forward pass. 
@@ -79,33 +108,30 @@ def forward(self, x): Returns: tensor: depth """ - if self.channels_last==True: + if self.channels_last == True: print("self.channels_last = ", self.channels_last) x.contiguous(memory_format=torch.channels_last) - layer_1 = self.pretrained.layer1(x) layer_2 = self.pretrained.layer2(layer_1) layer_3 = self.pretrained.layer3(layer_2) layer_4 = self.pretrained.layer4(layer_3) - + layer_1_rn = self.scratch.layer1_rn(layer_1) layer_2_rn = self.scratch.layer2_rn(layer_2) layer_3_rn = self.scratch.layer3_rn(layer_3) layer_4_rn = self.scratch.layer4_rn(layer_4) - path_4 = self.scratch.refinenet4(layer_4_rn) path_3 = self.scratch.refinenet3(path_4, layer_3_rn) path_2 = self.scratch.refinenet2(path_3, layer_2_rn) path_1 = self.scratch.refinenet1(path_2, layer_1_rn) - + out = self.scratch.output_conv(path_1) return torch.squeeze(out, dim=1) - def fuse_model(m): prev_previous_type = nn.Identity() prev_previous_name = '' @@ -125,4 +151,4 @@ def fuse_model(m): prev_previous_type = previous_type prev_previous_name = previous_name previous_type = type(module) - previous_name = name \ No newline at end of file + previous_name = name diff --git a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py index 350cbc116626..f60e2e9ca606 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/transforms.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/transforms.py @@ -1,7 +1,8 @@ -import numpy as np -import cv2 import math +import cv2 +import numpy as np + def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): """Rezise the sample to ensure the given size. Keeps aspect ratio. 
@@ -28,13 +29,9 @@ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): shape[1] = math.ceil(scale * shape[1]) # resize - sample["image"] = cv2.resize( - sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method - ) + sample["image"] = cv2.resize(sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method) - sample["disparity"] = cv2.resize( - sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST - ) + sample["disparity"] = cv2.resize(sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), tuple(shape[::-1]), @@ -133,24 +130,14 @@ def get_size(self, width, height): # fit height scale_width = scale_height else: - raise ValueError( - f"resize_method {self.__resize_method} not implemented" - ) + raise ValueError(f"resize_method {self.__resize_method} not implemented") if self.__resize_method == "lower_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, min_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, min_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) elif self.__resize_method == "upper_bound": - new_height = self.constrain_to_multiple_of( - scale_height * height, max_val=self.__height - ) - new_width = self.constrain_to_multiple_of( - scale_width * width, max_val=self.__width - ) + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) elif self.__resize_method == "minimal": new_height = self.constrain_to_multiple_of(scale_height * height) new_width = self.constrain_to_multiple_of(scale_width * width) @@ -160,9 +147,7 @@ def get_size(self, width, height): return 
(new_width, new_height) def __call__(self, sample): - width, height = self.get_size( - sample["image"].shape[1], sample["image"].shape[0] - ) + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) # resize sample sample["image"] = cv2.resize( @@ -180,9 +165,7 @@ def __call__(self, sample): ) if "depth" in sample: - sample["depth"] = cv2.resize( - sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST - ) + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) sample["mask"] = cv2.resize( sample["mask"].astype(np.float32), diff --git a/examples/images/diffusion/ldm/modules/midas/midas/vit.py b/examples/images/diffusion/ldm/modules/midas/midas/vit.py index ea46b1be88b2..88e7b50b54ed 100644 --- a/examples/images/diffusion/ldm/modules/midas/midas/vit.py +++ b/examples/images/diffusion/ldm/modules/midas/midas/vit.py @@ -1,21 +1,24 @@ +import math +import types + +import timm import torch import torch.nn as nn -import timm -import types -import math import torch.nn.functional as F class Slice(nn.Module): + def __init__(self, start_index=1): super(Slice, self).__init__() self.start_index = start_index def forward(self, x): - return x[:, self.start_index :] + return x[:, self.start_index:] class AddReadout(nn.Module): + def __init__(self, start_index=1): super(AddReadout, self).__init__() self.start_index = start_index @@ -25,10 +28,11 @@ def forward(self, x): readout = (x[:, 0] + x[:, 1]) / 2 else: readout = x[:, 0] - return x[:, self.start_index :] + readout.unsqueeze(1) + return x[:, self.start_index:] + readout.unsqueeze(1) class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): super(ProjectReadout, self).__init__() self.start_index = start_index @@ -36,13 +40,14 @@ def __init__(self, in_features, start_index=1): self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) def forward(self, x): - readout = x[:, 0].unsqueeze(1).expand_as(x[:, 
self.start_index :]) - features = torch.cat((x[:, self.start_index :], readout), -1) + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:]) + features = torch.cat((x[:, self.start_index:], readout), -1) return self.project(features) class Transpose(nn.Module): + def __init__(self, dim0, dim1): super(Transpose, self).__init__() self.dim0 = dim0 @@ -71,14 +76,11 @@ def forward_vit(pretrained, x): unflatten = nn.Sequential( nn.Unflatten( 2, - torch.Size( - [ - h // pretrained.model.patch_size[1], - w // pretrained.model.patch_size[0], - ] - ), - ) - ) + torch.Size([ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ]), + )) if layer_1.ndim == 3: layer_1 = unflatten(layer_1) @@ -89,18 +91,18 @@ def forward_vit(pretrained, x): if layer_4.ndim == 3: layer_4 = unflatten(layer_4) - layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) - layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) - layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) - layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](layer_4) return layer_1, layer_2, layer_3, layer_4 def _resize_pos_embed(self, posemb, gs_h, gs_w): posemb_tok, posemb_grid = ( - posemb[:, : self.start_index], - posemb[0, self.start_index :], + posemb[:, :self.start_index], + posemb[0, self.start_index:], ) gs_old = int(math.sqrt(len(posemb_grid))) @@ -117,29 +119,23 @@ def _resize_pos_embed(self, posemb, gs_h, gs_w): def forward_flex(self, x): b, c, h, w = x.shape - pos_embed = self._resize_pos_embed( - self.pos_embed, h // 
self.patch_size[1], w // self.patch_size[0] - ) + pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]) B = x.shape[0] if hasattr(self.patch_embed, "backbone"): x = self.patch_embed.backbone(x) if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features + x = x[-1] # last feature if backbone outputs list/tuple of features x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) if getattr(self, "dist_token", None) is not None: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks dist_token = self.dist_token.expand(B, -1, -1) x = torch.cat((cls_tokens, dist_token, x), dim=1) else: - cls_tokens = self.cls_token.expand( - B, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks x = torch.cat((cls_tokens, x), dim=1) x = x + pos_embed @@ -157,6 +153,7 @@ def forward_flex(self, x): def get_activation(name): + def hook(model, input, output): activations[name] = output @@ -169,13 +166,9 @@ def get_readout_oper(vit_features, features, use_readout, start_index=1): elif use_readout == "add": readout_oper = [AddReadout(start_index)] * len(features) elif use_readout == "project": - readout_oper = [ - ProjectReadout(vit_features, start_index) for out_feat in features - ] + readout_oper = [ProjectReadout(vit_features, start_index) for out_feat in features] else: - assert ( - False - ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + assert (False), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" return readout_oper @@ -287,9 +280,7 @@ def _make_vit_b16_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position 
embeddings without modifying the library source. pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) - pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained @@ -311,24 +302,18 @@ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks - return _make_vit_b16_backbone( - model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout - ) + return _make_vit_b16_backbone(model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout) def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): - model = timm.create_model( - "vit_deit_base_distilled_patch16_384", pretrained=pretrained - ) + model = timm.create_model("vit_deit_base_distilled_patch16_384", pretrained=pretrained) hooks = [2, 5, 8, 11] if hooks == None else hooks return _make_vit_b16_backbone( @@ -358,12 +343,8 @@ def _make_vit_b_rn50_backbone( pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) else: - pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( - get_activation("1") - ) - pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( - get_activation("2") - ) + 
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(get_activation("1")) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(get_activation("2")) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) @@ -419,12 +400,8 @@ def _make_vit_b_rn50_backbone( ), ) else: - pretrained.act_postprocess1 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) - pretrained.act_postprocess2 = nn.Sequential( - nn.Identity(), nn.Identity(), nn.Identity() - ) + pretrained.act_postprocess1 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) + pretrained.act_postprocess2 = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity()) pretrained.act_postprocess3 = nn.Sequential( readout_oper[2], @@ -468,16 +445,12 @@ def _make_vit_b_rn50_backbone( # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. 
- pretrained.model._resize_pos_embed = types.MethodType( - _resize_pos_embed, pretrained.model - ) + pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model) return pretrained -def _make_pretrained_vitb_rn50_384( - pretrained, use_readout="ignore", hooks=None, use_vit_only=False -): +def _make_pretrained_vitb_rn50_384(pretrained, use_readout="ignore", hooks=None, use_vit_only=False): model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) hooks = [0, 1, 8, 11] if hooks == None else hooks diff --git a/examples/images/diffusion/ldm/modules/midas/utils.py b/examples/images/diffusion/ldm/modules/midas/utils.py index 9a9d3b5b6637..d346a8977dd6 100644 --- a/examples/images/diffusion/ldm/modules/midas/utils.py +++ b/examples/images/diffusion/ldm/modules/midas/utils.py @@ -1,8 +1,9 @@ """Utils for monoDepth.""" -import sys import re -import numpy as np +import sys + import cv2 +import numpy as np import torch @@ -72,11 +73,9 @@ def write_pfm(path, image, scale=1): image = np.flipud(image) - if len(image.shape) == 3 and image.shape[2] == 3: # color image + if len(image.shape) == 3 and image.shape[2] == 3: # color image color = True - elif ( - len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 - ): # greyscale + elif (len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1): # greyscale color = False else: raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") @@ -135,9 +134,7 @@ def resize_image(img): img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) - img_resized = ( - torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() - ) + img_resized = (torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()) img_resized = img_resized.unsqueeze(0) return img_resized @@ -156,12 +153,11 @@ def resize_depth(depth, width, height): """ depth = torch.squeeze(depth[0, :, :, :]).to("cpu") - depth_resized = 
cv2.resize( - depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC - ) + depth_resized = cv2.resize(depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC) return depth_resized + def write_depth(path, depth, bits=1): """Write depth map to pfm and png file. @@ -174,7 +170,7 @@ def write_depth(path, depth, bits=1): depth_min = depth.min() depth_max = depth.max() - max_val = (2**(8*bits))-1 + max_val = (2**(8 * bits)) - 1 if depth_max - depth_min > np.finfo("float").eps: out = max_val * (depth - depth_min) / (depth_max - depth_min) diff --git a/examples/images/diffusion/ldm/util.py b/examples/images/diffusion/ldm/util.py index 8c09ca1c72f7..6b9dba3f6afd 100644 --- a/examples/images/diffusion/ldm/util.py +++ b/examples/images/diffusion/ldm/util.py @@ -1,11 +1,10 @@ import importlib +from inspect import isfunction -import torch -from torch import optim import numpy as np - -from inspect import isfunction +import torch from PIL import Image, ImageDraw, ImageFont +from torch import optim def log_txt_as_img(wh, xc, size=10): @@ -39,7 +38,7 @@ def ismap(x): def isimage(x): - if not isinstance(x,torch.Tensor): + if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) @@ -89,9 +88,17 @@ def get_obj_from_str(string, reload=False): class AdamWwithEMAandWings(optim.Optimizer): # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 - def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using - weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code - ema_power=1., param_names=()): + def __init__( + self, + params, + lr=1.e-3, + betas=(0.9, 0.999), + eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, + amsgrad=False, + ema_decay=0.9999, # ema decay to match previous code + ema_power=1., + param_names=()): """AdamW that saves EMA versions of the parameters.""" if not 0.0 <= 
lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -105,9 +112,14 @@ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: che raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) if not 0.0 <= ema_decay <= 1.0: raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, - ema_power=ema_power, param_names=param_names) + defaults = dict(lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ema_decay=ema_decay, + ema_power=ema_power, + param_names=param_names) super().__init__(params, defaults) def __setstate__(self, state): @@ -177,21 +189,21 @@ def step(self, closure=None): state_steps.append(state['step']) optim._functional.adamw(params_with_grad, - grads, - exp_avgs, - exp_avg_sqs, - max_exp_avg_sqs, - state_steps, - amsgrad=amsgrad, - beta1=beta1, - beta2=beta2, - lr=group['lr'], - weight_decay=group['weight_decay'], - eps=group['eps'], - maximize=False) - - cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power) for param, ema_param in zip(params_with_grad, ema_params_with_grad): ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) - return loss \ No newline at end of file + return loss diff --git a/examples/images/diffusion/scripts/download_first_stages.sh b/examples/images/diffusion/scripts/download_first_stages.sh index a8d79e99ccdf..50dab5de5b90 100755 --- a/examples/images/diffusion/scripts/download_first_stages.sh +++ b/examples/images/diffusion/scripts/download_first_stages.sh @@ -38,4 +38,4 @@ unzip -o model.zip cd ../vq-f16 unzip -o model.zip -cd 
../.. \ No newline at end of file +cd ../.. diff --git a/examples/images/diffusion/scripts/img2img.py b/examples/images/diffusion/scripts/img2img.py index 877538d4733d..07a0894960e1 100644 --- a/examples/images/diffusion/scripts/img2img.py +++ b/examples/images/diffusion/scripts/img2img.py @@ -1,28 +1,30 @@ """make variations of input image""" -import argparse, os +import argparse +import os +from contextlib import nullcontext +from itertools import islice + +import numpy as np import PIL import torch -import numpy as np +from einops import rearrange, repeat from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange, repeat -from torchvision.utils import make_grid from torch import autocast -from contextlib import nullcontext +from torchvision.utils import make_grid +from tqdm import tqdm, trange + try: from lightning.pytorch import seed_everything except: from pytorch_lightning import seed_everything -from imwatermark import WatermarkEncoder - -from scripts.txt2img import put_watermark -from ldm.util import instantiate_from_config +from imwatermark import WatermarkEncoder from ldm.models.diffusion.ddim import DDIMSampler -from utils import replace_module, getModelSize +from ldm.util import instantiate_from_config +from scripts.txt2img import put_watermark +from utils import getModelSize, replace_module def chunk(it, size): @@ -53,7 +55,7 @@ def load_img(path): image = Image.open(path).convert("RGB") w, h = image.size print(f"loaded input image of size ({w}, {h}) from {path}") - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) @@ -64,28 +66,19 @@ def load_img(path): def main(): parser = argparse.ArgumentParser() - 
parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="a painting of a virus monster playing guitar", - help="the prompt to render" - ) + parser.add_argument("--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render") - parser.add_argument( - "--init-img", - type=str, - nargs="?", - help="path to the input image" - ) + parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image") - parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/img2img-samples" - ) + parser.add_argument("--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/img2img-samples") parser.add_argument( "--ddim_steps", @@ -176,13 +169,11 @@ def main(): default=42, help="the seed (for reproducible sampling)", ) - parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" - ) + parser.add_argument("--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast") parser.add_argument( "--use_int8", type=bool, @@ -204,7 +195,7 @@ def main(): model = replace_module(model) # # to compute the model size # getModelSize(model) - + sampler = DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) @@ -236,7 +227,7 @@ def main(): assert os.path.isfile(opt.init_img) init_image = load_img(opt.init_img).to(device) init_image = repeat(init_image, '1 ... 
-> b ...', b=batch_size) - init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) @@ -261,8 +252,13 @@ def main(): # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) # decode it - samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, ) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + ) x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/examples/images/diffusion/scripts/inpaint.py b/examples/images/diffusion/scripts/inpaint.py index d6e6387a9a3b..5fbd991f6e48 100644 --- a/examples/images/diffusion/scripts/inpaint.py +++ b/examples/images/diffusion/scripts/inpaint.py @@ -1,32 +1,36 @@ -import argparse, os, sys, glob -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm +import argparse +import glob +import os +import sys + import numpy as np import torch -from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler +from main import instantiate_from_config +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm def make_batch(image, mask, device): image = np.array(Image.open(image).convert("RGB")) - image = image.astype(np.float32)/255.0 - image = image[None].transpose(0,3,1,2) + image = image.astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) mask = np.array(Image.open(mask).convert("L")) - mask = mask.astype(np.float32)/255.0 - mask = mask[None,None] + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, 
None] mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) - masked_image = (1-mask)*image + masked_image = (1 - mask) * image batch = {"image": image, "mask": mask, "masked_image": masked_image} for k in batch: batch[k] = batch[k].to(device=device) - batch[k] = batch[k]*2.0-1.0 + batch[k] = batch[k] * 2.0 - 1.0 return batch @@ -58,8 +62,7 @@ def make_batch(image, mask, device): config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") model = instantiate_from_config(config.model) - model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], - strict=False) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) @@ -74,11 +77,10 @@ def make_batch(image, mask, device): # encode masked image and concat downsampled mask c = model.cond_stage_model.encode(batch["masked_image"]) - cc = torch.nn.functional.interpolate(batch["mask"], - size=c.shape[-2:]) + cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:]) c = torch.cat((c, cc), dim=1) - shape = (c.shape[1]-1,)+c.shape[2:] + shape = (c.shape[1] - 1,) + c.shape[2:] samples_ddim, _ = sampler.sample(S=opt.steps, conditioning=c, batch_size=c.shape[0], @@ -86,13 +88,10 @@ def make_batch(image, mask, device): verbose=False) x_samples_ddim = model.decode_first_stage(samples_ddim) - image = torch.clamp((batch["image"]+1.0)/2.0, - min=0.0, max=1.0) - mask = torch.clamp((batch["mask"]+1.0)/2.0, - min=0.0, max=1.0) - predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, - min=0.0, max=1.0) + image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0) + mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - inpainted = (1-mask)*image+mask*predicted_image - inpainted = 
inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + inpainted = (1 - mask) * image + mask * predicted_image + inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/examples/images/diffusion/scripts/knn2img.py b/examples/images/diffusion/scripts/knn2img.py index e6eaaecab53e..a2129ae49aa6 100644 --- a/examples/images/diffusion/scripts/knn2img.py +++ b/examples/images/diffusion/scripts/knn2img.py @@ -1,22 +1,25 @@ -import argparse, os, sys, glob +import argparse +import glob +import os +import sys +import time +from itertools import islice +from multiprocessing import cpu_count + import clip +import numpy as np +import scann import torch import torch.nn as nn -import numpy as np -from omegaconf import OmegaConf -from PIL import Image -from tqdm import tqdm, trange -from itertools import islice from einops import rearrange, repeat -from torchvision.utils import make_grid -import scann -import time -from multiprocessing import cpu_count - -from ldm.util import instantiate_from_config, parallel_data_prefetch from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder +from ldm.util import instantiate_from_config, parallel_data_prefetch +from omegaconf import OmegaConf +from PIL import Image +from torchvision.utils import make_grid +from tqdm import tqdm, trange DATABASES = [ "openimages", @@ -59,6 +62,7 @@ def load_model_from_config(config, ckpt, verbose=False): class Searcher(object): + def __init__(self, database, retriever_version='ViT-L/14'): assert database in DATABASES # self.database = self.load_database(database) @@ -66,20 +70,15 @@ def __init__(self, database, retriever_version='ViT-L/14'): self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' self.retriever = 
self.load_retriever(version=retriever_version) - self.database = {'embedding': [], - 'img_id': [], - 'patch_coords': []} + self.database = {'embedding': [], 'img_id': [], 'patch_coords': []} self.load_database() self.load_searcher() - def train_searcher(self, k, - metric='dot_product', - searcher_savedir=None): + def train_searcher(self, k, metric='dot_product', searcher_savedir=None): print('Start training searcher') - searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / - np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], - k, metric) + searcher = scann.scann_ops_pybind.builder( + self.database['embedding'] / np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], k, metric) self.searcher = searcher.score_brute_force().build() print('Finish training searcher') @@ -110,17 +109,23 @@ def load_database(self): self.load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(self.load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') - - self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in - self.database} + prefetched_data = parallel_data_prefetch(self.load_multi_files, + data, + n_proc=min(len(data), cpu_count()), + target_data_type='dict') + + self.database = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database + } else: raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') - def load_retriever(self, version='ViT-L/14', ): + def load_retriever( + self, + version='ViT-L/14', + ): model = FrozenClipImageEmbedder(model=version) if torch.cuda.is_available(): model.cuda() @@ -134,7 +139,7 @@ def load_searcher(self): def search(self, x, k): if self.searcher is None and 
self.database['embedding'].shape[0] < 2e4: - self.train_searcher(k) # quickly fit searcher on the fly for small databases + self.train_searcher(k) # quickly fit searcher on the fly for small databases assert self.searcher is not None, 'Cannot search with uninitialized searcher' if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() @@ -150,13 +155,15 @@ def search(self, x, k): out_img_ids = self.database['img_id'][nns] out_pc = self.database['patch_coords'][nns] - out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], - 'img_ids': out_img_ids, - 'patch_coords': out_pc, - 'queries': x, - 'exec_time': end - start, - 'nns': nns, - 'q_embeddings': query_embeddings} + out = { + 'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + 'img_ids': out_img_ids, + 'patch_coords': out_pc, + 'queries': x, + 'exec_time': end - start, + 'nns': nns, + 'q_embeddings': query_embeddings + } return out @@ -168,21 +175,17 @@ def __call__(self, x, n): parser = argparse.ArgumentParser() # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc) # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt? 
- parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="a painting of a virus monster playing guitar", - help="the prompt to render" - ) - - parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" - ) + parser.add_argument("--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render") + + parser.add_argument("--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples") parser.add_argument( "--skip_grid", @@ -363,16 +366,17 @@ def __call__(self, x, n): uc = torch.zeros_like(c) if isinstance(prompts, tuple): prompts = list(prompts) - shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - ) + shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model + samples_ddim, _ = sampler.sample( + S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/examples/images/diffusion/scripts/sample_diffusion.py b/examples/images/diffusion/scripts/sample_diffusion.py index 876fe3c3642f..84131e7c4768 100644 --- a/examples/images/diffusion/scripts/sample_diffusion.py +++ b/examples/images/diffusion/scripts/sample_diffusion.py @@ -1,17 +1,22 @@ -import argparse, os, sys, glob, datetime, yaml -import torch +import argparse +import datetime +import glob +import os +import sys import time -import numpy as np -from tqdm import trange - -from omegaconf 
import OmegaConf -from PIL import Image +import numpy as np +import torch +import yaml from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config +from omegaconf import OmegaConf +from PIL import Image +from tqdm import trange rescale = lambda x: (x + 1.) / 2. + def custom_to_pil(x): x = x.detach().cpu() x = torch.clamp(x, -1., 1.) @@ -51,49 +56,51 @@ def logs2pil(logs, keys=["sample"]): @torch.no_grad() -def convsample(model, shape, return_intermediates=True, - verbose=True, - make_prog_row=False): - +def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False): if not make_prog_row: - return model.p_sample_loop(None, shape, - return_intermediates=return_intermediates, verbose=verbose) + return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose) else: - return model.progressive_denoising( - None, shape, verbose=True - ) + return model.progressive_denoising(None, shape, verbose=True) @torch.no_grad() -def convsample_ddim(model, steps, shape, eta=1.0 - ): +def convsample_ddim(model, steps, shape, eta=1.0): ddim = DDIMSampler(model) bs = shape[0] shape = shape[1:] - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + samples, intermediates = ddim.sample( + steps, + batch_size=bs, + shape=shape, + eta=eta, + verbose=False, + ) return samples, intermediates @torch.no_grad() -def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): - +def make_convolutional_sample( + model, + batch_size, + vanilla=False, + custom_steps=None, + eta=1.0, +): log = dict() - shape = [batch_size, - model.model.diffusion_model.in_channels, - model.model.diffusion_model.image_size, - model.model.diffusion_model.image_size] + shape = [ + batch_size, model.model.diffusion_model.in_channels, model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size + ] with model.ema_scope("Plotting"): 
t0 = time.time() if vanilla: - sample, progrow = convsample(model, shape, - make_prog_row=True) + sample, progrow = convsample(model, shape, make_prog_row=True) else: - sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, - eta=eta) + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta) t1 = time.time() @@ -105,23 +112,25 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non print(f'Throughput for this batch: {log["throughput"]}') return log + def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): if vanilla: print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') else: print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') - tstart = time.time() - n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + n_saved = len(glob.glob(os.path.join(logdir, '*.png'))) - 1 # path = logdir if model.cond_stage_model is None: all_images = [] print(f"Running unconditional sampling for {n_samples} samples") for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): - logs = make_convolutional_sample(model, batch_size=batch_size, - vanilla=vanilla, custom_steps=custom_steps, + logs = make_convolutional_sample(model, + batch_size=batch_size, + vanilla=vanilla, + custom_steps=custom_steps, eta=eta) n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") all_images.extend([custom_to_np(logs["sample"])]) @@ -135,7 +144,7 @@ def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None np.savez(nppath, all_img) else: - raise NotImplementedError('Currently only sampling for unconditional models supported.') + raise NotImplementedError('Currently only sampling for unconditional models supported.') print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") @@ -168,22 +177,13 @@ def get_parser(): 
nargs="?", help="load from logdir or checkpoint in logdir", ) - parser.add_argument( - "-n", - "--n_samples", - type=int, - nargs="?", - help="number of samples to draw", - default=50000 - ) - parser.add_argument( - "-e", - "--eta", - type=float, - nargs="?", - help="eta for ddim sampling (0.0 yields deterministic sampling)", - default=1.0 - ) + parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000) + parser.add_argument("-e", + "--eta", + type=float, + nargs="?", + help="eta for ddim sampling (0.0 yields deterministic sampling)", + default=1.0) parser.add_argument( "-v", "--vanilla_sample", @@ -191,35 +191,20 @@ def get_parser(): action='store_true', help="vanilla sampling (default option is DDIM sampling)?", ) - parser.add_argument( - "-l", - "--logdir", - type=str, - nargs="?", - help="extra logdir", - default="none" - ) - parser.add_argument( - "-c", - "--custom_steps", - type=int, - nargs="?", - help="number of steps for ddim and fastdpm sampling", - default=50 - ) - parser.add_argument( - "--batch_size", - type=int, - nargs="?", - help="the bs", - default=10 - ) + parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none") + parser.add_argument("-c", + "--custom_steps", + type=int, + nargs="?", + help="number of steps for ddim and fastdpm sampling", + default=50) + parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10) return parser def load_model_from_config(config, sd): model = instantiate_from_config(config) - model.load_state_dict(sd,strict=False) + model.load_state_dict(sd, strict=False) model.cuda() model.eval() return model @@ -233,8 +218,7 @@ def load_model(config, ckpt, gpu, eval_mode): else: pl_sd = {"state_dict": None} global_step = None - model = load_model_from_config(config.model, - pl_sd["state_dict"]) + model = load_model_from_config(config.model, pl_sd["state_dict"]) return model, global_step @@ -258,7 +242,7 @@ def 
load_model(config, ckpt, gpu, eval_mode): print(f'Logdir is {logdir}') except ValueError: paths = opt.resume.split("/") - idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt + idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt logdir = "/".join(paths[:idx]) ckpt = opt.resume else: @@ -278,7 +262,8 @@ def load_model(config, ckpt, gpu, eval_mode): if opt.logdir != "none": locallog = logdir.split(os.sep)[-1] - if locallog == "": locallog = logdir.split(os.sep)[-2] + if locallog == "": + locallog = logdir.split(os.sep)[-2] print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") logdir = os.path.join(opt.logdir, locallog) @@ -305,9 +290,13 @@ def load_model(config, ckpt, gpu, eval_mode): yaml.dump(sampling_conf, f, default_flow_style=False) print(sampling_conf) - - run(model, imglogdir, eta=opt.eta, - vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, - batch_size=opt.batch_size, nplog=numpylogdir) + run(model, + imglogdir, + eta=opt.eta, + vanilla=opt.vanilla_sample, + n_samples=opt.n_samples, + custom_steps=opt.custom_steps, + batch_size=opt.batch_size, + nplog=numpylogdir) print("done.") diff --git a/examples/images/diffusion/scripts/tests/test_checkpoint.py b/examples/images/diffusion/scripts/tests/test_checkpoint.py index a32e66d44cf2..1caa4b3fadda 100644 --- a/examples/images/diffusion/scripts/tests/test_checkpoint.py +++ b/examples/images/diffusion/scripts/tests/test_checkpoint.py @@ -1,12 +1,11 @@ import os import sys from copy import deepcopy - -import yaml from datetime import datetime -from diffusers import StableDiffusionPipeline import torch +import yaml +from diffusers import StableDiffusionPipeline from ldm.util import instantiate_from_config from main import get_parser @@ -19,9 +18,7 @@ unet_config = base_config['model']['params']['unet_config'] diffusion_model = instantiate_from_config(unet_config).to("cuda:0") - pipe = StableDiffusionPipeline.from_pretrained( - 
"/data/scratch/diffuser/stable-diffusion-v1-4" - ).to("cuda:0") + pipe = StableDiffusionPipeline.from_pretrained("/data/scratch/diffuser/stable-diffusion-v1-4").to("cuda:0") dif_model_2 = pipe.unet random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0") @@ -34,4 +31,4 @@ out_1 = diffusion_model(random_input_, time_stamp, context_) out_2 = dif_model_2(random_input_2, time_stamp2, context_2) print(out_1.shape) - print(out_2['sample'].shape) \ No newline at end of file + print(out_2['sample'].shape) diff --git a/examples/images/diffusion/scripts/tests/test_watermark.py b/examples/images/diffusion/scripts/tests/test_watermark.py index f93f8a6e7076..7997973e5e5a 100644 --- a/examples/images/diffusion/scripts/tests/test_watermark.py +++ b/examples/images/diffusion/scripts/tests/test_watermark.py @@ -15,4 +15,4 @@ def testit(img_path): if __name__ == "__main__": - fire.Fire(testit) \ No newline at end of file + fire.Fire(testit) diff --git a/examples/images/diffusion/scripts/train_searcher.py b/examples/images/diffusion/scripts/train_searcher.py index 1e7904889c01..cdde8e439936 100644 --- a/examples/images/diffusion/scripts/train_searcher.py +++ b/examples/images/diffusion/scripts/train_searcher.py @@ -1,20 +1,21 @@ -import os, sys -import numpy as np -import scann import argparse import glob +import os +import sys from multiprocessing import cpu_count -from tqdm import tqdm +import numpy as np +import scann from ldm.util import parallel_data_prefetch +from tqdm import tqdm def search_bruteforce(searcher): return searcher.score_brute_force().build() -def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search): +def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, + num_leaves_to_search): return searcher.tree(num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize). 
\ @@ -22,11 +23,11 @@ def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): - return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( - reorder_k).build() + return searcher.score_ah(dims_per_block, + anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() -def load_datapool(dpath): +def load_datapool(dpath): def load_single_file(saved_embeddings): compressed = np.load(saved_embeddings) @@ -48,10 +49,14 @@ def load_multi_files(data_archive): data_pool = load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') - - data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + prefetched_data = parallel_data_prefetch(load_multi_files, + data, + n_proc=min(len(data), cpu_count()), + target_data_type='dict') + + data_pool = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys() + } else: raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') @@ -59,15 +64,17 @@ def load_multi_files(data_archive): return data_pool -def train_searcher(opt, - metric='dot_product', - partioning_trainsize=None, - reorder_k=None, - # todo tune - aiq_thld=0.2, - dims_per_block=2, - num_leaves=None, - num_leaves_to_search=None,): +def train_searcher( + opt, + metric='dot_product', + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None, +): data_pool = load_datapool(opt.database) k = opt.knn @@ -77,7 +84,8 @@ def train_searcher(opt, # normalize # embeddings = - searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / 
np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) + searcher = scann.scann_ops_pybind.builder( + data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) pool_size = data_pool['embedding'].shape[0] print(*(['#'] * 100)) @@ -114,8 +122,8 @@ def train_searcher(opt, print(f'num_leaves: {num_leaves}') print(f'num_leaves_to_search: {num_leaves_to_search}') # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) - searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search) + searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, + num_leaves_to_search) print('Finish training searcher') searcher_savedir = opt.target_path @@ -123,6 +131,7 @@ def train_searcher(opt, searcher.serialize(searcher_savedir) print(f'Saved trained searcher under "{searcher_savedir}"') + if __name__ == '__main__': sys.path.append(os.getcwd()) parser = argparse.ArgumentParser() @@ -142,6 +151,6 @@ def train_searcher(opt, type=int, help='number of nearest neighbors, for which the searcher shall be optimized') - opt, _ = parser.parse_known_args() + opt, _ = parser.parse_known_args() - train_searcher(opt,) \ No newline at end of file + train_searcher(opt,) diff --git a/examples/images/diffusion/scripts/txt2img.py b/examples/images/diffusion/scripts/txt2img.py index 364ebac6c67b..a811c986451a 100644 --- a/examples/images/diffusion/scripts/txt2img.py +++ b/examples/images/diffusion/scripts/txt2img.py @@ -1,29 +1,34 @@ -import argparse, os +import argparse +import os +from itertools import islice + import cv2 -import torch import numpy as np +import torch +from einops import rearrange from omegaconf import OmegaConf from PIL import Image -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange from torchvision.utils import make_grid +from tqdm import 
tqdm, trange + try: from lightning.pytorch import seed_everything except: from pytorch_lightning import seed_everything -from torch import autocast + from contextlib import nullcontext -from imwatermark import WatermarkEncoder -from ldm.util import instantiate_from_config +from imwatermark import WatermarkEncoder from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from utils import replace_module, getModelSize +from ldm.models.diffusion.plms import PLMSSampler +from ldm.util import instantiate_from_config +from torch import autocast +from utils import getModelSize, replace_module torch.set_grad_enabled(False) + def chunk(it, size): it = iter(it) return iter(lambda: tuple(islice(it, size)), ()) @@ -50,20 +55,16 @@ def load_model_from_config(config, ckpt, verbose=False): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="a professional photograph of an astronaut riding a triceratops", - help="the prompt to render" - ) - parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" - ) + parser.add_argument("--prompt", + type=str, + nargs="?", + default="a professional photograph of an astronaut riding a triceratops", + help="the prompt to render") + parser.add_argument("--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples") parser.add_argument( "--steps", type=int, @@ -161,13 +162,11 @@ def parse_args(): default=42, help="the seed (for reproducible sampling)", ) - parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" - ) + parser.add_argument("--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast") parser.add_argument( 
"--repeat", type=int, @@ -197,17 +196,17 @@ def main(opt): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") - + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) - + # quantize model if opt.use_int8: model = replace_module(model) # # to compute the model size # getModelSize(model) - + if opt.plms: sampler = PLMSSampler(model) elif opt.dpm: @@ -251,50 +250,50 @@ def main(opt): with torch.no_grad(), \ precision_scope("cuda"), \ model.ema_scope(): - all_samples = list() - for n in trange(opt.n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - uc = None - if opt.scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples, _ = sampler.sample(S=opt.steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) - - x_samples = model.decode_first_stage(samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - - for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - img = Image.fromarray(x_sample.astype(np.uint8)) - img = put_watermark(img, wm_encoder) - img.save(os.path.join(sample_path, f"{base_count:05}.png")) - base_count += 1 - sample_count += 1 - - all_samples.append(x_samples) - - # additionally, save as grid - grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) - - # to image - grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() - grid = Image.fromarray(grid.astype(np.uint8)) - grid = put_watermark(grid, wm_encoder) - grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) - grid_count += 1 + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples, _ = sampler.sample(S=opt.steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + img = Image.fromarray(x_sample.astype(np.uint8)) + img = put_watermark(img, wm_encoder) + img.save(os.path.join(sample_path, f"{base_count:05}.png")) + base_count += 1 + sample_count += 1 + + all_samples.append(x_samples) + + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy() + grid = Image.fromarray(grid.astype(np.uint8)) + grid = put_watermark(grid, wm_encoder) + grid.save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") diff --git a/examples/images/diffusion/scripts/utils.py b/examples/images/diffusion/scripts/utils.py index c954b22ca190..5ce1b5e6ea91 100644 --- a/examples/images/diffusion/scripts/utils.py +++ b/examples/images/diffusion/scripts/utils.py @@ -1,22 +1,20 @@ import bitsandbytes as bnb -import torch.nn as nn import torch +import torch.nn as nn + class Linear8bit(nn.Linear): - def __init__( - self, - input_features, - output_features, - bias=True, - has_fp16_weights=False, - memory_efficient_backward=False, - threshold=6.0, - weight_data=None, - bias_data=None - ): - super(Linear8bit, self).__init__( - input_features, output_features, bias - ) + + def __init__(self, + input_features, + output_features, + bias=True, + has_fp16_weights=False, + memory_efficient_backward=False, + threshold=6.0, + weight_data=None, + bias_data=None): + super(Linear8bit, self).__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() self.bias = bias_data self.state.threshold = threshold @@ -24,13 +22,12 @@ def __init__( self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - + self.register_parameter("SCB", nn.Parameter(torch.empty(0), requires_grad=False)) self.weight = weight_data self.quant() - - def quant(self): + def quant(self): weight = self.weight.data.contiguous().half().cuda() CB, _, SCB, _, _ = bnb.functional.double_quant(weight) delattr(self, "weight") @@ -41,32 +38,34 @@ def quant(self): def forward(self, x): self.state.is_training = self.training - + if self.bias is not None and self.bias.dtype != torch.float16: self.bias.data = self.bias.data.half() - + self.state.CB = 
self.weight.data self.state.SCB = self.SCB.data - + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) del self.state.CxB return out + def replace_module(model): for name, module in model.named_children(): if len(list(module.children())) > 0: replace_module(module) - if isinstance(module, nn.Linear) and "out_proj" not in name: + if isinstance(module, nn.Linear) and "out_proj" not in name: model._modules[name] = Linear8bit( - input_features=module.in_features, - output_features=module.out_features, - threshold=6.0, - weight_data=module.weight, - bias_data=module.bias, - ) + input_features=module.in_features, + output_features=module.out_features, + threshold=6.0, + weight_data=module.weight, + bias_data=module.bias, + ) return model + def getModelSize(model): param_size = 0 param_sum = 0 diff --git a/examples/images/diffusion/setup.py b/examples/images/diffusion/setup.py index a24d54167640..f7684ea9fea1 100644 --- a/examples/images/diffusion/setup.py +++ b/examples/images/diffusion/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name='latent-diffusion', @@ -10,4 +10,4 @@ 'numpy', 'tqdm', ], -) \ No newline at end of file +) diff --git a/examples/images/diffusion/train_ddp.sh b/examples/images/diffusion/train_ddp.sh index 78fe765488c6..8304d6fa8b4f 100644 --- a/examples/images/diffusion/train_ddp.sh +++ b/examples/images/diffusion/train_ddp.sh @@ -1,5 +1,5 @@ -HF_DATASETS_OFFLINE=1 -TRANSFORMERS_OFFLINE=1 -DIFFUSERS_OFFLINE=1 +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 +DIFFUSERS_OFFLINE=1 python main.py --logdir /tmp -t -b /configs/train_ddp.yaml diff --git a/examples/images/dreambooth/colossalai.sh b/examples/images/dreambooth/colossalai.sh index 227d8b8bdb04..2a9a433e2e49 100755 --- a/examples/images/dreambooth/colossalai.sh +++ b/examples/images/dreambooth/colossalai.sh @@ -1,10 +1,10 @@ -export MODEL_NAME= +export MODEL_NAME= export INSTANCE_DIR= export 
CLASS_DIR="path-to-class-images" export OUTPUT_DIR="path-to-save-model" -HF_DATASETS_OFFLINE=1 -TRANSFORMERS_OFFLINE=1 +HF_DATASETS_OFFLINE=1 +TRANSFORMERS_OFFLINE=1 DIFFUSERS_OFFLINE=1 torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \ diff --git a/examples/images/dreambooth/inference.py b/examples/images/dreambooth/inference.py index c342821c7830..232348e6d058 100644 --- a/examples/images/dreambooth/inference.py +++ b/examples/images/dreambooth/inference.py @@ -1,5 +1,5 @@ -from diffusers import StableDiffusionPipeline, DiffusionPipeline import torch +from diffusers import DiffusionPipeline, StableDiffusionPipeline model_id = print(f"Loading model... from{model_id}") diff --git a/examples/language/gpt/experiments/auto_offload/model_zoo.py b/examples/language/gpt/experiments/auto_offload/model_zoo.py index 35e44608f810..d91680dbe4d5 100644 --- a/examples/language/gpt/experiments/auto_offload/model_zoo.py +++ b/examples/language/gpt/experiments/auto_offload/model_zoo.py @@ -2,14 +2,10 @@ import torch.nn as nn from transformers import GPT2Config, GPT2LMHeadModel + class GPTLMModel(nn.Module): - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): super().__init__() self.model = GPT2LMHeadModel( GPT2Config(n_embd=hidden_size, @@ -36,6 +32,7 @@ def forward(self, logits, labels): # Flatten the tokens return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + def get_gpt2_components(model_type: str, batch_size: int): vocab_size = 1024 seq_len = 8 @@ -62,4 +59,4 @@ def gpt2_data_gen(device="cuda"): kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - return gpt2_model_builder, gpt2_data_gen \ No newline at end of file + return gpt2_model_builder, gpt2_data_gen diff --git 
a/examples/language/gpt/experiments/auto_offload/requirements.txt b/examples/language/gpt/experiments/auto_offload/requirements.txt index 3ebde8d460aa..137a69e80498 100644 --- a/examples/language/gpt/experiments/auto_offload/requirements.txt +++ b/examples/language/gpt/experiments/auto_offload/requirements.txt @@ -1,2 +1,2 @@ colossalai >= 0.1.12 -torch >= 1.8.1 \ No newline at end of file +torch >= 1.8.1 diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py index 729d1ce4456b..5cc45d84fd2c 100644 --- a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py +++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py @@ -1,20 +1,21 @@ -import time -import pytest import argparse +import time from functools import partial +import pytest import torch -from torch.utils._pytree import tree_map import torch.multiprocessing as mp +from model_zoo import GPTLMLoss, get_gpt2_components +from torch.utils._pytree import tree_map import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.fx.profiler import parameter_size -from colossalai.utils import free_port, get_current_device from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML -from model_zoo import get_gpt2_components, GPTLMLoss +from colossalai.fx.profiler import parameter_size +from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import free_port, get_current_device + def parse_args(): parser = argparse.ArgumentParser() @@ -24,6 +25,7 @@ def parse_args(): parser.add_argument('--memory_budget', type=float, default=16) return parser.parse_args() + @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') def train_gpt(args): memory_budget = args.memory_budget * 1024 * 1024 * 1024 @@ -33,13 +35,16 @@ def 
train_gpt(args): # build model model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size) - label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) + label = torch.randint(low=0, high=128, size=( + 64, + 8, + ), device=get_current_device()) criterion = GPTLMLoss() start_time = time.time() model = model_builder() model.train() - param_size = parameter_size(model) / 1024 ** 2 / 2 + param_size = parameter_size(model) / 1024**2 / 2 init_time = time.time() - start_time print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") @@ -74,20 +79,20 @@ def train_gpt(args): torch.cuda.synchronize() exec_time = sum(sorted(time_list)[:5]) / 5 - runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 - runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 print(f'solver_type: {solver_type} | model_type: {model_type}') - print( - f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' - ) + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') print(time_list) + def run(rank, world_size, port, args): config = {} colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') train_gpt(args) + if __name__ == '__main__': args = parse_args() run_func = partial(run, world_size=1, port=free_port(), args=args) diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py index 4993ce25db17..f5fcb80eb77e 100755 --- a/examples/language/opt/train_gemini_opt.py +++ b/examples/language/opt/train_gemini_opt.py @@ 
-36,11 +36,10 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer from colossalai.nn.parallel import GeminiDDP +from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.utils import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ProcessGroup, ShardSpec - def get_data(batch_size, seq_len, vocab_size): input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) @@ -179,13 +178,15 @@ def main(): # build model if args.model_name_or_path is None: logger.info("Train a new model from scratch", ranks=[0]) - with ColoInitContext(device=init_dev, dtype=torch.half, + with ColoInitContext(device=init_dev, + dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg): model = OPTForCausalLM(config) else: logger.info("Finetune a pre-trained model", ranks=[0]) - with ColoInitContext(device=init_dev, dtype=torch.half, + with ColoInitContext(device=init_dev, + dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg): model = OPTForCausalLM.from_pretrained(args.model_name_or_path, @@ -198,8 +199,11 @@ def main(): numel = sum([p.numel() for p in model.parameters()]) PLACEMENT_POLICY = 'cpu' - model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, - pin_memory=True, strict_ddp_mode=args.shardinit) + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=PLACEMENT_POLICY, + pin_memory=True, + strict_ddp_mode=args.shardinit) optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) SEQ_LEN = 1024 diff --git a/examples/language/roberta/README.md b/examples/language/roberta/README.md index a42b1935dd85..631eac4dc689 100644 --- a/examples/language/roberta/README.md +++ b/examples/language/roberta/README.md @@ -11,7 +11,7 @@ ssh-keygen 
ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination ``` -- In all hosts, edit /etc/hosts to record all hosts' name and ip.The example is shown below. +- In all hosts, edit /etc/hosts to record all hosts' name and ip.The example is shown below. ```bash 192.168.2.1 GPU001 @@ -29,7 +29,7 @@ ssh-copy-id -i ~/.ssh/id_rsa.pub ip_destination service ssh restart ``` -## 1. Corpus Preprocessing +## 1. Corpus Preprocessing ```bash cd preprocessing ``` diff --git a/examples/language/roberta/configs/colossalai_ddp.py b/examples/language/roberta/configs/colossalai_ddp.py index c3c59aa4079c..3ac32ce65dec 100644 --- a/examples/language/roberta/configs/colossalai_ddp.py +++ b/examples/language/roberta/configs/colossalai_ddp.py @@ -1,4 +1,4 @@ -from colossalai.zero.shard_utils import TensorShardStrategy from colossalai.nn.optimizer import FusedAdam +from colossalai.zero.shard_utils import TensorShardStrategy clip_grad_norm = 1.0 diff --git a/examples/language/roberta/configs/colossalai_zero.py b/examples/language/roberta/configs/colossalai_zero.py index c5debdce0988..5a98839695a6 100644 --- a/examples/language/roberta/configs/colossalai_zero.py +++ b/examples/language/roberta/configs/colossalai_zero.py @@ -1,5 +1,5 @@ -from colossalai.zero.shard_utils import TensorShardStrategy from colossalai.nn.optimizer import FusedAdam +from colossalai.zero.shard_utils import TensorShardStrategy # fp16 = dict( # mode=AMP_TYPE.TORCH, @@ -29,4 +29,4 @@ weight_decay=1e-2, ) -# 64433 \ No newline at end of file +# 64433 diff --git a/examples/language/roberta/preprocessing/README.md b/examples/language/roberta/preprocessing/README.md index 1dbd745ab9bd..17cc2f4dc22c 100644 --- a/examples/language/roberta/preprocessing/README.md +++ b/examples/language/roberta/preprocessing/README.md @@ -21,7 +21,7 @@ This folder is used to preprocess chinese corpus with Whole Word Masked. You can ### 2.1. 
Split Sentence & Split data into multiple shard: -Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentence through punctuation, such as `。!`. **Secondly, split data into multiple shard based on server hardware (cpu, cpu memory, hard disk) and corpus size.** Each shard contains a part of corpus, and the model needs to train all the shards as one epoch. +Firstly, each file has multiple documents, and each document contains multiple sentences. Split sentence through punctuation, such as `。!`. **Secondly, split data into multiple shard based on server hardware (cpu, cpu memory, hard disk) and corpus size.** Each shard contains a part of corpus, and the model needs to train all the shards as one epoch. In this example, split 200G Corpus into 100 shard, and each shard is about 2G. The size of the shard is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training to read the shard (n data parallel requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining requires fighting with hardware, not just GPU.** ```python @@ -49,7 +49,7 @@ python sentence_split.py --input_path /orginal_corpus --output_path /shard --sha ] ``` -Output txt: +Output txt: ``` 我今天去打篮球。 @@ -76,7 +76,7 @@ make * `--input_path`: location of all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ... * `--output_path`: location of all h5 with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ... -* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenzier.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) +* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. 
Download config.json, special_tokens_map.json, vocab.txt and tokenzier.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) * `--backend`: python or c++, **specifies c++ can obtain faster preprocess speed** * `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document * `--worker`: number of process @@ -91,7 +91,7 @@ make 下周请假。 ``` -Output h5+numpy: +Output h5+numpy: ``` 'input_ids': [[id0,id1,id2,id3,id4,id5,id6,0,0..], @@ -102,4 +102,4 @@ make ...] 'masked_lm_positions': [[label1,-1,-1,label2,-1...], ...] -``` \ No newline at end of file +``` diff --git a/examples/language/roberta/preprocessing/get_mask.py b/examples/language/roberta/preprocessing/get_mask.py index da297f98e6c9..775f85433697 100644 --- a/examples/language/roberta/preprocessing/get_mask.py +++ b/examples/language/roberta/preprocessing/get_mask.py @@ -1,20 +1,22 @@ -import torch +import collections +import logging import os -from enum import IntEnum -from random import choice import random -import collections import time -import logging +from enum import IntEnum +from random import choice + import jieba +import torch + jieba.setLogLevel(logging.CRITICAL) import re -import numpy as np + import mask +import numpy as np PAD = 0 -MaskedLMInstance = collections.namedtuple("MaskedLMInstance", - ["index", "label"]) +MaskedLMInstance = collections.namedtuple("MaskedLMInstance", ["index", "label"]) def map_to_numpy(data): @@ -22,6 +24,7 @@ def map_to_numpy(data): class PreTrainingDataset(): + def __init__(self, tokenizer, max_seq_length, @@ -43,17 +46,15 @@ def __init__(self, self.mlm_tamper_p = 0.05 self.mlm_maintain_p = 0.1 - def tokenize(self, doc): temp = [] for d in doc: temp.append(self.tokenizer.tokenize(d)) return temp - def create_training_instance(self, instance): is_next = 1 - raw_text_list = self.get_new_segment(instance) + raw_text_list = self.get_new_segment(instance) tokens_a 
= raw_text_list assert len(tokens_a) == len(instance) # tokens_a, tokens_b, is_next = instance.get_values() @@ -83,8 +84,9 @@ def create_training_instance(self, instance): # Get Masked LM predictions if self.backend == 'c++': - output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions(tokens, original_tokens, self.vocab_words, - self.tokenizer.vocab, self.max_predictions_per_seq, self.masked_lm_prob) + output_tokens, masked_lm_output = mask.create_whole_masked_lm_predictions( + tokens, original_tokens, self.vocab_words, self.tokenizer.vocab, self.max_predictions_per_seq, + self.masked_lm_prob) elif self.backend == 'python': output_tokens, masked_lm_output = self.create_whole_masked_lm_predictions(tokens) @@ -102,29 +104,25 @@ def create_training_instance(self, instance): map_to_numpy(input_mask), map_to_numpy(segment_ids), map_to_numpy(masked_lm_output), - map_to_numpy([is_next]) + map_to_numpy([is_next]) ]) - def create_masked_lm_predictions(self, tokens): cand_indexes = [] for i, token in enumerate(tokens): if token == "[CLS]" or token == "[SEP]": continue - if (self.do_whole_word_mask and len(cand_indexes) >= 1 and - token.startswith("##")): + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): cand_indexes[-1].append(i) else: cand_indexes.append([i]) - + # cand_indexes.append(i) random.shuffle(cand_indexes) output_tokens = list(tokens) - num_to_predict = min( - self.max_predictions_per_seq, - max(1, int(round(len(tokens) * self.masked_lm_prob)))) + num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob)))) masked_lms = [] covered_indexes = set() @@ -145,13 +143,10 @@ def create_masked_lm_predictions(self, tokens): masked_token = tokens[index] # 10% replace w/ random word else: - masked_token = self.vocab_words[random.randint( - 0, - len(self.vocab_words) - 1)] + masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] output_tokens[index] = 
masked_token - masked_lms.append( - MaskedLMInstance(index=index, label=tokens[index])) + masked_lms.append(MaskedLMInstance(index=index, label=tokens[index])) masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_output = [-1] * len(output_tokens) @@ -160,7 +155,6 @@ def create_masked_lm_predictions(self, tokens): return (output_tokens, masked_lm_output) - def get_new_segment(self, segment): """ 输入一句话,返回一句经过处理的话: 为了支持中文全称mask,将被分开的词,将上特殊标记("#"),使得后续处理模块,能够知道哪些字是属于同一个词的。 @@ -172,7 +166,7 @@ def get_new_segment(self, segment): new_segment = [] i = 0 while i < len(segment): - if len(self.rec.findall(segment[i])) == 0: # 不是中文的,原文加进去。 + if len(self.rec.findall(segment[i])) == 0: # 不是中文的,原文加进去。 new_segment.append(segment[i]) i += 1 continue @@ -181,10 +175,10 @@ def get_new_segment(self, segment): for length in range(3, 0, -1): if i + length > len(segment): continue - if ''.join(segment[i: i+length]) in seq_cws_dict: + if ''.join(segment[i:i + length]) in seq_cws_dict: new_segment.append(segment[i]) for l in range(1, length): - new_segment.append('##' + segment[i+l]) + new_segment.append('##' + segment[i + l]) i += length has_add = True break @@ -193,7 +187,6 @@ def get_new_segment(self, segment): i += 1 return new_segment - def create_whole_masked_lm_predictions(self, tokens): """Creates the predictions for the masked LM objective.""" @@ -210,18 +203,16 @@ def create_whole_masked_lm_predictions(self, tokens): # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. 
- if (self.do_whole_word_mask and len(cand_indexes) >= 1 and - token.startswith("##")): + if (self.do_whole_word_mask and len(cand_indexes) >= 1 and token.startswith("##")): cand_indexes[-1].append(i) else: cand_indexes.append([i]) random.shuffle(cand_indexes) - output_tokens = [t[2:] if len(self.whole_rec.findall(t))>0 else t for t in tokens] # 去掉"##" + output_tokens = [t[2:] if len(self.whole_rec.findall(t)) > 0 else t for t in tokens] # 去掉"##" - num_to_predict = min(self.max_predictions_per_seq, - max(1, int(round(len(tokens) * self.masked_lm_prob)))) + num_to_predict = min(self.max_predictions_per_seq, max(1, int(round(len(tokens) * self.masked_lm_prob)))) masked_lms = [] covered_indexes = set() @@ -249,14 +240,18 @@ def create_whole_masked_lm_predictions(self, tokens): else: # 10% of the time, keep original if random.random() < 0.5: - masked_token = tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index] # 去掉"##" + masked_token = tokens[index][2:] if len(self.whole_rec.findall( + tokens[index])) > 0 else tokens[index] # 去掉"##" # 10% of the time, replace with random word else: masked_token = self.vocab_words[random.randint(0, len(self.vocab_words) - 1)] output_tokens[index] = masked_token - masked_lms.append(MaskedLMInstance(index=index, label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index]))>0 else tokens[index])) + masked_lms.append( + MaskedLMInstance( + index=index, + label=tokens[index][2:] if len(self.whole_rec.findall(tokens[index])) > 0 else tokens[index])) assert len(masked_lms) <= num_to_predict masked_lms = sorted(masked_lms, key=lambda x: x.index) masked_lm_output = [-1] * len(output_tokens) diff --git a/examples/language/roberta/preprocessing/mask.cpp b/examples/language/roberta/preprocessing/mask.cpp index 8355c45cff0a..13c497ed7edc 100644 --- a/examples/language/roberta/preprocessing/mask.cpp +++ b/examples/language/roberta/preprocessing/mask.cpp @@ -1,184 +1,190 @@ +#include +#include +#include 
+#include + #include +#include #include #include -#include -#include -#include -#include #include -#include +#include #include -#include -#include #include -#include #include +#include +#include namespace py = pybind11; const int32_t LONG_SENTENCE_LEN = 512; struct MaskedLMInstance { - int index; - std::string label; - MaskedLMInstance(int index, std::string label) { - this->index = index; - this->label = label; - } + int index; + std::string label; + MaskedLMInstance(int index, std::string label) { + this->index = index; + this->label = label; + } }; -auto get_new_segment(std::vector segment, std::vector segment_jieba, const std::vector chinese_vocab) { // const std::unordered_set &chinese_vocab - std::unordered_set seq_cws_dict; - for (auto word : segment_jieba) { - seq_cws_dict.insert(word); +auto get_new_segment(std::vector segment, + std::vector segment_jieba, + const std::vector chinese_vocab) { // const + // std::unordered_set + // &chinese_vocab + std::unordered_set seq_cws_dict; + for (auto word : segment_jieba) { + seq_cws_dict.insert(word); + } + int i = 0; + std::vector new_segment; + int segment_size = segment.size(); + while (i < segment_size) { + if (!chinese_vocab[i]) { // chinese_vocab.find(segment[i]) == + // chinese_vocab.end() + new_segment.emplace_back(segment[i]); + i += 1; + continue; } - int i = 0; - std::vector new_segment; - int segment_size = segment.size(); - while (i < segment_size) { - if (!chinese_vocab[i]) { //chinese_vocab.find(segment[i]) == chinese_vocab.end() - new_segment.emplace_back(segment[i]); - i += 1; - continue; - } - bool has_add = false; - for (int length = 3; length >= 1; length--) { - if (i + length > segment_size) { - continue; - } - std::string chinese_word = ""; - for (int j = i; j < i + length; j++) { - chinese_word += segment[j]; - } - if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) { - new_segment.emplace_back(segment[i]); - for (int j = i + 1; j < i + length; j++) { - new_segment.emplace_back("##" + 
segment[j]); - } - i += length; - has_add = true; - break; - } - } - if (!has_add) { - new_segment.emplace_back(segment[i]); - i += 1; + bool has_add = false; + for (int length = 3; length >= 1; length--) { + if (i + length > segment_size) { + continue; + } + std::string chinese_word = ""; + for (int j = i; j < i + length; j++) { + chinese_word += segment[j]; + } + if (seq_cws_dict.find(chinese_word) != seq_cws_dict.end()) { + new_segment.emplace_back(segment[i]); + for (int j = i + 1; j < i + length; j++) { + new_segment.emplace_back("##" + segment[j]); } + i += length; + has_add = true; + break; + } + } + if (!has_add) { + new_segment.emplace_back(segment[i]); + i += 1; } + } - return new_segment; + return new_segment; } -bool startsWith(const std::string& s, const std::string& sub) { - return s.find(sub) == 0 ? true : false; +bool startsWith(const std::string &s, const std::string &sub) { + return s.find(sub) == 0 ? true : false; } -auto create_whole_masked_lm_predictions(std::vector &tokens, - const std::vector &original_tokens, - const std::vector &vocab_words, - std::map &vocab, - const int max_predictions_per_seq, - const double masked_lm_prob) { - // for (auto item : vocab) { - // std::cout << "key=" << std::string(py::str(item.first)) << ", " - // << "value=" << std::string(py::str(item.second)) << std::endl; - // } - std::vector > cand_indexes; - std::vector cand_temp; - int tokens_size = tokens.size(); - std::string prefix = "##"; - bool do_whole_masked = true; - - for (int i = 0; i < tokens_size; i++) { - if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") { - continue; - } - if (do_whole_masked && (cand_indexes.size() > 0) && (tokens[i].rfind(prefix, 0) == 0)) { - cand_temp.emplace_back(i); - } - else { - if (cand_temp.size() > 0) { - cand_indexes.emplace_back(cand_temp); - } - cand_temp.clear(); - cand_temp.emplace_back(i); - } +auto create_whole_masked_lm_predictions( + std::vector &tokens, + const std::vector &original_tokens, + const std::vector 
&vocab_words, + std::map &vocab, const int max_predictions_per_seq, + const double masked_lm_prob) { + // for (auto item : vocab) { + // std::cout << "key=" << std::string(py::str(item.first)) << ", " + // << "value=" << std::string(py::str(item.second)) << + // std::endl; + // } + std::vector > cand_indexes; + std::vector cand_temp; + int tokens_size = tokens.size(); + std::string prefix = "##"; + bool do_whole_masked = true; + + for (int i = 0; i < tokens_size; i++) { + if (tokens[i] == "[CLS]" || tokens[i] == "[SEP]") { + continue; } - auto seed = std::chrono::system_clock::now().time_since_epoch().count(); - std::shuffle(cand_indexes.begin(), cand_indexes.end(), std::default_random_engine(seed)); - // for (auto i : cand_indexes) { - // for (auto j : i) { - // std::cout << tokens[j] << " "; - // } - // std::cout << std::endl; - // } - // for (auto i : output_tokens) { - // std::cout << i; - // } - // std::cout << std::endl; + if (do_whole_masked && (cand_indexes.size() > 0) && + (tokens[i].rfind(prefix, 0) == 0)) { + cand_temp.emplace_back(i); + } else { + if (cand_temp.size() > 0) { + cand_indexes.emplace_back(cand_temp); + } + cand_temp.clear(); + cand_temp.emplace_back(i); + } + } + auto seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::shuffle(cand_indexes.begin(), cand_indexes.end(), + std::default_random_engine(seed)); + // for (auto i : cand_indexes) { + // for (auto j : i) { + // std::cout << tokens[j] << " "; + // } + // std::cout << std::endl; + // } + // for (auto i : output_tokens) { + // std::cout << i; + // } + // std::cout << std::endl; - int num_to_predict = std::min(max_predictions_per_seq, - std::max(1, int(tokens_size * masked_lm_prob))); - // std::cout << num_to_predict << std::endl; - - std::set covered_indexes; - std::vector masked_lm_output(tokens_size, -1); - int vocab_words_len = vocab_words.size(); - std::default_random_engine e(seed); - std::uniform_real_distribution u1(0.0, 1.0); - 
std::uniform_int_distribution u2(0, vocab_words_len - 1); - int mask_cnt = 0; - std::vector output_tokens; - output_tokens = original_tokens; + int num_to_predict = std::min(max_predictions_per_seq, + std::max(1, int(tokens_size * masked_lm_prob))); + // std::cout << num_to_predict << std::endl; - for (auto index_set : cand_indexes) { - if (mask_cnt > num_to_predict) { - break; - } - int index_set_size = index_set.size(); - if (mask_cnt + index_set_size > num_to_predict) { - continue; - } - bool is_any_index_covered = false; - for (auto index : index_set) { - if (covered_indexes.find(index) != covered_indexes.end()) { - is_any_index_covered = true; - break; - } - } - if (is_any_index_covered) { - continue; - } - for (auto index : index_set) { - - covered_indexes.insert(index); - std::string masked_token; - if (u1(e) < 0.8) { - masked_token = "[MASK]"; - } - else { - if (u1(e) < 0.5) { - masked_token = output_tokens[index]; - } - else { - int random_index = u2(e); - masked_token = vocab_words[random_index]; - } - } - // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index])); - masked_lm_output[index] = vocab[output_tokens[index]]; - output_tokens[index] = masked_token; - mask_cnt++; + std::set covered_indexes; + std::vector masked_lm_output(tokens_size, -1); + int vocab_words_len = vocab_words.size(); + std::default_random_engine e(seed); + std::uniform_real_distribution u1(0.0, 1.0); + std::uniform_int_distribution u2(0, vocab_words_len - 1); + int mask_cnt = 0; + std::vector output_tokens; + output_tokens = original_tokens; + + for (auto index_set : cand_indexes) { + if (mask_cnt > num_to_predict) { + break; + } + int index_set_size = index_set.size(); + if (mask_cnt + index_set_size > num_to_predict) { + continue; + } + bool is_any_index_covered = false; + for (auto index : index_set) { + if (covered_indexes.find(index) != covered_indexes.end()) { + is_any_index_covered = true; + break; + } + } + if (is_any_index_covered) { + continue; + } + for 
(auto index : index_set) { + covered_indexes.insert(index); + std::string masked_token; + if (u1(e) < 0.8) { + masked_token = "[MASK]"; + } else { + if (u1(e) < 0.5) { + masked_token = output_tokens[index]; + } else { + int random_index = u2(e); + masked_token = vocab_words[random_index]; } + } + // masked_lms.emplace_back(MaskedLMInstance(index, output_tokens[index])); + masked_lm_output[index] = vocab[output_tokens[index]]; + output_tokens[index] = masked_token; + mask_cnt++; } - - // for (auto p : masked_lms) { - // masked_lm_output[p.index] = vocab[p.label]; - // } - return std::make_tuple(output_tokens, masked_lm_output); + } + + // for (auto p : masked_lms) { + // masked_lm_output[p.index] = vocab[p.label]; + // } + return std::make_tuple(output_tokens, masked_lm_output); } PYBIND11_MODULE(mask, m) { - m.def("create_whole_masked_lm_predictions", &create_whole_masked_lm_predictions); - m.def("get_new_segment", &get_new_segment); + m.def("create_whole_masked_lm_predictions", + &create_whole_masked_lm_predictions); + m.def("get_new_segment", &get_new_segment); } diff --git a/examples/language/roberta/preprocessing/sentence_split.py b/examples/language/roberta/preprocessing/sentence_split.py index 231be152b067..b9096424e1c2 100644 --- a/examples/language/roberta/preprocessing/sentence_split.py +++ b/examples/language/roberta/preprocessing/sentence_split.py @@ -1,13 +1,14 @@ - +import argparse +import functools +import json import multiprocessing import os import re -from tqdm import tqdm -from typing import List -import json import time -import argparse -import functools +from typing import List + +from tqdm import tqdm + def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[str]: """ @@ -20,16 +21,17 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s sent_list = [] try: if flag == "zh": - document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) # 单字符断句符 - document = 
re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) # 特殊引号 + document = re.sub('(?P([。?!…](?![”’"\'])))', r'\g\n', document) # 单字符断句符 + document = re.sub('(?P([。?!]|…{1,2})[”’"\'])', r'\g\n', document) # 特殊引号 elif flag == "en": - document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) # 英文单字符断句符 - document = re.sub('(?P([?!.]["\']))', r'\g\n', document) # 特殊引号 + document = re.sub('(?P([.?!](?![”’"\'])))', r'\g\n', document) # 英文单字符断句符 + document = re.sub('(?P([?!.]["\']))', r'\g\n', document) # 特殊引号 else: - document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', document) # 单字符断句符 - + document = re.sub('(?P([。?!….?!](?![”’"\'])))', r'\g\n', + document) # 单字符断句符 + document = re.sub('(?P(([。?!.!?]|…{1,2})[”’"\']))', r'\g\n', - document) # 特殊引号 + document) # 特殊引号 sent_list_ori = document.splitlines() for sent in sent_list_ori: @@ -50,17 +52,15 @@ def split_sentence(document: str, flag: str = "all", limit: int = 510) -> List[s return sent_list -def get_sent(output_path, - input_path, - fin_list=[], host=-1, seq_len=512) -> None: +def get_sent(output_path, input_path, fin_list=[], host=-1, seq_len=512) -> None: workers = 32 if input_path[-1] == '/': input_path = input_path[:-1] - + cur_path = os.path.join(output_path, str(host) + '.txt') - new_split_sentence = functools.partial(split_sentence, limit=seq_len-2) + new_split_sentence = functools.partial(split_sentence, limit=seq_len - 2) with open(cur_path, 'w', encoding='utf-8') as f: for fi, fin_path in enumerate(fin_list): if not os.path.exists(os.path.join(input_path, fin_path[0])): @@ -69,7 +69,7 @@ def get_sent(output_path, continue print("Processing ", fin_path[0], " ", fi) - + with open(os.path.join(input_path, fin_path[0]), 'r') as fin: f_data = [l['content'] for l in json.load(fin)] @@ -106,17 +106,17 @@ def getFileSize(filepath, shard): real_shard.append(temp) accu_size = 0 temp = [] - + if len(temp) > 0: real_shard.append(temp) - + return real_shard def get_start_end(real_shard, base=0, 
server_num=10, server_name='GPU'): import socket host = int(socket.gethostname().split(server_name)[-1]) - + fin_list = real_shard[server_num * base + host - 1] print(fin_list) print(f'I am server {host}, process {server_num * base + host - 1}, len {len(fin_list)}') @@ -133,28 +133,24 @@ def get_start_end(real_shard, base=0, server_num=10, server_name='GPU'): parser.add_argument('--output_path', type=str, required=True, help='output path of shard which has split sentence') args = parser.parse_args() - server_num = args.server_num + server_num = args.server_num seq_len = args.seq_len - shard = args.shard + shard = args.shard input_path = args.input_path - output_path = args.output_path + output_path = args.output_path real_shard = getFileSize(input_path, shard) start = time.time() for index, shard in enumerate(real_shard): - get_sent(output_path, - input_path, - fin_list=shard, - host=index, - seq_len=seq_len) + get_sent(output_path, input_path, fin_list=shard, host=index, seq_len=seq_len) print(f'cost {str(time.time() - start)}') # if you have multiple server, you can use code below or modify code to openmpi - + # for i in range(len(real_shard) // server_num + 1): # fin_list, host = get_start_end(real_shard, i) - + # start = time.time() # get_sent(output_path, # input_path, diff --git a/examples/language/roberta/preprocessing/tokenize_mask.py b/examples/language/roberta/preprocessing/tokenize_mask.py index b33871d5d037..30bfed7036cf 100644 --- a/examples/language/roberta/preprocessing/tokenize_mask.py +++ b/examples/language/roberta/preprocessing/tokenize_mask.py @@ -1,19 +1,19 @@ -import time +import argparse +import multiprocessing import os -import psutil -import h5py import socket -import argparse +import time +from random import shuffle + +import h5py import numpy as np -import multiprocessing +import psutil +from get_mask import PreTrainingDataset from tqdm import tqdm -from random import shuffle from transformers import AutoTokenizer -from get_mask import 
PreTrainingDataset def get_raw_instance(document, max_sequence_length=512): - """ 获取初步的训练实例,将整段按照max_sequence_length切分成多个部分,并以多个处理好的实例的形式返回。 :param document: 一整段 @@ -26,25 +26,25 @@ def get_raw_instance(document, max_sequence_length=512): sizes = [len(seq) for seq in document] result_list = [] - curr_seq = [] # 当前处理的序列 + curr_seq = [] # 当前处理的序列 sz_idx = 0 while sz_idx < len(sizes): # 当前句子加上新的句子,如果长度小于最大限制,则合并当前句子和新句子;否则即超过了最大限制,那么做为一个新的序列加到目标列表中 - - if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: + + if len(curr_seq) + sizes[sz_idx] <= max_sequence_length_allowed: # or len(curr_seq)==0: curr_seq += document[sz_idx] sz_idx += 1 elif sizes[sz_idx] >= max_sequence_length_allowed: if len(curr_seq) > 0: result_list.append(curr_seq) curr_seq = [] - result_list.append(document[sz_idx][ : max_sequence_length_allowed]) + result_list.append(document[sz_idx][:max_sequence_length_allowed]) sz_idx += 1 else: result_list.append(curr_seq) curr_seq = [] # 对最后一个序列进行处理,如果太短的话,丢弃掉。 - if len(curr_seq) > max_sequence_length_allowed / 2: # /2 + if len(curr_seq) > max_sequence_length_allowed / 2: # /2 result_list.append(curr_seq) # # 计算总共可以得到多少份 @@ -72,8 +72,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): # document = line # if len(document.split("")) <= 3: # continue - if len(line - ) > 0 and line[:2] == "]]": # This is end of document + if len(line) > 0 and line[:2] == "]]": # This is end of document documents.append(document) document = [] elif len(line) >= 2: @@ -86,8 +85,8 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): # print(len(documents)) # print(len(documents[0])) # print(documents[0][0:10]) - from typing import List import multiprocessing + from typing import List ans = [] for docs in tqdm(documents): @@ -100,7 +99,7 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): raw_ins = get_raw_instance(a) instances.extend(raw_ins) del ans - + print('len instance', len(instances)) sen_num = 
len(instances) @@ -118,21 +117,15 @@ def split_numpy_chunk(path, tokenizer, pretrain_data, host): masked_lm_output[index] = mask_dict[3] with h5py.File(f'/output/{host}.h5', 'w') as hf: - hf.create_dataset("input_ids", data=input_ids) - hf.create_dataset("input_mask", data=input_ids) - hf.create_dataset("segment_ids", data=segment_ids) - hf.create_dataset("masked_lm_positions", data=masked_lm_output) + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_ids) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) del instances -def split_numpy_chunk_pool(input_path, - output_path, - pretrain_data, - worker, - dupe_factor, - seq_len, - file_name): +def split_numpy_chunk_pool(input_path, output_path, pretrain_data, worker, dupe_factor, seq_len, file_name): if os.path.exists(os.path.join(output_path, f'{file_name}.h5')): print(f'{file_name}.h5 exists') @@ -146,8 +139,7 @@ def split_numpy_chunk_pool(input_path, document = [] for i, line in enumerate(tqdm(fd)): line = line.strip() - if len(line - ) > 0 and line[:2] == "]]": # This is end of document + if len(line) > 0 and line[:2] == "]]": # This is end of document documents.append(document) document = [] elif len(line) >= 2: @@ -155,7 +147,7 @@ def split_numpy_chunk_pool(input_path, if len(document) > 0: documents.append(document) print(f'read_file cost {time.time() - s}, length is {len(documents)}') - + ans = [] s = time.time() pool = multiprocessing.Pool(worker) @@ -171,7 +163,7 @@ def split_numpy_chunk_pool(input_path, raw_ins = get_raw_instance(a, max_sequence_length=seq_len) instances.extend(raw_ins) del ans - + print('len instance', len(instances)) new_instances = [] @@ -201,10 +193,10 @@ def split_numpy_chunk_pool(input_path, print((time.time() - s) / 60) with h5py.File(os.path.join(output_path, f'{file_name}.h5'), 'w') as hf: - hf.create_dataset("input_ids", data=input_ids) - hf.create_dataset("input_mask", 
data=input_mask) - hf.create_dataset("segment_ids", data=segment_ids) - hf.create_dataset("masked_lm_positions", data=masked_lm_output) + hf.create_dataset("input_ids", data=input_ids) + hf.create_dataset("input_mask", data=input_mask) + hf.create_dataset("segment_ids", data=segment_ids) + hf.create_dataset("masked_lm_positions", data=masked_lm_output) del instances @@ -214,22 +206,31 @@ def split_numpy_chunk_pool(input_path, parser = argparse.ArgumentParser() parser.add_argument('--tokenizer_path', type=str, required=True, default=10, help='path of tokenizer') parser.add_argument('--seq_len', type=int, default=512, help='sequence length') - parser.add_argument('--max_predictions_per_seq', type=int, default=80, help='number of shards, e.g., 10, 50, or 100') + parser.add_argument('--max_predictions_per_seq', + type=int, + default=80, + help='number of shards, e.g., 10, 50, or 100') parser.add_argument('--input_path', type=str, required=True, help='input path of shard which has split sentence') parser.add_argument('--output_path', type=str, required=True, help='output path of h5 contains token id') - parser.add_argument('--backend', type=str, default='python', help='backend of mask token, python, c++, numpy respectively') - parser.add_argument('--dupe_factor', type=int, default=1, help='specifies how many times the preprocessor repeats to create the input from the same article/document') + parser.add_argument('--backend', + type=str, + default='python', + help='backend of mask token, python, c++, numpy respectively') + parser.add_argument( + '--dupe_factor', + type=int, + default=1, + help='specifies how many times the preprocessor repeats to create the input from the same article/document') parser.add_argument('--worker', type=int, default=32, help='number of process') parser.add_argument('--server_num', type=int, default=10, help='number of servers') args = parser.parse_args() tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) - pretrain_data = 
PreTrainingDataset(tokenizer, - args.seq_len, - args.backend, - max_predictions_per_seq=args.max_predictions_per_seq) - - + pretrain_data = PreTrainingDataset(tokenizer, + args.seq_len, + args.backend, + max_predictions_per_seq=args.max_predictions_per_seq) + data_len = len(os.listdir(args.input_path)) for i in range(data_len): @@ -237,15 +238,10 @@ def split_numpy_chunk_pool(input_path, if os.path.exists(input_path): start = time.time() print(f'process {input_path}') - split_numpy_chunk_pool(input_path, - args.output_path, - pretrain_data, - args.worker, - args.dupe_factor, - args.seq_len, - i) + split_numpy_chunk_pool(input_path, args.output_path, pretrain_data, args.worker, args.dupe_factor, + args.seq_len, i) end_ = time.time() - print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024) ) + print(u'memory:%.4f GB' % (psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024 / 1024)) print(f'has cost {(end_ - start) / 60}') print('-' * 100) print('') @@ -259,9 +255,9 @@ def split_numpy_chunk_pool(input_path, # if os.path.exists(input_path): # start = time.time() # print(f'I am server {host}, process {input_path}') - # split_numpy_chunk_pool(input_path, - # args.output_path, - # pretrain_data, + # split_numpy_chunk_pool(input_path, + # args.output_path, + # pretrain_data, # args.worker, # args.dupe_factor, # args.seq_len, @@ -271,5 +267,3 @@ def split_numpy_chunk_pool(input_path, # print(f'has cost {(end_ - start) / 60}') # print('-' * 100) # print('') - - diff --git a/examples/language/roberta/pretraining/README.md b/examples/language/roberta/pretraining/README.md index 055d6969654d..c248fc1f5708 100644 --- a/examples/language/roberta/pretraining/README.md +++ b/examples/language/roberta/pretraining/README.md @@ -19,6 +19,5 @@ bash run_pretrain.sh bash run_pretrain_resume.sh ``` * `--resume_train`: whether to resume training -* `--load_pretrain_model`: absolute path which contains model checkpoint -* `--load_optimizer_lr`: 
absolute path which contains optimizer checkpoint - +* `--load_pretrain_model`: absolute path which contains model checkpoint +* `--load_optimizer_lr`: absolute path which contains optimizer checkpoint diff --git a/examples/language/roberta/pretraining/arguments.py b/examples/language/roberta/pretraining/arguments.py index 3a9370e00b0c..c24bb28d5a75 100644 --- a/examples/language/roberta/pretraining/arguments.py +++ b/examples/language/roberta/pretraining/arguments.py @@ -1,152 +1,62 @@ -import colossalai from numpy import require +import colossalai + __all__ = ['parse_args'] def parse_args(): parser = colossalai.get_default_parser() - - parser.add_argument( - '--lr', - type=float, - required=True, - help='initial learning rate') - parser.add_argument( - '--epoch', - type=int, - required=True, - help='number of epoch') - parser.add_argument( - '--data_path_prefix', - type=str, - required=True, - help="location of the train data corpus") - parser.add_argument( - '--eval_data_path_prefix', - type=str, - required=True, - help='location of the evaluation data corpus') - parser.add_argument( - '--tokenizer_path', - type=str, - required=True, - help='location of the tokenizer') - parser.add_argument( - '--max_seq_length', - type=int, - default=512, - help='sequence length') - parser.add_argument( - '--refresh_bucket_size', - type=int, - default=1, - help= - "This param makes sure that a certain task is repeated for this time steps to \ + + parser.add_argument('--lr', type=float, required=True, help='initial learning rate') + parser.add_argument('--epoch', type=int, required=True, help='number of epoch') + parser.add_argument('--data_path_prefix', type=str, required=True, help="location of the train data corpus") + parser.add_argument('--eval_data_path_prefix', + type=str, + required=True, + help='location of the evaluation data corpus') + parser.add_argument('--tokenizer_path', type=str, required=True, help='location of the tokenizer') + 
parser.add_argument('--max_seq_length', type=int, default=512, help='sequence length') + parser.add_argument('--refresh_bucket_size', + type=int, + default=1, + help="This param makes sure that a certain task is repeated for this time steps to \ optimise on the back propogation speed with APEX's DistributedDataParallel") - parser.add_argument( - "--max_predictions_per_seq", - "--max_pred", - default=80, - type=int, - help= - "The maximum number of masked tokens in a sequence to be predicted.") - parser.add_argument( - "--gradient_accumulation_steps", - default=1, - type=int, - help="accumulation_steps") - parser.add_argument( - "--train_micro_batch_size_per_gpu", - default=2, - type=int, - required=True, - help="train batch size") - parser.add_argument( - "--eval_micro_batch_size_per_gpu", - default=2, - type=int, - required=True, - help="eval batch size") - parser.add_argument( - "--num_workers", - default=8, - type=int, - help="") - parser.add_argument( - "--async_worker", - action='store_true', - help="") - parser.add_argument( - "--bert_config", - required=True, - type=str, - help="location of config.json") - parser.add_argument( - "--wandb", - action='store_true', - help="use wandb to watch model") - parser.add_argument( - "--wandb_project_name", - default='roberta', - help="wandb project name") - parser.add_argument( - "--log_interval", - default=100, - type=int, - help="report interval") - parser.add_argument( - "--log_path", - type=str, - required=True, - help="log file which records train step") - parser.add_argument( - "--tensorboard_path", - type=str, - required=True, - help="location of tensorboard file") - parser.add_argument( - "--colossal_config", - type=str, - required=True, - help="colossal config, which contains zero config and so on") - parser.add_argument( - "--ckpt_path", - type=str, - required=True, - help="location of saving checkpoint, which contains model and optimizer") - parser.add_argument( - '--seed', - type=int, - default=42, - 
help="random seed for initialization") - parser.add_argument( - '--vscode_debug', - action='store_true', - help="use vscode to debug") - parser.add_argument( - '--load_pretrain_model', - default='', - type=str, - help="location of model's checkpoin") + parser.add_argument("--max_predictions_per_seq", + "--max_pred", + default=80, + type=int, + help="The maximum number of masked tokens in a sequence to be predicted.") + parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="accumulation_steps") + parser.add_argument("--train_micro_batch_size_per_gpu", default=2, type=int, required=True, help="train batch size") + parser.add_argument("--eval_micro_batch_size_per_gpu", default=2, type=int, required=True, help="eval batch size") + parser.add_argument("--num_workers", default=8, type=int, help="") + parser.add_argument("--async_worker", action='store_true', help="") + parser.add_argument("--bert_config", required=True, type=str, help="location of config.json") + parser.add_argument("--wandb", action='store_true', help="use wandb to watch model") + parser.add_argument("--wandb_project_name", default='roberta', help="wandb project name") + parser.add_argument("--log_interval", default=100, type=int, help="report interval") + parser.add_argument("--log_path", type=str, required=True, help="log file which records train step") + parser.add_argument("--tensorboard_path", type=str, required=True, help="location of tensorboard file") + parser.add_argument("--colossal_config", + type=str, + required=True, + help="colossal config, which contains zero config and so on") + parser.add_argument("--ckpt_path", + type=str, + required=True, + help="location of saving checkpoint, which contains model and optimizer") + parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") + parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug") + parser.add_argument('--load_pretrain_model', default='', type=str, 
help="location of model's checkpoin") parser.add_argument( '--load_optimizer_lr', default='', type=str, help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step") - parser.add_argument( - '--resume_train', - action='store_true', - help="whether resume training from a early checkpoint") - parser.add_argument( - '--mlm', - default='bert', - type=str, - help="model type, bert or deberta") - parser.add_argument( - '--checkpoint_activations', - action='store_true', - help="whether to use gradient checkpointing") + parser.add_argument('--resume_train', action='store_true', help="whether resume training from a early checkpoint") + parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta") + parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing") args = parser.parse_args() return args diff --git a/examples/language/roberta/pretraining/bert_dataset_provider.py b/examples/language/roberta/pretraining/bert_dataset_provider.py index 1d8cf2a910e9..eaf165ed18f4 100644 --- a/examples/language/roberta/pretraining/bert_dataset_provider.py +++ b/examples/language/roberta/pretraining/bert_dataset_provider.py @@ -1,4 +1,5 @@ class BertDatasetProviderInterface: + def get_shard(self, index, shuffle=True): raise NotImplementedError diff --git a/examples/language/roberta/pretraining/evaluation.py b/examples/language/roberta/pretraining/evaluation.py index 83f94082f6c0..c44bcf13ac0f 100644 --- a/examples/language/roberta/pretraining/evaluation.py +++ b/examples/language/roberta/pretraining/evaluation.py @@ -1,9 +1,11 @@ -import os import math +import os + import torch +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider from tqdm import tqdm -from utils.global_vars import get_timers, get_tensorboard_writer -from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from utils.global_vars import get_tensorboard_writer, get_timers + 
def evaluate(engine, args, logger, global_step): evaluate_dataset_provider = NvidiaBertDatasetProvider(args, evaluate=True) @@ -20,16 +22,19 @@ def evaluate(engine, args, logger, global_step): for shard in range(start_shard, len(os.listdir(args.eval_data_path_prefix))): - timers('eval_shard_time').start() + timers('eval_shard_time').start() dataset_iterator, total_length = evaluate_dataset_provider.get_shard(shard) # evaluate_dataset_provider.prefetch_shard(shard + 1) if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), colour='MAGENTA', smoothing=1) + iterator_data = tqdm(enumerate(dataset_iterator), + total=(total_length // args.eval_micro_batch_size_per_gpu // world_size), + colour='MAGENTA', + smoothing=1) else: iterator_data = enumerate(dataset_iterator) - - for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): + + for step, batch_data in iterator_data: #tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1): # batch_data = pretrain_dataset_provider.get_batch(batch_index) eval_step += 1 @@ -40,8 +45,8 @@ def evaluate(engine, args, logger, global_step): # nsp_label = batch_data[5].cuda() output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - - loss = engine.criterion(output.logits, mlm_label)#prediction_scores + + loss = engine.criterion(output.logits, mlm_label) #prediction_scores evaluate_dataset_provider.prefetch_batch() eval_loss += loss.float().item() @@ -54,10 +59,10 @@ def evaluate(engine, args, logger, global_step): if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() tensorboard_log.log_eval({ - 'loss': cur_loss, - 'ppl': ppl, - 'mins_batch': 
elapsed_time_per_iteration - }, global_step) + 'loss': cur_loss, + 'ppl': ppl, + 'mins_batch': elapsed_time_per_iteration + }, global_step) eval_log_str = f'evaluation shard: {shard} | step: {eval_step} | elapsed_time: {elapsed_time / 60 :.3f} minutes ' + \ f'| mins/batch: {elapsed_time_per_iteration :.3f} seconds | loss: {cur_loss:.7f} | ppl: {ppl:.7f}' diff --git a/examples/language/roberta/pretraining/loss.py b/examples/language/roberta/pretraining/loss.py index dc4f872a755d..989c2bd5c450 100644 --- a/examples/language/roberta/pretraining/loss.py +++ b/examples/language/roberta/pretraining/loss.py @@ -13,5 +13,5 @@ def __init__(self, vocab_size): def forward(self, prediction_scores, masked_lm_labels, next_sentence_labels=None): masked_lm_loss = self.loss_fn(prediction_scores.view(-1, self.vocab_size), masked_lm_labels.view(-1)) # next_sentence_loss = self.loss_fn(seq_relationship_score.view(-1, 2), next_sentence_labels.view(-1)) - total_loss = masked_lm_loss #+ next_sentence_loss + total_loss = masked_lm_loss #+ next_sentence_loss return total_loss diff --git a/examples/language/roberta/pretraining/model/bert.py b/examples/language/roberta/pretraining/model/bert.py index 67c85f760776..a5da1bea6f65 100644 --- a/examples/language/roberta/pretraining/model/bert.py +++ b/examples/language/roberta/pretraining/model/bert.py @@ -15,7 +15,6 @@ # limitations under the License. 
"""PyTorch BERT model.""" - import math import os import warnings @@ -27,7 +26,6 @@ from packaging import version from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -41,8 +39,9 @@ TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel +from transformers.models.bert.configuration_bert import BertConfig from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer -from transformers.utils import ( +from transformers.utils import ( ModelOutput, add_code_sample_docstrings, add_start_docstrings, @@ -50,8 +49,6 @@ logging, replace_return_docstrings, ) -from transformers.models.bert.configuration_bert import BertConfig - logger = logging.get_logger(__name__) @@ -62,8 +59,7 @@ # TokenClassification docstring _CHECKPOINT_FOR_TOKEN_CLASSIFICATION = "dbmdz/bert-large-cased-finetuned-conll03-english" _TOKEN_CLASS_EXPECTED_OUTPUT = ( - "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] " -) + "['O', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'I-LOC', 'O', 'I-LOC', 'I-LOC'] ") _TOKEN_CLASS_EXPECTED_LOSS = 0.01 # QuestionAnswering docstring @@ -78,7 +74,6 @@ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'" _SEQ_CLASS_EXPECTED_LOSS = 0.01 - BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ "bert-base-uncased", "bert-large-uncased", @@ -114,10 +109,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): import numpy as np import tensorflow as tf except ImportError: - logger.error( - "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions." - ) + logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. 
Please see " + "https://www.tensorflow.org/install/ for installation instructions.") raise tf_path = os.path.abspath(tf_checkpoint_path) logger.info(f"Converting TensorFlow checkpoint from {tf_path}") @@ -135,10 +128,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path): name = name.split("/") # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # which are not required for using pretrained model - if any( - n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] - for n in name - ): + if any(n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] + for n in name): logger.info(f"Skipping {'/'.join(name)}") continue pointer = model @@ -218,7 +209,7 @@ def forward( seq_length = input_shape[1] if position_ids is None: - position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length] # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves @@ -245,13 +236,12 @@ def forward( class BertSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) + raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})") self.num_attention_heads = config.num_attention_heads self.attention_head_size = int(config.hidden_size / config.num_attention_heads) @@ -262,9 +252,7 @@ 
def __init__(self, config, position_embedding_type=None): self.value = nn.Linear(config.hidden_size, self.all_head_size) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = position_embedding_type or getattr( - config, "position_embedding_type", "absolute" - ) + self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute") if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) @@ -332,7 +320,7 @@ def forward( position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) distance = position_ids_l - position_ids_r positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) - positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility if self.position_embedding_type == "relative_key": relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) @@ -372,6 +360,7 @@ def forward( class BertSelfOutput(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -386,6 +375,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): super().__init__() self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type) @@ -395,9 +385,8 @@ def __init__(self, config, position_embedding_type=None): def prune_heads(self, heads): if len(heads) == 0: return - heads, index = find_pruneable_heads_and_indices( - heads, self.self.num_attention_heads, 
self.self.attention_head_size, self.pruned_heads - ) + heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) # Prune linear layers self.self.query = prune_linear_layer(self.self.query, index) @@ -430,11 +419,12 @@ def forward( output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs class BertIntermediate(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -450,6 +440,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertOutput(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -464,6 +455,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> to class BertLayer(nn.Module): + def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -504,15 +496,14 @@ def forward( outputs = self_attention_outputs[1:-1] present_key_value = self_attention_outputs[-1] else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" - " by setting `config.add_cross_attention=True`" - ) + " by setting `config.add_cross_attention=True`") # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple cross_attn_past_key_value = 
past_key_value[-2:] if past_key_value is not None else None @@ -526,15 +517,14 @@ def forward( output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights # add cross-attn cache to positions 3,4 of present_key_value tuple cross_attn_present_key_value = cross_attention_outputs[-1] present_key_value = present_key_value + cross_attn_present_key_value - layer_output = apply_chunking_to_forward( - self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output - ) + layer_output = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, + self.seq_len_dim, attention_output) outputs = (layer_output,) + outputs # if decoder, return the attn key/values as the last output @@ -550,6 +540,7 @@ def feed_forward_chunk(self, attention_output): class BertEncoder(nn.Module): + def __init__(self, config): super().__init__() self.config = config @@ -585,11 +576,11 @@ def forward( if use_cache: logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) + "`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`...") use_cache = False def create_custom_forward(module): + def custom_forward(*inputs): return module(*inputs, past_key_value, output_attentions) @@ -626,17 +617,13 @@ def custom_forward(*inputs): all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) + return tuple(v for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] if v is not None) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_decoder_cache, @@ -647,6 +634,7 @@ def custom_forward(*inputs): class BertPooler(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -662,6 +650,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -679,6 +668,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class BertLMPredictionHead(nn.Module): + def __init__(self, config): super().__init__() self.transform = BertPredictionHeadTransform(config) @@ -699,6 +689,7 @@ def forward(self, hidden_states): class BertOnlyMLMHead(nn.Module): + def __init__(self, config): super().__init__() self.predictions = BertLMPredictionHead(config) @@ -709,6 +700,7 @@ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor: class BertOnlyNSPHead(nn.Module): + def __init__(self, config): super().__init__() self.seq_relationship = nn.Linear(config.hidden_size, 2) @@ -719,6 +711,7 @@ def forward(self, pooled_output): class BertPreTrainingHeads(nn.Module): + def __init__(self, config): super().__init__() self.predictions = 
BertLMPredictionHead(config) @@ -950,9 +943,8 @@ def forward( `past_key_values`). """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: @@ -1051,6 +1043,7 @@ def forward( BERT_START_DOCSTRING, ) class BertForPreTraining(BertPreTrainedModel): + def __init__(self, config): super().__init__(config) @@ -1151,9 +1144,8 @@ def forward( ) -@add_start_docstrings( - """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING -) +@add_start_docstrings("""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", + BERT_START_DOCSTRING) class BertLMHeadModel(BertPreTrainedModel): _keys_to_ignore_on_load_unexpected = [r"pooler"] @@ -1298,10 +1290,8 @@ def __init__(self, config): super().__init__(config) if config.is_decoder: - logger.warning( - "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " - "bi-directional self-attention." 
- ) + logger.warning("If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention.") self.bert = BertModel(config, add_pooling_layer=False) self.cls = BertOnlyMLMHead(config) @@ -1367,7 +1357,7 @@ def forward( masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -1390,9 +1380,10 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ raise ValueError("The PAD token should be defined for generation") attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1) - dummy_token = torch.full( - (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device - ) + dummy_token = torch.full((effective_batch_size, 1), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device) input_ids = torch.cat([input_ids, dummy_token], dim=1) return {"input_ids": input_ids, "attention_mask": attention_mask} @@ -1403,6 +1394,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ BERT_START_DOCSTRING, ) class BertForNextSentencePrediction(BertPreTrainedModel): + def __init__(self, config): super().__init__(config) @@ -1508,15 +1500,15 @@ def forward( BERT_START_DOCSTRING, ) class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.config = config self.bert = BertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) + classifier_dropout = (config.classifier_dropout + if config.classifier_dropout is not None else config.hidden_dropout_prob) self.dropout = 
nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) @@ -1612,13 +1604,13 @@ def forward( BERT_START_DOCSTRING, ) class BertForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): super().__init__(config) self.bert = BertModel(config) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) + classifier_dropout = (config.classifier_dropout + if config.classifier_dropout is not None else config.hidden_dropout_prob) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, 1) @@ -1658,11 +1650,8 @@ def forward( attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None - inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None - else None - ) + inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None else None) outputs = self.bert( input_ids, @@ -1715,9 +1704,8 @@ def __init__(self, config): self.num_labels = config.num_labels self.bert = BertModel(config, add_pooling_layer=False) - classifier_dropout = ( - config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob - ) + classifier_dropout = (config.classifier_dropout + if config.classifier_dropout is not None else config.hidden_dropout_prob) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.hidden_size, config.num_labels) diff --git a/examples/language/roberta/pretraining/model/deberta_v2.py b/examples/language/roberta/pretraining/model/deberta_v2.py index c6ce82847f75..5fc284911e38 100644 --- 
a/examples/language/roberta/pretraining/model/deberta_v2.py +++ b/examples/language/roberta/pretraining/model/deberta_v2.py @@ -23,7 +23,7 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss - +from transformers import FillMaskPipeline, T5ForConditionalGeneration, T5Tokenizer from transformers.activations import ACT2FN from transformers.modeling_outputs import ( BaseModelOutput, @@ -34,10 +34,14 @@ TokenClassifierOutput, ) from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import softmax_backward_data -from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config -from transformers import T5Tokenizer, T5ForConditionalGeneration, FillMaskPipeline +from transformers.pytorch_utils import softmax_backward_data +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) logger = logging.get_logger(__name__) @@ -55,6 +59,7 @@ # Copied from transformers.models.deberta.modeling_deberta.ContextPooler class ContextPooler(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) @@ -133,15 +138,15 @@ def symbolic(g, self, mask, dim): g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value), to_i=sym_help.cast_pytorch_to_onnx["Byte"], ) - output = masked_fill( - g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min)) - ) + output = masked_fill(g, self, r_mask, + g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))) output = softmax(g, output, dim) return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8))) # Copied from 
transformers.models.deberta.modeling_deberta.DropoutContext class DropoutContext(object): + def __init__(self): self.dropout = 0 self.mask = None @@ -244,6 +249,7 @@ def get_context(self): # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm class DebertaV2SelfOutput(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -259,6 +265,7 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 class DebertaV2Attention(nn.Module): + def __init__(self, config): super().__init__() self.self = DisentangledSelfAttention(config) @@ -296,6 +303,7 @@ def forward( # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 class DebertaV2Intermediate(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) @@ -312,6 +320,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm class DebertaV2Output(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -328,6 +337,7 @@ def forward(self, hidden_states, input_tensor): # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 class DebertaV2Layer(nn.Module): + def __init__(self, config): super().__init__() self.attention = DebertaV2Attention(config) @@ -362,14 +372,17 @@ def forward( class ConvLayer(nn.Module): + def __init__(self, config): super().__init__() kernel_size = getattr(config, "conv_kernel_size", 3) groups = getattr(config, "conv_groups", 1) self.conv_act = getattr(config, "conv_act", "tanh") - self.conv = nn.Conv1d( - config.hidden_size, config.hidden_size, 
kernel_size, padding=(kernel_size - 1) // 2, groups=groups - ) + self.conv = nn.Conv1d(config.hidden_size, + config.hidden_size, + kernel_size, + padding=(kernel_size - 1) // 2, + groups=groups) self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) self.config = config @@ -452,9 +465,10 @@ def get_attention_mask(self, attention_mask): def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): if self.relative_attention and relative_pos is None: q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) - relative_pos = build_relative_position( - q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions - ) + relative_pos = build_relative_position(q, + hidden_states.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions) return relative_pos def forward( @@ -491,6 +505,7 @@ def forward( if self.gradient_checkpointing and self.training: def create_custom_forward(module): + def custom_forward(*inputs): return module(*inputs, output_attentions) @@ -535,9 +550,9 @@ def custom_forward(*inputs): if not return_dict: return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions - ) + return BaseModelOutput(last_hidden_state=output_states, + hidden_states=all_hidden_states, + attentions=all_attentions) def make_log_bucket_position(relative_pos, bucket_size, max_position): @@ -610,10 +625,8 @@ class DisentangledSelfAttention(nn.Module): def __init__(self, config): super().__init__() if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " - f"heads ({config.num_attention_heads})" - ) + raise ValueError(f"The hidden size 
({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})") self.num_attention_heads = config.num_attention_heads _attention_head_size = config.hidden_size // config.num_attention_heads self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) @@ -706,28 +719,22 @@ def forward( attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale if self.relative_attention: rel_embeddings = self.pos_dropout(rel_embeddings) - rel_att = self.disentangled_attention_bias( - query_layer, key_layer, relative_pos, rel_embeddings, scale_factor - ) + rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, + scale_factor) if rel_att is not None: attention_scores = attention_scores + rel_att attention_scores = attention_scores - attention_scores = attention_scores.view( - -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) - ) + attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), + attention_scores.size(-1)) # bsz x height x length x dimension attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) attention_probs = self.dropout(attention_probs) - context_layer = torch.bmm( - attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer - ) - context_layer = ( - context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) - .permute(0, 2, 1, 3) - .contiguous() - ) + context_layer = torch.bmm(attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), + value_layer) + context_layer = (context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), + context_layer.size(-1)).permute(0, 2, 1, 3).contiguous()) new_context_layer_shape = context_layer.size()[:-2] + (-1,) context_layer = context_layer.view(new_context_layer_shape) if output_attentions: @@ -738,9 +745,10 @@ 
def forward( def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): if relative_pos is None: q = query_layer.size(-2) - relative_pos = build_relative_position( - q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions - ) + relative_pos = build_relative_position(q, + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions) if relative_pos.dim() == 2: relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) elif relative_pos.dim() == 3: @@ -758,25 +766,22 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ # rel_embeddings = rel_embeddings.unsqueeze(0) # rel_embeddings = rel_embeddings[0 : att_span * 2, :].unsqueeze(0) if self.share_att_key: - pos_query_layer = self.transpose_for_scores( - self.query_proj(rel_embeddings), self.num_attention_heads - ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_query_layer = self.transpose_for_scores(self.query_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1) pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1 - ) + query_layer.size(0) // self.num_attention_heads, 1, 1) else: if "c2p" in self.pos_att_type: - pos_key_layer = self.transpose_for_scores( - self.pos_key_proj(rel_embeddings), self.num_attention_heads - ).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1 - ) # .split(self.all_head_size, dim=-1) + pos_key_layer = self.transpose_for_scores(self.pos_key_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, + 1) # .split(self.all_head_size, dim=-1) if "p2c" in self.pos_att_type: - pos_query_layer = self.transpose_for_scores( - self.pos_query_proj(rel_embeddings), self.num_attention_heads - 
).repeat( - query_layer.size(0) // self.num_attention_heads, 1, 1 - ) # .split(self.all_head_size, dim=-1) + pos_query_layer = self.transpose_for_scores(self.pos_query_proj(rel_embeddings), + self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, + 1) # .split(self.all_head_size, dim=-1) score = 0 # content->position @@ -787,7 +792,9 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ c2p_att = torch.gather( c2p_att, dim=-1, - index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + index=c2p_pos.squeeze(0).expand([query_layer.size(0), + query_layer.size(1), + relative_pos.size(-1)]), ) score += c2p_att / scale @@ -810,7 +817,9 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_ p2c_att = torch.gather( p2c_att, dim=-1, - index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + index=p2c_pos.squeeze(0).expand([query_layer.size(0), + key_layer.size(-2), + key_layer.size(-2)]), ).transpose(-1, -2) score += p2c_att / scale @@ -990,6 +999,7 @@ def _set_gradient_checkpointing(self, module, value=False): ) # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 class DebertaV2Model(DebertaV2PreTrainedModel): + def __init__(self, config): super().__init__(config) @@ -1032,9 +1042,8 @@ def forward( return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if input_ids is not None and inputs_embeds is not 
None: @@ -1091,7 +1100,7 @@ def forward( sequence_output = encoded_layers[-1] if not return_dict: - return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2):] return BaseModelOutput( last_hidden_state=sequence_output, @@ -1165,7 +1174,7 @@ def forward( masked_lm_loss = None if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token + loss_fct = CrossEntropyLoss() # -100 index = padding token masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) if not return_dict: @@ -1182,6 +1191,7 @@ def forward( # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta class DebertaV2PredictionHeadTransform(nn.Module): + def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) @@ -1200,6 +1210,7 @@ def forward(self, hidden_states): # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta class DebertaV2LMPredictionHead(nn.Module): + def __init__(self, config): super().__init__() self.transform = DebertaV2PredictionHeadTransform(config) @@ -1221,6 +1232,7 @@ def forward(self, hidden_states): # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta class DebertaV2OnlyMLMHead(nn.Module): + def __init__(self, config): super().__init__() self.predictions = DebertaV2LMPredictionHead(config) @@ -1239,6 +1251,7 @@ def forward(self, sequence_output): ) # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + def __init__(self, config): super().__init__(config) @@ -1318,9 +1331,8 @@ def forward( label_index = (labels >= 0).nonzero() labels = labels.long() if label_index.size(0) > 0: - labeled_logits = torch.gather( - logits, 0, label_index.expand(label_index.size(0), 
logits.size(1)) - ) + labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), + logits.size(1))) labels = torch.gather(labels, 0, label_index.view(-1)) loss_fct = CrossEntropyLoss() loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) @@ -1345,9 +1357,10 @@ def forward( output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutput( - loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions - ) + return SequenceClassifierOutput(loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions) @add_start_docstrings( @@ -1422,9 +1435,10 @@ def forward( output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output - return TokenClassifierOutput( - loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions - ) + return TokenClassifierOutput(loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions) @add_start_docstrings( @@ -1536,6 +1550,7 @@ def forward( DEBERTA_START_DOCSTRING, ) class DebertaV2ForMultipleChoice(DebertaV2PreTrainedModel): + def __init__(self, config): super().__init__(config) @@ -1591,11 +1606,8 @@ def forward( flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - flat_inputs_embeds = ( - inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) - if inputs_embeds is not None - else None - ) + flat_inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None else None) outputs = self.deberta( flat_input_ids, diff --git 
a/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py index cce836913505..72c7bd852a40 100644 --- a/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py +++ b/examples/language/roberta/pretraining/nvidia_bert_dataset_provider.py @@ -1,24 +1,25 @@ +import json +import logging import os import random -import h5py -import logging -import json import time from concurrent.futures import ProcessPoolExecutor +import h5py import numpy as np - import torch import torch.distributed as dist +from bert_dataset_provider import BertDatasetProviderInterface from torch.utils.data import DataLoader, Dataset -from torch.utils.data.sampler import RandomSampler from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import RandomSampler -from bert_dataset_provider import BertDatasetProviderInterface import colossalai.utils as utils + # Workaround because python functions are not picklable class WorkerInitObj(object): + def __init__(self, seed): self.seed = seed @@ -27,29 +28,25 @@ def __call__(self, id): random.seed(self.seed + id) -def create_pretraining_dataset(input_file, max_predictions_per_seq, - num_workers, train_batch_size, worker_init, +def create_pretraining_dataset(input_file, max_predictions_per_seq, num_workers, train_batch_size, worker_init, data_sampler): - train_data = pretraining_dataset( - input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) + train_data = pretraining_dataset(input_file=input_file, max_predictions_per_seq=max_predictions_per_seq) train_dataloader = DataLoader(train_data, sampler=data_sampler(train_data), batch_size=train_batch_size, num_workers=num_workers, worker_init_fn=worker_init, - pin_memory=True - ) + pin_memory=True) return train_dataloader, len(train_data) class pretraining_dataset(Dataset): + def __init__(self, input_file, max_predictions_per_seq): self.input_file = input_file 
self.max_predictions_per_seq = max_predictions_per_seq f = h5py.File(input_file, "r") - keys = [ - 'input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions' - ] + keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions'] self.inputs = [np.asarray(f[key][:]) for key in keys] f.close() @@ -59,21 +56,16 @@ def __len__(self): def __getitem__(self, index): - [ - input_ids, input_mask, segment_ids, masked_lm_labels - ] = [ - torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else - torch.from_numpy(np.asarray(input[index].astype(np.int64))) - for indice, input in enumerate(self.inputs) + [input_ids, input_mask, segment_ids, masked_lm_labels] = [ + torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy( + np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs) ] - return [ - input_ids, input_mask, - segment_ids, masked_lm_labels - ] + return [input_ids, input_mask, segment_ids, masked_lm_labels] class NvidiaBertDatasetProvider(BertDatasetProviderInterface): + def __init__(self, args, evaluate=False): self.num_workers = args.num_workers self.max_seq_length = args.max_seq_length @@ -85,22 +77,24 @@ def __init__(self, args, evaluate=False): else: self.train_micro_batch_size_per_gpu = args.eval_micro_batch_size_per_gpu self.logger = args.logger - + self.global_rank = dist.get_rank() self.world_size = dist.get_world_size() # Initialize dataset files if not evaluate: self.dataset_files = [ - os.path.join(args.data_path_prefix, f) for f in os.listdir(args.data_path_prefix) if - os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f + os.path.join(args.data_path_prefix, f) + for f in os.listdir(args.data_path_prefix) + if os.path.isfile(os.path.join(args.data_path_prefix, f)) and 'h5' in f ] else: self.dataset_files = [ - os.path.join(args.eval_data_path_prefix, f) for f in os.listdir(args.eval_data_path_prefix) if - os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) 
and 'h5' in f + os.path.join(args.eval_data_path_prefix, f) + for f in os.listdir(args.eval_data_path_prefix) + if os.path.isfile(os.path.join(args.eval_data_path_prefix, f)) and 'h5' in f ] - + self.dataset_files.sort() # random.shuffle(self.dataset_files) self.num_files = len(self.dataset_files) @@ -114,9 +108,7 @@ def __init__(self, args, evaluate=False): self.shuffle = True if self.global_rank == 0: - self.logger.info( - f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}" - ) + self.logger.info(f"NvidiaBertDatasetProvider - Initialization: num_files = {self.num_files}") def get_shard(self, index): start = time.time() @@ -130,9 +122,8 @@ def get_shard(self, index): worker_init=self.worker_init, data_sampler=self.data_sampler) else: - self.train_dataloader, sample_count = self.dataset_future.result( - timeout=None) - + self.train_dataloader, sample_count = self.dataset_future.result(timeout=None) + self.logger.info( f"Data Loading Completed for Pretraining Data from {self.data_file} with {sample_count} samples took {time.time()-start:.2f}s." 
) @@ -145,11 +136,9 @@ def release_shard(self): def prefetch_shard(self, index): self.data_file = self._get_shard_file(index) - self.dataset_future = self.pool.submit( - create_pretraining_dataset, self.data_file, - self.max_predictions_per_seq, self.num_workers, - self.train_micro_batch_size_per_gpu, self.worker_init, - self.data_sampler) + self.dataset_future = self.pool.submit(create_pretraining_dataset, self.data_file, self.max_predictions_per_seq, + self.num_workers, self.train_micro_batch_size_per_gpu, self.worker_init, + self.data_sampler) def get_batch(self, batch_iter): return batch_iter @@ -179,4 +168,3 @@ def shuffle_dataset(self, epoch): indices = torch.randperm(self.num_files, generator=g).tolist() new_dataset = [self.dataset_files[i] for i in indices] self.dataset_files = new_dataset - \ No newline at end of file diff --git a/examples/language/roberta/pretraining/pretrain_utils.py b/examples/language/roberta/pretraining/pretrain_utils.py index ba17b0f5ee09..e018bab5788e 100644 --- a/examples/language/roberta/pretraining/pretrain_utils.py +++ b/examples/language/roberta/pretraining/pretrain_utils.py @@ -1,35 +1,45 @@ -import transformers import logging -from colossalai.nn.lr_scheduler import LinearWarmupLR -from transformers import get_linear_schedule_with_warmup -from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig -from transformers import GPT2Config, GPT2LMHeadModel -from transformers import AutoTokenizer, AutoModelForMaskedLM -from colossalai.nn.optimizer import FusedAdam -from torch.optim import AdamW -from colossalai.core import global_context as gpc -import torch import os import sys -sys.path.append(os.getcwd()) -from model.deberta_v2 import DebertaV2ForMaskedLM -from model.bert import BertForMaskedLM -import torch.nn as nn +import torch +import transformers +from torch.optim import AdamW +from transformers import ( + AutoModelForMaskedLM, + AutoTokenizer, + BertForPreTraining, + GPT2Config, + GPT2LMHeadModel, + 
RobertaConfig, + RobertaForMaskedLM, + get_linear_schedule_with_warmup, +) + +from colossalai.core import global_context as gpc +from colossalai.nn.lr_scheduler import LinearWarmupLR +from colossalai.nn.optimizer import FusedAdam + +sys.path.append(os.getcwd()) from collections import OrderedDict +import torch.nn as nn +from model.bert import BertForMaskedLM +from model.deberta_v2 import DebertaV2ForMaskedLM + __all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining'] def get_new_state_dict(state_dict, start_index=13): - new_state_dict = OrderedDict() + new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k[start_index:] - new_state_dict[name] = v + new_state_dict[name] = v return new_state_dict class LMModel(nn.Module): + def __init__(self, model, config, args): super().__init__() @@ -58,16 +68,18 @@ def get_model(args, logger): if len(args.load_pretrain_model) > 0: assert os.path.exists(args.load_pretrain_model) # load_checkpoint(args.load_pretrain_model, model, strict=False) - m_state_dict = torch.load(args.load_pretrain_model, map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) + m_state_dict = torch.load(args.load_pretrain_model, + map_location=torch.device(f"cuda:{torch.cuda.current_device()}")) # new_state_dict = get_new_state_dict(m_state_dict) - model.load_state_dict(m_state_dict, strict=True) # must insure that every process have identical parameters !!!!!!! + model.load_state_dict(m_state_dict, + strict=True) # must insure that every process have identical parameters !!!!!!! 
logger.info("load model success") - + numel = sum([p.numel() for p in model.parameters()]) if args.checkpoint_activations: model.gradient_checkpointing_enable() # model = LMModel(model, config, args) - + return config, model, numel @@ -89,7 +101,10 @@ def get_optimizer(model, lr): def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1): # warmup_steps = int(total_steps * warmup_ratio) - lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps, last_epoch=last_epoch) + lr_scheduler = get_linear_schedule_with_warmup(optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=total_steps, + last_epoch=last_epoch) # lr_scheduler = LinearWarmupLR(optimizer, total_steps=total_steps, warmup_steps=warmup_steps) return lr_scheduler @@ -103,10 +118,7 @@ def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step): checkpoint['epoch'] = epoch checkpoint['shard'] = shard checkpoint['global_step'] = global_step - model_state = model.state_dict() #each process must run model.state_dict() + model_state = model.state_dict() #each process must run model.state_dict() if gpc.get_global_rank() == 0: torch.save(checkpoint, optimizer_lr_path) torch.save(model_state, model_path) - - - diff --git a/examples/language/roberta/pretraining/run_pretrain.sh b/examples/language/roberta/pretraining/run_pretrain.sh index 144cd0ab96fd..50f08db0c32f 100644 --- a/examples/language/roberta/pretraining/run_pretrain.sh +++ b/examples/language/roberta/pretraining/run_pretrain.sh @@ -37,4 +37,3 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ --mlm bert \ --wandb \ --checkpoint_activations \ - \ No newline at end of file diff --git a/examples/language/roberta/pretraining/run_pretrain_resume.sh b/examples/language/roberta/pretraining/run_pretrain_resume.sh index a0704cf7c517..e74d668165cc 100644 --- a/examples/language/roberta/pretraining/run_pretrain_resume.sh +++ 
b/examples/language/roberta/pretraining/run_pretrain_resume.sh @@ -40,4 +40,3 @@ env OMP_NUM_THREADS=40 colossalai run --hostfile ./hostfile \ --resume_train \ --load_pretrain_model /ckpt/1.pt \ --load_optimizer_lr /ckpt/1.op_lrs \ - \ No newline at end of file diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py index 9840a122cbc4..f95bc20b1335 100644 --- a/examples/language/roberta/pretraining/run_pretraining.py +++ b/examples/language/roberta/pretraining/run_pretraining.py @@ -1,69 +1,70 @@ -import colossalai import math +import os +import time +from functools import partial + import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -import colossalai.nn as col_nn from arguments import parse_args -from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt -from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args -from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer -from utils.logger import Logger from evaluation import evaluate from loss import LossForPretraining - -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_optim import ShardedOptimizerV2 from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt from tqdm import tqdm -import os -import time -from functools import partial - from transformers import AutoTokenizer +from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator +from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables +from utils.logger import Logger +import colossalai +import colossalai.nn as col_nn +from colossalai.context import 
ParallelMode +from colossalai.core import global_context as gpc from colossalai.gemini import ChunkManager, GeminiManager -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.utils import get_current_device +from colossalai.nn.optimizer import HybridAdam from colossalai.nn.parallel import ZeroDDP -from colossalai.zero import ZeroOptimizer from colossalai.tensor import ProcessGroup -from colossalai.nn.optimizer import HybridAdam +from colossalai.utils import get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ZeroOptimizer +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import TensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_optim import ShardedOptimizerV2 def main(): args = parse_args() launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) - + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) os.environ['CUDA_LAUNCH_BLOCKING'] = '1' - + logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) - + if args.vscode_debug: colossalai.launch(config={}, - rank=args.rank, - world_size=args.world_size, - host=args.host, - port=args.port, - backend=args.backend) + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend) args.local_rank = -1 args.log_interval = 1 else: - colossalai.launch_from_torch(args.colossal_config) #args.colossal_config + colossalai.launch_from_torch(args.colossal_config) #args.colossal_config args.local_rank = int(os.environ["LOCAL_RANK"]) - logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + - f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}') + logger.info( + f'launch_from_torch, world size: 
{torch.distributed.get_world_size()} | ' + + f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}' + ) log_args(logger, args) args.tokenizer = tokenizer args.logger = logger set_global_variables(launch_time, args.tensorboard_path) - + use_zero = hasattr(gpc.config, 'zero') world_size = torch.distributed.get_world_size() @@ -71,8 +72,8 @@ def main(): if use_zero: shard_strategy = TensorShardStrategy() with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, - shard_param=True): - + shard_param=True): + config, model, numel = get_model(args, logger) # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) else: @@ -82,9 +83,9 @@ def main(): os.mkdir(os.path.join(args.ckpt_path, launch_time)) logger.info(f'Model numel: {numel}') - + get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) - steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) total_steps = steps_per_epoch * args.epoch # build optimizer and lr_scheduler @@ -98,18 +99,22 @@ def main(): o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 optimizer = get_optimizer(model, lr=args.lr) optimizer.load_state_dict(o_l_state_dict['optimizer']) - lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] + lr_scheduler = get_lr_scheduler( + optimizer, total_steps=total_steps, + last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] for state in 
optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") # if you want delete the above three code, have to move the model to gpu, because in optimizer.step() lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) - + start_epoch = o_l_state_dict['epoch'] start_shard = o_l_state_dict['shard'] + 1 # global_step = o_l_state_dict['global_step'] + 1 - logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') + logger.info( + f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' + ) else: optimizer = get_optimizer(model, lr=args.lr) lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) @@ -124,12 +129,11 @@ def main(): # initialize with colossalai engine, _, _, lr_scheduelr = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler) - + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler) + logger.info(get_mem_info(prefix='After init model, ')) - best_loss = None eval_loss = 0 @@ -146,13 +150,16 @@ def main(): dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1) + iterator_data = tqdm(enumerate(dataset_iterator), + total=(total_length // args.train_micro_batch_size_per_gpu // world_size), + colour='cyan', + smoothing=1) else: iterator_data = enumerate(dataset_iterator) engine.train() - - for step, batch_data in iterator_data: + + for step, batch_data in iterator_data: # batch_data = pretrain_dataset_provider.get_batch(batch_index) input_ids = 
batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") @@ -162,7 +169,7 @@ def main(): # nsp_label = batch_data[5].cuda() output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - + loss = engine.criterion(output.logits, mlm_label) pretrain_dataset_provider.prefetch_batch() @@ -172,14 +179,15 @@ def main(): engine.step() lr_scheduelr.step() engine.zero_grad() - + global_step += 1 if global_step % args.log_interval == 0 and global_step != 0 \ and torch.distributed.get_rank() == 0: elapsed_time = timers('interval_time').elapsed(reset=False) elapsed_time_per_iteration = elapsed_time / global_step - samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size) + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( + numel, args, config, elapsed_time, global_step, world_size) cur_loss = train_loss / args.log_interval current_lr = lr_scheduelr.get_last_lr()[0] @@ -189,12 +197,13 @@ def main(): if args.wandb: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_train({ - 'lr': current_lr, - 'loss': cur_loss, - 'ppl': math.exp(cur_loss), - 'mins_batch': elapsed_time_per_iteration - }, global_step) + tensorboard_log.log_train( + { + 'lr': current_lr, + 'loss': cur_loss, + 'ppl': math.exp(cur_loss), + 'mins_batch': elapsed_time_per_iteration + }, global_step) train_loss = 0 @@ -202,9 +211,10 @@ def main(): logger.info('*' * 100) eval_loss += evaluate(engine, args, logger, global_step) - save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) - - + save_ckpt(engine.model, optimizer, lr_scheduelr, + os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, + shard, global_step) + eval_loss /= len(os.listdir(args.data_path_prefix)) logger.info(f'epoch {epoch} | 
shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \ f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') diff --git a/examples/language/roberta/pretraining/utils/WandbLog.py b/examples/language/roberta/pretraining/utils/WandbLog.py index 9dd28a98186b..b68ba8387dcd 100644 --- a/examples/language/roberta/pretraining/utils/WandbLog.py +++ b/examples/language/roberta/pretraining/utils/WandbLog.py @@ -1,8 +1,10 @@ +import os import time + import wandb -import os from torch.utils.tensorboard import SummaryWriter + class WandbLog: @classmethod @@ -15,7 +17,7 @@ def log(cls, result, model=None, gradient=None): if model: wandb.watch(model) - + if gradient: wandb.watch(gradient) @@ -30,7 +32,7 @@ def __init__(self, location, name=time.strftime("%Y-%m-%d %H:%M:%S", time.localt def log_train(self, result, step): for k, v in result.items(): self.writer.add_scalar(f'{k}/train', v, step) - + def log_eval(self, result, step): for k, v in result.items(): self.writer.add_scalar(f'{k}/eval', v, step) @@ -38,9 +40,3 @@ def log_eval(self, result, step): def log_zeroshot(self, result, step): for k, v in result.items(): self.writer.add_scalar(f'{k}_acc/eval', v, step) - - - - - - diff --git a/examples/language/roberta/pretraining/utils/exp_util.py b/examples/language/roberta/pretraining/utils/exp_util.py index a02b0872acbc..0cdb56bad031 100644 --- a/examples/language/roberta/pretraining/utils/exp_util.py +++ b/examples/language/roberta/pretraining/utils/exp_util.py @@ -1,9 +1,13 @@ import functools -import os, shutil -import torch +import os +import shutil + import psutil +import torch + from colossalai.core import global_context as gpc + def logging(s, log_path, print_=True, log_=True): if print_: print(s) @@ -11,9 +15,11 @@ def logging(s, log_path, print_=True, log_=True): with open(log_path, 'a+') as f_log: f_log.write(s + '\n') + def get_logger(log_path, **kwargs): return functools.partial(logging, 
log_path=log_path, **kwargs) + def create_exp_dir(dir_path, scripts_to_save=None, debug=False): if debug: print('Debug Mode : no experiment dir created') @@ -33,6 +39,7 @@ def create_exp_dir(dir_path, scripts_to_save=None, debug=False): return get_logger(log_path=os.path.join(dir_path, 'log.txt')) + def get_cpu_mem(): return psutil.Process().memory_info().rss / 1024**2 @@ -52,11 +59,15 @@ def get_tflops(model_numel, batch_size, seq_len, step_time): def get_parameters_in_billions(model, world_size=1): gpus_per_model = world_size - approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()]) - for model_module in model]) + approx_parameters_in_billions = sum([ + sum([p.ds_numel if hasattr(p, 'ds_id') else p.nelement() + for p in model_module.parameters()]) + for model_module in model + ]) return approx_parameters_in_billions * gpus_per_model / (1e9) + def throughput_calculator(numel, args, config, iteration_time, total_iterations, world_size=1): gpus_per_model = 1 batch_size = args.train_micro_batch_size_per_gpu @@ -76,10 +87,13 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations, # The factor of 4 is when used with activation check-pointing, # otherwise it will be 3. checkpoint_activations_factor = 4 if args.checkpoint_activations else 3 - flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + (vocab_size / (16. * num_layers * hidden_size))) + flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * args.max_seq_length * num_layers * + (hidden_size**2)) * (1. + (args.max_seq_length / (6. * hidden_size)) + + (vocab_size / (16. 
* num_layers * hidden_size))) tflops = flops_per_iteration / (elapsed_time_per_iter * (10**12)) return samples_per_second, tflops, approx_parameters_in_billions + def synchronize(): if not torch.distributed.is_available(): return @@ -90,10 +104,11 @@ def synchronize(): return torch.distributed.barrier() + def log_args(logger, args): logger.info('--------args----------') message = '\n'.join([f'{k:<30}: {v}' for k, v in vars(args).items()]) message += '\n' message += '\n'.join([f'{k:<30}: {v}' for k, v in gpc.config.items()]) logger.info(message) - logger.info('--------args----------\n') \ No newline at end of file + logger.info('--------args----------\n') diff --git a/examples/language/roberta/pretraining/utils/global_vars.py b/examples/language/roberta/pretraining/utils/global_vars.py index 363cbf91c065..7b0c5a2be73d 100644 --- a/examples/language/roberta/pretraining/utils/global_vars.py +++ b/examples/language/roberta/pretraining/utils/global_vars.py @@ -1,5 +1,7 @@ import time + import torch + from .WandbLog import TensorboardLog _GLOBAL_TIMERS = None @@ -10,30 +12,34 @@ def set_global_variables(launch_time, tensorboard_path): _set_timers() _set_tensorboard_writer(launch_time, tensorboard_path) + def _set_timers(): """Initialize timers.""" global _GLOBAL_TIMERS _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') _GLOBAL_TIMERS = Timers() + def _set_tensorboard_writer(launch_time, tensorboard_path): """Set tensorboard writer.""" global _GLOBAL_TENSORBOARD_WRITER - _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, - 'tensorboard writer') + _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, 'tensorboard writer') if torch.distributed.get_rank() == 0: _GLOBAL_TENSORBOARD_WRITER = TensorboardLog(tensorboard_path + f'/{launch_time}', launch_time) - + + def get_timers(): """Return timers.""" _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') return _GLOBAL_TIMERS + def get_tensorboard_writer(): """Return tensorboard writer. 
It can be None so no need to check if it is initialized.""" return _GLOBAL_TENSORBOARD_WRITER + def _ensure_var_is_initialized(var, name): """Make sure the input variable is not None.""" assert var is not None, '{} is not initialized.'.format(name) @@ -115,12 +121,10 @@ def log(self, names, normalizer=1.0, reset=True): assert normalizer > 0.0 string = 'time (ms)' for name in names: - elapsed_time = self.timers[name].elapsed( - reset=reset) * 1000.0 / normalizer + elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer string += ' | {}: {:.2f}'.format(name, elapsed_time) if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == ( - torch.distributed.get_world_size() - 1): + if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): print(string, flush=True) else: print(string, flush=True) diff --git a/examples/language/roberta/pretraining/utils/logger.py b/examples/language/roberta/pretraining/utils/logger.py index 481c4c6ce61b..75c9bf4bef25 100644 --- a/examples/language/roberta/pretraining/utils/logger.py +++ b/examples/language/roberta/pretraining/utils/logger.py @@ -1,22 +1,22 @@ -import os import logging +import os + import torch.distributed as dist -logging.basicConfig( - format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO) +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S', + level=logging.INFO) logger = logging.getLogger(__name__) class Logger(): + def __init__(self, log_path, cuda=False, debug=False): self.logger = logging.getLogger(__name__) self.cuda = cuda self.log_path = log_path self.debug = debug - def info(self, message, log_=True, print_=True, *args, **kwargs): if (self.cuda and dist.get_rank() == 0) or not self.cuda: if print_: @@ -26,6 +26,5 @@ def info(self, message, log_=True, print_=True, *args, **kwargs): with open(self.log_path, 'a+') as f_log: f_log.write(message 
+ '\n') - def error(self, message, *args, **kwargs): self.logger.error(message, *args, **kwargs) diff --git a/examples/tutorial/fp8/mnist/README.md b/examples/tutorial/fp8/mnist/README.md index 46711f9ebdd8..e1128c1054b7 100644 --- a/examples/tutorial/fp8/mnist/README.md +++ b/examples/tutorial/fp8/mnist/README.md @@ -1,13 +1,13 @@ -# Basic MNIST Example with optional FP8 of TransformerEngine - -[TransformerEngine](https://github.com/NVIDIA/TransformerEngine) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower memory utilization in both training and inference. - -Thanks for the contribution to this tutorial from NVIDIA. - -```bash -python main.py -python main.py --use-te # Linear layers from TransformerEngine -python main.py --use-fp8 # FP8 + TransformerEngine for Linear layers -``` - -> We are working to integrate it with Colossal-AI and will finish it soon. +# Basic MNIST Example with optional FP8 of TransformerEngine + +[TransformerEngine](https://github.com/NVIDIA/TransformerEngine) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower memory utilization in both training and inference. + +Thanks for the contribution to this tutorial from NVIDIA. + +```bash +python main.py +python main.py --use-te # Linear layers from TransformerEngine +python main.py --use-fp8 # FP8 + TransformerEngine for Linear layers +``` + +> We are working to integrate it with Colossal-AI and will finish it soon. diff --git a/examples/tutorial/fp8/mnist/main.py b/examples/tutorial/fp8/mnist/main.py index 000ded2f111f..a534663d380f 100644 --- a/examples/tutorial/fp8/mnist/main.py +++ b/examples/tutorial/fp8/mnist/main.py @@ -3,12 +3,13 @@ # See LICENSE for license information. 
import argparse + import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -from torchvision import datasets, transforms from torch.optim.lr_scheduler import StepLR +from torchvision import datasets, transforms try: from transformer_engine import pytorch as te @@ -18,6 +19,7 @@ class Net(nn.Module): + def __init__(self, use_te=False): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) @@ -62,12 +64,10 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): loss.backward() optimizer.step() if batch_idx % args.log_interval == 0: - print( - f"Train Epoch: {epoch} " - f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " - f"({100. * batch_idx / len(train_loader):.0f}%)]\t" - f"Loss: {loss.item():.6f}" - ) + print(f"Train Epoch: {epoch} " + f"[{batch_idx * len(data)}/{len(train_loader.dataset)} " + f"({100. * batch_idx / len(train_loader):.0f}%)]\t" + f"Loss: {loss.item():.6f}") if args.dry_run: break @@ -83,6 +83,7 @@ def calibrate(model, device, test_loader): with te.fp8_autocast(enabled=False, calibrating=True): output = model(data) + def test(model, device, test_loader, use_fp8): """Testing function.""" model.eval() @@ -93,21 +94,15 @@ def test(model, device, test_loader, use_fp8): data, target = data.to(device), target.to(device) with te.fp8_autocast(enabled=use_fp8): output = model(data) - test_loss += F.nll_loss( - output, target, reduction="sum" - ).item() # sum up batch loss - pred = output.argmax( - dim=1, keepdim=True - ) # get the index of the max log-probability + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print( - f"\nTest set: Average loss: {test_loss:.4f}, " - f"Accuracy: {correct}/{len(test_loader.dataset)} " - f"({100. 
* correct / len(test_loader.dataset):.0f}%)\n" - ) + print(f"\nTest set: Average loss: {test_loss:.4f}, " + f"Accuracy: {correct}/{len(test_loader.dataset)} " + f"({100. * correct / len(test_loader.dataset):.0f}%)\n") def main(): @@ -154,9 +149,7 @@ def main(): default=False, help="quickly check a single pass", ) - parser.add_argument( - "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)" - ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") parser.add_argument( "--log-interval", type=int, @@ -170,15 +163,12 @@ def main(): default=False, help="For Saving the current Model", ) - parser.add_argument( - "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration" - ) - parser.add_argument( - "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only" - ) - parser.add_argument( - "--use-te", action="store_true", default=False, help="Use Transformer Engine" - ) + parser.add_argument("--use-fp8", + action="store_true", + default=False, + help="Use FP8 for inference and training without recalibration") + parser.add_argument("--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only") + parser.add_argument("--use-te", action="store_true", default=False, help="Use Transformer Engine") args = parser.parse_args() use_cuda = torch.cuda.is_available() @@ -205,9 +195,7 @@ def main(): train_kwargs.update(cuda_kwargs) test_kwargs.update(cuda_kwargs) - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] - ) + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) dataset2 = datasets.MNIST("../data", train=False, transform=transform) train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) @@ -227,7 +215,7 @@ def main(): 
if args.save_model or args.use_fp8_infer: torch.save(model.state_dict(), "mnist_cnn.pt") - print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer)) + print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8_infer)) weights = torch.load("mnist_cnn.pt") model.load_state_dict(weights) test(model, device, test_loader, args.use_fp8_infer) diff --git a/examples/tutorial/opt/inference/batch.py b/examples/tutorial/opt/inference/batch.py index 1a0876ca8338..5bc131b0e683 100644 --- a/examples/tutorial/opt/inference/batch.py +++ b/examples/tutorial/opt/inference/batch.py @@ -1,9 +1,11 @@ +from typing import Any, Deque, Hashable, List, Tuple + import torch -from typing import List, Deque, Tuple, Hashable, Any from energonai import BatchManager, SubmitEntry, TaskEntry class BatchManagerForGeneration(BatchManager): + def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None: super().__init__() self.max_batch_size = max_batch_size diff --git a/examples/tutorial/opt/inference/benchmark/locustfile.py b/examples/tutorial/opt/inference/benchmark/locustfile.py index 4d829e5d83bf..984cc7d31046 100644 --- a/examples/tutorial/opt/inference/benchmark/locustfile.py +++ b/examples/tutorial/opt/inference/benchmark/locustfile.py @@ -1,8 +1,10 @@ -from locust import HttpUser, task from json import JSONDecodeError +from locust import HttpUser, task + class GenerationUser(HttpUser): + @task def generate(self): prompt = 'Question: What is the longest river on the earth? 
Answer:' diff --git a/examples/tutorial/opt/inference/cache.py b/examples/tutorial/opt/inference/cache.py index 30febc44fbb3..98731121ae1a 100644 --- a/examples/tutorial/opt/inference/cache.py +++ b/examples/tutorial/opt/inference/cache.py @@ -1,7 +1,7 @@ from collections import OrderedDict -from threading import Lock from contextlib import contextmanager -from typing import List, Any, Hashable, Dict +from threading import Lock +from typing import Any, Dict, Hashable, List class MissCacheError(Exception): @@ -9,6 +9,7 @@ class MissCacheError(Exception): class ListCache: + def __init__(self, cache_size: int, list_size: int, fixed_keys: List[Hashable] = []) -> None: """Cache a list of values. The fixed keys won't be removed. For other keys, LRU is applied. When the value list is not full, a cache miss occurs. Otherwise, a cache hit occurs. Redundant values will be removed. diff --git a/examples/tutorial/opt/inference/opt_fastapi.py b/examples/tutorial/opt/inference/opt_fastapi.py index cbfc2a22e7c0..41f011c56505 100644 --- a/examples/tutorial/opt/inference/opt_fastapi.py +++ b/examples/tutorial/opt/inference/opt_fastapi.py @@ -4,20 +4,22 @@ from typing import Optional import uvicorn +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError from energonai import QueueFullError, launch_engine from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel, Field from transformers import GPT2Tokenizer -from batch import BatchManagerForGeneration -from cache import ListCache, MissCacheError - class GenerationTaskReq(BaseModel): max_tokens: int = Field(gt=0, le=256, example=64) prompt: str = Field( - min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + min_length=1, + example= + 'Question: Where were the 2004 Olympics held?\nAnswer: Athens, 
Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:' + ) top_k: Optional[int] = Field(default=None, gt=0, example=50) top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) @@ -64,12 +66,7 @@ async def shutdown(*_): def get_model_fn(model_name: str): - model_map = { - 'opt-125m': opt_125M, - 'opt-6.7b': opt_6B, - 'opt-30b': opt_30B, - 'opt-175b': opt_175B - } + model_map = {'opt-125m': opt_125M, 'opt-6.7b': opt_6B, 'opt-30b': opt_30B, 'opt-175b': opt_175B} return model_map[model_name] @@ -80,9 +77,12 @@ def print_args(args: argparse.Namespace): FIXED_CACHE_KEYS = [ - ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), - ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), - ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', + 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? 
\nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', + 64), + ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", + 64) ] if __name__ == '__main__': @@ -112,7 +112,12 @@ def print_args(args: argparse.Namespace): cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) else: cache = None - engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + engine = launch_engine(args.tp, + 1, + args.master_host, + args.master_port, + args.rpc_port, + get_model_fn(args.model), batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id), pipe_size=args.pipe_size, diff --git a/examples/tutorial/opt/inference/opt_server.py b/examples/tutorial/opt/inference/opt_server.py index 8dab82622c59..a4389ba0a704 100644 --- a/examples/tutorial/opt/inference/opt_server.py +++ b/examples/tutorial/opt/inference/opt_server.py @@ -1,24 +1,28 @@ -import logging import argparse +import logging import random -from torch import Tensor -from pydantic import BaseModel, Field from typing import Optional -from energonai.model import opt_125M, opt_30B, opt_175B, opt_6B -from transformers import GPT2Tokenizer -from energonai import launch_engine, QueueFullError + +from batch import BatchManagerForGeneration +from cache import ListCache, MissCacheError +from energonai import QueueFullError, launch_engine +from energonai.model import opt_6B, opt_30B, opt_125M, opt_175B +from pydantic import BaseModel, Field from sanic import Sanic from sanic.request import Request from sanic.response import json -from sanic_ext import validate, openapi -from batch import BatchManagerForGeneration -from cache import ListCache, MissCacheError +from sanic_ext import openapi, 
validate +from torch import Tensor +from transformers import GPT2Tokenizer class GenerationTaskReq(BaseModel): max_tokens: int = Field(gt=0, le=256, example=64) prompt: str = Field( - min_length=1, example='Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:') + min_length=1, + example= + 'Question: Where were the 2004 Olympics held?\nAnswer: Athens, Greece\n\nQuestion: What is the longest river on the earth?\nAnswer:' + ) top_k: Optional[int] = Field(default=None, gt=0, example=50) top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5) temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7) @@ -65,12 +69,7 @@ def shutdown(*_): def get_model_fn(model_name: str): - model_map = { - 'opt-125m': opt_125M, - 'opt-6.7b': opt_6B, - 'opt-30b': opt_30B, - 'opt-175b': opt_175B - } + model_map = {'opt-125m': opt_125M, 'opt-6.7b': opt_6B, 'opt-30b': opt_30B, 'opt-175b': opt_175B} return model_map[model_name] @@ -81,9 +80,12 @@ def print_args(args: argparse.Namespace): FIXED_CACHE_KEYS = [ - ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', 64), - ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? 
\nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', 64), - ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", 64) + ('Question: What is the name of the largest continent on earth?\nAnswer: Asia\n\nQuestion: What is at the center of the solar system?\nAnswer:', + 64), + ('A chat between a salesman and a student.\n\nSalesman: Hi boy, are you looking for a new phone?\nStudent: Yes, my phone is not functioning well.\nSalesman: What is your budget? \nStudent: I have received my scholarship so I am fine with any phone.\nSalesman: Great, then perhaps this latest flagship phone is just right for you.', + 64), + ("English: I am happy today.\nChinese: 我今天很开心。\n\nEnglish: I am going to play basketball.\nChinese: 我一会去打篮球。\n\nEnglish: Let's celebrate our anniversary.\nChinese:", + 64) ] if __name__ == '__main__': @@ -113,7 +115,12 @@ def print_args(args: argparse.Namespace): cache = ListCache(args.cache_size, args.cache_list_size, fixed_keys=FIXED_CACHE_KEYS) else: cache = None - engine = launch_engine(args.tp, 1, args.master_host, args.master_port, args.rpc_port, get_model_fn(args.model), + engine = launch_engine(args.tp, + 1, + args.master_host, + args.master_port, + args.rpc_port, + get_model_fn(args.model), batch_manager=BatchManagerForGeneration(max_batch_size=args.max_batch_size, pad_token_id=tokenizer.pad_token_id), pipe_size=args.pipe_size, diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/README.md b/examples/tutorial/opt/inference/script/process-opt-175b/README.md index bc3cba72df33..665c459fec69 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/README.md +++ b/examples/tutorial/opt/inference/script/process-opt-175b/README.md @@ -43,4 +43,3 @@ Finally, you will get 8 files in `` with following checksums: 
5d63b8750d827a1aa7c8ae5b02a3a2ca reshard-model_part-6.pt f888bd41e009096804fe9a4b48c7ffe8 reshard-model_part-7.pt ``` - diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py index a17ddd4fa173..b849c79daede 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py +++ b/examples/tutorial/opt/inference/script/process-opt-175b/convert_ckpt.py @@ -39,7 +39,8 @@ def convert(flat_dir: str, output_dir: str, part: int): flat_param = flat_sd['model'][flat_key] assert sum(param_meta['numels']) == flat_param.numel( ), f'flat {flat_key} {flat_param.numel()} vs {sum(param_meta["numels"])}' - for name, shape, param in zip(param_meta['names'], param_meta['shapes'], flat_param.split(param_meta['numels'])): + for name, shape, param in zip(param_meta['names'], param_meta['shapes'], + flat_param.split(param_meta['numels'])): output_sd[name] = param.view(shape) torch.save(output_sd, output_path) diff --git a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json index 59d285565cfd..ce70451cc4e5 100644 --- a/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json +++ b/examples/tutorial/opt/inference/script/process-opt-175b/flat-meta.json @@ -1 +1,6944 @@ -{"flat_param_0": {"names": ["decoder.embed_tokens.weight", "decoder.embed_positions.weight", "decoder.layer_norm.weight", "decoder.layer_norm.bias"], "shapes": [[6284, 12288], [2050, 12288], [12288], [12288]], "numels": [77217792, 25190400, 12288, 12288]}, "decoder.layers.0.flat_param_0": {"names": ["decoder.layers.0.self_attn.qkv_proj.weight", "decoder.layers.0.self_attn.qkv_proj.bias", "decoder.layers.0.self_attn.out_proj.weight", "decoder.layers.0.self_attn.out_proj.bias", "decoder.layers.0.self_attn_layer_norm.weight", "decoder.layers.0.self_attn_layer_norm.bias", "decoder.layers.0.fc1.weight", 
"decoder.layers.0.fc1.bias", "decoder.layers.0.fc2.weight", "decoder.layers.0.fc2.bias", "decoder.layers.0.final_layer_norm.weight", "decoder.layers.0.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.1.flat_param_0": {"names": ["decoder.layers.1.self_attn.qkv_proj.weight", "decoder.layers.1.self_attn.qkv_proj.bias", "decoder.layers.1.self_attn.out_proj.weight", "decoder.layers.1.self_attn.out_proj.bias", "decoder.layers.1.self_attn_layer_norm.weight", "decoder.layers.1.self_attn_layer_norm.bias", "decoder.layers.1.fc1.weight", "decoder.layers.1.fc1.bias", "decoder.layers.1.fc2.weight", "decoder.layers.1.fc2.bias", "decoder.layers.1.final_layer_norm.weight", "decoder.layers.1.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.2.flat_param_0": {"names": ["decoder.layers.2.self_attn.qkv_proj.weight", "decoder.layers.2.self_attn.qkv_proj.bias", "decoder.layers.2.self_attn.out_proj.weight", "decoder.layers.2.self_attn.out_proj.bias", "decoder.layers.2.self_attn_layer_norm.weight", "decoder.layers.2.self_attn_layer_norm.bias", "decoder.layers.2.fc1.weight", "decoder.layers.2.fc1.bias", "decoder.layers.2.fc2.weight", "decoder.layers.2.fc2.bias", "decoder.layers.2.final_layer_norm.weight", "decoder.layers.2.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, 
"decoder.layers.3.flat_param_0": {"names": ["decoder.layers.3.self_attn.qkv_proj.weight", "decoder.layers.3.self_attn.qkv_proj.bias", "decoder.layers.3.self_attn.out_proj.weight", "decoder.layers.3.self_attn.out_proj.bias", "decoder.layers.3.self_attn_layer_norm.weight", "decoder.layers.3.self_attn_layer_norm.bias", "decoder.layers.3.fc1.weight", "decoder.layers.3.fc1.bias", "decoder.layers.3.fc2.weight", "decoder.layers.3.fc2.bias", "decoder.layers.3.final_layer_norm.weight", "decoder.layers.3.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.4.flat_param_0": {"names": ["decoder.layers.4.self_attn.qkv_proj.weight", "decoder.layers.4.self_attn.qkv_proj.bias", "decoder.layers.4.self_attn.out_proj.weight", "decoder.layers.4.self_attn.out_proj.bias", "decoder.layers.4.self_attn_layer_norm.weight", "decoder.layers.4.self_attn_layer_norm.bias", "decoder.layers.4.fc1.weight", "decoder.layers.4.fc1.bias", "decoder.layers.4.fc2.weight", "decoder.layers.4.fc2.bias", "decoder.layers.4.final_layer_norm.weight", "decoder.layers.4.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.5.flat_param_0": {"names": ["decoder.layers.5.self_attn.qkv_proj.weight", "decoder.layers.5.self_attn.qkv_proj.bias", "decoder.layers.5.self_attn.out_proj.weight", "decoder.layers.5.self_attn.out_proj.bias", "decoder.layers.5.self_attn_layer_norm.weight", "decoder.layers.5.self_attn_layer_norm.bias", "decoder.layers.5.fc1.weight", "decoder.layers.5.fc1.bias", "decoder.layers.5.fc2.weight", "decoder.layers.5.fc2.bias", 
"decoder.layers.5.final_layer_norm.weight", "decoder.layers.5.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.6.flat_param_0": {"names": ["decoder.layers.6.self_attn.qkv_proj.weight", "decoder.layers.6.self_attn.qkv_proj.bias", "decoder.layers.6.self_attn.out_proj.weight", "decoder.layers.6.self_attn.out_proj.bias", "decoder.layers.6.self_attn_layer_norm.weight", "decoder.layers.6.self_attn_layer_norm.bias", "decoder.layers.6.fc1.weight", "decoder.layers.6.fc1.bias", "decoder.layers.6.fc2.weight", "decoder.layers.6.fc2.bias", "decoder.layers.6.final_layer_norm.weight", "decoder.layers.6.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.7.flat_param_0": {"names": ["decoder.layers.7.self_attn.qkv_proj.weight", "decoder.layers.7.self_attn.qkv_proj.bias", "decoder.layers.7.self_attn.out_proj.weight", "decoder.layers.7.self_attn.out_proj.bias", "decoder.layers.7.self_attn_layer_norm.weight", "decoder.layers.7.self_attn_layer_norm.bias", "decoder.layers.7.fc1.weight", "decoder.layers.7.fc1.bias", "decoder.layers.7.fc2.weight", "decoder.layers.7.fc2.bias", "decoder.layers.7.final_layer_norm.weight", "decoder.layers.7.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.8.flat_param_0": {"names": ["decoder.layers.8.self_attn.qkv_proj.weight", 
"decoder.layers.8.self_attn.qkv_proj.bias", "decoder.layers.8.self_attn.out_proj.weight", "decoder.layers.8.self_attn.out_proj.bias", "decoder.layers.8.self_attn_layer_norm.weight", "decoder.layers.8.self_attn_layer_norm.bias", "decoder.layers.8.fc1.weight", "decoder.layers.8.fc1.bias", "decoder.layers.8.fc2.weight", "decoder.layers.8.fc2.bias", "decoder.layers.8.final_layer_norm.weight", "decoder.layers.8.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.9.flat_param_0": {"names": ["decoder.layers.9.self_attn.qkv_proj.weight", "decoder.layers.9.self_attn.qkv_proj.bias", "decoder.layers.9.self_attn.out_proj.weight", "decoder.layers.9.self_attn.out_proj.bias", "decoder.layers.9.self_attn_layer_norm.weight", "decoder.layers.9.self_attn_layer_norm.bias", "decoder.layers.9.fc1.weight", "decoder.layers.9.fc1.bias", "decoder.layers.9.fc2.weight", "decoder.layers.9.fc2.bias", "decoder.layers.9.final_layer_norm.weight", "decoder.layers.9.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.10.flat_param_0": {"names": ["decoder.layers.10.self_attn.qkv_proj.weight", "decoder.layers.10.self_attn.qkv_proj.bias", "decoder.layers.10.self_attn.out_proj.weight", "decoder.layers.10.self_attn.out_proj.bias", "decoder.layers.10.self_attn_layer_norm.weight", "decoder.layers.10.self_attn_layer_norm.bias", "decoder.layers.10.fc1.weight", "decoder.layers.10.fc1.bias", "decoder.layers.10.fc2.weight", "decoder.layers.10.fc2.bias", "decoder.layers.10.final_layer_norm.weight", "decoder.layers.10.final_layer_norm.bias"], 
"shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.11.flat_param_0": {"names": ["decoder.layers.11.self_attn.qkv_proj.weight", "decoder.layers.11.self_attn.qkv_proj.bias", "decoder.layers.11.self_attn.out_proj.weight", "decoder.layers.11.self_attn.out_proj.bias", "decoder.layers.11.self_attn_layer_norm.weight", "decoder.layers.11.self_attn_layer_norm.bias", "decoder.layers.11.fc1.weight", "decoder.layers.11.fc1.bias", "decoder.layers.11.fc2.weight", "decoder.layers.11.fc2.bias", "decoder.layers.11.final_layer_norm.weight", "decoder.layers.11.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.12.flat_param_0": {"names": ["decoder.layers.12.self_attn.qkv_proj.weight", "decoder.layers.12.self_attn.qkv_proj.bias", "decoder.layers.12.self_attn.out_proj.weight", "decoder.layers.12.self_attn.out_proj.bias", "decoder.layers.12.self_attn_layer_norm.weight", "decoder.layers.12.self_attn_layer_norm.bias", "decoder.layers.12.fc1.weight", "decoder.layers.12.fc1.bias", "decoder.layers.12.fc2.weight", "decoder.layers.12.fc2.bias", "decoder.layers.12.final_layer_norm.weight", "decoder.layers.12.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.13.flat_param_0": {"names": ["decoder.layers.13.self_attn.qkv_proj.weight", "decoder.layers.13.self_attn.qkv_proj.bias", 
"decoder.layers.13.self_attn.out_proj.weight", "decoder.layers.13.self_attn.out_proj.bias", "decoder.layers.13.self_attn_layer_norm.weight", "decoder.layers.13.self_attn_layer_norm.bias", "decoder.layers.13.fc1.weight", "decoder.layers.13.fc1.bias", "decoder.layers.13.fc2.weight", "decoder.layers.13.fc2.bias", "decoder.layers.13.final_layer_norm.weight", "decoder.layers.13.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.14.flat_param_0": {"names": ["decoder.layers.14.self_attn.qkv_proj.weight", "decoder.layers.14.self_attn.qkv_proj.bias", "decoder.layers.14.self_attn.out_proj.weight", "decoder.layers.14.self_attn.out_proj.bias", "decoder.layers.14.self_attn_layer_norm.weight", "decoder.layers.14.self_attn_layer_norm.bias", "decoder.layers.14.fc1.weight", "decoder.layers.14.fc1.bias", "decoder.layers.14.fc2.weight", "decoder.layers.14.fc2.bias", "decoder.layers.14.final_layer_norm.weight", "decoder.layers.14.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.15.flat_param_0": {"names": ["decoder.layers.15.self_attn.qkv_proj.weight", "decoder.layers.15.self_attn.qkv_proj.bias", "decoder.layers.15.self_attn.out_proj.weight", "decoder.layers.15.self_attn.out_proj.bias", "decoder.layers.15.self_attn_layer_norm.weight", "decoder.layers.15.self_attn_layer_norm.bias", "decoder.layers.15.fc1.weight", "decoder.layers.15.fc1.bias", "decoder.layers.15.fc2.weight", "decoder.layers.15.fc2.bias", "decoder.layers.15.final_layer_norm.weight", "decoder.layers.15.final_layer_norm.bias"], "shapes": [[4608, 
12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.16.flat_param_0": {"names": ["decoder.layers.16.self_attn.qkv_proj.weight", "decoder.layers.16.self_attn.qkv_proj.bias", "decoder.layers.16.self_attn.out_proj.weight", "decoder.layers.16.self_attn.out_proj.bias", "decoder.layers.16.self_attn_layer_norm.weight", "decoder.layers.16.self_attn_layer_norm.bias", "decoder.layers.16.fc1.weight", "decoder.layers.16.fc1.bias", "decoder.layers.16.fc2.weight", "decoder.layers.16.fc2.bias", "decoder.layers.16.final_layer_norm.weight", "decoder.layers.16.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.17.flat_param_0": {"names": ["decoder.layers.17.self_attn.qkv_proj.weight", "decoder.layers.17.self_attn.qkv_proj.bias", "decoder.layers.17.self_attn.out_proj.weight", "decoder.layers.17.self_attn.out_proj.bias", "decoder.layers.17.self_attn_layer_norm.weight", "decoder.layers.17.self_attn_layer_norm.bias", "decoder.layers.17.fc1.weight", "decoder.layers.17.fc1.bias", "decoder.layers.17.fc2.weight", "decoder.layers.17.fc2.bias", "decoder.layers.17.final_layer_norm.weight", "decoder.layers.17.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.18.flat_param_0": {"names": ["decoder.layers.18.self_attn.qkv_proj.weight", "decoder.layers.18.self_attn.qkv_proj.bias", "decoder.layers.18.self_attn.out_proj.weight", 
"decoder.layers.18.self_attn.out_proj.bias", "decoder.layers.18.self_attn_layer_norm.weight", "decoder.layers.18.self_attn_layer_norm.bias", "decoder.layers.18.fc1.weight", "decoder.layers.18.fc1.bias", "decoder.layers.18.fc2.weight", "decoder.layers.18.fc2.bias", "decoder.layers.18.final_layer_norm.weight", "decoder.layers.18.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.19.flat_param_0": {"names": ["decoder.layers.19.self_attn.qkv_proj.weight", "decoder.layers.19.self_attn.qkv_proj.bias", "decoder.layers.19.self_attn.out_proj.weight", "decoder.layers.19.self_attn.out_proj.bias", "decoder.layers.19.self_attn_layer_norm.weight", "decoder.layers.19.self_attn_layer_norm.bias", "decoder.layers.19.fc1.weight", "decoder.layers.19.fc1.bias", "decoder.layers.19.fc2.weight", "decoder.layers.19.fc2.bias", "decoder.layers.19.final_layer_norm.weight", "decoder.layers.19.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.20.flat_param_0": {"names": ["decoder.layers.20.self_attn.qkv_proj.weight", "decoder.layers.20.self_attn.qkv_proj.bias", "decoder.layers.20.self_attn.out_proj.weight", "decoder.layers.20.self_attn.out_proj.bias", "decoder.layers.20.self_attn_layer_norm.weight", "decoder.layers.20.self_attn_layer_norm.bias", "decoder.layers.20.fc1.weight", "decoder.layers.20.fc1.bias", "decoder.layers.20.fc2.weight", "decoder.layers.20.fc2.bias", "decoder.layers.20.final_layer_norm.weight", "decoder.layers.20.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], 
[12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.21.flat_param_0": {"names": ["decoder.layers.21.self_attn.qkv_proj.weight", "decoder.layers.21.self_attn.qkv_proj.bias", "decoder.layers.21.self_attn.out_proj.weight", "decoder.layers.21.self_attn.out_proj.bias", "decoder.layers.21.self_attn_layer_norm.weight", "decoder.layers.21.self_attn_layer_norm.bias", "decoder.layers.21.fc1.weight", "decoder.layers.21.fc1.bias", "decoder.layers.21.fc2.weight", "decoder.layers.21.fc2.bias", "decoder.layers.21.final_layer_norm.weight", "decoder.layers.21.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.22.flat_param_0": {"names": ["decoder.layers.22.self_attn.qkv_proj.weight", "decoder.layers.22.self_attn.qkv_proj.bias", "decoder.layers.22.self_attn.out_proj.weight", "decoder.layers.22.self_attn.out_proj.bias", "decoder.layers.22.self_attn_layer_norm.weight", "decoder.layers.22.self_attn_layer_norm.bias", "decoder.layers.22.fc1.weight", "decoder.layers.22.fc1.bias", "decoder.layers.22.fc2.weight", "decoder.layers.22.fc2.bias", "decoder.layers.22.final_layer_norm.weight", "decoder.layers.22.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.23.flat_param_0": {"names": ["decoder.layers.23.self_attn.qkv_proj.weight", "decoder.layers.23.self_attn.qkv_proj.bias", "decoder.layers.23.self_attn.out_proj.weight", "decoder.layers.23.self_attn.out_proj.bias", 
"decoder.layers.23.self_attn_layer_norm.weight", "decoder.layers.23.self_attn_layer_norm.bias", "decoder.layers.23.fc1.weight", "decoder.layers.23.fc1.bias", "decoder.layers.23.fc2.weight", "decoder.layers.23.fc2.bias", "decoder.layers.23.final_layer_norm.weight", "decoder.layers.23.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.24.flat_param_0": {"names": ["decoder.layers.24.self_attn.qkv_proj.weight", "decoder.layers.24.self_attn.qkv_proj.bias", "decoder.layers.24.self_attn.out_proj.weight", "decoder.layers.24.self_attn.out_proj.bias", "decoder.layers.24.self_attn_layer_norm.weight", "decoder.layers.24.self_attn_layer_norm.bias", "decoder.layers.24.fc1.weight", "decoder.layers.24.fc1.bias", "decoder.layers.24.fc2.weight", "decoder.layers.24.fc2.bias", "decoder.layers.24.final_layer_norm.weight", "decoder.layers.24.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.25.flat_param_0": {"names": ["decoder.layers.25.self_attn.qkv_proj.weight", "decoder.layers.25.self_attn.qkv_proj.bias", "decoder.layers.25.self_attn.out_proj.weight", "decoder.layers.25.self_attn.out_proj.bias", "decoder.layers.25.self_attn_layer_norm.weight", "decoder.layers.25.self_attn_layer_norm.bias", "decoder.layers.25.fc1.weight", "decoder.layers.25.fc1.bias", "decoder.layers.25.fc2.weight", "decoder.layers.25.fc2.bias", "decoder.layers.25.final_layer_norm.weight", "decoder.layers.25.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], 
[12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.26.flat_param_0": {"names": ["decoder.layers.26.self_attn.qkv_proj.weight", "decoder.layers.26.self_attn.qkv_proj.bias", "decoder.layers.26.self_attn.out_proj.weight", "decoder.layers.26.self_attn.out_proj.bias", "decoder.layers.26.self_attn_layer_norm.weight", "decoder.layers.26.self_attn_layer_norm.bias", "decoder.layers.26.fc1.weight", "decoder.layers.26.fc1.bias", "decoder.layers.26.fc2.weight", "decoder.layers.26.fc2.bias", "decoder.layers.26.final_layer_norm.weight", "decoder.layers.26.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.27.flat_param_0": {"names": ["decoder.layers.27.self_attn.qkv_proj.weight", "decoder.layers.27.self_attn.qkv_proj.bias", "decoder.layers.27.self_attn.out_proj.weight", "decoder.layers.27.self_attn.out_proj.bias", "decoder.layers.27.self_attn_layer_norm.weight", "decoder.layers.27.self_attn_layer_norm.bias", "decoder.layers.27.fc1.weight", "decoder.layers.27.fc1.bias", "decoder.layers.27.fc2.weight", "decoder.layers.27.fc2.bias", "decoder.layers.27.final_layer_norm.weight", "decoder.layers.27.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.28.flat_param_0": {"names": ["decoder.layers.28.self_attn.qkv_proj.weight", "decoder.layers.28.self_attn.qkv_proj.bias", "decoder.layers.28.self_attn.out_proj.weight", "decoder.layers.28.self_attn.out_proj.bias", "decoder.layers.28.self_attn_layer_norm.weight", 
"decoder.layers.28.self_attn_layer_norm.bias", "decoder.layers.28.fc1.weight", "decoder.layers.28.fc1.bias", "decoder.layers.28.fc2.weight", "decoder.layers.28.fc2.bias", "decoder.layers.28.final_layer_norm.weight", "decoder.layers.28.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.29.flat_param_0": {"names": ["decoder.layers.29.self_attn.qkv_proj.weight", "decoder.layers.29.self_attn.qkv_proj.bias", "decoder.layers.29.self_attn.out_proj.weight", "decoder.layers.29.self_attn.out_proj.bias", "decoder.layers.29.self_attn_layer_norm.weight", "decoder.layers.29.self_attn_layer_norm.bias", "decoder.layers.29.fc1.weight", "decoder.layers.29.fc1.bias", "decoder.layers.29.fc2.weight", "decoder.layers.29.fc2.bias", "decoder.layers.29.final_layer_norm.weight", "decoder.layers.29.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.30.flat_param_0": {"names": ["decoder.layers.30.self_attn.qkv_proj.weight", "decoder.layers.30.self_attn.qkv_proj.bias", "decoder.layers.30.self_attn.out_proj.weight", "decoder.layers.30.self_attn.out_proj.bias", "decoder.layers.30.self_attn_layer_norm.weight", "decoder.layers.30.self_attn_layer_norm.bias", "decoder.layers.30.fc1.weight", "decoder.layers.30.fc1.bias", "decoder.layers.30.fc2.weight", "decoder.layers.30.fc2.bias", "decoder.layers.30.final_layer_norm.weight", "decoder.layers.30.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 
4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.31.flat_param_0": {"names": ["decoder.layers.31.self_attn.qkv_proj.weight", "decoder.layers.31.self_attn.qkv_proj.bias", "decoder.layers.31.self_attn.out_proj.weight", "decoder.layers.31.self_attn.out_proj.bias", "decoder.layers.31.self_attn_layer_norm.weight", "decoder.layers.31.self_attn_layer_norm.bias", "decoder.layers.31.fc1.weight", "decoder.layers.31.fc1.bias", "decoder.layers.31.fc2.weight", "decoder.layers.31.fc2.bias", "decoder.layers.31.final_layer_norm.weight", "decoder.layers.31.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.32.flat_param_0": {"names": ["decoder.layers.32.self_attn.qkv_proj.weight", "decoder.layers.32.self_attn.qkv_proj.bias", "decoder.layers.32.self_attn.out_proj.weight", "decoder.layers.32.self_attn.out_proj.bias", "decoder.layers.32.self_attn_layer_norm.weight", "decoder.layers.32.self_attn_layer_norm.bias", "decoder.layers.32.fc1.weight", "decoder.layers.32.fc1.bias", "decoder.layers.32.fc2.weight", "decoder.layers.32.fc2.bias", "decoder.layers.32.final_layer_norm.weight", "decoder.layers.32.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.33.flat_param_0": {"names": ["decoder.layers.33.self_attn.qkv_proj.weight", "decoder.layers.33.self_attn.qkv_proj.bias", "decoder.layers.33.self_attn.out_proj.weight", "decoder.layers.33.self_attn.out_proj.bias", "decoder.layers.33.self_attn_layer_norm.weight", "decoder.layers.33.self_attn_layer_norm.bias", 
"decoder.layers.33.fc1.weight", "decoder.layers.33.fc1.bias", "decoder.layers.33.fc2.weight", "decoder.layers.33.fc2.bias", "decoder.layers.33.final_layer_norm.weight", "decoder.layers.33.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.34.flat_param_0": {"names": ["decoder.layers.34.self_attn.qkv_proj.weight", "decoder.layers.34.self_attn.qkv_proj.bias", "decoder.layers.34.self_attn.out_proj.weight", "decoder.layers.34.self_attn.out_proj.bias", "decoder.layers.34.self_attn_layer_norm.weight", "decoder.layers.34.self_attn_layer_norm.bias", "decoder.layers.34.fc1.weight", "decoder.layers.34.fc1.bias", "decoder.layers.34.fc2.weight", "decoder.layers.34.fc2.bias", "decoder.layers.34.final_layer_norm.weight", "decoder.layers.34.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.35.flat_param_0": {"names": ["decoder.layers.35.self_attn.qkv_proj.weight", "decoder.layers.35.self_attn.qkv_proj.bias", "decoder.layers.35.self_attn.out_proj.weight", "decoder.layers.35.self_attn.out_proj.bias", "decoder.layers.35.self_attn_layer_norm.weight", "decoder.layers.35.self_attn_layer_norm.bias", "decoder.layers.35.fc1.weight", "decoder.layers.35.fc1.bias", "decoder.layers.35.fc2.weight", "decoder.layers.35.fc2.bias", "decoder.layers.35.final_layer_norm.weight", "decoder.layers.35.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 
6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.36.flat_param_0": {"names": ["decoder.layers.36.self_attn.qkv_proj.weight", "decoder.layers.36.self_attn.qkv_proj.bias", "decoder.layers.36.self_attn.out_proj.weight", "decoder.layers.36.self_attn.out_proj.bias", "decoder.layers.36.self_attn_layer_norm.weight", "decoder.layers.36.self_attn_layer_norm.bias", "decoder.layers.36.fc1.weight", "decoder.layers.36.fc1.bias", "decoder.layers.36.fc2.weight", "decoder.layers.36.fc2.bias", "decoder.layers.36.final_layer_norm.weight", "decoder.layers.36.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.37.flat_param_0": {"names": ["decoder.layers.37.self_attn.qkv_proj.weight", "decoder.layers.37.self_attn.qkv_proj.bias", "decoder.layers.37.self_attn.out_proj.weight", "decoder.layers.37.self_attn.out_proj.bias", "decoder.layers.37.self_attn_layer_norm.weight", "decoder.layers.37.self_attn_layer_norm.bias", "decoder.layers.37.fc1.weight", "decoder.layers.37.fc1.bias", "decoder.layers.37.fc2.weight", "decoder.layers.37.fc2.bias", "decoder.layers.37.final_layer_norm.weight", "decoder.layers.37.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.38.flat_param_0": {"names": ["decoder.layers.38.self_attn.qkv_proj.weight", "decoder.layers.38.self_attn.qkv_proj.bias", "decoder.layers.38.self_attn.out_proj.weight", "decoder.layers.38.self_attn.out_proj.bias", "decoder.layers.38.self_attn_layer_norm.weight", "decoder.layers.38.self_attn_layer_norm.bias", "decoder.layers.38.fc1.weight", "decoder.layers.38.fc1.bias", 
"decoder.layers.38.fc2.weight", "decoder.layers.38.fc2.bias", "decoder.layers.38.final_layer_norm.weight", "decoder.layers.38.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.39.flat_param_0": {"names": ["decoder.layers.39.self_attn.qkv_proj.weight", "decoder.layers.39.self_attn.qkv_proj.bias", "decoder.layers.39.self_attn.out_proj.weight", "decoder.layers.39.self_attn.out_proj.bias", "decoder.layers.39.self_attn_layer_norm.weight", "decoder.layers.39.self_attn_layer_norm.bias", "decoder.layers.39.fc1.weight", "decoder.layers.39.fc1.bias", "decoder.layers.39.fc2.weight", "decoder.layers.39.fc2.bias", "decoder.layers.39.final_layer_norm.weight", "decoder.layers.39.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.40.flat_param_0": {"names": ["decoder.layers.40.self_attn.qkv_proj.weight", "decoder.layers.40.self_attn.qkv_proj.bias", "decoder.layers.40.self_attn.out_proj.weight", "decoder.layers.40.self_attn.out_proj.bias", "decoder.layers.40.self_attn_layer_norm.weight", "decoder.layers.40.self_attn_layer_norm.bias", "decoder.layers.40.fc1.weight", "decoder.layers.40.fc1.bias", "decoder.layers.40.fc2.weight", "decoder.layers.40.fc2.bias", "decoder.layers.40.final_layer_norm.weight", "decoder.layers.40.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, 
"decoder.layers.41.flat_param_0": {"names": ["decoder.layers.41.self_attn.qkv_proj.weight", "decoder.layers.41.self_attn.qkv_proj.bias", "decoder.layers.41.self_attn.out_proj.weight", "decoder.layers.41.self_attn.out_proj.bias", "decoder.layers.41.self_attn_layer_norm.weight", "decoder.layers.41.self_attn_layer_norm.bias", "decoder.layers.41.fc1.weight", "decoder.layers.41.fc1.bias", "decoder.layers.41.fc2.weight", "decoder.layers.41.fc2.bias", "decoder.layers.41.final_layer_norm.weight", "decoder.layers.41.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.42.flat_param_0": {"names": ["decoder.layers.42.self_attn.qkv_proj.weight", "decoder.layers.42.self_attn.qkv_proj.bias", "decoder.layers.42.self_attn.out_proj.weight", "decoder.layers.42.self_attn.out_proj.bias", "decoder.layers.42.self_attn_layer_norm.weight", "decoder.layers.42.self_attn_layer_norm.bias", "decoder.layers.42.fc1.weight", "decoder.layers.42.fc1.bias", "decoder.layers.42.fc2.weight", "decoder.layers.42.fc2.bias", "decoder.layers.42.final_layer_norm.weight", "decoder.layers.42.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.43.flat_param_0": {"names": ["decoder.layers.43.self_attn.qkv_proj.weight", "decoder.layers.43.self_attn.qkv_proj.bias", "decoder.layers.43.self_attn.out_proj.weight", "decoder.layers.43.self_attn.out_proj.bias", "decoder.layers.43.self_attn_layer_norm.weight", "decoder.layers.43.self_attn_layer_norm.bias", "decoder.layers.43.fc1.weight", "decoder.layers.43.fc1.bias", "decoder.layers.43.fc2.weight", 
"decoder.layers.43.fc2.bias", "decoder.layers.43.final_layer_norm.weight", "decoder.layers.43.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.44.flat_param_0": {"names": ["decoder.layers.44.self_attn.qkv_proj.weight", "decoder.layers.44.self_attn.qkv_proj.bias", "decoder.layers.44.self_attn.out_proj.weight", "decoder.layers.44.self_attn.out_proj.bias", "decoder.layers.44.self_attn_layer_norm.weight", "decoder.layers.44.self_attn_layer_norm.bias", "decoder.layers.44.fc1.weight", "decoder.layers.44.fc1.bias", "decoder.layers.44.fc2.weight", "decoder.layers.44.fc2.bias", "decoder.layers.44.final_layer_norm.weight", "decoder.layers.44.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.45.flat_param_0": {"names": ["decoder.layers.45.self_attn.qkv_proj.weight", "decoder.layers.45.self_attn.qkv_proj.bias", "decoder.layers.45.self_attn.out_proj.weight", "decoder.layers.45.self_attn.out_proj.bias", "decoder.layers.45.self_attn_layer_norm.weight", "decoder.layers.45.self_attn_layer_norm.bias", "decoder.layers.45.fc1.weight", "decoder.layers.45.fc1.bias", "decoder.layers.45.fc2.weight", "decoder.layers.45.fc2.bias", "decoder.layers.45.final_layer_norm.weight", "decoder.layers.45.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.46.flat_param_0": {"names": 
["decoder.layers.46.self_attn.qkv_proj.weight", "decoder.layers.46.self_attn.qkv_proj.bias", "decoder.layers.46.self_attn.out_proj.weight", "decoder.layers.46.self_attn.out_proj.bias", "decoder.layers.46.self_attn_layer_norm.weight", "decoder.layers.46.self_attn_layer_norm.bias", "decoder.layers.46.fc1.weight", "decoder.layers.46.fc1.bias", "decoder.layers.46.fc2.weight", "decoder.layers.46.fc2.bias", "decoder.layers.46.final_layer_norm.weight", "decoder.layers.46.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.47.flat_param_0": {"names": ["decoder.layers.47.self_attn.qkv_proj.weight", "decoder.layers.47.self_attn.qkv_proj.bias", "decoder.layers.47.self_attn.out_proj.weight", "decoder.layers.47.self_attn.out_proj.bias", "decoder.layers.47.self_attn_layer_norm.weight", "decoder.layers.47.self_attn_layer_norm.bias", "decoder.layers.47.fc1.weight", "decoder.layers.47.fc1.bias", "decoder.layers.47.fc2.weight", "decoder.layers.47.fc2.bias", "decoder.layers.47.final_layer_norm.weight", "decoder.layers.47.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.48.flat_param_0": {"names": ["decoder.layers.48.self_attn.qkv_proj.weight", "decoder.layers.48.self_attn.qkv_proj.bias", "decoder.layers.48.self_attn.out_proj.weight", "decoder.layers.48.self_attn.out_proj.bias", "decoder.layers.48.self_attn_layer_norm.weight", "decoder.layers.48.self_attn_layer_norm.bias", "decoder.layers.48.fc1.weight", "decoder.layers.48.fc1.bias", "decoder.layers.48.fc2.weight", "decoder.layers.48.fc2.bias", 
"decoder.layers.48.final_layer_norm.weight", "decoder.layers.48.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.49.flat_param_0": {"names": ["decoder.layers.49.self_attn.qkv_proj.weight", "decoder.layers.49.self_attn.qkv_proj.bias", "decoder.layers.49.self_attn.out_proj.weight", "decoder.layers.49.self_attn.out_proj.bias", "decoder.layers.49.self_attn_layer_norm.weight", "decoder.layers.49.self_attn_layer_norm.bias", "decoder.layers.49.fc1.weight", "decoder.layers.49.fc1.bias", "decoder.layers.49.fc2.weight", "decoder.layers.49.fc2.bias", "decoder.layers.49.final_layer_norm.weight", "decoder.layers.49.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.50.flat_param_0": {"names": ["decoder.layers.50.self_attn.qkv_proj.weight", "decoder.layers.50.self_attn.qkv_proj.bias", "decoder.layers.50.self_attn.out_proj.weight", "decoder.layers.50.self_attn.out_proj.bias", "decoder.layers.50.self_attn_layer_norm.weight", "decoder.layers.50.self_attn_layer_norm.bias", "decoder.layers.50.fc1.weight", "decoder.layers.50.fc1.bias", "decoder.layers.50.fc2.weight", "decoder.layers.50.fc2.bias", "decoder.layers.50.final_layer_norm.weight", "decoder.layers.50.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.51.flat_param_0": {"names": 
["decoder.layers.51.self_attn.qkv_proj.weight", "decoder.layers.51.self_attn.qkv_proj.bias", "decoder.layers.51.self_attn.out_proj.weight", "decoder.layers.51.self_attn.out_proj.bias", "decoder.layers.51.self_attn_layer_norm.weight", "decoder.layers.51.self_attn_layer_norm.bias", "decoder.layers.51.fc1.weight", "decoder.layers.51.fc1.bias", "decoder.layers.51.fc2.weight", "decoder.layers.51.fc2.bias", "decoder.layers.51.final_layer_norm.weight", "decoder.layers.51.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.52.flat_param_0": {"names": ["decoder.layers.52.self_attn.qkv_proj.weight", "decoder.layers.52.self_attn.qkv_proj.bias", "decoder.layers.52.self_attn.out_proj.weight", "decoder.layers.52.self_attn.out_proj.bias", "decoder.layers.52.self_attn_layer_norm.weight", "decoder.layers.52.self_attn_layer_norm.bias", "decoder.layers.52.fc1.weight", "decoder.layers.52.fc1.bias", "decoder.layers.52.fc2.weight", "decoder.layers.52.fc2.bias", "decoder.layers.52.final_layer_norm.weight", "decoder.layers.52.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.53.flat_param_0": {"names": ["decoder.layers.53.self_attn.qkv_proj.weight", "decoder.layers.53.self_attn.qkv_proj.bias", "decoder.layers.53.self_attn.out_proj.weight", "decoder.layers.53.self_attn.out_proj.bias", "decoder.layers.53.self_attn_layer_norm.weight", "decoder.layers.53.self_attn_layer_norm.bias", "decoder.layers.53.fc1.weight", "decoder.layers.53.fc1.bias", "decoder.layers.53.fc2.weight", "decoder.layers.53.fc2.bias", 
"decoder.layers.53.final_layer_norm.weight", "decoder.layers.53.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.54.flat_param_0": {"names": ["decoder.layers.54.self_attn.qkv_proj.weight", "decoder.layers.54.self_attn.qkv_proj.bias", "decoder.layers.54.self_attn.out_proj.weight", "decoder.layers.54.self_attn.out_proj.bias", "decoder.layers.54.self_attn_layer_norm.weight", "decoder.layers.54.self_attn_layer_norm.bias", "decoder.layers.54.fc1.weight", "decoder.layers.54.fc1.bias", "decoder.layers.54.fc2.weight", "decoder.layers.54.fc2.bias", "decoder.layers.54.final_layer_norm.weight", "decoder.layers.54.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.55.flat_param_0": {"names": ["decoder.layers.55.self_attn.qkv_proj.weight", "decoder.layers.55.self_attn.qkv_proj.bias", "decoder.layers.55.self_attn.out_proj.weight", "decoder.layers.55.self_attn.out_proj.bias", "decoder.layers.55.self_attn_layer_norm.weight", "decoder.layers.55.self_attn_layer_norm.bias", "decoder.layers.55.fc1.weight", "decoder.layers.55.fc1.bias", "decoder.layers.55.fc2.weight", "decoder.layers.55.fc2.bias", "decoder.layers.55.final_layer_norm.weight", "decoder.layers.55.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.56.flat_param_0": {"names": 
["decoder.layers.56.self_attn.qkv_proj.weight", "decoder.layers.56.self_attn.qkv_proj.bias", "decoder.layers.56.self_attn.out_proj.weight", "decoder.layers.56.self_attn.out_proj.bias", "decoder.layers.56.self_attn_layer_norm.weight", "decoder.layers.56.self_attn_layer_norm.bias", "decoder.layers.56.fc1.weight", "decoder.layers.56.fc1.bias", "decoder.layers.56.fc2.weight", "decoder.layers.56.fc2.bias", "decoder.layers.56.final_layer_norm.weight", "decoder.layers.56.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.57.flat_param_0": {"names": ["decoder.layers.57.self_attn.qkv_proj.weight", "decoder.layers.57.self_attn.qkv_proj.bias", "decoder.layers.57.self_attn.out_proj.weight", "decoder.layers.57.self_attn.out_proj.bias", "decoder.layers.57.self_attn_layer_norm.weight", "decoder.layers.57.self_attn_layer_norm.bias", "decoder.layers.57.fc1.weight", "decoder.layers.57.fc1.bias", "decoder.layers.57.fc2.weight", "decoder.layers.57.fc2.bias", "decoder.layers.57.final_layer_norm.weight", "decoder.layers.57.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.58.flat_param_0": {"names": ["decoder.layers.58.self_attn.qkv_proj.weight", "decoder.layers.58.self_attn.qkv_proj.bias", "decoder.layers.58.self_attn.out_proj.weight", "decoder.layers.58.self_attn.out_proj.bias", "decoder.layers.58.self_attn_layer_norm.weight", "decoder.layers.58.self_attn_layer_norm.bias", "decoder.layers.58.fc1.weight", "decoder.layers.58.fc1.bias", "decoder.layers.58.fc2.weight", "decoder.layers.58.fc2.bias", 
"decoder.layers.58.final_layer_norm.weight", "decoder.layers.58.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.59.flat_param_0": {"names": ["decoder.layers.59.self_attn.qkv_proj.weight", "decoder.layers.59.self_attn.qkv_proj.bias", "decoder.layers.59.self_attn.out_proj.weight", "decoder.layers.59.self_attn.out_proj.bias", "decoder.layers.59.self_attn_layer_norm.weight", "decoder.layers.59.self_attn_layer_norm.bias", "decoder.layers.59.fc1.weight", "decoder.layers.59.fc1.bias", "decoder.layers.59.fc2.weight", "decoder.layers.59.fc2.bias", "decoder.layers.59.final_layer_norm.weight", "decoder.layers.59.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.60.flat_param_0": {"names": ["decoder.layers.60.self_attn.qkv_proj.weight", "decoder.layers.60.self_attn.qkv_proj.bias", "decoder.layers.60.self_attn.out_proj.weight", "decoder.layers.60.self_attn.out_proj.bias", "decoder.layers.60.self_attn_layer_norm.weight", "decoder.layers.60.self_attn_layer_norm.bias", "decoder.layers.60.fc1.weight", "decoder.layers.60.fc1.bias", "decoder.layers.60.fc2.weight", "decoder.layers.60.fc2.bias", "decoder.layers.60.final_layer_norm.weight", "decoder.layers.60.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.61.flat_param_0": {"names": 
["decoder.layers.61.self_attn.qkv_proj.weight", "decoder.layers.61.self_attn.qkv_proj.bias", "decoder.layers.61.self_attn.out_proj.weight", "decoder.layers.61.self_attn.out_proj.bias", "decoder.layers.61.self_attn_layer_norm.weight", "decoder.layers.61.self_attn_layer_norm.bias", "decoder.layers.61.fc1.weight", "decoder.layers.61.fc1.bias", "decoder.layers.61.fc2.weight", "decoder.layers.61.fc2.bias", "decoder.layers.61.final_layer_norm.weight", "decoder.layers.61.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.62.flat_param_0": {"names": ["decoder.layers.62.self_attn.qkv_proj.weight", "decoder.layers.62.self_attn.qkv_proj.bias", "decoder.layers.62.self_attn.out_proj.weight", "decoder.layers.62.self_attn.out_proj.bias", "decoder.layers.62.self_attn_layer_norm.weight", "decoder.layers.62.self_attn_layer_norm.bias", "decoder.layers.62.fc1.weight", "decoder.layers.62.fc1.bias", "decoder.layers.62.fc2.weight", "decoder.layers.62.fc2.bias", "decoder.layers.62.final_layer_norm.weight", "decoder.layers.62.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.63.flat_param_0": {"names": ["decoder.layers.63.self_attn.qkv_proj.weight", "decoder.layers.63.self_attn.qkv_proj.bias", "decoder.layers.63.self_attn.out_proj.weight", "decoder.layers.63.self_attn.out_proj.bias", "decoder.layers.63.self_attn_layer_norm.weight", "decoder.layers.63.self_attn_layer_norm.bias", "decoder.layers.63.fc1.weight", "decoder.layers.63.fc1.bias", "decoder.layers.63.fc2.weight", "decoder.layers.63.fc2.bias", 
"decoder.layers.63.final_layer_norm.weight", "decoder.layers.63.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.64.flat_param_0": {"names": ["decoder.layers.64.self_attn.qkv_proj.weight", "decoder.layers.64.self_attn.qkv_proj.bias", "decoder.layers.64.self_attn.out_proj.weight", "decoder.layers.64.self_attn.out_proj.bias", "decoder.layers.64.self_attn_layer_norm.weight", "decoder.layers.64.self_attn_layer_norm.bias", "decoder.layers.64.fc1.weight", "decoder.layers.64.fc1.bias", "decoder.layers.64.fc2.weight", "decoder.layers.64.fc2.bias", "decoder.layers.64.final_layer_norm.weight", "decoder.layers.64.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.65.flat_param_0": {"names": ["decoder.layers.65.self_attn.qkv_proj.weight", "decoder.layers.65.self_attn.qkv_proj.bias", "decoder.layers.65.self_attn.out_proj.weight", "decoder.layers.65.self_attn.out_proj.bias", "decoder.layers.65.self_attn_layer_norm.weight", "decoder.layers.65.self_attn_layer_norm.bias", "decoder.layers.65.fc1.weight", "decoder.layers.65.fc1.bias", "decoder.layers.65.fc2.weight", "decoder.layers.65.fc2.bias", "decoder.layers.65.final_layer_norm.weight", "decoder.layers.65.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.66.flat_param_0": {"names": 
["decoder.layers.66.self_attn.qkv_proj.weight", "decoder.layers.66.self_attn.qkv_proj.bias", "decoder.layers.66.self_attn.out_proj.weight", "decoder.layers.66.self_attn.out_proj.bias", "decoder.layers.66.self_attn_layer_norm.weight", "decoder.layers.66.self_attn_layer_norm.bias", "decoder.layers.66.fc1.weight", "decoder.layers.66.fc1.bias", "decoder.layers.66.fc2.weight", "decoder.layers.66.fc2.bias", "decoder.layers.66.final_layer_norm.weight", "decoder.layers.66.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.67.flat_param_0": {"names": ["decoder.layers.67.self_attn.qkv_proj.weight", "decoder.layers.67.self_attn.qkv_proj.bias", "decoder.layers.67.self_attn.out_proj.weight", "decoder.layers.67.self_attn.out_proj.bias", "decoder.layers.67.self_attn_layer_norm.weight", "decoder.layers.67.self_attn_layer_norm.bias", "decoder.layers.67.fc1.weight", "decoder.layers.67.fc1.bias", "decoder.layers.67.fc2.weight", "decoder.layers.67.fc2.bias", "decoder.layers.67.final_layer_norm.weight", "decoder.layers.67.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.68.flat_param_0": {"names": ["decoder.layers.68.self_attn.qkv_proj.weight", "decoder.layers.68.self_attn.qkv_proj.bias", "decoder.layers.68.self_attn.out_proj.weight", "decoder.layers.68.self_attn.out_proj.bias", "decoder.layers.68.self_attn_layer_norm.weight", "decoder.layers.68.self_attn_layer_norm.bias", "decoder.layers.68.fc1.weight", "decoder.layers.68.fc1.bias", "decoder.layers.68.fc2.weight", "decoder.layers.68.fc2.bias", 
"decoder.layers.68.final_layer_norm.weight", "decoder.layers.68.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.69.flat_param_0": {"names": ["decoder.layers.69.self_attn.qkv_proj.weight", "decoder.layers.69.self_attn.qkv_proj.bias", "decoder.layers.69.self_attn.out_proj.weight", "decoder.layers.69.self_attn.out_proj.bias", "decoder.layers.69.self_attn_layer_norm.weight", "decoder.layers.69.self_attn_layer_norm.bias", "decoder.layers.69.fc1.weight", "decoder.layers.69.fc1.bias", "decoder.layers.69.fc2.weight", "decoder.layers.69.fc2.bias", "decoder.layers.69.final_layer_norm.weight", "decoder.layers.69.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.70.flat_param_0": {"names": ["decoder.layers.70.self_attn.qkv_proj.weight", "decoder.layers.70.self_attn.qkv_proj.bias", "decoder.layers.70.self_attn.out_proj.weight", "decoder.layers.70.self_attn.out_proj.bias", "decoder.layers.70.self_attn_layer_norm.weight", "decoder.layers.70.self_attn_layer_norm.bias", "decoder.layers.70.fc1.weight", "decoder.layers.70.fc1.bias", "decoder.layers.70.fc2.weight", "decoder.layers.70.fc2.bias", "decoder.layers.70.final_layer_norm.weight", "decoder.layers.70.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.71.flat_param_0": {"names": 
["decoder.layers.71.self_attn.qkv_proj.weight", "decoder.layers.71.self_attn.qkv_proj.bias", "decoder.layers.71.self_attn.out_proj.weight", "decoder.layers.71.self_attn.out_proj.bias", "decoder.layers.71.self_attn_layer_norm.weight", "decoder.layers.71.self_attn_layer_norm.bias", "decoder.layers.71.fc1.weight", "decoder.layers.71.fc1.bias", "decoder.layers.71.fc2.weight", "decoder.layers.71.fc2.bias", "decoder.layers.71.final_layer_norm.weight", "decoder.layers.71.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.72.flat_param_0": {"names": ["decoder.layers.72.self_attn.qkv_proj.weight", "decoder.layers.72.self_attn.qkv_proj.bias", "decoder.layers.72.self_attn.out_proj.weight", "decoder.layers.72.self_attn.out_proj.bias", "decoder.layers.72.self_attn_layer_norm.weight", "decoder.layers.72.self_attn_layer_norm.bias", "decoder.layers.72.fc1.weight", "decoder.layers.72.fc1.bias", "decoder.layers.72.fc2.weight", "decoder.layers.72.fc2.bias", "decoder.layers.72.final_layer_norm.weight", "decoder.layers.72.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.73.flat_param_0": {"names": ["decoder.layers.73.self_attn.qkv_proj.weight", "decoder.layers.73.self_attn.qkv_proj.bias", "decoder.layers.73.self_attn.out_proj.weight", "decoder.layers.73.self_attn.out_proj.bias", "decoder.layers.73.self_attn_layer_norm.weight", "decoder.layers.73.self_attn_layer_norm.bias", "decoder.layers.73.fc1.weight", "decoder.layers.73.fc1.bias", "decoder.layers.73.fc2.weight", "decoder.layers.73.fc2.bias", 
"decoder.layers.73.final_layer_norm.weight", "decoder.layers.73.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.74.flat_param_0": {"names": ["decoder.layers.74.self_attn.qkv_proj.weight", "decoder.layers.74.self_attn.qkv_proj.bias", "decoder.layers.74.self_attn.out_proj.weight", "decoder.layers.74.self_attn.out_proj.bias", "decoder.layers.74.self_attn_layer_norm.weight", "decoder.layers.74.self_attn_layer_norm.bias", "decoder.layers.74.fc1.weight", "decoder.layers.74.fc1.bias", "decoder.layers.74.fc2.weight", "decoder.layers.74.fc2.bias", "decoder.layers.74.final_layer_norm.weight", "decoder.layers.74.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.75.flat_param_0": {"names": ["decoder.layers.75.self_attn.qkv_proj.weight", "decoder.layers.75.self_attn.qkv_proj.bias", "decoder.layers.75.self_attn.out_proj.weight", "decoder.layers.75.self_attn.out_proj.bias", "decoder.layers.75.self_attn_layer_norm.weight", "decoder.layers.75.self_attn_layer_norm.bias", "decoder.layers.75.fc1.weight", "decoder.layers.75.fc1.bias", "decoder.layers.75.fc2.weight", "decoder.layers.75.fc2.bias", "decoder.layers.75.final_layer_norm.weight", "decoder.layers.75.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.76.flat_param_0": {"names": 
["decoder.layers.76.self_attn.qkv_proj.weight", "decoder.layers.76.self_attn.qkv_proj.bias", "decoder.layers.76.self_attn.out_proj.weight", "decoder.layers.76.self_attn.out_proj.bias", "decoder.layers.76.self_attn_layer_norm.weight", "decoder.layers.76.self_attn_layer_norm.bias", "decoder.layers.76.fc1.weight", "decoder.layers.76.fc1.bias", "decoder.layers.76.fc2.weight", "decoder.layers.76.fc2.bias", "decoder.layers.76.final_layer_norm.weight", "decoder.layers.76.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.77.flat_param_0": {"names": ["decoder.layers.77.self_attn.qkv_proj.weight", "decoder.layers.77.self_attn.qkv_proj.bias", "decoder.layers.77.self_attn.out_proj.weight", "decoder.layers.77.self_attn.out_proj.bias", "decoder.layers.77.self_attn_layer_norm.weight", "decoder.layers.77.self_attn_layer_norm.bias", "decoder.layers.77.fc1.weight", "decoder.layers.77.fc1.bias", "decoder.layers.77.fc2.weight", "decoder.layers.77.fc2.bias", "decoder.layers.77.final_layer_norm.weight", "decoder.layers.77.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.78.flat_param_0": {"names": ["decoder.layers.78.self_attn.qkv_proj.weight", "decoder.layers.78.self_attn.qkv_proj.bias", "decoder.layers.78.self_attn.out_proj.weight", "decoder.layers.78.self_attn.out_proj.bias", "decoder.layers.78.self_attn_layer_norm.weight", "decoder.layers.78.self_attn_layer_norm.bias", "decoder.layers.78.fc1.weight", "decoder.layers.78.fc1.bias", "decoder.layers.78.fc2.weight", "decoder.layers.78.fc2.bias", 
"decoder.layers.78.final_layer_norm.weight", "decoder.layers.78.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.79.flat_param_0": {"names": ["decoder.layers.79.self_attn.qkv_proj.weight", "decoder.layers.79.self_attn.qkv_proj.bias", "decoder.layers.79.self_attn.out_proj.weight", "decoder.layers.79.self_attn.out_proj.bias", "decoder.layers.79.self_attn_layer_norm.weight", "decoder.layers.79.self_attn_layer_norm.bias", "decoder.layers.79.fc1.weight", "decoder.layers.79.fc1.bias", "decoder.layers.79.fc2.weight", "decoder.layers.79.fc2.bias", "decoder.layers.79.final_layer_norm.weight", "decoder.layers.79.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.80.flat_param_0": {"names": ["decoder.layers.80.self_attn.qkv_proj.weight", "decoder.layers.80.self_attn.qkv_proj.bias", "decoder.layers.80.self_attn.out_proj.weight", "decoder.layers.80.self_attn.out_proj.bias", "decoder.layers.80.self_attn_layer_norm.weight", "decoder.layers.80.self_attn_layer_norm.bias", "decoder.layers.80.fc1.weight", "decoder.layers.80.fc1.bias", "decoder.layers.80.fc2.weight", "decoder.layers.80.fc2.bias", "decoder.layers.80.final_layer_norm.weight", "decoder.layers.80.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.81.flat_param_0": {"names": 
["decoder.layers.81.self_attn.qkv_proj.weight", "decoder.layers.81.self_attn.qkv_proj.bias", "decoder.layers.81.self_attn.out_proj.weight", "decoder.layers.81.self_attn.out_proj.bias", "decoder.layers.81.self_attn_layer_norm.weight", "decoder.layers.81.self_attn_layer_norm.bias", "decoder.layers.81.fc1.weight", "decoder.layers.81.fc1.bias", "decoder.layers.81.fc2.weight", "decoder.layers.81.fc2.bias", "decoder.layers.81.final_layer_norm.weight", "decoder.layers.81.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.82.flat_param_0": {"names": ["decoder.layers.82.self_attn.qkv_proj.weight", "decoder.layers.82.self_attn.qkv_proj.bias", "decoder.layers.82.self_attn.out_proj.weight", "decoder.layers.82.self_attn.out_proj.bias", "decoder.layers.82.self_attn_layer_norm.weight", "decoder.layers.82.self_attn_layer_norm.bias", "decoder.layers.82.fc1.weight", "decoder.layers.82.fc1.bias", "decoder.layers.82.fc2.weight", "decoder.layers.82.fc2.bias", "decoder.layers.82.final_layer_norm.weight", "decoder.layers.82.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.83.flat_param_0": {"names": ["decoder.layers.83.self_attn.qkv_proj.weight", "decoder.layers.83.self_attn.qkv_proj.bias", "decoder.layers.83.self_attn.out_proj.weight", "decoder.layers.83.self_attn.out_proj.bias", "decoder.layers.83.self_attn_layer_norm.weight", "decoder.layers.83.self_attn_layer_norm.bias", "decoder.layers.83.fc1.weight", "decoder.layers.83.fc1.bias", "decoder.layers.83.fc2.weight", "decoder.layers.83.fc2.bias", 
"decoder.layers.83.final_layer_norm.weight", "decoder.layers.83.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.84.flat_param_0": {"names": ["decoder.layers.84.self_attn.qkv_proj.weight", "decoder.layers.84.self_attn.qkv_proj.bias", "decoder.layers.84.self_attn.out_proj.weight", "decoder.layers.84.self_attn.out_proj.bias", "decoder.layers.84.self_attn_layer_norm.weight", "decoder.layers.84.self_attn_layer_norm.bias", "decoder.layers.84.fc1.weight", "decoder.layers.84.fc1.bias", "decoder.layers.84.fc2.weight", "decoder.layers.84.fc2.bias", "decoder.layers.84.final_layer_norm.weight", "decoder.layers.84.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.85.flat_param_0": {"names": ["decoder.layers.85.self_attn.qkv_proj.weight", "decoder.layers.85.self_attn.qkv_proj.bias", "decoder.layers.85.self_attn.out_proj.weight", "decoder.layers.85.self_attn.out_proj.bias", "decoder.layers.85.self_attn_layer_norm.weight", "decoder.layers.85.self_attn_layer_norm.bias", "decoder.layers.85.fc1.weight", "decoder.layers.85.fc1.bias", "decoder.layers.85.fc2.weight", "decoder.layers.85.fc2.bias", "decoder.layers.85.final_layer_norm.weight", "decoder.layers.85.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.86.flat_param_0": {"names": 
["decoder.layers.86.self_attn.qkv_proj.weight", "decoder.layers.86.self_attn.qkv_proj.bias", "decoder.layers.86.self_attn.out_proj.weight", "decoder.layers.86.self_attn.out_proj.bias", "decoder.layers.86.self_attn_layer_norm.weight", "decoder.layers.86.self_attn_layer_norm.bias", "decoder.layers.86.fc1.weight", "decoder.layers.86.fc1.bias", "decoder.layers.86.fc2.weight", "decoder.layers.86.fc2.bias", "decoder.layers.86.final_layer_norm.weight", "decoder.layers.86.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.87.flat_param_0": {"names": ["decoder.layers.87.self_attn.qkv_proj.weight", "decoder.layers.87.self_attn.qkv_proj.bias", "decoder.layers.87.self_attn.out_proj.weight", "decoder.layers.87.self_attn.out_proj.bias", "decoder.layers.87.self_attn_layer_norm.weight", "decoder.layers.87.self_attn_layer_norm.bias", "decoder.layers.87.fc1.weight", "decoder.layers.87.fc1.bias", "decoder.layers.87.fc2.weight", "decoder.layers.87.fc2.bias", "decoder.layers.87.final_layer_norm.weight", "decoder.layers.87.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.88.flat_param_0": {"names": ["decoder.layers.88.self_attn.qkv_proj.weight", "decoder.layers.88.self_attn.qkv_proj.bias", "decoder.layers.88.self_attn.out_proj.weight", "decoder.layers.88.self_attn.out_proj.bias", "decoder.layers.88.self_attn_layer_norm.weight", "decoder.layers.88.self_attn_layer_norm.bias", "decoder.layers.88.fc1.weight", "decoder.layers.88.fc1.bias", "decoder.layers.88.fc2.weight", "decoder.layers.88.fc2.bias", 
"decoder.layers.88.final_layer_norm.weight", "decoder.layers.88.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.89.flat_param_0": {"names": ["decoder.layers.89.self_attn.qkv_proj.weight", "decoder.layers.89.self_attn.qkv_proj.bias", "decoder.layers.89.self_attn.out_proj.weight", "decoder.layers.89.self_attn.out_proj.bias", "decoder.layers.89.self_attn_layer_norm.weight", "decoder.layers.89.self_attn_layer_norm.bias", "decoder.layers.89.fc1.weight", "decoder.layers.89.fc1.bias", "decoder.layers.89.fc2.weight", "decoder.layers.89.fc2.bias", "decoder.layers.89.final_layer_norm.weight", "decoder.layers.89.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.90.flat_param_0": {"names": ["decoder.layers.90.self_attn.qkv_proj.weight", "decoder.layers.90.self_attn.qkv_proj.bias", "decoder.layers.90.self_attn.out_proj.weight", "decoder.layers.90.self_attn.out_proj.bias", "decoder.layers.90.self_attn_layer_norm.weight", "decoder.layers.90.self_attn_layer_norm.bias", "decoder.layers.90.fc1.weight", "decoder.layers.90.fc1.bias", "decoder.layers.90.fc2.weight", "decoder.layers.90.fc2.bias", "decoder.layers.90.final_layer_norm.weight", "decoder.layers.90.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.91.flat_param_0": {"names": 
["decoder.layers.91.self_attn.qkv_proj.weight", "decoder.layers.91.self_attn.qkv_proj.bias", "decoder.layers.91.self_attn.out_proj.weight", "decoder.layers.91.self_attn.out_proj.bias", "decoder.layers.91.self_attn_layer_norm.weight", "decoder.layers.91.self_attn_layer_norm.bias", "decoder.layers.91.fc1.weight", "decoder.layers.91.fc1.bias", "decoder.layers.91.fc2.weight", "decoder.layers.91.fc2.bias", "decoder.layers.91.final_layer_norm.weight", "decoder.layers.91.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.92.flat_param_0": {"names": ["decoder.layers.92.self_attn.qkv_proj.weight", "decoder.layers.92.self_attn.qkv_proj.bias", "decoder.layers.92.self_attn.out_proj.weight", "decoder.layers.92.self_attn.out_proj.bias", "decoder.layers.92.self_attn_layer_norm.weight", "decoder.layers.92.self_attn_layer_norm.bias", "decoder.layers.92.fc1.weight", "decoder.layers.92.fc1.bias", "decoder.layers.92.fc2.weight", "decoder.layers.92.fc2.bias", "decoder.layers.92.final_layer_norm.weight", "decoder.layers.92.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.93.flat_param_0": {"names": ["decoder.layers.93.self_attn.qkv_proj.weight", "decoder.layers.93.self_attn.qkv_proj.bias", "decoder.layers.93.self_attn.out_proj.weight", "decoder.layers.93.self_attn.out_proj.bias", "decoder.layers.93.self_attn_layer_norm.weight", "decoder.layers.93.self_attn_layer_norm.bias", "decoder.layers.93.fc1.weight", "decoder.layers.93.fc1.bias", "decoder.layers.93.fc2.weight", "decoder.layers.93.fc2.bias", 
"decoder.layers.93.final_layer_norm.weight", "decoder.layers.93.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.94.flat_param_0": {"names": ["decoder.layers.94.self_attn.qkv_proj.weight", "decoder.layers.94.self_attn.qkv_proj.bias", "decoder.layers.94.self_attn.out_proj.weight", "decoder.layers.94.self_attn.out_proj.bias", "decoder.layers.94.self_attn_layer_norm.weight", "decoder.layers.94.self_attn_layer_norm.bias", "decoder.layers.94.fc1.weight", "decoder.layers.94.fc1.bias", "decoder.layers.94.fc2.weight", "decoder.layers.94.fc2.bias", "decoder.layers.94.final_layer_norm.weight", "decoder.layers.94.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}, "decoder.layers.95.flat_param_0": {"names": ["decoder.layers.95.self_attn.qkv_proj.weight", "decoder.layers.95.self_attn.qkv_proj.bias", "decoder.layers.95.self_attn.out_proj.weight", "decoder.layers.95.self_attn.out_proj.bias", "decoder.layers.95.self_attn_layer_norm.weight", "decoder.layers.95.self_attn_layer_norm.bias", "decoder.layers.95.fc1.weight", "decoder.layers.95.fc1.bias", "decoder.layers.95.fc2.weight", "decoder.layers.95.fc2.bias", "decoder.layers.95.final_layer_norm.weight", "decoder.layers.95.final_layer_norm.bias"], "shapes": [[4608, 12288], [4608], [12288, 1536], [12288], [12288], [12288], [6144, 12288], [6144], [12288, 6144], [12288], [12288], [12288]], "numels": [56623104, 4608, 18874368, 12288, 12288, 12288, 75497472, 6144, 75497472, 12288, 12288, 12288]}} \ No newline at end of file +{ + "flat_param_0": { + "names": [ + 
"decoder.embed_tokens.weight", + "decoder.embed_positions.weight", + "decoder.layer_norm.weight", + "decoder.layer_norm.bias" + ], + "shapes": [ + [ + 6284, + 12288 + ], + [ + 2050, + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 77217792, + 25190400, + 12288, + 12288 + ] + }, + "decoder.layers.0.flat_param_0": { + "names": [ + "decoder.layers.0.self_attn.qkv_proj.weight", + "decoder.layers.0.self_attn.qkv_proj.bias", + "decoder.layers.0.self_attn.out_proj.weight", + "decoder.layers.0.self_attn.out_proj.bias", + "decoder.layers.0.self_attn_layer_norm.weight", + "decoder.layers.0.self_attn_layer_norm.bias", + "decoder.layers.0.fc1.weight", + "decoder.layers.0.fc1.bias", + "decoder.layers.0.fc2.weight", + "decoder.layers.0.fc2.bias", + "decoder.layers.0.final_layer_norm.weight", + "decoder.layers.0.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.1.flat_param_0": { + "names": [ + "decoder.layers.1.self_attn.qkv_proj.weight", + "decoder.layers.1.self_attn.qkv_proj.bias", + "decoder.layers.1.self_attn.out_proj.weight", + "decoder.layers.1.self_attn.out_proj.bias", + "decoder.layers.1.self_attn_layer_norm.weight", + "decoder.layers.1.self_attn_layer_norm.bias", + "decoder.layers.1.fc1.weight", + "decoder.layers.1.fc1.bias", + "decoder.layers.1.fc2.weight", + "decoder.layers.1.fc2.bias", + "decoder.layers.1.final_layer_norm.weight", + "decoder.layers.1.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + 
], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.2.flat_param_0": { + "names": [ + "decoder.layers.2.self_attn.qkv_proj.weight", + "decoder.layers.2.self_attn.qkv_proj.bias", + "decoder.layers.2.self_attn.out_proj.weight", + "decoder.layers.2.self_attn.out_proj.bias", + "decoder.layers.2.self_attn_layer_norm.weight", + "decoder.layers.2.self_attn_layer_norm.bias", + "decoder.layers.2.fc1.weight", + "decoder.layers.2.fc1.bias", + "decoder.layers.2.fc2.weight", + "decoder.layers.2.fc2.bias", + "decoder.layers.2.final_layer_norm.weight", + "decoder.layers.2.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.3.flat_param_0": { + "names": [ + "decoder.layers.3.self_attn.qkv_proj.weight", + "decoder.layers.3.self_attn.qkv_proj.bias", + "decoder.layers.3.self_attn.out_proj.weight", + "decoder.layers.3.self_attn.out_proj.bias", + "decoder.layers.3.self_attn_layer_norm.weight", + "decoder.layers.3.self_attn_layer_norm.bias", + "decoder.layers.3.fc1.weight", + "decoder.layers.3.fc1.bias", + "decoder.layers.3.fc2.weight", + "decoder.layers.3.fc2.bias", + "decoder.layers.3.final_layer_norm.weight", + "decoder.layers.3.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 
18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.4.flat_param_0": { + "names": [ + "decoder.layers.4.self_attn.qkv_proj.weight", + "decoder.layers.4.self_attn.qkv_proj.bias", + "decoder.layers.4.self_attn.out_proj.weight", + "decoder.layers.4.self_attn.out_proj.bias", + "decoder.layers.4.self_attn_layer_norm.weight", + "decoder.layers.4.self_attn_layer_norm.bias", + "decoder.layers.4.fc1.weight", + "decoder.layers.4.fc1.bias", + "decoder.layers.4.fc2.weight", + "decoder.layers.4.fc2.bias", + "decoder.layers.4.final_layer_norm.weight", + "decoder.layers.4.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.5.flat_param_0": { + "names": [ + "decoder.layers.5.self_attn.qkv_proj.weight", + "decoder.layers.5.self_attn.qkv_proj.bias", + "decoder.layers.5.self_attn.out_proj.weight", + "decoder.layers.5.self_attn.out_proj.bias", + "decoder.layers.5.self_attn_layer_norm.weight", + "decoder.layers.5.self_attn_layer_norm.bias", + "decoder.layers.5.fc1.weight", + "decoder.layers.5.fc1.bias", + "decoder.layers.5.fc2.weight", + "decoder.layers.5.fc2.bias", + "decoder.layers.5.final_layer_norm.weight", + "decoder.layers.5.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] 
+ }, + "decoder.layers.6.flat_param_0": { + "names": [ + "decoder.layers.6.self_attn.qkv_proj.weight", + "decoder.layers.6.self_attn.qkv_proj.bias", + "decoder.layers.6.self_attn.out_proj.weight", + "decoder.layers.6.self_attn.out_proj.bias", + "decoder.layers.6.self_attn_layer_norm.weight", + "decoder.layers.6.self_attn_layer_norm.bias", + "decoder.layers.6.fc1.weight", + "decoder.layers.6.fc1.bias", + "decoder.layers.6.fc2.weight", + "decoder.layers.6.fc2.bias", + "decoder.layers.6.final_layer_norm.weight", + "decoder.layers.6.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.7.flat_param_0": { + "names": [ + "decoder.layers.7.self_attn.qkv_proj.weight", + "decoder.layers.7.self_attn.qkv_proj.bias", + "decoder.layers.7.self_attn.out_proj.weight", + "decoder.layers.7.self_attn.out_proj.bias", + "decoder.layers.7.self_attn_layer_norm.weight", + "decoder.layers.7.self_attn_layer_norm.bias", + "decoder.layers.7.fc1.weight", + "decoder.layers.7.fc1.bias", + "decoder.layers.7.fc2.weight", + "decoder.layers.7.fc2.bias", + "decoder.layers.7.final_layer_norm.weight", + "decoder.layers.7.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.8.flat_param_0": { + "names": [ + 
"decoder.layers.8.self_attn.qkv_proj.weight", + "decoder.layers.8.self_attn.qkv_proj.bias", + "decoder.layers.8.self_attn.out_proj.weight", + "decoder.layers.8.self_attn.out_proj.bias", + "decoder.layers.8.self_attn_layer_norm.weight", + "decoder.layers.8.self_attn_layer_norm.bias", + "decoder.layers.8.fc1.weight", + "decoder.layers.8.fc1.bias", + "decoder.layers.8.fc2.weight", + "decoder.layers.8.fc2.bias", + "decoder.layers.8.final_layer_norm.weight", + "decoder.layers.8.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.9.flat_param_0": { + "names": [ + "decoder.layers.9.self_attn.qkv_proj.weight", + "decoder.layers.9.self_attn.qkv_proj.bias", + "decoder.layers.9.self_attn.out_proj.weight", + "decoder.layers.9.self_attn.out_proj.bias", + "decoder.layers.9.self_attn_layer_norm.weight", + "decoder.layers.9.self_attn_layer_norm.bias", + "decoder.layers.9.fc1.weight", + "decoder.layers.9.fc1.bias", + "decoder.layers.9.fc2.weight", + "decoder.layers.9.fc2.bias", + "decoder.layers.9.final_layer_norm.weight", + "decoder.layers.9.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.10.flat_param_0": { + "names": [ + "decoder.layers.10.self_attn.qkv_proj.weight", + "decoder.layers.10.self_attn.qkv_proj.bias", + 
"decoder.layers.10.self_attn.out_proj.weight", + "decoder.layers.10.self_attn.out_proj.bias", + "decoder.layers.10.self_attn_layer_norm.weight", + "decoder.layers.10.self_attn_layer_norm.bias", + "decoder.layers.10.fc1.weight", + "decoder.layers.10.fc1.bias", + "decoder.layers.10.fc2.weight", + "decoder.layers.10.fc2.bias", + "decoder.layers.10.final_layer_norm.weight", + "decoder.layers.10.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.11.flat_param_0": { + "names": [ + "decoder.layers.11.self_attn.qkv_proj.weight", + "decoder.layers.11.self_attn.qkv_proj.bias", + "decoder.layers.11.self_attn.out_proj.weight", + "decoder.layers.11.self_attn.out_proj.bias", + "decoder.layers.11.self_attn_layer_norm.weight", + "decoder.layers.11.self_attn_layer_norm.bias", + "decoder.layers.11.fc1.weight", + "decoder.layers.11.fc1.bias", + "decoder.layers.11.fc2.weight", + "decoder.layers.11.fc2.bias", + "decoder.layers.11.final_layer_norm.weight", + "decoder.layers.11.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.12.flat_param_0": { + "names": [ + "decoder.layers.12.self_attn.qkv_proj.weight", + "decoder.layers.12.self_attn.qkv_proj.bias", + "decoder.layers.12.self_attn.out_proj.weight", + 
"decoder.layers.12.self_attn.out_proj.bias", + "decoder.layers.12.self_attn_layer_norm.weight", + "decoder.layers.12.self_attn_layer_norm.bias", + "decoder.layers.12.fc1.weight", + "decoder.layers.12.fc1.bias", + "decoder.layers.12.fc2.weight", + "decoder.layers.12.fc2.bias", + "decoder.layers.12.final_layer_norm.weight", + "decoder.layers.12.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.13.flat_param_0": { + "names": [ + "decoder.layers.13.self_attn.qkv_proj.weight", + "decoder.layers.13.self_attn.qkv_proj.bias", + "decoder.layers.13.self_attn.out_proj.weight", + "decoder.layers.13.self_attn.out_proj.bias", + "decoder.layers.13.self_attn_layer_norm.weight", + "decoder.layers.13.self_attn_layer_norm.bias", + "decoder.layers.13.fc1.weight", + "decoder.layers.13.fc1.bias", + "decoder.layers.13.fc2.weight", + "decoder.layers.13.fc2.bias", + "decoder.layers.13.final_layer_norm.weight", + "decoder.layers.13.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.14.flat_param_0": { + "names": [ + "decoder.layers.14.self_attn.qkv_proj.weight", + "decoder.layers.14.self_attn.qkv_proj.bias", + "decoder.layers.14.self_attn.out_proj.weight", + "decoder.layers.14.self_attn.out_proj.bias", + 
"decoder.layers.14.self_attn_layer_norm.weight", + "decoder.layers.14.self_attn_layer_norm.bias", + "decoder.layers.14.fc1.weight", + "decoder.layers.14.fc1.bias", + "decoder.layers.14.fc2.weight", + "decoder.layers.14.fc2.bias", + "decoder.layers.14.final_layer_norm.weight", + "decoder.layers.14.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.15.flat_param_0": { + "names": [ + "decoder.layers.15.self_attn.qkv_proj.weight", + "decoder.layers.15.self_attn.qkv_proj.bias", + "decoder.layers.15.self_attn.out_proj.weight", + "decoder.layers.15.self_attn.out_proj.bias", + "decoder.layers.15.self_attn_layer_norm.weight", + "decoder.layers.15.self_attn_layer_norm.bias", + "decoder.layers.15.fc1.weight", + "decoder.layers.15.fc1.bias", + "decoder.layers.15.fc2.weight", + "decoder.layers.15.fc2.bias", + "decoder.layers.15.final_layer_norm.weight", + "decoder.layers.15.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.16.flat_param_0": { + "names": [ + "decoder.layers.16.self_attn.qkv_proj.weight", + "decoder.layers.16.self_attn.qkv_proj.bias", + "decoder.layers.16.self_attn.out_proj.weight", + "decoder.layers.16.self_attn.out_proj.bias", + "decoder.layers.16.self_attn_layer_norm.weight", + 
"decoder.layers.16.self_attn_layer_norm.bias", + "decoder.layers.16.fc1.weight", + "decoder.layers.16.fc1.bias", + "decoder.layers.16.fc2.weight", + "decoder.layers.16.fc2.bias", + "decoder.layers.16.final_layer_norm.weight", + "decoder.layers.16.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.17.flat_param_0": { + "names": [ + "decoder.layers.17.self_attn.qkv_proj.weight", + "decoder.layers.17.self_attn.qkv_proj.bias", + "decoder.layers.17.self_attn.out_proj.weight", + "decoder.layers.17.self_attn.out_proj.bias", + "decoder.layers.17.self_attn_layer_norm.weight", + "decoder.layers.17.self_attn_layer_norm.bias", + "decoder.layers.17.fc1.weight", + "decoder.layers.17.fc1.bias", + "decoder.layers.17.fc2.weight", + "decoder.layers.17.fc2.bias", + "decoder.layers.17.final_layer_norm.weight", + "decoder.layers.17.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.18.flat_param_0": { + "names": [ + "decoder.layers.18.self_attn.qkv_proj.weight", + "decoder.layers.18.self_attn.qkv_proj.bias", + "decoder.layers.18.self_attn.out_proj.weight", + "decoder.layers.18.self_attn.out_proj.bias", + "decoder.layers.18.self_attn_layer_norm.weight", + "decoder.layers.18.self_attn_layer_norm.bias", + 
"decoder.layers.18.fc1.weight", + "decoder.layers.18.fc1.bias", + "decoder.layers.18.fc2.weight", + "decoder.layers.18.fc2.bias", + "decoder.layers.18.final_layer_norm.weight", + "decoder.layers.18.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.19.flat_param_0": { + "names": [ + "decoder.layers.19.self_attn.qkv_proj.weight", + "decoder.layers.19.self_attn.qkv_proj.bias", + "decoder.layers.19.self_attn.out_proj.weight", + "decoder.layers.19.self_attn.out_proj.bias", + "decoder.layers.19.self_attn_layer_norm.weight", + "decoder.layers.19.self_attn_layer_norm.bias", + "decoder.layers.19.fc1.weight", + "decoder.layers.19.fc1.bias", + "decoder.layers.19.fc2.weight", + "decoder.layers.19.fc2.bias", + "decoder.layers.19.final_layer_norm.weight", + "decoder.layers.19.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.20.flat_param_0": { + "names": [ + "decoder.layers.20.self_attn.qkv_proj.weight", + "decoder.layers.20.self_attn.qkv_proj.bias", + "decoder.layers.20.self_attn.out_proj.weight", + "decoder.layers.20.self_attn.out_proj.bias", + "decoder.layers.20.self_attn_layer_norm.weight", + "decoder.layers.20.self_attn_layer_norm.bias", + "decoder.layers.20.fc1.weight", + "decoder.layers.20.fc1.bias", + 
"decoder.layers.20.fc2.weight", + "decoder.layers.20.fc2.bias", + "decoder.layers.20.final_layer_norm.weight", + "decoder.layers.20.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.21.flat_param_0": { + "names": [ + "decoder.layers.21.self_attn.qkv_proj.weight", + "decoder.layers.21.self_attn.qkv_proj.bias", + "decoder.layers.21.self_attn.out_proj.weight", + "decoder.layers.21.self_attn.out_proj.bias", + "decoder.layers.21.self_attn_layer_norm.weight", + "decoder.layers.21.self_attn_layer_norm.bias", + "decoder.layers.21.fc1.weight", + "decoder.layers.21.fc1.bias", + "decoder.layers.21.fc2.weight", + "decoder.layers.21.fc2.bias", + "decoder.layers.21.final_layer_norm.weight", + "decoder.layers.21.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.22.flat_param_0": { + "names": [ + "decoder.layers.22.self_attn.qkv_proj.weight", + "decoder.layers.22.self_attn.qkv_proj.bias", + "decoder.layers.22.self_attn.out_proj.weight", + "decoder.layers.22.self_attn.out_proj.bias", + "decoder.layers.22.self_attn_layer_norm.weight", + "decoder.layers.22.self_attn_layer_norm.bias", + "decoder.layers.22.fc1.weight", + "decoder.layers.22.fc1.bias", + "decoder.layers.22.fc2.weight", + "decoder.layers.22.fc2.bias", + 
"decoder.layers.22.final_layer_norm.weight", + "decoder.layers.22.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.23.flat_param_0": { + "names": [ + "decoder.layers.23.self_attn.qkv_proj.weight", + "decoder.layers.23.self_attn.qkv_proj.bias", + "decoder.layers.23.self_attn.out_proj.weight", + "decoder.layers.23.self_attn.out_proj.bias", + "decoder.layers.23.self_attn_layer_norm.weight", + "decoder.layers.23.self_attn_layer_norm.bias", + "decoder.layers.23.fc1.weight", + "decoder.layers.23.fc1.bias", + "decoder.layers.23.fc2.weight", + "decoder.layers.23.fc2.bias", + "decoder.layers.23.final_layer_norm.weight", + "decoder.layers.23.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.24.flat_param_0": { + "names": [ + "decoder.layers.24.self_attn.qkv_proj.weight", + "decoder.layers.24.self_attn.qkv_proj.bias", + "decoder.layers.24.self_attn.out_proj.weight", + "decoder.layers.24.self_attn.out_proj.bias", + "decoder.layers.24.self_attn_layer_norm.weight", + "decoder.layers.24.self_attn_layer_norm.bias", + "decoder.layers.24.fc1.weight", + "decoder.layers.24.fc1.bias", + "decoder.layers.24.fc2.weight", + "decoder.layers.24.fc2.bias", + "decoder.layers.24.final_layer_norm.weight", + 
"decoder.layers.24.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.25.flat_param_0": { + "names": [ + "decoder.layers.25.self_attn.qkv_proj.weight", + "decoder.layers.25.self_attn.qkv_proj.bias", + "decoder.layers.25.self_attn.out_proj.weight", + "decoder.layers.25.self_attn.out_proj.bias", + "decoder.layers.25.self_attn_layer_norm.weight", + "decoder.layers.25.self_attn_layer_norm.bias", + "decoder.layers.25.fc1.weight", + "decoder.layers.25.fc1.bias", + "decoder.layers.25.fc2.weight", + "decoder.layers.25.fc2.bias", + "decoder.layers.25.final_layer_norm.weight", + "decoder.layers.25.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.26.flat_param_0": { + "names": [ + "decoder.layers.26.self_attn.qkv_proj.weight", + "decoder.layers.26.self_attn.qkv_proj.bias", + "decoder.layers.26.self_attn.out_proj.weight", + "decoder.layers.26.self_attn.out_proj.bias", + "decoder.layers.26.self_attn_layer_norm.weight", + "decoder.layers.26.self_attn_layer_norm.bias", + "decoder.layers.26.fc1.weight", + "decoder.layers.26.fc1.bias", + "decoder.layers.26.fc2.weight", + "decoder.layers.26.fc2.bias", + "decoder.layers.26.final_layer_norm.weight", + "decoder.layers.26.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, 
+ 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.27.flat_param_0": { + "names": [ + "decoder.layers.27.self_attn.qkv_proj.weight", + "decoder.layers.27.self_attn.qkv_proj.bias", + "decoder.layers.27.self_attn.out_proj.weight", + "decoder.layers.27.self_attn.out_proj.bias", + "decoder.layers.27.self_attn_layer_norm.weight", + "decoder.layers.27.self_attn_layer_norm.bias", + "decoder.layers.27.fc1.weight", + "decoder.layers.27.fc1.bias", + "decoder.layers.27.fc2.weight", + "decoder.layers.27.fc2.bias", + "decoder.layers.27.final_layer_norm.weight", + "decoder.layers.27.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.28.flat_param_0": { + "names": [ + "decoder.layers.28.self_attn.qkv_proj.weight", + "decoder.layers.28.self_attn.qkv_proj.bias", + "decoder.layers.28.self_attn.out_proj.weight", + "decoder.layers.28.self_attn.out_proj.bias", + "decoder.layers.28.self_attn_layer_norm.weight", + "decoder.layers.28.self_attn_layer_norm.bias", + "decoder.layers.28.fc1.weight", + "decoder.layers.28.fc1.bias", + "decoder.layers.28.fc2.weight", + "decoder.layers.28.fc2.bias", + "decoder.layers.28.final_layer_norm.weight", + "decoder.layers.28.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + 
[ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.29.flat_param_0": { + "names": [ + "decoder.layers.29.self_attn.qkv_proj.weight", + "decoder.layers.29.self_attn.qkv_proj.bias", + "decoder.layers.29.self_attn.out_proj.weight", + "decoder.layers.29.self_attn.out_proj.bias", + "decoder.layers.29.self_attn_layer_norm.weight", + "decoder.layers.29.self_attn_layer_norm.bias", + "decoder.layers.29.fc1.weight", + "decoder.layers.29.fc1.bias", + "decoder.layers.29.fc2.weight", + "decoder.layers.29.fc2.bias", + "decoder.layers.29.final_layer_norm.weight", + "decoder.layers.29.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.30.flat_param_0": { + "names": [ + "decoder.layers.30.self_attn.qkv_proj.weight", + "decoder.layers.30.self_attn.qkv_proj.bias", + "decoder.layers.30.self_attn.out_proj.weight", + "decoder.layers.30.self_attn.out_proj.bias", + "decoder.layers.30.self_attn_layer_norm.weight", + "decoder.layers.30.self_attn_layer_norm.bias", + "decoder.layers.30.fc1.weight", + "decoder.layers.30.fc1.bias", + "decoder.layers.30.fc2.weight", + "decoder.layers.30.fc2.bias", + "decoder.layers.30.final_layer_norm.weight", + "decoder.layers.30.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], 
+ [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.31.flat_param_0": { + "names": [ + "decoder.layers.31.self_attn.qkv_proj.weight", + "decoder.layers.31.self_attn.qkv_proj.bias", + "decoder.layers.31.self_attn.out_proj.weight", + "decoder.layers.31.self_attn.out_proj.bias", + "decoder.layers.31.self_attn_layer_norm.weight", + "decoder.layers.31.self_attn_layer_norm.bias", + "decoder.layers.31.fc1.weight", + "decoder.layers.31.fc1.bias", + "decoder.layers.31.fc2.weight", + "decoder.layers.31.fc2.bias", + "decoder.layers.31.final_layer_norm.weight", + "decoder.layers.31.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.32.flat_param_0": { + "names": [ + "decoder.layers.32.self_attn.qkv_proj.weight", + "decoder.layers.32.self_attn.qkv_proj.bias", + "decoder.layers.32.self_attn.out_proj.weight", + "decoder.layers.32.self_attn.out_proj.bias", + "decoder.layers.32.self_attn_layer_norm.weight", + "decoder.layers.32.self_attn_layer_norm.bias", + "decoder.layers.32.fc1.weight", + "decoder.layers.32.fc1.bias", + "decoder.layers.32.fc2.weight", + "decoder.layers.32.fc2.bias", + "decoder.layers.32.final_layer_norm.weight", + "decoder.layers.32.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + 
] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.33.flat_param_0": { + "names": [ + "decoder.layers.33.self_attn.qkv_proj.weight", + "decoder.layers.33.self_attn.qkv_proj.bias", + "decoder.layers.33.self_attn.out_proj.weight", + "decoder.layers.33.self_attn.out_proj.bias", + "decoder.layers.33.self_attn_layer_norm.weight", + "decoder.layers.33.self_attn_layer_norm.bias", + "decoder.layers.33.fc1.weight", + "decoder.layers.33.fc1.bias", + "decoder.layers.33.fc2.weight", + "decoder.layers.33.fc2.bias", + "decoder.layers.33.final_layer_norm.weight", + "decoder.layers.33.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.34.flat_param_0": { + "names": [ + "decoder.layers.34.self_attn.qkv_proj.weight", + "decoder.layers.34.self_attn.qkv_proj.bias", + "decoder.layers.34.self_attn.out_proj.weight", + "decoder.layers.34.self_attn.out_proj.bias", + "decoder.layers.34.self_attn_layer_norm.weight", + "decoder.layers.34.self_attn_layer_norm.bias", + "decoder.layers.34.fc1.weight", + "decoder.layers.34.fc1.bias", + "decoder.layers.34.fc2.weight", + "decoder.layers.34.fc2.bias", + "decoder.layers.34.final_layer_norm.weight", + "decoder.layers.34.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 
12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.35.flat_param_0": { + "names": [ + "decoder.layers.35.self_attn.qkv_proj.weight", + "decoder.layers.35.self_attn.qkv_proj.bias", + "decoder.layers.35.self_attn.out_proj.weight", + "decoder.layers.35.self_attn.out_proj.bias", + "decoder.layers.35.self_attn_layer_norm.weight", + "decoder.layers.35.self_attn_layer_norm.bias", + "decoder.layers.35.fc1.weight", + "decoder.layers.35.fc1.bias", + "decoder.layers.35.fc2.weight", + "decoder.layers.35.fc2.bias", + "decoder.layers.35.final_layer_norm.weight", + "decoder.layers.35.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.36.flat_param_0": { + "names": [ + "decoder.layers.36.self_attn.qkv_proj.weight", + "decoder.layers.36.self_attn.qkv_proj.bias", + "decoder.layers.36.self_attn.out_proj.weight", + "decoder.layers.36.self_attn.out_proj.bias", + "decoder.layers.36.self_attn_layer_norm.weight", + "decoder.layers.36.self_attn_layer_norm.bias", + "decoder.layers.36.fc1.weight", + "decoder.layers.36.fc1.bias", + "decoder.layers.36.fc2.weight", + "decoder.layers.36.fc2.bias", + "decoder.layers.36.final_layer_norm.weight", + "decoder.layers.36.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, 
+ "decoder.layers.37.flat_param_0": { + "names": [ + "decoder.layers.37.self_attn.qkv_proj.weight", + "decoder.layers.37.self_attn.qkv_proj.bias", + "decoder.layers.37.self_attn.out_proj.weight", + "decoder.layers.37.self_attn.out_proj.bias", + "decoder.layers.37.self_attn_layer_norm.weight", + "decoder.layers.37.self_attn_layer_norm.bias", + "decoder.layers.37.fc1.weight", + "decoder.layers.37.fc1.bias", + "decoder.layers.37.fc2.weight", + "decoder.layers.37.fc2.bias", + "decoder.layers.37.final_layer_norm.weight", + "decoder.layers.37.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.38.flat_param_0": { + "names": [ + "decoder.layers.38.self_attn.qkv_proj.weight", + "decoder.layers.38.self_attn.qkv_proj.bias", + "decoder.layers.38.self_attn.out_proj.weight", + "decoder.layers.38.self_attn.out_proj.bias", + "decoder.layers.38.self_attn_layer_norm.weight", + "decoder.layers.38.self_attn_layer_norm.bias", + "decoder.layers.38.fc1.weight", + "decoder.layers.38.fc1.bias", + "decoder.layers.38.fc2.weight", + "decoder.layers.38.fc2.bias", + "decoder.layers.38.final_layer_norm.weight", + "decoder.layers.38.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.39.flat_param_0": { + "names": [ + 
"decoder.layers.39.self_attn.qkv_proj.weight", + "decoder.layers.39.self_attn.qkv_proj.bias", + "decoder.layers.39.self_attn.out_proj.weight", + "decoder.layers.39.self_attn.out_proj.bias", + "decoder.layers.39.self_attn_layer_norm.weight", + "decoder.layers.39.self_attn_layer_norm.bias", + "decoder.layers.39.fc1.weight", + "decoder.layers.39.fc1.bias", + "decoder.layers.39.fc2.weight", + "decoder.layers.39.fc2.bias", + "decoder.layers.39.final_layer_norm.weight", + "decoder.layers.39.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.40.flat_param_0": { + "names": [ + "decoder.layers.40.self_attn.qkv_proj.weight", + "decoder.layers.40.self_attn.qkv_proj.bias", + "decoder.layers.40.self_attn.out_proj.weight", + "decoder.layers.40.self_attn.out_proj.bias", + "decoder.layers.40.self_attn_layer_norm.weight", + "decoder.layers.40.self_attn_layer_norm.bias", + "decoder.layers.40.fc1.weight", + "decoder.layers.40.fc1.bias", + "decoder.layers.40.fc2.weight", + "decoder.layers.40.fc2.bias", + "decoder.layers.40.final_layer_norm.weight", + "decoder.layers.40.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.41.flat_param_0": { + "names": [ + "decoder.layers.41.self_attn.qkv_proj.weight", + 
"decoder.layers.41.self_attn.qkv_proj.bias", + "decoder.layers.41.self_attn.out_proj.weight", + "decoder.layers.41.self_attn.out_proj.bias", + "decoder.layers.41.self_attn_layer_norm.weight", + "decoder.layers.41.self_attn_layer_norm.bias", + "decoder.layers.41.fc1.weight", + "decoder.layers.41.fc1.bias", + "decoder.layers.41.fc2.weight", + "decoder.layers.41.fc2.bias", + "decoder.layers.41.final_layer_norm.weight", + "decoder.layers.41.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.42.flat_param_0": { + "names": [ + "decoder.layers.42.self_attn.qkv_proj.weight", + "decoder.layers.42.self_attn.qkv_proj.bias", + "decoder.layers.42.self_attn.out_proj.weight", + "decoder.layers.42.self_attn.out_proj.bias", + "decoder.layers.42.self_attn_layer_norm.weight", + "decoder.layers.42.self_attn_layer_norm.bias", + "decoder.layers.42.fc1.weight", + "decoder.layers.42.fc1.bias", + "decoder.layers.42.fc2.weight", + "decoder.layers.42.fc2.bias", + "decoder.layers.42.final_layer_norm.weight", + "decoder.layers.42.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.43.flat_param_0": { + "names": [ + "decoder.layers.43.self_attn.qkv_proj.weight", + "decoder.layers.43.self_attn.qkv_proj.bias", + 
"decoder.layers.43.self_attn.out_proj.weight", + "decoder.layers.43.self_attn.out_proj.bias", + "decoder.layers.43.self_attn_layer_norm.weight", + "decoder.layers.43.self_attn_layer_norm.bias", + "decoder.layers.43.fc1.weight", + "decoder.layers.43.fc1.bias", + "decoder.layers.43.fc2.weight", + "decoder.layers.43.fc2.bias", + "decoder.layers.43.final_layer_norm.weight", + "decoder.layers.43.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.44.flat_param_0": { + "names": [ + "decoder.layers.44.self_attn.qkv_proj.weight", + "decoder.layers.44.self_attn.qkv_proj.bias", + "decoder.layers.44.self_attn.out_proj.weight", + "decoder.layers.44.self_attn.out_proj.bias", + "decoder.layers.44.self_attn_layer_norm.weight", + "decoder.layers.44.self_attn_layer_norm.bias", + "decoder.layers.44.fc1.weight", + "decoder.layers.44.fc1.bias", + "decoder.layers.44.fc2.weight", + "decoder.layers.44.fc2.bias", + "decoder.layers.44.final_layer_norm.weight", + "decoder.layers.44.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.45.flat_param_0": { + "names": [ + "decoder.layers.45.self_attn.qkv_proj.weight", + "decoder.layers.45.self_attn.qkv_proj.bias", + "decoder.layers.45.self_attn.out_proj.weight", + 
"decoder.layers.45.self_attn.out_proj.bias", + "decoder.layers.45.self_attn_layer_norm.weight", + "decoder.layers.45.self_attn_layer_norm.bias", + "decoder.layers.45.fc1.weight", + "decoder.layers.45.fc1.bias", + "decoder.layers.45.fc2.weight", + "decoder.layers.45.fc2.bias", + "decoder.layers.45.final_layer_norm.weight", + "decoder.layers.45.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.46.flat_param_0": { + "names": [ + "decoder.layers.46.self_attn.qkv_proj.weight", + "decoder.layers.46.self_attn.qkv_proj.bias", + "decoder.layers.46.self_attn.out_proj.weight", + "decoder.layers.46.self_attn.out_proj.bias", + "decoder.layers.46.self_attn_layer_norm.weight", + "decoder.layers.46.self_attn_layer_norm.bias", + "decoder.layers.46.fc1.weight", + "decoder.layers.46.fc1.bias", + "decoder.layers.46.fc2.weight", + "decoder.layers.46.fc2.bias", + "decoder.layers.46.final_layer_norm.weight", + "decoder.layers.46.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.47.flat_param_0": { + "names": [ + "decoder.layers.47.self_attn.qkv_proj.weight", + "decoder.layers.47.self_attn.qkv_proj.bias", + "decoder.layers.47.self_attn.out_proj.weight", + "decoder.layers.47.self_attn.out_proj.bias", + 
"decoder.layers.47.self_attn_layer_norm.weight", + "decoder.layers.47.self_attn_layer_norm.bias", + "decoder.layers.47.fc1.weight", + "decoder.layers.47.fc1.bias", + "decoder.layers.47.fc2.weight", + "decoder.layers.47.fc2.bias", + "decoder.layers.47.final_layer_norm.weight", + "decoder.layers.47.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.48.flat_param_0": { + "names": [ + "decoder.layers.48.self_attn.qkv_proj.weight", + "decoder.layers.48.self_attn.qkv_proj.bias", + "decoder.layers.48.self_attn.out_proj.weight", + "decoder.layers.48.self_attn.out_proj.bias", + "decoder.layers.48.self_attn_layer_norm.weight", + "decoder.layers.48.self_attn_layer_norm.bias", + "decoder.layers.48.fc1.weight", + "decoder.layers.48.fc1.bias", + "decoder.layers.48.fc2.weight", + "decoder.layers.48.fc2.bias", + "decoder.layers.48.final_layer_norm.weight", + "decoder.layers.48.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.49.flat_param_0": { + "names": [ + "decoder.layers.49.self_attn.qkv_proj.weight", + "decoder.layers.49.self_attn.qkv_proj.bias", + "decoder.layers.49.self_attn.out_proj.weight", + "decoder.layers.49.self_attn.out_proj.bias", + "decoder.layers.49.self_attn_layer_norm.weight", + 
"decoder.layers.49.self_attn_layer_norm.bias", + "decoder.layers.49.fc1.weight", + "decoder.layers.49.fc1.bias", + "decoder.layers.49.fc2.weight", + "decoder.layers.49.fc2.bias", + "decoder.layers.49.final_layer_norm.weight", + "decoder.layers.49.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.50.flat_param_0": { + "names": [ + "decoder.layers.50.self_attn.qkv_proj.weight", + "decoder.layers.50.self_attn.qkv_proj.bias", + "decoder.layers.50.self_attn.out_proj.weight", + "decoder.layers.50.self_attn.out_proj.bias", + "decoder.layers.50.self_attn_layer_norm.weight", + "decoder.layers.50.self_attn_layer_norm.bias", + "decoder.layers.50.fc1.weight", + "decoder.layers.50.fc1.bias", + "decoder.layers.50.fc2.weight", + "decoder.layers.50.fc2.bias", + "decoder.layers.50.final_layer_norm.weight", + "decoder.layers.50.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.51.flat_param_0": { + "names": [ + "decoder.layers.51.self_attn.qkv_proj.weight", + "decoder.layers.51.self_attn.qkv_proj.bias", + "decoder.layers.51.self_attn.out_proj.weight", + "decoder.layers.51.self_attn.out_proj.bias", + "decoder.layers.51.self_attn_layer_norm.weight", + "decoder.layers.51.self_attn_layer_norm.bias", + 
"decoder.layers.51.fc1.weight", + "decoder.layers.51.fc1.bias", + "decoder.layers.51.fc2.weight", + "decoder.layers.51.fc2.bias", + "decoder.layers.51.final_layer_norm.weight", + "decoder.layers.51.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.52.flat_param_0": { + "names": [ + "decoder.layers.52.self_attn.qkv_proj.weight", + "decoder.layers.52.self_attn.qkv_proj.bias", + "decoder.layers.52.self_attn.out_proj.weight", + "decoder.layers.52.self_attn.out_proj.bias", + "decoder.layers.52.self_attn_layer_norm.weight", + "decoder.layers.52.self_attn_layer_norm.bias", + "decoder.layers.52.fc1.weight", + "decoder.layers.52.fc1.bias", + "decoder.layers.52.fc2.weight", + "decoder.layers.52.fc2.bias", + "decoder.layers.52.final_layer_norm.weight", + "decoder.layers.52.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.53.flat_param_0": { + "names": [ + "decoder.layers.53.self_attn.qkv_proj.weight", + "decoder.layers.53.self_attn.qkv_proj.bias", + "decoder.layers.53.self_attn.out_proj.weight", + "decoder.layers.53.self_attn.out_proj.bias", + "decoder.layers.53.self_attn_layer_norm.weight", + "decoder.layers.53.self_attn_layer_norm.bias", + "decoder.layers.53.fc1.weight", + "decoder.layers.53.fc1.bias", + 
"decoder.layers.53.fc2.weight", + "decoder.layers.53.fc2.bias", + "decoder.layers.53.final_layer_norm.weight", + "decoder.layers.53.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.54.flat_param_0": { + "names": [ + "decoder.layers.54.self_attn.qkv_proj.weight", + "decoder.layers.54.self_attn.qkv_proj.bias", + "decoder.layers.54.self_attn.out_proj.weight", + "decoder.layers.54.self_attn.out_proj.bias", + "decoder.layers.54.self_attn_layer_norm.weight", + "decoder.layers.54.self_attn_layer_norm.bias", + "decoder.layers.54.fc1.weight", + "decoder.layers.54.fc1.bias", + "decoder.layers.54.fc2.weight", + "decoder.layers.54.fc2.bias", + "decoder.layers.54.final_layer_norm.weight", + "decoder.layers.54.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.55.flat_param_0": { + "names": [ + "decoder.layers.55.self_attn.qkv_proj.weight", + "decoder.layers.55.self_attn.qkv_proj.bias", + "decoder.layers.55.self_attn.out_proj.weight", + "decoder.layers.55.self_attn.out_proj.bias", + "decoder.layers.55.self_attn_layer_norm.weight", + "decoder.layers.55.self_attn_layer_norm.bias", + "decoder.layers.55.fc1.weight", + "decoder.layers.55.fc1.bias", + "decoder.layers.55.fc2.weight", + "decoder.layers.55.fc2.bias", + 
"decoder.layers.55.final_layer_norm.weight", + "decoder.layers.55.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.56.flat_param_0": { + "names": [ + "decoder.layers.56.self_attn.qkv_proj.weight", + "decoder.layers.56.self_attn.qkv_proj.bias", + "decoder.layers.56.self_attn.out_proj.weight", + "decoder.layers.56.self_attn.out_proj.bias", + "decoder.layers.56.self_attn_layer_norm.weight", + "decoder.layers.56.self_attn_layer_norm.bias", + "decoder.layers.56.fc1.weight", + "decoder.layers.56.fc1.bias", + "decoder.layers.56.fc2.weight", + "decoder.layers.56.fc2.bias", + "decoder.layers.56.final_layer_norm.weight", + "decoder.layers.56.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.57.flat_param_0": { + "names": [ + "decoder.layers.57.self_attn.qkv_proj.weight", + "decoder.layers.57.self_attn.qkv_proj.bias", + "decoder.layers.57.self_attn.out_proj.weight", + "decoder.layers.57.self_attn.out_proj.bias", + "decoder.layers.57.self_attn_layer_norm.weight", + "decoder.layers.57.self_attn_layer_norm.bias", + "decoder.layers.57.fc1.weight", + "decoder.layers.57.fc1.bias", + "decoder.layers.57.fc2.weight", + "decoder.layers.57.fc2.bias", + "decoder.layers.57.final_layer_norm.weight", + 
"decoder.layers.57.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.58.flat_param_0": { + "names": [ + "decoder.layers.58.self_attn.qkv_proj.weight", + "decoder.layers.58.self_attn.qkv_proj.bias", + "decoder.layers.58.self_attn.out_proj.weight", + "decoder.layers.58.self_attn.out_proj.bias", + "decoder.layers.58.self_attn_layer_norm.weight", + "decoder.layers.58.self_attn_layer_norm.bias", + "decoder.layers.58.fc1.weight", + "decoder.layers.58.fc1.bias", + "decoder.layers.58.fc2.weight", + "decoder.layers.58.fc2.bias", + "decoder.layers.58.final_layer_norm.weight", + "decoder.layers.58.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.59.flat_param_0": { + "names": [ + "decoder.layers.59.self_attn.qkv_proj.weight", + "decoder.layers.59.self_attn.qkv_proj.bias", + "decoder.layers.59.self_attn.out_proj.weight", + "decoder.layers.59.self_attn.out_proj.bias", + "decoder.layers.59.self_attn_layer_norm.weight", + "decoder.layers.59.self_attn_layer_norm.bias", + "decoder.layers.59.fc1.weight", + "decoder.layers.59.fc1.bias", + "decoder.layers.59.fc2.weight", + "decoder.layers.59.fc2.bias", + "decoder.layers.59.final_layer_norm.weight", + "decoder.layers.59.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, 
+ 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.60.flat_param_0": { + "names": [ + "decoder.layers.60.self_attn.qkv_proj.weight", + "decoder.layers.60.self_attn.qkv_proj.bias", + "decoder.layers.60.self_attn.out_proj.weight", + "decoder.layers.60.self_attn.out_proj.bias", + "decoder.layers.60.self_attn_layer_norm.weight", + "decoder.layers.60.self_attn_layer_norm.bias", + "decoder.layers.60.fc1.weight", + "decoder.layers.60.fc1.bias", + "decoder.layers.60.fc2.weight", + "decoder.layers.60.fc2.bias", + "decoder.layers.60.final_layer_norm.weight", + "decoder.layers.60.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.61.flat_param_0": { + "names": [ + "decoder.layers.61.self_attn.qkv_proj.weight", + "decoder.layers.61.self_attn.qkv_proj.bias", + "decoder.layers.61.self_attn.out_proj.weight", + "decoder.layers.61.self_attn.out_proj.bias", + "decoder.layers.61.self_attn_layer_norm.weight", + "decoder.layers.61.self_attn_layer_norm.bias", + "decoder.layers.61.fc1.weight", + "decoder.layers.61.fc1.bias", + "decoder.layers.61.fc2.weight", + "decoder.layers.61.fc2.bias", + "decoder.layers.61.final_layer_norm.weight", + "decoder.layers.61.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + 
[ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.62.flat_param_0": { + "names": [ + "decoder.layers.62.self_attn.qkv_proj.weight", + "decoder.layers.62.self_attn.qkv_proj.bias", + "decoder.layers.62.self_attn.out_proj.weight", + "decoder.layers.62.self_attn.out_proj.bias", + "decoder.layers.62.self_attn_layer_norm.weight", + "decoder.layers.62.self_attn_layer_norm.bias", + "decoder.layers.62.fc1.weight", + "decoder.layers.62.fc1.bias", + "decoder.layers.62.fc2.weight", + "decoder.layers.62.fc2.bias", + "decoder.layers.62.final_layer_norm.weight", + "decoder.layers.62.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.63.flat_param_0": { + "names": [ + "decoder.layers.63.self_attn.qkv_proj.weight", + "decoder.layers.63.self_attn.qkv_proj.bias", + "decoder.layers.63.self_attn.out_proj.weight", + "decoder.layers.63.self_attn.out_proj.bias", + "decoder.layers.63.self_attn_layer_norm.weight", + "decoder.layers.63.self_attn_layer_norm.bias", + "decoder.layers.63.fc1.weight", + "decoder.layers.63.fc1.bias", + "decoder.layers.63.fc2.weight", + "decoder.layers.63.fc2.bias", + "decoder.layers.63.final_layer_norm.weight", + "decoder.layers.63.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], 
+ [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.64.flat_param_0": { + "names": [ + "decoder.layers.64.self_attn.qkv_proj.weight", + "decoder.layers.64.self_attn.qkv_proj.bias", + "decoder.layers.64.self_attn.out_proj.weight", + "decoder.layers.64.self_attn.out_proj.bias", + "decoder.layers.64.self_attn_layer_norm.weight", + "decoder.layers.64.self_attn_layer_norm.bias", + "decoder.layers.64.fc1.weight", + "decoder.layers.64.fc1.bias", + "decoder.layers.64.fc2.weight", + "decoder.layers.64.fc2.bias", + "decoder.layers.64.final_layer_norm.weight", + "decoder.layers.64.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.65.flat_param_0": { + "names": [ + "decoder.layers.65.self_attn.qkv_proj.weight", + "decoder.layers.65.self_attn.qkv_proj.bias", + "decoder.layers.65.self_attn.out_proj.weight", + "decoder.layers.65.self_attn.out_proj.bias", + "decoder.layers.65.self_attn_layer_norm.weight", + "decoder.layers.65.self_attn_layer_norm.bias", + "decoder.layers.65.fc1.weight", + "decoder.layers.65.fc1.bias", + "decoder.layers.65.fc2.weight", + "decoder.layers.65.fc2.bias", + "decoder.layers.65.final_layer_norm.weight", + "decoder.layers.65.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + 
] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.66.flat_param_0": { + "names": [ + "decoder.layers.66.self_attn.qkv_proj.weight", + "decoder.layers.66.self_attn.qkv_proj.bias", + "decoder.layers.66.self_attn.out_proj.weight", + "decoder.layers.66.self_attn.out_proj.bias", + "decoder.layers.66.self_attn_layer_norm.weight", + "decoder.layers.66.self_attn_layer_norm.bias", + "decoder.layers.66.fc1.weight", + "decoder.layers.66.fc1.bias", + "decoder.layers.66.fc2.weight", + "decoder.layers.66.fc2.bias", + "decoder.layers.66.final_layer_norm.weight", + "decoder.layers.66.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.67.flat_param_0": { + "names": [ + "decoder.layers.67.self_attn.qkv_proj.weight", + "decoder.layers.67.self_attn.qkv_proj.bias", + "decoder.layers.67.self_attn.out_proj.weight", + "decoder.layers.67.self_attn.out_proj.bias", + "decoder.layers.67.self_attn_layer_norm.weight", + "decoder.layers.67.self_attn_layer_norm.bias", + "decoder.layers.67.fc1.weight", + "decoder.layers.67.fc1.bias", + "decoder.layers.67.fc2.weight", + "decoder.layers.67.fc2.bias", + "decoder.layers.67.final_layer_norm.weight", + "decoder.layers.67.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 
12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.68.flat_param_0": { + "names": [ + "decoder.layers.68.self_attn.qkv_proj.weight", + "decoder.layers.68.self_attn.qkv_proj.bias", + "decoder.layers.68.self_attn.out_proj.weight", + "decoder.layers.68.self_attn.out_proj.bias", + "decoder.layers.68.self_attn_layer_norm.weight", + "decoder.layers.68.self_attn_layer_norm.bias", + "decoder.layers.68.fc1.weight", + "decoder.layers.68.fc1.bias", + "decoder.layers.68.fc2.weight", + "decoder.layers.68.fc2.bias", + "decoder.layers.68.final_layer_norm.weight", + "decoder.layers.68.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.69.flat_param_0": { + "names": [ + "decoder.layers.69.self_attn.qkv_proj.weight", + "decoder.layers.69.self_attn.qkv_proj.bias", + "decoder.layers.69.self_attn.out_proj.weight", + "decoder.layers.69.self_attn.out_proj.bias", + "decoder.layers.69.self_attn_layer_norm.weight", + "decoder.layers.69.self_attn_layer_norm.bias", + "decoder.layers.69.fc1.weight", + "decoder.layers.69.fc1.bias", + "decoder.layers.69.fc2.weight", + "decoder.layers.69.fc2.bias", + "decoder.layers.69.final_layer_norm.weight", + "decoder.layers.69.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, 
+ "decoder.layers.70.flat_param_0": { + "names": [ + "decoder.layers.70.self_attn.qkv_proj.weight", + "decoder.layers.70.self_attn.qkv_proj.bias", + "decoder.layers.70.self_attn.out_proj.weight", + "decoder.layers.70.self_attn.out_proj.bias", + "decoder.layers.70.self_attn_layer_norm.weight", + "decoder.layers.70.self_attn_layer_norm.bias", + "decoder.layers.70.fc1.weight", + "decoder.layers.70.fc1.bias", + "decoder.layers.70.fc2.weight", + "decoder.layers.70.fc2.bias", + "decoder.layers.70.final_layer_norm.weight", + "decoder.layers.70.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.71.flat_param_0": { + "names": [ + "decoder.layers.71.self_attn.qkv_proj.weight", + "decoder.layers.71.self_attn.qkv_proj.bias", + "decoder.layers.71.self_attn.out_proj.weight", + "decoder.layers.71.self_attn.out_proj.bias", + "decoder.layers.71.self_attn_layer_norm.weight", + "decoder.layers.71.self_attn_layer_norm.bias", + "decoder.layers.71.fc1.weight", + "decoder.layers.71.fc1.bias", + "decoder.layers.71.fc2.weight", + "decoder.layers.71.fc2.bias", + "decoder.layers.71.final_layer_norm.weight", + "decoder.layers.71.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.72.flat_param_0": { + "names": [ + 
"decoder.layers.72.self_attn.qkv_proj.weight", + "decoder.layers.72.self_attn.qkv_proj.bias", + "decoder.layers.72.self_attn.out_proj.weight", + "decoder.layers.72.self_attn.out_proj.bias", + "decoder.layers.72.self_attn_layer_norm.weight", + "decoder.layers.72.self_attn_layer_norm.bias", + "decoder.layers.72.fc1.weight", + "decoder.layers.72.fc1.bias", + "decoder.layers.72.fc2.weight", + "decoder.layers.72.fc2.bias", + "decoder.layers.72.final_layer_norm.weight", + "decoder.layers.72.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.73.flat_param_0": { + "names": [ + "decoder.layers.73.self_attn.qkv_proj.weight", + "decoder.layers.73.self_attn.qkv_proj.bias", + "decoder.layers.73.self_attn.out_proj.weight", + "decoder.layers.73.self_attn.out_proj.bias", + "decoder.layers.73.self_attn_layer_norm.weight", + "decoder.layers.73.self_attn_layer_norm.bias", + "decoder.layers.73.fc1.weight", + "decoder.layers.73.fc1.bias", + "decoder.layers.73.fc2.weight", + "decoder.layers.73.fc2.bias", + "decoder.layers.73.final_layer_norm.weight", + "decoder.layers.73.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.74.flat_param_0": { + "names": [ + "decoder.layers.74.self_attn.qkv_proj.weight", + 
"decoder.layers.74.self_attn.qkv_proj.bias", + "decoder.layers.74.self_attn.out_proj.weight", + "decoder.layers.74.self_attn.out_proj.bias", + "decoder.layers.74.self_attn_layer_norm.weight", + "decoder.layers.74.self_attn_layer_norm.bias", + "decoder.layers.74.fc1.weight", + "decoder.layers.74.fc1.bias", + "decoder.layers.74.fc2.weight", + "decoder.layers.74.fc2.bias", + "decoder.layers.74.final_layer_norm.weight", + "decoder.layers.74.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.75.flat_param_0": { + "names": [ + "decoder.layers.75.self_attn.qkv_proj.weight", + "decoder.layers.75.self_attn.qkv_proj.bias", + "decoder.layers.75.self_attn.out_proj.weight", + "decoder.layers.75.self_attn.out_proj.bias", + "decoder.layers.75.self_attn_layer_norm.weight", + "decoder.layers.75.self_attn_layer_norm.bias", + "decoder.layers.75.fc1.weight", + "decoder.layers.75.fc1.bias", + "decoder.layers.75.fc2.weight", + "decoder.layers.75.fc2.bias", + "decoder.layers.75.final_layer_norm.weight", + "decoder.layers.75.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.76.flat_param_0": { + "names": [ + "decoder.layers.76.self_attn.qkv_proj.weight", + "decoder.layers.76.self_attn.qkv_proj.bias", + 
"decoder.layers.76.self_attn.out_proj.weight", + "decoder.layers.76.self_attn.out_proj.bias", + "decoder.layers.76.self_attn_layer_norm.weight", + "decoder.layers.76.self_attn_layer_norm.bias", + "decoder.layers.76.fc1.weight", + "decoder.layers.76.fc1.bias", + "decoder.layers.76.fc2.weight", + "decoder.layers.76.fc2.bias", + "decoder.layers.76.final_layer_norm.weight", + "decoder.layers.76.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.77.flat_param_0": { + "names": [ + "decoder.layers.77.self_attn.qkv_proj.weight", + "decoder.layers.77.self_attn.qkv_proj.bias", + "decoder.layers.77.self_attn.out_proj.weight", + "decoder.layers.77.self_attn.out_proj.bias", + "decoder.layers.77.self_attn_layer_norm.weight", + "decoder.layers.77.self_attn_layer_norm.bias", + "decoder.layers.77.fc1.weight", + "decoder.layers.77.fc1.bias", + "decoder.layers.77.fc2.weight", + "decoder.layers.77.fc2.bias", + "decoder.layers.77.final_layer_norm.weight", + "decoder.layers.77.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.78.flat_param_0": { + "names": [ + "decoder.layers.78.self_attn.qkv_proj.weight", + "decoder.layers.78.self_attn.qkv_proj.bias", + "decoder.layers.78.self_attn.out_proj.weight", + 
"decoder.layers.78.self_attn.out_proj.bias", + "decoder.layers.78.self_attn_layer_norm.weight", + "decoder.layers.78.self_attn_layer_norm.bias", + "decoder.layers.78.fc1.weight", + "decoder.layers.78.fc1.bias", + "decoder.layers.78.fc2.weight", + "decoder.layers.78.fc2.bias", + "decoder.layers.78.final_layer_norm.weight", + "decoder.layers.78.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.79.flat_param_0": { + "names": [ + "decoder.layers.79.self_attn.qkv_proj.weight", + "decoder.layers.79.self_attn.qkv_proj.bias", + "decoder.layers.79.self_attn.out_proj.weight", + "decoder.layers.79.self_attn.out_proj.bias", + "decoder.layers.79.self_attn_layer_norm.weight", + "decoder.layers.79.self_attn_layer_norm.bias", + "decoder.layers.79.fc1.weight", + "decoder.layers.79.fc1.bias", + "decoder.layers.79.fc2.weight", + "decoder.layers.79.fc2.bias", + "decoder.layers.79.final_layer_norm.weight", + "decoder.layers.79.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.80.flat_param_0": { + "names": [ + "decoder.layers.80.self_attn.qkv_proj.weight", + "decoder.layers.80.self_attn.qkv_proj.bias", + "decoder.layers.80.self_attn.out_proj.weight", + "decoder.layers.80.self_attn.out_proj.bias", + 
"decoder.layers.80.self_attn_layer_norm.weight", + "decoder.layers.80.self_attn_layer_norm.bias", + "decoder.layers.80.fc1.weight", + "decoder.layers.80.fc1.bias", + "decoder.layers.80.fc2.weight", + "decoder.layers.80.fc2.bias", + "decoder.layers.80.final_layer_norm.weight", + "decoder.layers.80.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.81.flat_param_0": { + "names": [ + "decoder.layers.81.self_attn.qkv_proj.weight", + "decoder.layers.81.self_attn.qkv_proj.bias", + "decoder.layers.81.self_attn.out_proj.weight", + "decoder.layers.81.self_attn.out_proj.bias", + "decoder.layers.81.self_attn_layer_norm.weight", + "decoder.layers.81.self_attn_layer_norm.bias", + "decoder.layers.81.fc1.weight", + "decoder.layers.81.fc1.bias", + "decoder.layers.81.fc2.weight", + "decoder.layers.81.fc2.bias", + "decoder.layers.81.final_layer_norm.weight", + "decoder.layers.81.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.82.flat_param_0": { + "names": [ + "decoder.layers.82.self_attn.qkv_proj.weight", + "decoder.layers.82.self_attn.qkv_proj.bias", + "decoder.layers.82.self_attn.out_proj.weight", + "decoder.layers.82.self_attn.out_proj.bias", + "decoder.layers.82.self_attn_layer_norm.weight", + 
"decoder.layers.82.self_attn_layer_norm.bias", + "decoder.layers.82.fc1.weight", + "decoder.layers.82.fc1.bias", + "decoder.layers.82.fc2.weight", + "decoder.layers.82.fc2.bias", + "decoder.layers.82.final_layer_norm.weight", + "decoder.layers.82.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.83.flat_param_0": { + "names": [ + "decoder.layers.83.self_attn.qkv_proj.weight", + "decoder.layers.83.self_attn.qkv_proj.bias", + "decoder.layers.83.self_attn.out_proj.weight", + "decoder.layers.83.self_attn.out_proj.bias", + "decoder.layers.83.self_attn_layer_norm.weight", + "decoder.layers.83.self_attn_layer_norm.bias", + "decoder.layers.83.fc1.weight", + "decoder.layers.83.fc1.bias", + "decoder.layers.83.fc2.weight", + "decoder.layers.83.fc2.bias", + "decoder.layers.83.final_layer_norm.weight", + "decoder.layers.83.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.84.flat_param_0": { + "names": [ + "decoder.layers.84.self_attn.qkv_proj.weight", + "decoder.layers.84.self_attn.qkv_proj.bias", + "decoder.layers.84.self_attn.out_proj.weight", + "decoder.layers.84.self_attn.out_proj.bias", + "decoder.layers.84.self_attn_layer_norm.weight", + "decoder.layers.84.self_attn_layer_norm.bias", + 
"decoder.layers.84.fc1.weight", + "decoder.layers.84.fc1.bias", + "decoder.layers.84.fc2.weight", + "decoder.layers.84.fc2.bias", + "decoder.layers.84.final_layer_norm.weight", + "decoder.layers.84.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.85.flat_param_0": { + "names": [ + "decoder.layers.85.self_attn.qkv_proj.weight", + "decoder.layers.85.self_attn.qkv_proj.bias", + "decoder.layers.85.self_attn.out_proj.weight", + "decoder.layers.85.self_attn.out_proj.bias", + "decoder.layers.85.self_attn_layer_norm.weight", + "decoder.layers.85.self_attn_layer_norm.bias", + "decoder.layers.85.fc1.weight", + "decoder.layers.85.fc1.bias", + "decoder.layers.85.fc2.weight", + "decoder.layers.85.fc2.bias", + "decoder.layers.85.final_layer_norm.weight", + "decoder.layers.85.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.86.flat_param_0": { + "names": [ + "decoder.layers.86.self_attn.qkv_proj.weight", + "decoder.layers.86.self_attn.qkv_proj.bias", + "decoder.layers.86.self_attn.out_proj.weight", + "decoder.layers.86.self_attn.out_proj.bias", + "decoder.layers.86.self_attn_layer_norm.weight", + "decoder.layers.86.self_attn_layer_norm.bias", + "decoder.layers.86.fc1.weight", + "decoder.layers.86.fc1.bias", + 
"decoder.layers.86.fc2.weight", + "decoder.layers.86.fc2.bias", + "decoder.layers.86.final_layer_norm.weight", + "decoder.layers.86.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.87.flat_param_0": { + "names": [ + "decoder.layers.87.self_attn.qkv_proj.weight", + "decoder.layers.87.self_attn.qkv_proj.bias", + "decoder.layers.87.self_attn.out_proj.weight", + "decoder.layers.87.self_attn.out_proj.bias", + "decoder.layers.87.self_attn_layer_norm.weight", + "decoder.layers.87.self_attn_layer_norm.bias", + "decoder.layers.87.fc1.weight", + "decoder.layers.87.fc1.bias", + "decoder.layers.87.fc2.weight", + "decoder.layers.87.fc2.bias", + "decoder.layers.87.final_layer_norm.weight", + "decoder.layers.87.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.88.flat_param_0": { + "names": [ + "decoder.layers.88.self_attn.qkv_proj.weight", + "decoder.layers.88.self_attn.qkv_proj.bias", + "decoder.layers.88.self_attn.out_proj.weight", + "decoder.layers.88.self_attn.out_proj.bias", + "decoder.layers.88.self_attn_layer_norm.weight", + "decoder.layers.88.self_attn_layer_norm.bias", + "decoder.layers.88.fc1.weight", + "decoder.layers.88.fc1.bias", + "decoder.layers.88.fc2.weight", + "decoder.layers.88.fc2.bias", + 
"decoder.layers.88.final_layer_norm.weight", + "decoder.layers.88.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.89.flat_param_0": { + "names": [ + "decoder.layers.89.self_attn.qkv_proj.weight", + "decoder.layers.89.self_attn.qkv_proj.bias", + "decoder.layers.89.self_attn.out_proj.weight", + "decoder.layers.89.self_attn.out_proj.bias", + "decoder.layers.89.self_attn_layer_norm.weight", + "decoder.layers.89.self_attn_layer_norm.bias", + "decoder.layers.89.fc1.weight", + "decoder.layers.89.fc1.bias", + "decoder.layers.89.fc2.weight", + "decoder.layers.89.fc2.bias", + "decoder.layers.89.final_layer_norm.weight", + "decoder.layers.89.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.90.flat_param_0": { + "names": [ + "decoder.layers.90.self_attn.qkv_proj.weight", + "decoder.layers.90.self_attn.qkv_proj.bias", + "decoder.layers.90.self_attn.out_proj.weight", + "decoder.layers.90.self_attn.out_proj.bias", + "decoder.layers.90.self_attn_layer_norm.weight", + "decoder.layers.90.self_attn_layer_norm.bias", + "decoder.layers.90.fc1.weight", + "decoder.layers.90.fc1.bias", + "decoder.layers.90.fc2.weight", + "decoder.layers.90.fc2.bias", + "decoder.layers.90.final_layer_norm.weight", + 
"decoder.layers.90.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.91.flat_param_0": { + "names": [ + "decoder.layers.91.self_attn.qkv_proj.weight", + "decoder.layers.91.self_attn.qkv_proj.bias", + "decoder.layers.91.self_attn.out_proj.weight", + "decoder.layers.91.self_attn.out_proj.bias", + "decoder.layers.91.self_attn_layer_norm.weight", + "decoder.layers.91.self_attn_layer_norm.bias", + "decoder.layers.91.fc1.weight", + "decoder.layers.91.fc1.bias", + "decoder.layers.91.fc2.weight", + "decoder.layers.91.fc2.bias", + "decoder.layers.91.final_layer_norm.weight", + "decoder.layers.91.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.92.flat_param_0": { + "names": [ + "decoder.layers.92.self_attn.qkv_proj.weight", + "decoder.layers.92.self_attn.qkv_proj.bias", + "decoder.layers.92.self_attn.out_proj.weight", + "decoder.layers.92.self_attn.out_proj.bias", + "decoder.layers.92.self_attn_layer_norm.weight", + "decoder.layers.92.self_attn_layer_norm.bias", + "decoder.layers.92.fc1.weight", + "decoder.layers.92.fc1.bias", + "decoder.layers.92.fc2.weight", + "decoder.layers.92.fc2.bias", + "decoder.layers.92.final_layer_norm.weight", + "decoder.layers.92.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, 
+ 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.93.flat_param_0": { + "names": [ + "decoder.layers.93.self_attn.qkv_proj.weight", + "decoder.layers.93.self_attn.qkv_proj.bias", + "decoder.layers.93.self_attn.out_proj.weight", + "decoder.layers.93.self_attn.out_proj.bias", + "decoder.layers.93.self_attn_layer_norm.weight", + "decoder.layers.93.self_attn_layer_norm.bias", + "decoder.layers.93.fc1.weight", + "decoder.layers.93.fc1.bias", + "decoder.layers.93.fc2.weight", + "decoder.layers.93.fc2.bias", + "decoder.layers.93.final_layer_norm.weight", + "decoder.layers.93.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.94.flat_param_0": { + "names": [ + "decoder.layers.94.self_attn.qkv_proj.weight", + "decoder.layers.94.self_attn.qkv_proj.bias", + "decoder.layers.94.self_attn.out_proj.weight", + "decoder.layers.94.self_attn.out_proj.bias", + "decoder.layers.94.self_attn_layer_norm.weight", + "decoder.layers.94.self_attn_layer_norm.bias", + "decoder.layers.94.fc1.weight", + "decoder.layers.94.fc1.bias", + "decoder.layers.94.fc2.weight", + "decoder.layers.94.fc2.bias", + "decoder.layers.94.final_layer_norm.weight", + "decoder.layers.94.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + 
[ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + }, + "decoder.layers.95.flat_param_0": { + "names": [ + "decoder.layers.95.self_attn.qkv_proj.weight", + "decoder.layers.95.self_attn.qkv_proj.bias", + "decoder.layers.95.self_attn.out_proj.weight", + "decoder.layers.95.self_attn.out_proj.bias", + "decoder.layers.95.self_attn_layer_norm.weight", + "decoder.layers.95.self_attn_layer_norm.bias", + "decoder.layers.95.fc1.weight", + "decoder.layers.95.fc1.bias", + "decoder.layers.95.fc2.weight", + "decoder.layers.95.fc2.bias", + "decoder.layers.95.final_layer_norm.weight", + "decoder.layers.95.final_layer_norm.bias" + ], + "shapes": [ + [ + 4608, + 12288 + ], + [ + 4608 + ], + [ + 12288, + 1536 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 6144, + 12288 + ], + [ + 6144 + ], + [ + 12288, + 6144 + ], + [ + 12288 + ], + [ + 12288 + ], + [ + 12288 + ] + ], + "numels": [ + 56623104, + 4608, + 18874368, + 12288, + 12288, + 12288, + 75497472, + 6144, + 75497472, + 12288, + 12288, + 12288 + ] + } +} diff --git a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py index 0494647d7bcc..ffa6f0d83808 100644 --- a/examples/tutorial/opt/inference/script/processing_ckpt_66b.py +++ b/examples/tutorial/opt/inference/script/processing_ckpt_66b.py @@ -1,7 +1,8 @@ import os -import torch from multiprocessing import Pool +import torch + # download pytorch model ckpt in https://huggingface.co/facebook/opt-66b/tree/main # you can use whether wget or git lfs @@ -20,28 +21,27 @@ restored = {} for ckpt in ckpts: - for k,v in ckpt.items(): - if(k[0] == 'm'): - k = k[6:] - if(k == "lm_head.weight"): + for k, v in ckpt.items(): + if (k[0] == 'm'): + k = k[6:] + if (k == 
"lm_head.weight"): k = "head.dense.weight" - if(k == "decoder.final_layer_norm.weight"): + if (k == "decoder.final_layer_norm.weight"): k = "decoder.layer_norm.weight" - if(k == "decoder.final_layer_norm.bias"): + if (k == "decoder.final_layer_norm.bias"): k = "decoder.layer_norm.bias" restored[k] = v restored["decoder.version"] = "0.0" - split_num = len(restored.keys()) // 60 count = 0 file_count = 1 tmp = {} -for k,v in restored.items(): +for k, v in restored.items(): print(k) tmp[k] = v - count = count + 1 - if(count == split_num): + count = count + 1 + if (count == split_num): filename = str(file_count) + "-restored.pt" torch.save(tmp, os.path.join(new_path, filename)) file_count = file_count + 1 @@ -50,6 +50,3 @@ filename = str(file_count) + "-restored.pt" torch.save(tmp, os.path.join(new_path, filename)) - - - diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index c4f576cb18aa..b634780ff707 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ -30,24 +30,13 @@ import datasets import torch import torch.distributed as dist +import transformers from accelerate.utils import set_seed from context import barrier_context from datasets import load_dataset from packaging import version from torch.utils.data import DataLoader from tqdm.auto import tqdm - -import colossalai -import transformers -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP -from colossalai.tensor import ProcessGroup -from colossalai.utils import get_current_device, get_dataloader -from colossalai.utils.model.colo_init_context import ColoInitContext from transformers import ( CONFIG_MAPPING, MODEL_MAPPING, @@ -61,6 +50,17 @@ ) from 
transformers.utils.versions import require_version +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.utils.model.colo_init_context import ColoInitContext + require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) diff --git a/examples/tutorial/sequence_parallel/data/__init__.py b/examples/tutorial/sequence_parallel/data/__init__.py index 1ef2d999389f..9815b487e50f 100644 --- a/examples/tutorial/sequence_parallel/data/__init__.py +++ b/examples/tutorial/sequence_parallel/data/__init__.py @@ -1,10 +1,12 @@ +import torch + +from colossalai.context import ParallelMode from colossalai.context.parallel_context import ParallelContext from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger -from colossalai.context import ParallelMode -from .datasets.data_samplers import build_pretraining_data_loader + from .datasets.builder import build_train_valid_test_datasets -import torch +from .datasets.data_samplers import build_pretraining_data_loader def cyclic_iter(iter): @@ -18,8 +20,7 @@ def build_train_valid_test_data_iterators(train_iters, eval_interval, eval_iters, dataloader_type='single', - **kwargs - ): + **kwargs): (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) logger = get_dist_logger() @@ -42,9 +43,7 @@ def build_train_valid_test_data_iterators(train_iters, train_samples = train_iters * global_batch_size eval_iters_ = (train_iters // eval_interval + 1) * eval_iters test_iters = 
eval_iters - train_val_test_num_samples = [train_samples, - eval_iters_ * global_batch_size, - test_iters * global_batch_size] + train_val_test_num_samples = [train_samples, eval_iters_ * global_batch_size, test_iters * global_batch_size] logger.info(' > datasets target sizes (minimum size):') logger.info(' train: {}'.format(train_val_test_num_samples[0]), ranks=[0]) logger.info(' validation: {}'.format(train_val_test_num_samples[1]), ranks=[0]) @@ -56,19 +55,20 @@ def build_train_valid_test_data_iterators(train_iters, # Build dataloaders. dp_size = gpc.get_world_size(ParallelMode.DATA) - train_dataloader = build_pretraining_data_loader( - train_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) - valid_dataloader = build_pretraining_data_loader( - valid_ds, consumed_samples=0, micro_batch_size=global_batch_size//dp_size) - test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size//dp_size) + train_dataloader = build_pretraining_data_loader(train_ds, + consumed_samples=0, + micro_batch_size=global_batch_size // dp_size) + valid_dataloader = build_pretraining_data_loader(valid_ds, + consumed_samples=0, + micro_batch_size=global_batch_size // dp_size) + test_dataloader = build_pretraining_data_loader(test_ds, 0, micro_batch_size=global_batch_size // dp_size) # Flags to know if we need to do training/validation/testing. do_train = train_dataloader is not None and train_iters > 0 do_valid = valid_dataloader is not None and eval_iters > 0 do_test = test_dataloader is not None and eval_iters > 0 # Need to broadcast num_tokens and num_type_tokens. 
- flags = torch.cuda.LongTensor( - [int(do_train), int(do_valid), int(do_test)]) + flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)]) else: flags = torch.cuda.LongTensor([0, 0, 0]) diff --git a/examples/tutorial/sequence_parallel/data/bert_helper.py b/examples/tutorial/sequence_parallel/data/bert_helper.py index d092db3e7dd8..3f890c8bef79 100644 --- a/examples/tutorial/sequence_parallel/data/bert_helper.py +++ b/examples/tutorial/sequence_parallel/data/bert_helper.py @@ -1,7 +1,8 @@ -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode import torch +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc + _MAX_DATA_DIM = 5 @@ -22,7 +23,8 @@ def _build_key_size_numel_dictionaries(keys, data): # Move to GPU and broadcast. sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast(sizes_cuda, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], + torch.distributed.broadcast(sizes_cuda, + gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], group=gpc.get_group(ParallelMode.TENSOR)) # Move back to cpu and unpack. @@ -60,19 +62,15 @@ def broadcast_data(keys, data, datatype): """ # Build (key, size) and (key, number of elements) dictionaries along # with the total number of elements on all ranks. - key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, - data) + key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) # Pack on rank zero. if not gpc.is_initialized(ParallelMode.TENSOR) or gpc.get_local_rank(ParallelMode.TENSOR) == 0: # Check that all keys have the same data type. 
# Flatten the data associated with the keys - flatten_data = torch.cat( - [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() + flatten_data = torch.cat([data[key].contiguous().view(-1) for key in keys], dim=0).cuda() else: - flatten_data = torch.empty(total_numel, - device=torch.cuda.current_device(), - dtype=datatype) + flatten_data = torch.empty(total_numel, device=torch.cuda.current_device(), dtype=datatype) # Broadcast torch.distributed.broadcast(flatten_data, @@ -139,7 +137,7 @@ def get_batch_for_sequence_parallel(data_iterator): seq_length = data_b['text'].size(1) sub_seq_length = seq_length // local_world_size sub_seq_start = local_rank * sub_seq_length - sub_seq_end = (local_rank+1) * sub_seq_length + sub_seq_end = (local_rank + 1) * sub_seq_length # # # Unpack. tokens = data_b['text'][:, sub_seq_start:sub_seq_end].long() @@ -156,10 +154,9 @@ class SequenceParallelDataIterator: def __init__(self, data_iter): self.data_iter = data_iter - def __iter__(self): return self.data_iter def __next__(self): - return get_batch_for_sequence_parallel(self.data_iter) \ No newline at end of file + return get_batch_for_sequence_parallel(self.data_iter) diff --git a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py index 6a06c869d8c8..c23f6e7dad64 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/blendable_dataset.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Blendable dataset.""" import time @@ -46,9 +45,7 @@ def __init__(self, datasets, weights): self.dataset_sample_index = np.zeros(self.size, dtype=np.int64) from . 
import helpers - helpers.build_blending_indices(self.dataset_index, - self.dataset_sample_index, - weights, num_datasets, self.size, + helpers.build_blending_indices(self.dataset_index, self.dataset_sample_index, weights, num_datasets, self.size, torch.distributed.get_rank() == 0) print('> elapsed time for building blendable dataset indices: ' '{:.2f} (sec)'.format(time.time() - start_time)) diff --git a/examples/tutorial/sequence_parallel/data/datasets/builder.py b/examples/tutorial/sequence_parallel/data/datasets/builder.py index 6106f833b462..8d3fbd9c3bbe 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/builder.py +++ b/examples/tutorial/sequence_parallel/data/datasets/builder.py @@ -1,7 +1,8 @@ +from colossalai.logging import get_dist_logger + +from .bert_dataset import BertDataset from .blendable_dataset import BlendableDataset from .dataset_utils import get_datasets_weights_and_num_samples, get_indexed_dataset_, get_train_valid_test_split_ -from .bert_dataset import BertDataset -from colossalai.logging import get_dist_logger DSET_TYPE_BERT = 'standard_bert' DSET_TYPE_ICT = 'ict' @@ -10,10 +11,15 @@ DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, +def _build_train_valid_test_datasets(data_prefix, + data_impl, + splits_string, train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, binary_head, dataset_type='standard_bert'): @@ -21,9 +27,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, raise ValueError("Invalid dataset_type: ", dataset_type) # Indexed dataset. 
- indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) # Get start and end indices of train/valid/train into doc-idx # Note that doc-idx is designed to be num-docs + 1 so we can @@ -39,14 +43,12 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, def print_split_stats(name, index): start_index = indexed_dataset.doc_idx[splits[index]] end_index = indexed_dataset.doc_idx[splits[index + 1]] - logger.info('\n {}:'.format(name) + - '\n document indices in [{}, {}) total of {} documents'.format( - splits[index], splits[index + 1], - splits[index + 1] - splits[index]) + + logger.info('\n {}:'.format(name) + '\n document indices in [{}, {}) total of {} documents'.format( + splits[index], splits[index + 1], splits[index + 1] - splits[index]) + '\n sentence indices in [{}, {}) total of {} sentences'.format( - start_index, end_index, - end_index - start_index), + start_index, end_index, end_index - start_index), ranks=[0]) + print_split_stats('train', 0) print_split_stats('validation', 1) print_split_stats('test', 2) @@ -75,13 +77,11 @@ def build_dataset(index, name): if dataset_type != DSET_TYPE_BERT: raise NotImplementedError("Only BERT dataset is supported") else: - dataset = BertDataset( - indexed_dataset=indexed_dataset, - masked_lm_prob=masked_lm_prob, - short_seq_prob=short_seq_prob, - binary_head=binary_head, - **kwargs - ) + dataset = BertDataset(indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + binary_head=binary_head, + **kwargs) # Set the original pointer so dataset remains the main dataset. 
indexed_dataset.set_doc_idx(doc_idx_ptr) @@ -98,26 +98,33 @@ def build_dataset(index, name): return (train_dataset, valid_dataset, test_dataset) -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, +def build_train_valid_test_datasets(data_prefix, + data_impl, + splits_string, train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, binary_head, dataset_type='standard_bert'): if len(data_prefix) == 1: return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, + data_impl, + splits_string, train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, skip_warmup, binary_head, dataset_type=dataset_type) # Blending dataset. # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. 
@@ -125,11 +132,17 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, valid_datasets = [] test_datasets = [] for i in range(len(prefixes)): - train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - prefixes[i], data_impl, splits_string, - datasets_train_valid_test_num_samples[i], - max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(prefixes[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -148,5 +161,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_datasets: blending_test_dataset = BlendableDataset(test_datasets, weights) - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) diff --git a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py index cf547ad97558..482f343c54be 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py +++ b/examples/tutorial/sequence_parallel/data/datasets/data_samplers.py @@ -14,10 +14,12 @@ # limitations under the License. 
"""Dataloaders.""" -import torch import random -from colossalai.core import global_context as gpc + +import torch + from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc def build_pretraining_data_loader(dataset, consumed_samples, micro_batch_size, dataloader_type='single', num_workers=0): diff --git a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py index cf4e4763fc10..e6827b0b6c8f 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py +++ b/examples/tutorial/sequence_parallel/data/datasets/dataset_utils.py @@ -13,16 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. - # Most of the code here has been copied from: # https://github.com/google-research/albert/blob/master/create_pretraining_data.py # with some modifications. +import collections import math import time -import collections -from colossalai.logging import get_dist_logger + import numpy as np + +from colossalai.logging import get_dist_logger + from .blendable_dataset import BlendableDataset from .indexed_dataset import make_dataset as make_indexed_dataset @@ -32,18 +34,17 @@ DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD] -def get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples): +def get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples): # The data prefix should be in the format of: # weight-1, data-prefix-1, weight-2, data-prefix-2, .. 
assert len(data_prefix) % 2 == 0 num_datasets = len(data_prefix) // 2 - weights = [0]*num_datasets - prefixes = [0]*num_datasets + weights = [0] * num_datasets + prefixes = [0] * num_datasets for i in range(num_datasets): - weights[i] = float(data_prefix[2*i]) - prefixes[i] = (data_prefix[2*i+1]).strip() + weights[i] = float(data_prefix[2 * i]) + prefixes[i] = (data_prefix[2 * i + 1]).strip() # Normalize weights weight_sum = 0.0 for weight in weights: @@ -57,8 +58,7 @@ def get_datasets_weights_and_num_samples(data_prefix, datasets_train_valid_test_num_samples = [] for weight in weights: datasets_train_valid_test_num_samples.append( - [int(math.ceil(val * weight * 1.005)) - for val in train_valid_test_num_samples]) + [int(math.ceil(val * weight * 1.005)) for val in train_valid_test_num_samples]) return prefixes, weights, datasets_train_valid_test_num_samples @@ -155,8 +155,7 @@ def create_tokens_and_tokentypes(tokens_a, tokens_b, cls_id, sep_id): return tokens, tokentypes -MaskedLmInstance = collections.namedtuple("MaskedLmInstance", - ["index", "label"]) +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) def is_start_piece(piece): @@ -169,9 +168,12 @@ def is_start_piece(piece): def create_masked_lm_predictions(tokens, - vocab_id_list, vocab_id_to_token_dict, + vocab_id_list, + vocab_id_to_token_dict, masked_lm_prob, - cls_id, sep_id, mask_id, + cls_id, + sep_id, + mask_id, max_predictions_per_seq, np_rng, max_ngrams=3, @@ -197,8 +199,7 @@ def create_masked_lm_predictions(tokens, # Note that Whole Word Masking does *not* change the training code # at all -- we still predict each WordPiece independently, softmaxed # over the entire vocabulary. 
- if (do_whole_word_mask and len(cand_indexes) >= 1 and - not is_start_piece(vocab_id_to_token_dict[token])): + if (do_whole_word_mask and len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token])): cand_indexes[-1].append(i) else: cand_indexes.append([i]) @@ -211,11 +212,9 @@ def create_masked_lm_predictions(tokens, masked_lm_labels = [] if masked_lm_prob == 0: - return (output_tokens, masked_lm_positions, - masked_lm_labels, token_boundary) + return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) - num_to_predict = min(max_predictions_per_seq, - max(1, int(round(len(tokens) * masked_lm_prob)))) + num_to_predict = min(max_predictions_per_seq, max(1, int(round(len(tokens) * masked_lm_prob)))) # Note(mingdachen): # By default, we set the probabilities to favor shorter ngram sequences. @@ -250,8 +249,7 @@ def create_masked_lm_predictions(tokens, continue n = np_rng.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) + p=pvals[:len(cand_index_set)] / pvals[:len(cand_index_set)].sum(keepdims=True)) index_set = sum(cand_index_set[n - 1], []) n -= 1 # Note(mingdachen): @@ -310,8 +308,7 @@ def create_masked_lm_predictions(tokens, continue n = np.random.choice(ngrams[:len(cand_index_set)], - p=pvals[:len(cand_index_set)] / - pvals[:len(cand_index_set)].sum(keepdims=True)) + p=pvals[:len(cand_index_set)] / pvals[:len(cand_index_set)].sum(keepdims=True)) index_set = sum(cand_index_set[n - 1], []) n -= 1 @@ -353,8 +350,7 @@ def create_masked_lm_predictions(tokens, return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary) -def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, - masked_labels, pad_id, max_seq_length): +def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, masked_labels, pad_id, max_seq_length): """Pad sequences and convert them to numpy.""" # Some checks. 
@@ -370,8 +366,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, tokentypes_np = np.array(tokentypes + filler, dtype=np.int64) # Padding mask. - padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, - dtype=np.int64) + padding_mask_np = np.array([1] * num_tokens + [0] * padding_length, dtype=np.int64) # Lables and loss mask. labels = [-1] * max_seq_length @@ -386,26 +381,33 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions, return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np -def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, +def build_train_valid_test_datasets(data_prefix, + data_impl, + splits_string, train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, binary_head, dataset_type='standard_bert'): if len(data_prefix) == 1: return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, + data_impl, + splits_string, train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, skip_warmup, binary_head, dataset_type=dataset_type) # Blending dataset. # Parse the values. - output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + output = get_datasets_weights_and_num_samples(data_prefix, train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output # Build individual datasets. 
@@ -413,11 +415,17 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, valid_datasets = [] test_datasets = [] for i in range(len(prefixes)): - train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( - prefixes[i], data_impl, splits_string, - datasets_train_valid_test_num_samples[i], - max_seq_length, masked_lm_prob, short_seq_prob, - seed, skip_warmup, binary_head, dataset_type=dataset_type) + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets(prefixes[i], + data_impl, + splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, + binary_head, + dataset_type=dataset_type) if train_ds: train_datasets.append(train_ds) if valid_ds: @@ -436,14 +444,18 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, if test_datasets: blending_test_dataset = BlendableDataset(test_datasets, weights) - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) + return (blending_train_dataset, blending_valid_dataset, blending_test_dataset) -def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, +def _build_train_valid_test_datasets(data_prefix, + data_impl, + splits_string, train_valid_test_num_samples, - max_seq_length, masked_lm_prob, - short_seq_prob, seed, skip_warmup, + max_seq_length, + masked_lm_prob, + short_seq_prob, + seed, + skip_warmup, binary_head, dataset_type='standard_bert'): logger = get_dist_logger() @@ -452,15 +464,11 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, raise ValueError("Invalid dataset_type: ", dataset_type) # Indexed dataset. 
- indexed_dataset = get_indexed_dataset_(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) if dataset_type == DSET_TYPE_ICT: args = get_args() - title_dataset = get_indexed_dataset_(args.titles_data_path, - data_impl, - skip_warmup) + title_dataset = get_indexed_dataset_(args.titles_data_path, data_impl, skip_warmup) # Get start and end indices of train/valid/train into doc-idx # Note that doc-idx is designed to be num-docs + 1 so we can @@ -474,16 +482,12 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, def print_split_stats(name, index): start_index = indexed_dataset.doc_idx[splits[index]] end_index = indexed_dataset.doc_idx[splits[index + 1]] - logger.info('\n {}:'.format(name) + - '\n document indices in [{}, {}) total of {} documents'.format( - splits[index], - splits[index + 1], - splits[index + 1] - splits[index]) + + logger.info('\n {}:'.format(name) + '\n document indices in [{}, {}) total of {} documents'.format( + splits[index], splits[index + 1], splits[index + 1] - splits[index]) + '\n sentence indices in [{}, {}) total of {} sentences'.format( - start_index, - end_index, - end_index - start_index), + start_index, end_index, end_index - start_index), ranks=[0]) + print_split_stats('train', 0) print_split_stats('validation', 1) print_split_stats('test', 2) @@ -501,32 +505,26 @@ def build_dataset(index, name): # New doc_idx view. indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) # Build the dataset accordingly. 
- kwargs = dict( - name=name, - data_prefix=data_prefix, - num_epochs=None, - max_num_samples=train_valid_test_num_samples[index], - max_seq_length=max_seq_length, - seed=seed, - binary_head=binary_head - ) + kwargs = dict(name=name, + data_prefix=data_prefix, + num_epochs=None, + max_num_samples=train_valid_test_num_samples[index], + max_seq_length=max_seq_length, + seed=seed, + binary_head=binary_head) if dataset_type == DSET_TYPE_ICT: args = get_args() - dataset = ICTDataset( - block_dataset=indexed_dataset, - title_dataset=title_dataset, - query_in_block_prob=args.query_in_block_prob, - use_one_sent_docs=args.use_one_sent_docs, - **kwargs - ) + dataset = ICTDataset(block_dataset=indexed_dataset, + title_dataset=title_dataset, + query_in_block_prob=args.query_in_block_prob, + use_one_sent_docs=args.use_one_sent_docs, + **kwargs) else: - dataset = BertDataset( - indexed_dataset=indexed_dataset, - masked_lm_prob=masked_lm_prob, - short_seq_prob=short_seq_prob, - **kwargs - ) + dataset = BertDataset(indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + short_seq_prob=short_seq_prob, + **kwargs) # Set the original pointer so dataset remains the main dataset. 
indexed_dataset.set_doc_idx(doc_idx_ptr) @@ -546,20 +544,16 @@ def build_dataset(index, name): def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): logger = get_dist_logger() start_time = time.time() - indexed_dataset = make_indexed_dataset(data_prefix, - data_impl, - skip_warmup) + indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup) assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1] logger.info('\n > building dataset index ...', ranks=[0]) logger.info('\n > finished creating indexed dataset in {:4f} ' - 'seconds'.format(time.time() - start_time), ranks=[0]) + 'seconds'.format(time.time() - start_time), + ranks=[0]) logger.info('\n > indexed dataset stats:' + - '\n number of documents: {}'.format( - indexed_dataset.doc_idx.shape[0] - 1) + - '\n number of sentences: {}'.format( - indexed_dataset.sizes.shape[0]), - ranks=[0] - ) + '\n number of documents: {}'.format(indexed_dataset.doc_idx.shape[0] - 1) + + '\n number of sentences: {}'.format(indexed_dataset.sizes.shape[0]), + ranks=[0]) return indexed_dataset @@ -582,8 +576,7 @@ def get_train_valid_test_split_(splits_string, size): splits = [split / splits_sum for split in splits] splits_index = [0] for index, split in enumerate(splits): - splits_index.append(splits_index[index] + - int(round(split * float(size)))) + splits_index.append(splits_index[index] + int(round(split * float(size)))) diff = splits_index[-1] - size for index in range(1, len(splits_index)): splits_index[index] -= diff diff --git a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp index e45926a97696..52977e63181f 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp +++ b/examples/tutorial/sequence_parallel/data/datasets/helpers.cpp @@ -15,29 +15,28 @@ limitations under the License. 
*/ - /* Helper methods for fast index mapping builds */ +#include +#include +#include + #include #include #include -#include -#include -#include -#include #include +#include namespace py = pybind11; using namespace std; const int32_t LONG_SENTENCE_LEN = 512; - void build_blending_indices(py::array_t& dataset_index, - py::array_t& dataset_sample_index, - const py::array_t& weights, - const int32_t num_datasets, - const int64_t size, const bool verbose) { + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, const int64_t size, + const bool verbose) { /* Given multiple datasets and a weighting array, build samples such that it follows those wieghts.*/ @@ -52,24 +51,23 @@ void build_blending_indices(py::array_t& dataset_index, // Initialize buffer for number of samples used for each dataset. int64_t current_samples[num_datasets]; - for(int64_t i = 0; i < num_datasets; ++i) { + for (int64_t i = 0; i < num_datasets; ++i) { current_samples[i] = 0; } // For each sample: - for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { - + for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { // Determine where the max error in sampling is happening. auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); int64_t max_error_index = 0; double max_error = weights_ptr[0] * sample_idx_double - - static_cast(current_samples[0]); + static_cast(current_samples[0]); for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { double error = weights_ptr[dataset_idx] * sample_idx_double - - static_cast(current_samples[dataset_idx]); + static_cast(current_samples[dataset_idx]); if (error > max_error) { - max_error = error; - max_error_index = dataset_idx; + max_error = error; + max_error_index = dataset_idx; } } @@ -79,7 +77,6 @@ void build_blending_indices(py::array_t& dataset_index, // Update the total samples. 
current_samples[max_error_index] += 1; - } // print info @@ -87,631 +84,607 @@ void build_blending_indices(py::array_t& dataset_index, std::cout << " > sample ratios:" << std::endl; for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { auto ratio = static_cast(current_samples[dataset_idx]) / - static_cast(size); - std::cout << " dataset " << dataset_idx << ", input: " << - weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + static_cast(size); + std::cout << " dataset " << dataset_idx + << ", input: " << weights_ptr[dataset_idx] + << ", achieved: " << ratio << std::endl; } } - } - py::array build_sample_idx(const py::array_t& sizes_, - const py::array_t& doc_idx_, - const int32_t seq_length, - const int32_t num_epochs, - const int64_t tokens_per_epoch) { - /* Sample index (sample_idx) is used for gpt2 like dataset for which - the documents are flattened and the samples are built based on this - 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] - where [..., 0] contains the index into `doc_idx` and [..., 1] is the - starting offset in that document.*/ - - // Consistency checks. - assert(seq_length > 1); - assert(num_epochs > 0); - assert(tokens_per_epoch > 1); - - // Remove bound checks. - auto sizes = sizes_.unchecked<1>(); - auto doc_idx = doc_idx_.unchecked<1>(); - - // Mapping and it's length (1D). - int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; - int32_t* sample_idx = new int32_t[2*(num_samples+1)]; - - cout << " using:" << endl << std::flush; - cout << " number of documents: " << - doc_idx_.shape(0) / num_epochs << endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " sequence length: " << seq_length << - endl << std::flush; - cout << " total number of samples: " << num_samples << - endl << std::flush; - - // Index into sample_idx. - int64_t sample_index = 0; - // Index into doc_idx. 
- int64_t doc_idx_index = 0; - // Begining offset for each document. - int32_t doc_offset = 0; - // Start with first document and no offset. + const py::array_t& doc_idx_, + const int32_t seq_length, const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2 * (num_samples + 1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << doc_idx_.shape(0) / num_epochs + << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " sequence length: " << seq_length << endl + << std::flush; + cout << " total number of samples: " << num_samples << endl + << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. 
+ remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. + if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. sample_idx[2 * sample_index] = doc_idx_index; sample_idx[2 * sample_index + 1] = doc_offset; ++sample_index; + } - while (sample_index <= num_samples) { - // Start with a fresh sequence. - int32_t remaining_seq_length = seq_length + 1; - while (remaining_seq_length != 0) { - // Get the document length. - auto doc_id = doc_idx[doc_idx_index]; - auto doc_length = sizes[doc_id] - doc_offset; - // And add it to the current sequence. - remaining_seq_length -= doc_length; - // If we have more than a full sequence, adjust offset and set - // remaining length to zero so we return from the while loop. - // Note that -1 here is for the same reason we have -1 in - // `_num_epochs` calculations. - if (remaining_seq_length <= 0) { - doc_offset += (remaining_seq_length + doc_length - 1); - remaining_seq_length = 0; - } else { - // Otherwise, start from the begining of the next document. - ++doc_idx_index; - doc_offset = 0; - } - } - // Record the sequence. - sample_idx[2 * sample_index] = doc_idx_index; - sample_idx[2 * sample_index + 1] = doc_offset; - ++sample_index; - } - - // Method to deallocate memory. - py::capsule free_when_done(sample_idx, [](void *mem_) { - int32_t *mem = reinterpret_cast(mem_); - delete[] mem; - }); - - // Return the numpy array. 
- const auto byte_size = sizeof(int32_t); - return py::array(std::vector{num_samples+1, 2}, // shape - {2*byte_size, byte_size}, // C-style contiguous strides - sample_idx, // the data pointer - free_when_done); // numpy array references - + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void* mem_) { + int32_t* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples + 1, 2}, // shape + {2 * byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references } - inline int32_t get_target_sample_len(const int32_t short_seq_ratio, - const int32_t max_length, - std::mt19937& rand32_gen) { - /* Training sample length. */ - if (short_seq_ratio == 0) { - return max_length; - } - const auto random_number = rand32_gen(); - if ((random_number % short_seq_ratio) == 0) { - return 2 + random_number % (max_length - 1); - } + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; } - -template +template py::array build_mapping_impl(const py::array_t& docs_, const py::array_t& sizes_, const int32_t num_epochs, const uint64_t max_num_samples, const int32_t max_seq_length, - const double short_seq_prob, - const int32_t seed, - const bool verbose, - const int32_t min_num_sent) { - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. 
- assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(short_seq_prob >= 0.0); - assert(short_seq_prob <= 1.0); - assert(seed > 0); - - // Remove bound checks. - auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - - // For efficiency, convert probability to ratio. Note: rand() generates int. - int32_t short_seq_ratio = 0; - if (short_seq_prob > 0) { - short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); - } + const double short_seq_prob, const int32_t seed, + const bool verbose, const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. 
+ int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } - if (verbose) { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " short sequence probability: " << short_seq_prob << - endl << std::flush; - cout << " short sequence ration (1/prob): " << short_seq_ratio << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; - } + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 + << endl + << std::flush; + cout << " sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl + << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl + << std::flush; + cout << " short sequence probability: " << short_seq_prob << endl + << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << endl + << 
std::flush; + cout << " seed: " << seed << endl + << std::flush; + } - // Mapping and it's length (1D). - int64_t num_samples = -1; - DocIdx* maps = NULL; - - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - - // Set the seed so both iterations produce the same results. - std::mt19937 rand32_gen(seed); - - // Set the flag on second iteration. - second = (iteration == 1); - - // Counters: - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - - // Current map index. - uint64_t map_index = 0; - - // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } - break; + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration = 0; iteration < 2; ++iteration) { + // Set the seed so both iterations produce the same results. + std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { + if (map_index >= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." 
<< endl + << std::flush; + } + break; + } + // For each document: + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; } - // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - - // At the begining of the document previous index is the - // start index. - auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } - } - - // Detect documents with long sentences. 
- bool contains_long_sentence = false; - if (num_remain_sent > 1) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - - // If we have more than two sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - auto target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - - // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and if not only one sentence is left in the document. - // and if we have at least two sentneces. - // and if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent > 1) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - - // Check for overflow. - if ((3 * map_index + 2) > - std::numeric_limits::max()) { - cout << "number of samples exceeded maximum " - << "allowed by type int64: " - << std::numeric_limits::max() - << endl; - throw std::overflow_error("Number of samples"); - } - - // Populate the map. - if (second) { - const auto map_index_0 = 3 * map_index; - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(target_seq_len); - } - - // Update indices / counters. - ++map_index; - prev_start_index = sent_index + 1; - target_seq_len = get_target_sample_len(short_seq_ratio, - max_seq_length, - rand32_gen); - seq_len = 0; - num_sent = 0; - } - - } // for (auto sent_index=sent_index_first; ... 
- } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; - } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[3*map_index]; - num_samples = static_cast(map_index); + } } - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 3 * i; - const auto j0 = 3 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - } + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len( + short_seq_ratio, max_seq_length, rand32_gen); + + // Loop through sentences. + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. 
+ if (((seq_len >= target_seq_len) && (num_remain_sent > 1) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { + // Check for overflow. + if ((3 * map_index + 2) > std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len( + short_seq_ratio, max_seq_length, rand32_gen); + seq_len = 0; + num_sent = 0; + } - // Method to deallocate memory. - py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl + << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs + << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs + << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3 * map_index]; + num_samples = static_cast(map_index); + } - // Return the numpy array. 
- const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 3}, // shape - {3*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } - py::array build_mapping(const py::array_t& docs_, - const py::array_t& sizes_, - const int num_epochs, + const py::array_t& sizes_, const int num_epochs, const uint64_t max_num_samples, - const int max_seq_length, - const double short_seq_prob, - const int seed, - const bool verbose, - const int32_t min_num_sent) { - - if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); - } else { - if (verbose) { - cout << " using uint32 for data mapping..." 
<< endl << std::flush; - } - return build_mapping_impl(docs_, sizes_, num_epochs, - max_num_samples, max_seq_length, - short_seq_prob, seed, verbose, - min_num_sent); + const int max_seq_length, const double short_seq_prob, + const int seed, const bool verbose, + const int32_t min_num_sent) { + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl( + docs_, sizes_, num_epochs, max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_mapping_impl( + docs_, sizes_, num_epochs, max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, min_num_sent); + } } -template -py::array build_blocks_mapping_impl(const py::array_t& docs_, - const py::array_t& sizes_, - const py::array_t& titles_sizes_, - const int32_t num_epochs, - const uint64_t max_num_samples, - const int32_t max_seq_length, - const int32_t seed, - const bool verbose, - const bool use_one_sent_blocks) { - /* Build a mapping of (start-index, end-index, sequence-length) where - start and end index are the indices of the sentences in the sample - and sequence-length is the target sequence length. - */ - - // Consistency checks. - assert(num_epochs > 0); - assert(max_seq_length > 1); - assert(seed > 0); - - // Remove bound checks. 
- auto docs = docs_.unchecked<1>(); - auto sizes = sizes_.unchecked<1>(); - auto titles_sizes = titles_sizes_.unchecked<1>(); +template +py::array build_blocks_mapping_impl( + const py::array_t& docs_, const py::array_t& sizes_, + const py::array_t& titles_sizes_, const int32_t num_epochs, + const uint64_t max_num_samples, const int32_t max_seq_length, + const int32_t seed, const bool verbose, const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); - if (verbose) { - const auto sent_start_index = docs[0]; - const auto sent_end_index = docs[docs_.shape(0) - 1]; - const auto num_sentences = sent_end_index - sent_start_index; - cout << " using:" << endl << std::flush; - cout << " number of documents: " << docs_.shape(0) - 1 << - endl << std::flush; - cout << " sentences range: [" << sent_start_index << - ", " << sent_end_index << ")" << endl << std::flush; - cout << " total number of sentences: " << num_sentences << - endl << std::flush; - cout << " number of epochs: " << num_epochs << - endl << std::flush; - cout << " maximum number of samples: " << max_num_samples << - endl << std::flush; - cout << " maximum sequence length: " << max_seq_length << - endl << std::flush; - cout << " seed: " << seed << endl << - std::flush; - } + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 + << endl + << std::flush; + cout << " 
sentences range: [" << sent_start_index << ", " + << sent_end_index << ")" << endl + << std::flush; + cout << " total number of sentences: " << num_sentences << endl + << std::flush; + cout << " number of epochs: " << num_epochs << endl + << std::flush; + cout << " maximum number of samples: " << max_num_samples << endl + << std::flush; + cout << " maximum sequence length: " << max_seq_length << endl + << std::flush; + cout << " seed: " << seed << endl + << std::flush; + } - // Mapping and its length (1D). - int64_t num_samples = -1; - DocIdx* maps = NULL; + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; - // Acceptable number of sentences per block. - int min_num_sent = 2; - if (use_one_sent_blocks) { - min_num_sent = 1; - } + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } - // Perform two iterations, in the first iteration get the size - // and allocate memory and in the second iteration populate the map. - bool second = false; - for (int32_t iteration=0; iteration<2; ++iteration) { - - // Set the flag on second iteration. - second = (iteration == 1); - - // Current map index. - uint64_t map_index = 0; - - uint64_t empty_docs = 0; - uint64_t one_sent_docs = 0; - uint64_t long_sent_docs = 0; - // For each epoch: - for (int32_t epoch=0; epoch= max_num_samples) { - if (verbose && (!second)) { - cout << " reached " << max_num_samples << " samples after " - << epoch << " epochs ..." << endl << std::flush; - } - break; - } - // For each document: - for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { - - // Document sentences are in [sent_index_first, sent_index_last) - const auto sent_index_first = docs[doc]; - const auto sent_index_last = docs[doc + 1]; - const auto target_seq_len = max_seq_length - titles_sizes[doc]; - - // At the begining of the document previous index is the - // start index. 
- auto prev_start_index = sent_index_first; - - // Remaining documents. - auto num_remain_sent = sent_index_last - sent_index_first; - - // Some bookkeeping - if ((epoch == 0) && (!second)) { - if (num_remain_sent == 0) { - ++empty_docs; - } - if (num_remain_sent == 1) { - ++one_sent_docs; - } - } - // Detect documents with long sentences. - bool contains_long_sentence = false; - if (num_remain_sent >= min_num_sent) { - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - if (sizes[sent_index] > LONG_SENTENCE_LEN){ - if ((epoch == 0) && (!second)) { - ++long_sent_docs; - } - contains_long_sentence = true; - break; - } - } - } - // If we have enough sentences and no long sentences. - if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { - - // Set values. - auto seq_len = int32_t{0}; - auto num_sent = int32_t{0}; - - // Loop through sentences. - for (auto sent_index=sent_index_first; - sent_index < sent_index_last; ++sent_index) { - - // Add the size and number of sentences. - seq_len += sizes[sent_index]; - ++num_sent; - --num_remain_sent; - - // If we have reached the target length. - // and there are an acceptable number of sentences left - // and if we have at least the minimum number of sentences. - // or if we have reached end of the document. - if (((seq_len >= target_seq_len) && - (num_remain_sent >= min_num_sent) && - (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { - - // Populate the map. 
- if (second) { - const auto map_index_0 = 4 * map_index; - // Each sample has 4 items: the starting sentence index, ending sentence index, - // the index of the document from which the block comes (used for fetching titles) - // and the unique id of the block (used for creating block indexes) - - maps[map_index_0] = static_cast(prev_start_index); - maps[map_index_0 + 1] = static_cast(sent_index + 1); - maps[map_index_0 + 2] = static_cast(doc); - maps[map_index_0 + 3] = static_cast(block_id); - } - - // Update indices / counters. - ++map_index; - ++block_id; - prev_start_index = sent_index + 1; - seq_len = 0; - num_sent = 0; - } - } // for (auto sent_index=sent_index_first; ... - } // if (num_remain_sent > 1) { - } // for (int doc=0; doc < num_docs; ++doc) { - } // for (int epoch=0; epoch < num_epochs; ++epoch) { - - if (!second) { - if (verbose) { - cout << " number of empty documents: " << empty_docs << - endl << std::flush; - cout << " number of documents with one sentence: " << - one_sent_docs << endl << std::flush; - cout << " number of documents with long sentences: " << - long_sent_docs << endl << std::flush; - cout << " will create mapping for " << map_index << - " samples" << endl << std::flush; + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration = 0; iteration < 2; ++iteration) { + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch = 0; epoch < num_epochs; ++epoch) { + // assign every block a unique id + int32_t block_id = 0; + + if (map_index >= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." 
<< endl + << std::flush; + } + break; + } + // For each document: + for (int32_t doc = 0; doc < (docs.shape(0) - 1); ++doc) { + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN) { + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; } - assert(maps == NULL); - assert(num_samples < 0); - maps = new DocIdx[4*map_index]; - num_samples = static_cast(map_index); + } } - - } // for (int iteration=0; iteration < 2; ++iteration) { - - // Shuffle. - // We need a 64 bit random number generator as we might have more - // than 2 billion samples. - std::mt19937_64 rand64_gen(seed + 1); - for (auto i=(num_samples - 1); i > 0; --i) { - const auto j = static_cast(rand64_gen() % (i + 1)); - const auto i0 = 4 * i; - const auto j0 = 4 * j; - // Swap values. - swap(maps[i0], maps[j0]); - swap(maps[i0 + 1], maps[j0 + 1]); - swap(maps[i0 + 2], maps[j0 + 2]); - swap(maps[i0 + 3], maps[j0 + 3]); + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. 
+ for (auto sent_index = sent_index_first; sent_index < sent_index_last; + ++sent_index) { + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent)) || + (num_remain_sent == 0)) { + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending + // sentence index, the index of the document from which the + // block comes (used for fetching titles) and the unique id of + // the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << endl + << std::flush; + cout << " number of documents with one sentence: " << one_sent_docs + << endl + << std::flush; + cout << " number of documents with long sentences: " << long_sent_docs + << endl + << std::flush; + cout << " will create mapping for " << map_index << " samples" << endl + << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4 * map_index]; + num_samples = static_cast(map_index); } - // Method to deallocate memory. 
- py::capsule free_when_done(maps, [](void *mem_) { - DocIdx *mem = reinterpret_cast(mem_); - delete[] mem; - }); - - // Return the numpy array. - const auto byte_size = sizeof(DocIdx); - return py::array(std::vector{num_samples, 4}, // shape - {4*byte_size, byte_size}, // C-style contiguous strides - maps, // the data pointer - free_when_done); // numpy array references + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i = (num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void* mem_) { + DocIdx* mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4 * byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references } -py::array build_blocks_mapping(const py::array_t& docs_, - const py::array_t& sizes_, - const py::array_t& titles_sizes_, - const int num_epochs, - const uint64_t max_num_samples, - const int max_seq_length, - const int seed, - const bool verbose, - const bool use_one_sent_blocks) { - - if (sizes_.size() > std::numeric_limits::max()) { - if (verbose) { - cout << " using uint64 for data mapping..." << endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); - } else { - if (verbose) { - cout << " using uint32 for data mapping..." 
<< endl << std::flush; - } - return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, - num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); +py::array build_blocks_mapping( + const py::array_t& docs_, const py::array_t& sizes_, + const py::array_t& titles_sizes_, const int num_epochs, + const uint64_t max_num_samples, const int max_seq_length, const int seed, + const bool verbose, const bool use_one_sent_blocks) { + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl( + docs_, sizes_, titles_sizes_, num_epochs, max_num_samples, + max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; } + return build_blocks_mapping_impl( + docs_, sizes_, titles_sizes_, num_epochs, max_num_samples, + max_seq_length, seed, verbose, use_one_sent_blocks); + } } PYBIND11_MODULE(helpers, m) { - m.def("build_mapping", &build_mapping); - m.def("build_blocks_mapping", &build_blocks_mapping); - m.def("build_sample_idx", &build_sample_idx); - m.def("build_blending_indices", &build_blending_indices); + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); } diff --git a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py index 6dac35ff9d41..eb45ead6aebb 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/ict_dataset.py @@ -2,12 +2,11 @@ import random import numpy as np -from torch.utils.data import Dataset - -from megatron import get_tokenizer -from megatron import get_args +from megatron import get_args, get_tokenizer from 
megatron.data.dataset_utils import get_indexed_dataset_ from megatron.data.realm_dataset_utils import get_block_samples_mapping +from torch.utils.data import Dataset + def make_attention_mask(source_block, target_block): """ @@ -20,6 +19,7 @@ def make_attention_mask(source_block, target_block): # (source_length, target_length) return mask + def get_ict_dataset(use_titles=True, query_in_block_prob=1): """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block()) rather than for training, since it is only built with a single epoch sample mapping. @@ -28,28 +28,37 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1): block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True) titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True) - kwargs = dict( - name='full', - block_dataset=block_dataset, - title_dataset=titles_dataset, - data_prefix=args.data_path, - num_epochs=1, - max_num_samples=None, - max_seq_length=args.seq_length, - seed=1, - query_in_block_prob=query_in_block_prob, - use_titles=use_titles, - use_one_sent_docs=args.use_one_sent_docs - ) + kwargs = dict(name='full', + block_dataset=block_dataset, + title_dataset=titles_dataset, + data_prefix=args.data_path, + num_epochs=1, + max_num_samples=None, + max_seq_length=args.seq_length, + seed=1, + query_in_block_prob=query_in_block_prob, + use_titles=use_titles, + use_one_sent_docs=args.use_one_sent_docs) dataset = ICTDataset(**kwargs) return dataset class ICTDataset(Dataset): """Dataset containing sentences and their blocks for an inverse cloze task.""" - def __init__(self, name, block_dataset, title_dataset, data_prefix, - num_epochs, max_num_samples, max_seq_length, query_in_block_prob, - seed, use_titles=True, use_one_sent_docs=False, binary_head=False): + + def __init__(self, + name, + block_dataset, + title_dataset, + data_prefix, + num_epochs, + max_num_samples, + max_seq_length, + query_in_block_prob, + seed, + use_titles=True, + 
use_one_sent_docs=False, + binary_head=False): self.name = name self.seed = seed self.max_seq_length = max_seq_length @@ -60,9 +69,8 @@ def __init__(self, name, block_dataset, title_dataset, data_prefix, self.use_titles = use_titles self.use_one_sent_docs = use_one_sent_docs - self.samples_mapping = get_block_samples_mapping( - block_dataset, title_dataset, data_prefix, num_epochs, - max_num_samples, max_seq_length, seed, name, use_one_sent_docs) + self.samples_mapping = get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs, + max_num_samples, max_seq_length, seed, name, use_one_sent_docs) self.tokenizer = get_tokenizer() self.vocab_id_list = list(self.tokenizer.inv_vocab.keys()) self.vocab_id_to_token_list = self.tokenizer.inv_vocab diff --git a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py index b4febcd822e1..a0f8d656055b 100644 --- a/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py +++ b/examples/tutorial/sequence_parallel/data/datasets/indexed_dataset.py @@ -3,17 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - # copied from fairseq/fairseq/data/indexed_dataset.py # Removed IndexedRawTextDataset since it relied on Fairseq dictionary # other slight modifications to remove fairseq dependencies # Added document index to index file and made it accessible. # An empty sentence no longer separates documents. 
-from functools import lru_cache import os import shutil import struct +from functools import lru_cache from itertools import accumulate import numpy as np @@ -88,16 +87,7 @@ def write_longs(f, a): f.write(np.array(a, dtype=np.int64)) -dtypes = { - 1: np.uint8, - 2: np.int8, - 3: np.int16, - 4: np.int32, - 5: np.int64, - 6: np.float, - 7: np.double, - 8: np.uint16 -} +dtypes = {1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16} def code(dtype): @@ -136,10 +126,8 @@ def __init__(self, path): def read_index(self, path): with open(index_file_path(path), 'rb') as f: magic = f.read(8) - assert magic == self._HDR_MAGIC, ( - 'Index file doesn\'t match expected format. ' - 'Make sure that --dataset-impl is configured properly.' - ) + assert magic == self._HDR_MAGIC, ('Index file doesn\'t match expected format. ' + 'Make sure that --dataset-impl is configured properly.') version = f.read(8) assert struct.unpack('= 0x4E00 and cp <= 0x9FFF) or # - (cp >= 0x3400 and cp <= 0x4DBF) or # - (cp >= 0x20000 and cp <= 0x2A6DF) or # - (cp >= 0x2A700 and cp <= 0x2B73F) or # - (cp >= 0x2B740 and cp <= 0x2B81F) or # - (cp >= 0x2B820 and cp <= 0x2CEAF) or - (cp >= 0xF900 and cp <= 0xFAFF) or # - (cp >= 0x2F800 and cp <= 0x2FA1F)): # + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # return True return False @@ -422,8 +410,7 @@ def _is_punctuation(char): # Characters such as "^", "$", and "`" are not in the Unicode # Punctuation class but we treat them as punctuation anyways, for # consistency. 
- if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or - (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): return True cat = unicodedata.category(char) if cat.startswith("P"): diff --git a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py index ee3c923e8e76..2bf680de87cb 100644 --- a/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py +++ b/examples/tutorial/sequence_parallel/data/tokenizer/tokenizer.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Megatron tokenizers.""" -from abc import ABC -from abc import abstractmethod -from colossalai.core import global_context as gpc +from abc import ABC, abstractmethod + from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from .bert_tokenization import FullTokenizer as FullBertTokenizer @@ -26,18 +25,13 @@ def build_tokenizer(vocab_file, tokenizer_type, vocab_extra_ids=0): """Initialize tokenizer.""" if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: - print('> building {} tokenizer ...'.format(tokenizer_type), - flush=True) + print('> building {} tokenizer ...'.format(tokenizer_type), flush=True) # Select and instantiate the tokenizer. 
if tokenizer_type == 'BertWordPieceLowerCase': - tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, - lower_case=True, - vocab_extra_ids=vocab_extra_ids) + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True, vocab_extra_ids=vocab_extra_ids) elif tokenizer_type == 'BertWordPieceCase': - tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, - lower_case=False, - vocab_extra_ids=vocab_extra_ids) + tokenizer = _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False, vocab_extra_ids=vocab_extra_ids) else: raise NotImplementedError('{} tokenizer is not ' 'implemented.'.format(tokenizer_type)) @@ -62,8 +56,8 @@ def _vocab_size_with_padding(orig_vocab_size, make_vocab_size_divisible_by=128): after += 1 if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0: print(' > padded vocab (size: {}) with {} dummy tokens ' - '(new size: {})'.format( - orig_vocab_size, after - orig_vocab_size, after), flush=True) + '(new size: {})'.format(orig_vocab_size, after - orig_vocab_size, after), + flush=True) return after @@ -142,8 +136,7 @@ def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): self._additional_special_tokens = [] # (dsachan) Add BOS and EOS tokens - SPECIAL_TOKENS = {'eos_token': '[EOS]', - 'bos_token': '[BOS]'} + SPECIAL_TOKENS = {'eos_token': '[EOS]', 'bos_token': '[BOS]'} self._bos_token = '[BOS]' self.add_token(self._bos_token) self._bos_token_id = self.vocab.get(self._bos_token) @@ -155,8 +148,7 @@ def __init__(self, vocab_file, lower_case=True, vocab_extra_ids=0): # (dsachan) Add additional special tokens # These can be used as sentinel tokens in T5 model inputs additional_special_tokens = [] - additional_special_tokens.extend( - ["".format(i) for i in range(vocab_extra_ids)]) + additional_special_tokens.extend(["".format(i) for i in range(vocab_extra_ids)]) self.add_additional_special_tokens(additional_special_tokens) def add_token(self, token): diff --git 
a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py index e87a778cf5d5..9ef0ce4ef96a 100644 --- a/examples/tutorial/sequence_parallel/loss_func/bert_loss.py +++ b/examples/tutorial/sequence_parallel/loss_func/bert_loss.py @@ -1,37 +1,29 @@ import torch +import torch.distributed as dist import torch.nn as nn -from colossalai.core import global_context as gpc +import torch.nn.functional as F + from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger -import torch.nn.functional as F -import torch.distributed as dist + from .cross_entropy import vocab_cross_entropy class BertLoss(nn.Module): - def forward(self, - lm_loss, - sop_logits, - loss_mask, - sentence_order): + def forward(self, lm_loss, sop_logits, loss_mask, sentence_order): lm_loss_ = lm_loss.float() loss_mask = loss_mask.float() loss_mask_sum = loss_mask.sum() - lm_loss = torch.sum( - lm_loss_.view(-1) * loss_mask.reshape(-1)) + lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) lm_loss /= loss_mask_sum - torch.distributed.all_reduce( - lm_loss, - group=gpc.get_group(ParallelMode.SEQUENCE) - ) + torch.distributed.all_reduce(lm_loss, group=gpc.get_group(ParallelMode.SEQUENCE)) if sop_logits is not None: - sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), - sentence_order.view(-1), - ignore_index=-1) + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1) sop_loss = sop_loss.float() loss = lm_loss + sop_loss * gpc.get_world_size(ParallelMode.SEQUENCE) else: diff --git a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py index 54553c29a61f..face52fc8b3d 100644 --- a/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py +++ b/examples/tutorial/sequence_parallel/loss_func/cross_entropy.py @@ -1,7 +1,8 @@ 
-from colossalai.context.parallel_mode import ParallelMode import torch from torch.cuda.amp import custom_bwd, custom_fwd +from colossalai.context.parallel_mode import ParallelMode + class _VocabCrossEntropy(torch.autograd.Function): @@ -24,8 +25,7 @@ def forward(ctx, vocab_parallel_logits, target): # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)) masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], - device=logits_2d.device) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] predicted_logits_1d = predicted_logits_1d.clone().contiguous() predicted_logits = predicted_logits_1d.view_as(target) @@ -58,10 +58,8 @@ def backward(ctx, grad_output): grad_2d = grad_input.view(-1, partition_vocab_size) # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], - device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= ( - 1.0 - target_mask.view(-1).float()) + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float()) # Finally elementwise multiplication with the output gradients. 
grad_input.mul_(grad_output.unsqueeze(dim=-1)) diff --git a/examples/tutorial/sequence_parallel/loss_func/utils.py b/examples/tutorial/sequence_parallel/loss_func/utils.py index a3d92f294326..aa5a92a0ca98 100644 --- a/examples/tutorial/sequence_parallel/loss_func/utils.py +++ b/examples/tutorial/sequence_parallel/loss_func/utils.py @@ -1,11 +1,9 @@ - import torch def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, '{} is not divisible by {}'.format( - numerator, denominator) + assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator) def divide(numerator, denominator): @@ -15,8 +13,7 @@ def divide(numerator, denominator): return numerator // denominator -def split_tensor_along_last_dim(tensor, num_partitions, - contiguous_split_chunks=False): +def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. Arguments: tensor: input tensor. 
@@ -42,8 +39,7 @@ class VocabUtility: partition: Note that indices in [fist, last)""" @staticmethod - def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank, world_size): + def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size): index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size return index_f, index_l @@ -51,5 +47,4 @@ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, @staticmethod def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size): per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size) + return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size) diff --git a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py index 8d95679ff76d..59c856064142 100644 --- a/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py +++ b/examples/tutorial/sequence_parallel/lr_scheduler/annealing_lr.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- """Learning rate decay functions.""" import math @@ -86,8 +85,7 @@ def get_lr(self): elif self.decay_style == 'cosine': coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) else: - raise Exception('{} decay style is not supported.'.format( - self.decay_style)) + raise Exception('{} decay style is not supported.'.format(self.decay_style)) return self.min_lr + coeff * delta_lr @@ -127,29 +125,22 @@ def load_state_dict(self, sd): max_lr_ = sd['start_lr'] else: max_lr_ = sd['max_lr'] - self.max_lr = self._check_and_set(self.max_lr, max_lr_, - 'learning rate') + self.max_lr = self._check_and_set(self.max_lr, max_lr_, 'learning rate') - self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], - 'minimum learning rate') + self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], 'minimum learning rate') if 'warmup_iter' in sd: warmup_steps_ = sd['warmup_iter'] else: warmup_steps_ = sd['warmup_steps'] - self.warmup_steps = self._check_and_set(self.warmup_steps, - warmup_steps_, - 'warmup iterations') + self.warmup_steps = self._check_and_set(self.warmup_steps, warmup_steps_, 'warmup iterations') if 'end_iter' in sd: decay_steps_ = sd['end_iter'] else: decay_steps_ = sd['decay_steps'] - self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, - 'total number of iterations') - self.decay_style = self._check_and_set(self.decay_style, - sd['decay_style'], - 'decay style') + self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_, 'total number of iterations') + self.decay_style = self._check_and_set(self.decay_style, sd['decay_style'], 'decay style') if 'num_iters' in sd: num_steps = sd['num_iters'] diff --git a/examples/tutorial/sequence_parallel/model/__init__.py b/examples/tutorial/sequence_parallel/model/__init__.py index 139597f9cb07..e69de29bb2d1 100644 --- a/examples/tutorial/sequence_parallel/model/__init__.py +++ b/examples/tutorial/sequence_parallel/model/__init__.py @@ -1,2 +0,0 @@ - - diff --git 
a/examples/tutorial/sequence_parallel/model/bert.py b/examples/tutorial/sequence_parallel/model/bert.py index 049579c5a639..e8987fecff9a 100644 --- a/examples/tutorial/sequence_parallel/model/bert.py +++ b/examples/tutorial/sequence_parallel/model/bert.py @@ -1,33 +1,37 @@ -from colossalai.context.parallel_mode import ParallelMode +import inspect + import torch import torch.nn as nn -import inspect -from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding -from .layers.init_method import init_normal, output_init_normal -from colossalai.core import global_context as gpc + from colossalai.context import ParallelMode +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc from colossalai.kernel import LayerNorm -from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.logging import get_dist_logger +from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper from colossalai.pipeline.utils import partition_uniform +from .layers import BertDualHead, BertLayer, Embedding, PreProcessor, VocabEmbedding +from .layers.init_method import init_normal, output_init_normal + class BertForPretrain(nn.Module): - def __init__(self, - vocab_size, - hidden_size, - max_sequence_length, - num_attention_heads, - num_layers, - add_binary_head, - is_naive_fp16, - num_tokentypes=2, - dropout_prob=0.1, - mlp_ratio=4, - init_std=0.02, - convert_fp16_to_fp32_in_softmax=False, - ): + def __init__( + self, + vocab_size, + hidden_size, + max_sequence_length, + num_attention_heads, + num_layers, + add_binary_head, + is_naive_fp16, + num_tokentypes=2, + dropout_prob=0.1, + mlp_ratio=4, + init_std=0.02, + convert_fp16_to_fp32_in_softmax=False, + ): super().__init__() self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE) assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size' @@ -47,19 +51,19 @@ def 
__init__(self, self.bert_layers = nn.ModuleList() for i in range(num_layers): - bert_layer = BertLayer(layer_number=i+1, + bert_layer = BertLayer(layer_number=i + 1, hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_dropout=dropout_prob, mlp_ratio=mlp_ratio, hidden_dropout=dropout_prob, convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - is_naive_fp16=is_naive_fp16 - ) + is_naive_fp16=is_naive_fp16) self.bert_layers.append(bert_layer) self.layer_norm = LayerNorm(hidden_size) - self.head = BertDualHead(hidden_size, self.embedding.word_embedding_weight.size(0), + self.head = BertDualHead(hidden_size, + self.embedding.word_embedding_weight.size(0), add_binary_head=add_binary_head) self.reset_parameters() @@ -166,22 +170,20 @@ def __init__(self, end_idx = num_layers for i in range(start_idx, end_idx): - bert_layer = BertLayer(layer_number=i+1, + bert_layer = BertLayer(layer_number=i + 1, hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_dropout=dropout_prob, mlp_ratio=mlp_ratio, hidden_dropout=dropout_prob, convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - is_naive_fp16=is_naive_fp16 - ) + is_naive_fp16=is_naive_fp16) self.bert_layers.append(bert_layer) if self.last_stage: self.word_embeddings = VocabEmbedding(vocab_size, hidden_size) self.layer_norm = LayerNorm(hidden_size) - self.head = BertDualHead(hidden_size, vocab_size, - add_binary_head=add_binary_head) + self.head = BertDualHead(hidden_size, vocab_size, add_binary_head=add_binary_head) self.reset_parameters() def _init_normal(self, tensor): diff --git a/examples/tutorial/sequence_parallel/model/layers/__init__.py b/examples/tutorial/sequence_parallel/model/layers/__init__.py index 3a8823caa81b..58495c516239 100644 --- a/examples/tutorial/sequence_parallel/model/layers/__init__.py +++ b/examples/tutorial/sequence_parallel/model/layers/__init__.py @@ -1,4 +1,4 @@ -from .embedding import VocabEmbedding, Embedding from .bert_layer 
import BertLayer +from .embedding import Embedding, VocabEmbedding from .head import BertDualHead from .preprocess import PreProcessor diff --git a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py index 4ede21516f65..6f10aba5f533 100644 --- a/examples/tutorial/sequence_parallel/model/layers/bert_layer.py +++ b/examples/tutorial/sequence_parallel/model/layers/bert_layer.py @@ -1,10 +1,12 @@ import torch import torch.nn as nn -from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing -from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference + from colossalai.kernel.cuda_native import LayerNorm -from .mlp import TransformerMLP +from colossalai.kernel.jit import bias_dropout_add_fused_inference, bias_dropout_add_fused_train +from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing + from .dropout import get_bias_dropout_add +from .mlp import TransformerMLP def attention_mask_func(attention_scores, attention_mask): @@ -48,8 +50,7 @@ def __init__(self, layer_number=layer_number, apply_query_key_layer_scaling=True, convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax, - fp16=is_naive_fp16 - ) + fp16=is_naive_fp16) self.hidden_dropout = hidden_dropout self.bias_dropout_fusion = bias_dropout_fusion @@ -89,11 +90,8 @@ def forward(self, hidden_states, attention_mask): # re-enable torch grad to enable fused optimization. with torch.enable_grad(): - layernorm_input = bias_dropout_add_func( - attention_output, - attention_bias.expand_as(residual), - residual, - self.hidden_dropout) + layernorm_input = bias_dropout_add_func(attention_output, attention_bias.expand_as(residual), residual, + self.hidden_dropout) # Layer norm post the self attention. 
layernorm_output = self.post_attention_layernorm(layernorm_input) @@ -109,10 +107,6 @@ def forward(self, hidden_states, attention_mask): # re-enable torch grad to enable fused optimization. with torch.enable_grad(): - output = bias_dropout_add_func( - mlp_output, - mlp_bias.expand_as(residual), - residual, - self.hidden_dropout) + output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) return output diff --git a/examples/tutorial/sequence_parallel/model/layers/dropout.py b/examples/tutorial/sequence_parallel/model/layers/dropout.py index 0e99105b8f7e..a8445120c04f 100644 --- a/examples/tutorial/sequence_parallel/model/layers/dropout.py +++ b/examples/tutorial/sequence_parallel/model/layers/dropout.py @@ -1,5 +1,6 @@ import torch + def bias_dropout_add(x, bias, residual, prob, training): # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor out = torch.nn.functional.dropout(x + bias, p=prob, training=training) @@ -8,6 +9,8 @@ def bias_dropout_add(x, bias, residual, prob, training): def get_bias_dropout_add(training): + def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) - return _bias_dropout_add \ No newline at end of file + + return _bias_dropout_add diff --git a/examples/tutorial/sequence_parallel/model/layers/embedding.py b/examples/tutorial/sequence_parallel/model/layers/embedding.py index 0700d960d845..88f670848db4 100644 --- a/examples/tutorial/sequence_parallel/model/layers/embedding.py +++ b/examples/tutorial/sequence_parallel/model/layers/embedding.py @@ -19,15 +19,12 @@ def __init__(self, num_embeddings, embedding_dim): self._weight = None # Allocate weights and initialize. 
- self.weight = nn.Parameter(torch.empty( - self.num_embeddings, self.embedding_dim)) + self.weight = nn.Parameter(torch.empty(self.num_embeddings, self.embedding_dim)) init.xavier_uniform_(self.weight) def forward(self, hidden_state): - output = F.embedding(hidden_state, self.weight, - self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, - self.sparse) + output = F.embedding(hidden_state, self.weight, self.padding_idx, self.max_norm, self.norm_type, + self.scale_grad_by_freq, self.sparse) return output def __repr__(self): @@ -48,12 +45,7 @@ class Embedding(nn.Module): will ignore this embedding """ - def __init__(self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - num_tokentypes): + def __init__(self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes): super(Embedding, self).__init__() self.hidden_size = hidden_size @@ -62,16 +54,14 @@ def __init__(self, self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size) # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) + self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size) # Token type embedding. # Add this as an optional field that can be added through # method call so we can load a pretrain model without # token types and add them as needed. 
if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, - self.hidden_size) + self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size) else: self.tokentype_embeddings = None diff --git a/examples/tutorial/sequence_parallel/model/layers/head.py b/examples/tutorial/sequence_parallel/model/layers/head.py index ea336b9d131e..097be8c1c8e3 100644 --- a/examples/tutorial/sequence_parallel/model/layers/head.py +++ b/examples/tutorial/sequence_parallel/model/layers/head.py @@ -1,14 +1,16 @@ -import colossalai import torch import torch.nn as nn import torch.nn.functional as F -from .pooler import Pooler -from .linear import Linear -from .embedding import VocabEmbedding -from colossalai.core import global_context as gpc +from loss_func.cross_entropy import vocab_cross_entropy + +import colossalai from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from colossalai.kernel import LayerNorm -from loss_func.cross_entropy import vocab_cross_entropy + +from .embedding import VocabEmbedding +from .linear import Linear +from .pooler import Pooler class BertLMHead(nn.Module): @@ -19,10 +21,11 @@ class BertLMHead(nn.Module): layernorm_epsilon: tolerance for layer norm divisions """ - def __init__(self, - vocab_size, - hidden_size, - ): + def __init__( + self, + vocab_size, + hidden_size, + ): super(BertLMHead, self).__init__() self.bias = torch.nn.Parameter(torch.zeros(vocab_size)) diff --git a/examples/tutorial/sequence_parallel/model/layers/init_method.py b/examples/tutorial/sequence_parallel/model/layers/init_method.py index 1b409dfe4054..22d12a504fab 100644 --- a/examples/tutorial/sequence_parallel/model/layers/init_method.py +++ b/examples/tutorial/sequence_parallel/model/layers/init_method.py @@ -1,6 +1,8 @@ -import torch import math +import torch + + def init_normal(tensor, sigma): """Init method based on N(0, sigma).""" torch.nn.init.normal_(tensor, mean=0.0, 
std=sigma) diff --git a/examples/tutorial/sequence_parallel/model/layers/linear.py b/examples/tutorial/sequence_parallel/model/layers/linear.py index 5ae7d671e2bf..fb56ad60c322 100644 --- a/examples/tutorial/sequence_parallel/model/layers/linear.py +++ b/examples/tutorial/sequence_parallel/model/layers/linear.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from torch.nn import Parameter import torch.nn.functional as F import torch.nn.init as init +from torch.nn import Parameter class Linear(nn.Module): @@ -24,11 +24,7 @@ class Linear(nn.Module): adding bias but instead return it. """ - def __init__(self, - input_size, - output_size, - bias=True, - skip_bias_add=False): + def __init__(self, input_size, output_size, bias=True, skip_bias_add=False): super(Linear, self).__init__() # Keep input parameters @@ -36,9 +32,10 @@ def __init__(self, self.output_size = output_size self.skip_bias_add = skip_bias_add - self.weight = Parameter(torch.empty(self.output_size, - self.input_size, - )) + self.weight = Parameter(torch.empty( + self.output_size, + self.input_size, + )) init.normal_(self.weight) if bias: self.bias = Parameter(torch.empty(self.output_size)) diff --git a/examples/tutorial/sequence_parallel/model/layers/mlp.py b/examples/tutorial/sequence_parallel/model/layers/mlp.py index a255de813d13..2147d5ff5c71 100644 --- a/examples/tutorial/sequence_parallel/model/layers/mlp.py +++ b/examples/tutorial/sequence_parallel/model/layers/mlp.py @@ -2,9 +2,10 @@ import torch.nn as nn import torch.nn.functional as F -from .linear import Linear from colossalai.kernel.jit import bias_gelu_impl +from .linear import Linear + class TransformerMLP(nn.Module): """MLP. @@ -18,19 +19,13 @@ def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True): super(TransformerMLP, self).__init__() # Project to 4h. 
- self.dense_h_to_4h = Linear( - hidden_size, - int(hidden_size*mlp_ratio), - skip_bias_add=True) + self.dense_h_to_4h = Linear(hidden_size, int(hidden_size * mlp_ratio), skip_bias_add=True) self.bias_gelu_fusion = fuse_gelu self.activation_func = F.gelu # Project back to h. - self.dense_4h_to_h = Linear( - int(hidden_size*mlp_ratio), - hidden_size, - skip_bias_add=True) + self.dense_4h_to_h = Linear(int(hidden_size * mlp_ratio), hidden_size, skip_bias_add=True) def forward(self, hidden_states): # hidden states should be in the shape of [s, b, h] diff --git a/examples/tutorial/sequence_parallel/model/layers/pooler.py b/examples/tutorial/sequence_parallel/model/layers/pooler.py index 282ed114790b..c3397787aecf 100644 --- a/examples/tutorial/sequence_parallel/model/layers/pooler.py +++ b/examples/tutorial/sequence_parallel/model/layers/pooler.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn + from .linear import Linear diff --git a/examples/tutorial/sequence_parallel/model/layers/preprocess.py b/examples/tutorial/sequence_parallel/model/layers/preprocess.py index 53a326ddacf1..688edf891e51 100644 --- a/examples/tutorial/sequence_parallel/model/layers/preprocess.py +++ b/examples/tutorial/sequence_parallel/model/layers/preprocess.py @@ -1,6 +1,7 @@ -from colossalai.context.parallel_mode import ParallelMode import torch import torch.nn as nn + +from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc @@ -14,8 +15,8 @@ def bert_position_ids(self, token_ids): # Create position ids seq_length = token_ids.size(1) local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE) - position_ids = torch.arange(seq_length*local_rank, - seq_length * (local_rank+1), + position_ids = torch.arange(seq_length * local_rank, + seq_length * (local_rank + 1), dtype=torch.long, device=token_ids.device) position_ids = position_ids.unsqueeze(0).expand_as(token_ids) diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py index 
500e2cc0eddc..883361ee5ddd 100644 --- a/op_builder/cpu_adam.py +++ b/op_builder/cpu_adam.py @@ -20,10 +20,7 @@ def sources_files(self): return ret def include_dirs(self): - return [ - self.csrc_abs_path("includes"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("includes"), self.get_cuda_home_include()] def cxx_flags(self): extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native'] diff --git a/op_builder/fused_optim.py b/op_builder/fused_optim.py index 31ddfced1db2..e6933cc76594 100644 --- a/op_builder/fused_optim.py +++ b/op_builder/fused_optim.py @@ -10,7 +10,7 @@ class FusedOptimBuilder(Builder): def __init__(self): super().__init__(name=FusedOptimBuilder.NAME, prebuilt_import_path=FusedOptimBuilder.PREBUILT_IMPORT_PATH) - + def sources_files(self): ret = [ self.csrc_abs_path(fname) for fname in [ diff --git a/op_builder/moe.py b/op_builder/moe.py index eeb7d8e3980c..0831aa157599 100644 --- a/op_builder/moe.py +++ b/op_builder/moe.py @@ -13,10 +13,7 @@ def __init__(self): super().__init__(name=MOEBuilder.NAME, prebuilt_import_path=MOEBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): - ret = [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + ret = [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] return ret def sources_files(self): diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py index f9103fe94729..51fc300098eb 100644 --- a/op_builder/multi_head_attn.py +++ b/op_builder/multi_head_attn.py @@ -10,9 +10,7 @@ class MultiHeadAttnBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.multihead_attention" def __init__(self): - super().__init__(name=MultiHeadAttnBuilder.NAME, - prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) - + super().__init__(name=MultiHeadAttnBuilder.NAME, prebuilt_import_path=MultiHeadAttnBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): ret = [self.csrc_abs_path("kernels/include"), 
self.get_cuda_home_include()] diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py index 11cfda39a85c..38de3447ee80 100644 --- a/op_builder/scaled_masked_softmax.py +++ b/op_builder/scaled_masked_softmax.py @@ -9,21 +9,16 @@ class ScaledMaskedSoftmaxBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.scaled_masked_softmax" def __init__(self): - super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + super().__init__(name=ScaledMaskedSoftmaxBuilder.NAME, + prebuilt_import_path=ScaledMaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) # necessary 4 functions def sources_files(self): - ret = [ - self.csrc_abs_path(fname) for fname in - ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu'] - ] + ret = [self.csrc_abs_path(fname) for fname in ['scaled_masked_softmax.cpp', 'scaled_masked_softmax_cuda.cu']] return ret def include_dirs(self): - return [ - self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] def cxx_flags(self): return ['-O3'] + self.version_dependent_macros diff --git a/op_builder/scaled_upper_triangle_masked_softmax.py b/op_builder/scaled_upper_triangle_masked_softmax.py index d0d2433aa645..db4f8ddfd9e9 100644 --- a/op_builder/scaled_upper_triangle_masked_softmax.py +++ b/op_builder/scaled_upper_triangle_masked_softmax.py @@ -9,13 +9,11 @@ class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder): PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax" def __init__(self): - super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) + super().__init__(name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME, + prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH) def include_dirs(self): - return [ - 
self.csrc_abs_path("kernels/include"), - self.get_cuda_home_include() - ] + return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()] def sources_files(self): ret = [ diff --git a/op_builder/utils.py b/op_builder/utils.py index 4029703e4829..79fbaa6d50da 100644 --- a/op_builder/utils.py +++ b/op_builder/utils.py @@ -156,16 +156,15 @@ def set_cuda_arch_list(cuda_dir): # we only need to set this when CUDA is not available for cross-compilation if not cuda_available: - warnings.warn( - '\n[extension] PyTorch did not find available GPUs on this system.\n' - 'If your intention is to cross-compile, this is not an error.\n' - 'By default, Colossal-AI will cross-compile for \n' - '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n' - '2. Volta (compute capability 7.0)\n' - '3. Turing (compute capability 7.5),\n' - '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n' - '\nIf you wish to cross-compile for a single specific architecture,\n' - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') + warnings.warn('\n[extension] PyTorch did not find available GPUs on this system.\n' + 'If your intention is to cross-compile, this is not an error.\n' + 'By default, Colossal-AI will cross-compile for \n' + '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n' + '2. Volta (compute capability 7.0)\n' + '3. Turing (compute capability 7.5),\n' + '4. 
Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n' + '\nIf you wish to cross-compile for a single specific architecture,\n' + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n') if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) diff --git a/pytest.ini b/pytest.ini index ac31ace4bfae..01e5cd217c5d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,4 +3,4 @@ markers = cpu: tests which can run on CPU gpu: tests which requires a single GPU dist: tests which are run in a multi-GPU or multi-machine environment - experiment: tests for experimental features \ No newline at end of file + experiment: tests for experimental features diff --git a/tests/components_to_test/resnet.py b/tests/components_to_test/resnet.py index 193832ebc12d..2833afa97942 100644 --- a/tests/components_to_test/resnet.py +++ b/tests/components_to_test/resnet.py @@ -1,12 +1,15 @@ -from torchvision.models import resnet18 -from .registry import non_distributed_component_funcs -from pathlib import Path import os +from pathlib import Path + import torch -from torchvision.transforms import transforms from torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 +from torchvision.transforms import transforms + from colossalai.utils import get_dataloader +from .registry import non_distributed_component_funcs + def get_cifar10_dataloader(train): # build dataloaders diff --git a/tests/test_analyzer/test_fx/test_nested_ckpt.py b/tests/test_analyzer/test_fx/test_nested_ckpt.py index c31aab6752f8..1344fe539968 100644 --- a/tests/test_analyzer/test_fx/test_nested_ckpt.py +++ b/tests/test_analyzer/test_fx/test_nested_ckpt.py @@ -1,7 +1,7 @@ +import pytest import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint -import pytest try: from colossalai._analyzer.fx import symbolic_trace diff --git a/tests/test_analyzer/test_subclasses/test_aten.py 
b/tests/test_analyzer/test_subclasses/test_aten.py index 591a8d617580..1ca8647d54d0 100644 --- a/tests/test_analyzer/test_subclasses/test_aten.py +++ b/tests/test_analyzer/test_subclasses/test_aten.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Union -import pytest +import pytest import torch import torch.nn as nn diff --git a/tests/test_auto_parallel/test_offload/model_utils.py b/tests/test_auto_parallel/test_offload/model_utils.py index c22b17ae42ba..8789e442793d 100644 --- a/tests/test_auto_parallel/test_offload/model_utils.py +++ b/tests/test_auto_parallel/test_offload/model_utils.py @@ -1,17 +1,13 @@ import torch import torch.nn as nn -from transformers import GPT2Config, GPT2LMHeadModel -from transformers import BertConfig, BertLMHeadModel +from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel + from tests.components_to_test.registry import non_distributed_component_funcs + class GPTLMModel(nn.Module): - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50257): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257): super().__init__() self.model = GPT2LMHeadModel( GPT2Config(n_embd=hidden_size, @@ -38,17 +34,24 @@ def forward(self, logits, labels): # Flatten the tokens return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + class BertLMModel(nn.Module): + def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522): super().__init__() - self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size, - num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size, - vocab_size=vocab_size)) + self.model = BertLMHeadModel( + BertConfig(n_embd=hidden_size, + num_hidden_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + 
max_position_embeddings=hidden_size, + vocab_size=vocab_size)) def forward(self, input_ids, attention_mask): # Only return lm_logits return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0] + @non_distributed_component_funcs.register(name='bert_') def get_bert_components(): vocab_size = 1024 @@ -67,6 +70,7 @@ def bert_data_gen(device="meta"): return bert_model_builder, bert_data_gen + @non_distributed_component_funcs.register(name='gpt2_') def get_gpt2_components(): vocab_size = 1024 @@ -83,4 +87,4 @@ def gpt2_data_gen(device="meta"): kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - return gpt2_model_builder, gpt2_data_gen \ No newline at end of file + return gpt2_model_builder, gpt2_data_gen diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index d569570f4b7d..3e231aeff3fe 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -1,46 +1,44 @@ import time -import pytest from functools import partial +import pytest import torch -from torch.utils._pytree import tree_map import torch.multiprocessing as mp +from torch.utils._pytree import tree_map import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.fx.profiler import parameter_size -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.utils import free_port, get_current_device -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML +from colossalai.fx.profiler import parameter_size +from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.testing import 
parameterize - -from tests.test_tensor.common_utils import set_seed +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext from tests.test_auto_parallel.test_offload.model_utils import * +from tests.test_tensor.common_utils import set_seed @parameterize('model_name', ['gpt2_']) @parameterize('memory_budget', [5000]) @parameterize('solver_name', ['asyn']) -def exam_fwd_bwd( - model_name: str, - memory_budget: float, - solver_name: str -): +def exam_fwd_bwd(model_name: str, memory_budget: float, solver_name: str): # build model get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() - label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device()) + label = torch.randint(low=0, high=128, size=( + 64, + 8, + ), device=get_current_device()) criterion = LMLoss() set_seed(42) start_time = time.time() model = model_builder() model.train() - param_size = parameter_size(model) / 1024 ** 2 / 2 + param_size = parameter_size(model) / 1024**2 / 2 init_time = time.time() - start_time print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s") @@ -92,13 +90,11 @@ def exam_fwd_bwd( torch.cuda.synchronize() exec_time = sum(sorted(time_list)[:5]) / 5 - runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 - runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 print(f'gemini | model_name: {model_name}') - print( - f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' - ) + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| 
runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') print(time_list) del data_args @@ -129,15 +125,14 @@ def exam_fwd_bwd( torch.cuda.synchronize() exec_time = sum(sorted(time_list)[:5]) / 5 - runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2 - runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2 + runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024**2 + runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024**2 print(f'solver_name: {solver_name} | model_name: {model_name}') - print( - f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' - f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|' - ) + print(f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB ' + f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|') print(time_list) + @pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed') def test_perf(rank, world_size, port): config = {} diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py index 2efbb750f80d..f74784bd1243 100644 --- a/tests/test_auto_parallel/test_offload/test_solver.py +++ b/tests/test_auto_parallel/test_offload/test_solver.py @@ -3,20 +3,19 @@ from torch.fx import GraphModule from torch.utils._pytree import tree_map +from colossalai.auto_parallel.offload.region_manager import RegionManager +from colossalai.auto_parallel.offload.solver import NOT_NVML, SolverFactory from colossalai.fx import ColoTracer, is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from colossalai.auto_parallel.offload.region_manager import RegionManager -from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML from colossalai.testing import parameterize from tests.test_auto_parallel.test_offload.model_utils import * + @pytest.mark.skipif(NOT_NVML, 
reason='pynvml is not installed') @parameterize('model_name', ['gpt2_', 'bert_']) @parameterize('memory_budget', [4000]) @parameterize('solver_name', ['syn', 'asyn']) -def solver_test(model_name: str, - memory_budget: float, - solver_name: str): +def solver_test(model_name: str, memory_budget: float, solver_name: str): get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen = get_components_func() @@ -52,11 +51,16 @@ def solver_test(model_name: str, for region in region_list: need_offload = region.need_offload to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None - print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') + print( + f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + ) for region in region_list.__reversed__(): need_offload = region.need_offload to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None - print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}') + print( + f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}' + ) + if __name__ == '__main__': - solver_test() \ No newline at end of file + solver_test() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py index 9838e2eb01c6..002b8003bb2d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_where_handler.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn -from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import \ - WhereHandler -from 
colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, StrategiesVector) +from colossalai.auto_parallel.tensor_shard.node_handler.where_handler import WhereHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector from colossalai.device.device_mesh import DeviceMesh from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx.tracer.meta_patch.patched_module import linear diff --git a/tests/test_comm/test_boardcast_send_recv_v2.py b/tests/test_comm/test_boardcast_send_recv_v2.py index 1520d6054043..7a93eeb43951 100644 --- a/tests/test_comm/test_boardcast_send_recv_v2.py +++ b/tests/test_comm/test_boardcast_send_recv_v2.py @@ -5,13 +5,14 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from colossalai.communication.p2p_v2 import _send_object, _recv_object, init_process_group + +from colossalai.communication.p2p_v2 import _recv_object, _send_object, init_process_group from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device disable_existing_loggers() world_size = 4 diff --git a/tests/test_comm/test_comm.py b/tests/test_comm/test_comm.py index 07cb67730d24..1cf21812e49f 100644 --- a/tests/test_comm/test_comm.py +++ b/tests/test_comm/test_comm.py @@ -4,12 +4,13 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp + from colossalai.communication import all_gather, all_reduce, reduce_scatter from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import 
launch -from colossalai.utils import free_port, get_current_device from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device CONFIG = dict(parallel=dict(data=8, pipeline=1, tensor=dict(mode=None, size=1))) diff --git a/tests/test_comm/test_object_list_p2p.py b/tests/test_comm/test_object_list_p2p.py index 701e3e8ade79..d90f516229b8 100644 --- a/tests/test_comm/test_object_list_p2p.py +++ b/tests/test_comm/test_object_list_p2p.py @@ -4,12 +4,20 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from colossalai.communication.p2p import send_forward, recv_forward, send_backward, recv_backward, send_forward_recv_backward, send_backward_recv_forward + +from colossalai.communication.p2p import ( + recv_backward, + recv_forward, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, +) from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device CONFIG = dict(parallel=dict(pipeline=2)) torch.manual_seed(123) diff --git a/tests/test_comm/test_object_list_p2p_v2.py b/tests/test_comm/test_object_list_p2p_v2.py index c639ac9f8ef3..da5325fc1531 100644 --- a/tests/test_comm/test_object_list_p2p_v2.py +++ b/tests/test_comm/test_object_list_p2p_v2.py @@ -4,13 +4,14 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from colossalai.communication.p2p_v2 import send_forward, recv_forward, send_backward, recv_backward, init_process_group -from colossalai.context import ParallelMode, Initializer_Pipeline + +from colossalai.communication.p2p_v2 import init_process_group, recv_backward, recv_forward, send_backward, send_forward +from colossalai.context import 
Initializer_Pipeline, ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch -from colossalai.utils import free_port, get_current_device -from colossalai.testing import rerun_if_address_is_in_use from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device disable_existing_loggers() diff --git a/tests/test_config/sample_config.py b/tests/test_config/sample_config.py index 08ca108281b9..19e156841574 100644 --- a/tests/test_config/sample_config.py +++ b/tests/test_config/sample_config.py @@ -1,25 +1,19 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -train_data = dict( - dataset=dict( - type='CIFAR10Dataset', - root='/path/to/data', - download=True, - transform_pipeline=[ - dict(type='RandomResizedCrop', size=224), - dict(type='RandomHorizontalFlip'), - dict(type='ToTensor'), - dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) - ] - ), - dataloader=dict( - batch_size=64, - pin_memory=True, - num_workers=4, - sampler=dict( - type='DataParallelSampler', - shuffle=True, - ) - ) -) +train_data = dict(dataset=dict(type='CIFAR10Dataset', + root='/path/to/data', + download=True, + transform_pipeline=[ + dict(type='RandomResizedCrop', size=224), + dict(type='RandomHorizontalFlip'), + dict(type='ToTensor'), + dict(type='Normalize', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ]), + dataloader=dict(batch_size=64, + pin_memory=True, + num_workers=4, + sampler=dict( + type='DataParallelSampler', + shuffle=True, + ))) diff --git a/tests/test_context/configs/parallel_2d_init.py b/tests/test_context/configs/parallel_2d_init.py index 6af884450ad0..6cf816942fdd 100644 --- a/tests/test_context/configs/parallel_2d_init.py +++ b/tests/test_context/configs/parallel_2d_init.py @@ -1,10 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=4, - 
mode='2d' - ) -) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=4, mode='2d')) diff --git a/tests/test_context/configs/parallel_2p5d_init.py b/tests/test_context/configs/parallel_2p5d_init.py index c2d896d383e2..b946d45b3a91 100644 --- a/tests/test_context/configs/parallel_2p5d_init.py +++ b/tests/test_context/configs/parallel_2p5d_init.py @@ -1,11 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=8, - depth=2, - mode='2.5d' - ) -) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, depth=2, mode='2.5d')) diff --git a/tests/test_context/configs/parallel_3d_init.py b/tests/test_context/configs/parallel_3d_init.py index 0ec724f8bb4f..a1564bbb2d51 100644 --- a/tests/test_context/configs/parallel_3d_init.py +++ b/tests/test_context/configs/parallel_3d_init.py @@ -1,10 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -parallel = dict( - pipeline=dict(size=2), - tensor=dict( - size=8, - mode='3d' - ) -) +parallel = dict(pipeline=dict(size=2), tensor=dict(size=8, mode='3d')) diff --git a/tests/test_context/test_hybrid_parallel.py b/tests/test_context/test_hybrid_parallel.py index f311b1d2e736..fc13e0e35e7c 100644 --- a/tests/test_context/test_hybrid_parallel.py +++ b/tests/test_context/test_hybrid_parallel.py @@ -3,17 +3,18 @@ from functools import partial from pathlib import Path + import pytest import torch import torch.multiprocessing as mp from colossalai import launch +from colossalai.context import reset_seeds from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import free_port -from colossalai.context import reset_seeds from colossalai.global_variables import tensor_parallel_env as tp_env from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port CONFIG_PATH_LIST = list(Path(__file__).parent.glob('configs/*.py')) diff --git 
a/tests/test_data/test_cifar10_dataset.py b/tests/test_data/test_cifar10_dataset.py index 4b9ca61d9f17..501bfc2b957a 100644 --- a/tests/test_data/test_cifar10_dataset.py +++ b/tests/test_data/test_cifar10_dataset.py @@ -5,8 +5,8 @@ from pathlib import Path import pytest -from torchvision import transforms, datasets from torch.utils.data import DataLoader +from torchvision import datasets, transforms @pytest.mark.cpu diff --git a/tests/test_data/test_data_parallel_sampler.py b/tests/test_data/test_data_parallel_sampler.py index 54fa44bdc0c2..7a65e64c7c5d 100644 --- a/tests/test_data/test_data_parallel_sampler.py +++ b/tests/test_data/test_data_parallel_sampler.py @@ -9,13 +9,13 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp +from torchvision import datasets, transforms import colossalai -from torchvision import transforms, datasets -from colossalai.context import ParallelMode, Config +from colossalai.context import Config, ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader, free_port from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_dataloader CONFIG = Config(dict( parallel=dict( diff --git a/tests/test_data/test_deterministic_dataloader.py b/tests/test_data/test_deterministic_dataloader.py index 4d76e7f137f1..da934053ebf4 100644 --- a/tests/test_data/test_deterministic_dataloader.py +++ b/tests/test_data/test_deterministic_dataloader.py @@ -9,14 +9,13 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from torchvision import transforms, datasets +from torchvision import datasets, transforms import colossalai -from colossalai.context import ParallelMode, Config +from colossalai.context import Config, ParallelMode from colossalai.core import global_context as gpc -from colossalai.utils import get_dataloader, free_port from colossalai.testing import rerun_if_address_is_in_use -from 
torchvision import transforms +from colossalai.utils import free_port, get_dataloader CONFIG = Config( dict( diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py index 3c2390c92837..64f2cf68f364 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py +++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor.py @@ -1,25 +1,24 @@ import os - from functools import partial from pathlib import Path -import colossalai import pytest import torch import torch.multiprocessing as mp +from torchvision import transforms +from torchvision.datasets import CIFAR10 + +import colossalai from colossalai.amp import AMP_TYPE -from colossalai.trainer import Trainer, hooks from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from colossalai.utils import free_port from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn import CrossEntropyLoss from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import get_dataloader from colossalai.pipeline.pipelinable import PipelinableContext -from torchvision.datasets import CIFAR10 -from torchvision import transforms +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus +from colossalai.trainer import Trainer, hooks +from colossalai.utils import free_port, get_dataloader BATCH_SIZE = 4 NUM_EPOCHS = 60 diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py index 2bafe0f7e374..762a9b3741c5 100644 --- a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py +++ 
b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py @@ -1,111 +1,108 @@ -import os - -from functools import partial -from pathlib import Path - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.amp import AMP_TYPE -from colossalai.trainer import Trainer, hooks -from colossalai.context import ParallelMode -from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from colossalai.utils import free_port -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.nn import CrossEntropyLoss -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR -from colossalai.utils import get_dataloader -from colossalai.pipeline.pipelinable import PipelinableContext -from colossalai.logging import disable_existing_loggers -from torchvision.datasets import CIFAR10 -from torchvision import transforms - -from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 - -disable_existing_loggers() -BATCH_SIZE = 4 -NUM_EPOCHS = 10 -WARMUP_EPOCHS = 5 -CONFIG = dict(NUM_MICRO_BATCHES=2, - parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), - fp16=dict(mode=AMP_TYPE.NAIVE), - gradient_accumulation=2) - - -def run_trainer(rank, world_size, port): - disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - disable_existing_loggers() - # get logger - logger = get_dist_logger() - - pipelinable = PipelinableContext() - try: - from titans.model.vit import vit_tiny_patch4_32 - except ImportError: - logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') - logger.warning('please install titan from https://github.com/hpcaitech/Titans') - return - with pipelinable: - model = vit_tiny_patch4_32() - pipelinable.to_layer_list() - pipelinable.policy = "uniform" - model = 
pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) - - # craete dataloaders - root = Path(os.environ['DATA']) - transform_train = transforms.Compose([ - transforms.RandomCrop(32, padding=4, pad_if_needed=True), - transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - ]) - train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) - train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) - - # create loss function - criterion = CrossEntropyLoss(label_smoothing=0.1) - - # create optimizer - optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) - - # create lr scheduler - lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) - - # intiailize - engine, train_dataloader, *_ = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - - engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) - - logger = get_dist_logger() - - trainer = Trainer(engine=engine, logger=logger) - - hook_list = [ - hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), - ] - - trainer.fit(train_dataloader=train_dataloader, - max_steps=2, - epochs=NUM_EPOCHS, - hooks=hook_list, - display_progress=True) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_hybrid_parallel(): - world_size = 2 - run_func = partial(run_trainer, world_size=world_size, port=free_port()) - disable_existing_loggers() - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_hybrid_parallel() +import os +from functools import partial +from pathlib import Path + +import pytest +import torch +import torch.multiprocessing as mp +from torchvision import transforms +from 
torchvision.datasets import CIFAR10 + +import colossalai +from colossalai.amp import AMP_TYPE +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2 +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn import CrossEntropyLoss +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.pipeline.pipelinable import PipelinableContext +from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus +from colossalai.trainer import Trainer, hooks +from colossalai.utils import free_port, get_dataloader + +disable_existing_loggers() +BATCH_SIZE = 4 +NUM_EPOCHS = 10 +WARMUP_EPOCHS = 5 +CONFIG = dict(NUM_MICRO_BATCHES=2, + parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')), + fp16=dict(mode=AMP_TYPE.NAIVE), + gradient_accumulation=2) + + +def run_trainer(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + disable_existing_loggers() + # get logger + logger = get_dist_logger() + + pipelinable = PipelinableContext() + try: + from titans.model.vit import vit_tiny_patch4_32 + except ImportError: + logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed') + logger.warning('please install titan from https://github.com/hpcaitech/Titans') + return + with pipelinable: + model = vit_tiny_patch4_32() + pipelinable.to_layer_list() + pipelinable.policy = "uniform" + model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE)) + + # craete dataloaders + root = Path(os.environ['DATA']) + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4, pad_if_needed=True), + transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), + transforms.ToTensor(), + 
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train) + train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True) + + # create loss function + criterion = CrossEntropyLoss(label_smoothing=0.1) + + # create optimizer + optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0) + + # create lr scheduler + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS) + + # intiailize + engine, train_dataloader, *_ = colossalai.initialize(model=model, + optimizer=optimizer, + criterion=criterion, + train_dataloader=train_dataloader) + + engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES) + + logger = get_dist_logger() + + trainer = Trainer(engine=engine, logger=logger) + + hook_list = [ + hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False), + ] + + trainer.fit(train_dataloader=train_dataloader, + max_steps=2, + epochs=NUM_EPOCHS, + hooks=hook_list, + display_progress=True) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_hybrid_parallel(): + world_size = 2 + run_func = partial(run_trainer, world_size=world_size, port=free_port()) + disable_existing_loggers() + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_hybrid_parallel() diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index f229364c6eb1..af812ea59e55 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -1,18 +1,19 @@ import copy +from collections import OrderedDict +from functools import partial import pytest -import colossalai import torch import torch.multiprocessing as mp + +import colossalai +from colossalai.nn.parallel import ColoDDP +from colossalai.tensor import ColoParameter, ProcessGroup from 
colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from functools import partial from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.nn.parallel import ColoDDP -from collections import OrderedDict -from colossalai.tensor import ProcessGroup, ColoParameter def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): diff --git a/tests/test_ddp/test_reducer.py b/tests/test_ddp/test_reducer.py index 5b302d99ffb1..4fbd983fc8ff 100644 --- a/tests/test_ddp/test_reducer.py +++ b/tests/test_ddp/test_reducer.py @@ -1,14 +1,16 @@ +from functools import partial + import pytest -import colossalai import torch +import torch.distributed as dist import torch.multiprocessing as mp +from torch.distributed.distributed_c10d import _get_default_group + +import colossalai +from colossalai.nn.parallel.reducer import Reducer from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port -from functools import partial -from colossalai.nn.parallel.reducer import Reducer -import torch.distributed as dist -from torch.distributed.distributed_c10d import _get_default_group +from colossalai.utils.cuda import get_current_device REDUCE_CNT = 0 diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 3be057b3a98b..789ce8ab35b8 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -1,6 +1,7 @@ -from colossalai.device.device_mesh import DeviceMesh import torch +from colossalai.device.device_mesh import DeviceMesh + def test_device_mesh(): physical_mesh_id = torch.arange(0, 16).reshape(2, 8) diff --git a/tests/test_device/test_init_logical_pg.py 
b/tests/test_device/test_init_logical_pg.py index 3172897fb5cd..ad7efe6cda22 100644 --- a/tests/test_device/test_init_logical_pg.py +++ b/tests/test_device/test_init_logical_pg.py @@ -1,15 +1,16 @@ -import torch from functools import partial + import pytest +import torch import torch.distributed as dist import torch.multiprocessing as mp from torch.distributed import ReduceOp from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch -from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use -from colossalai.device.device_mesh import DeviceMesh +from colossalai.utils import free_port def check_layer(rank, world_size, port): diff --git a/tests/test_engine/test_engine.py b/tests/test_engine/test_engine.py index fb5bd1e1602e..181e2250bfbe 100644 --- a/tests/test_engine/test_engine.py +++ b/tests/test_engine/test_engine.py @@ -1,13 +1,14 @@ from functools import partial -import colossalai import pytest import torch.multiprocessing as mp + +import colossalai from colossalai.amp import AMP_TYPE from colossalai.core import global_context as gpc +from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize, rerun_if_address_is_in_use CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), diff --git a/tests/test_engine/test_gradient_accumluation.py b/tests/test_engine/test_gradient_accumluation.py index 7f5ee47be8e6..b7c50122d3ed 100644 --- a/tests/test_engine/test_gradient_accumluation.py +++ b/tests/test_engine/test_gradient_accumluation.py @@ -2,21 +2,22 @@ from functools import partial from pathlib import Path -import colossalai -from colossalai.testing.utils import rerun_if_address_is_in_use import pytest import torch import 
torch.multiprocessing as mp import torch.nn as nn -from colossalai.core import global_context as gpc -from colossalai.logging import get_dist_logger -from colossalai.utils import free_port, get_dataloader -from colossalai.testing import rerun_if_address_is_in_use from torch.optim import Adam from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.models import resnet18 +import colossalai +from colossalai.core import global_context as gpc +from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.testing.utils import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_dataloader + # Config BATCH_SIZE = 2 NUM_CLASSES = 10 diff --git a/tests/test_fx/test_coloproxy.py b/tests/test_fx/test_coloproxy.py index 2bb6cf86466c..14f6c92b6b07 100644 --- a/tests/test_fx/test_coloproxy.py +++ b/tests/test_fx/test_coloproxy.py @@ -1,9 +1,10 @@ +import pytest import torch import torch.nn as nn +from torch.fx import GraphModule + from colossalai.fx.proxy import ColoProxy from colossalai.fx.tracer.tracer import ColoTracer -from torch.fx import GraphModule -import pytest class Conv1D(nn.Module): diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py index 8825bbb461d6..438915bee113 100644 --- a/tests/test_fx/test_comm_size_compute.py +++ b/tests/test_fx/test_comm_size_compute.py @@ -1,13 +1,14 @@ -import colossalai -import colossalai.nn as col_nn import pytest import torch import torch.nn as nn +from torch.fx import symbolic_trace + +import colossalai +import colossalai.nn as col_nn from colossalai.fx._compatibility import is_compatible_with_meta -from colossalai.fx.passes.adding_split_node_pass import (split_with_split_nodes_pass, uniform_split_pass) +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass from colossalai.fx.passes.meta_info_prop import MetaInfoProp from 
colossalai.fx.passes.utils import get_comm_size -from torch.fx import symbolic_trace is_compatible = is_compatible_with_meta() if is_compatible: diff --git a/tests/test_fx/test_graph_manipulation.py b/tests/test_fx/test_graph_manipulation.py index fb33e58a778c..f0c00f4f07bc 100644 --- a/tests/test_fx/test_graph_manipulation.py +++ b/tests/test_fx/test_graph_manipulation.py @@ -1,9 +1,10 @@ -import colossalai import torch -from colossalai.fx.passes.utils import get_leaf, get_top, assign_bfs_level_to_nodes -from colossalai.fx import ColoTracer from torch.fx import GraphModule + +import colossalai +from colossalai.fx import ColoTracer from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata +from colossalai.fx.passes.utils import assign_bfs_level_to_nodes, get_leaf, get_top class MLP(torch.nn.Module): diff --git a/tests/test_fx/test_meta/test_aten.py b/tests/test_fx/test_meta/test_aten.py index 209ded89cfb9..2a444489c7f2 100644 --- a/tests/test_fx/test_meta/test_aten.py +++ b/tests/test_fx/test_meta/test_aten.py @@ -3,6 +3,7 @@ import pytest import torch import torch.nn as nn + from colossalai.fx._compatibility import is_compatible_with_meta if is_compatible_with_meta(): diff --git a/tests/test_fx/test_meta/test_backward.py b/tests/test_fx/test_meta/test_backward.py index 351c02c5744a..28fb9f442e94 100644 --- a/tests/test_fx/test_meta/test_backward.py +++ b/tests/test_fx/test_meta/test_backward.py @@ -2,6 +2,7 @@ import timm.models as tmm import torch import torchvision.models as tm + from colossalai.fx._compatibility import is_compatible_with_meta if is_compatible_with_meta(): diff --git a/tests/test_fx/test_meta/test_meta_trace.py b/tests/test_fx/test_meta/test_meta_trace.py index 404b6d27d2d4..b671a3086e2e 100644 --- a/tests/test_fx/test_meta/test_meta_trace.py +++ b/tests/test_fx/test_meta/test_meta_trace.py @@ -2,6 +2,7 @@ import timm.models as tmm import torch import torchvision.models as tm + from colossalai.fx._compatibility import 
is_compatible_with_meta if is_compatible_with_meta(): diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py index 6fac180d8ba2..d23059a2beec 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -1,7 +1,8 @@ import torch +from torch.fx import symbolic_trace + from colossalai.fx._compatibility import is_compatible_with_meta from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata -from torch.fx import symbolic_trace if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor diff --git a/tests/test_fx/test_parallel_1d.py b/tests/test_fx/test_parallel_1d.py index 8963ba29cb03..714441138d10 100644 --- a/tests/test_fx/test_parallel_1d.py +++ b/tests/test_fx/test_parallel_1d.py @@ -6,13 +6,14 @@ import pytest import torch import torch.multiprocessing as mp +from torch.fx import symbolic_trace + from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers +from colossalai.fx.passes import column_shard_linear_pass from colossalai.initialize import launch -from colossalai.utils import free_port +from colossalai.logging import disable_existing_loggers from colossalai.testing import rerun_if_address_is_in_use -from torch.fx import symbolic_trace -from colossalai.fx.passes import column_shard_linear_pass +from colossalai.utils import free_port class MLP(torch.nn.Module): diff --git a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py index 3afc6c97e2bb..2bd229192598 100644 --- a/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py +++ b/tests/test_fx/test_pipeline/test_hf_model/hf_utils.py @@ -1,11 +1,12 @@ -import torch -from torch.fx import symbolic_trace -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer import inspect import random 
+ import numpy as np +import torch +from torch.fx import GraphModule, symbolic_trace + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) diff --git a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py index aa870e5c7a65..4cb7fa158615 100644 --- a/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py +++ b/tests/test_fx/test_pipeline/test_timm_model/timm_utils.py @@ -1,11 +1,12 @@ -import torch -from torch.fx import symbolic_trace -from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer import inspect import random + import numpy as np +import torch +from torch.fx import GraphModule, symbolic_trace + +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass MANUAL_SEED = 0 random.seed(MANUAL_SEED) diff --git a/tests/test_fx/test_pipeline/test_topo/test_topo.py b/tests/test_fx/test_pipeline/test_topo/test_topo.py index 75c74870523c..68f4f5e55210 100644 --- a/tests/test_fx/test_pipeline/test_topo/test_topo.py +++ b/tests/test_fx/test_pipeline/test_topo/test_topo.py @@ -1,11 +1,12 @@ import pytest import torch import transformers -from topo_utils import split_model_and_get_DAG, check_topo, MLP +from topo_utils import MLP, check_topo, split_model_and_get_DAG BATCH_SIZE = 1 SEQ_LENGHT = 16 + def test_opt(): MODEL_LIST = [ MLP, @@ -13,7 +14,10 @@ def test_opt(): ] CONFIGS = [ - {'dim': 10, 'layers': 12}, + { + 'dim': 10, + 'layers': 12 + }, transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4), ] @@ -21,15 +25,15 @@ def data_gen_MLP(): x = torch.zeros((16, 10)) kwargs = dict(x=x) return kwargs - + def data_gen_OPT(): 
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) return kwargs - + DATAGEN = [ - data_gen_MLP, + data_gen_MLP, data_gen_OPT, ] @@ -39,5 +43,6 @@ def data_gen_OPT(): # print(f'{top_mod=}\n----\n{topo=}') check_topo(top_mod, topo) + if __name__ == '__main__': - test_opt() \ No newline at end of file + test_opt() diff --git a/tests/test_fx/test_pipeline/test_topo/topo_utils.py b/tests/test_fx/test_pipeline/test_topo/topo_utils.py index 55dd65201acd..5fd13399ac84 100644 --- a/tests/test_fx/test_pipeline/test_topo/topo_utils.py +++ b/tests/test_fx/test_pipeline/test_topo/topo_utils.py @@ -1,18 +1,22 @@ +import random + +import numpy as np import torch from torch.fx import GraphModule -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass + from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo from colossalai.pipeline.middleware.adaptor import get_fx_topology -import random -import numpy as np MANUAL_SEED = 0 random.seed(MANUAL_SEED) np.random.seed(MANUAL_SEED) torch.manual_seed(MANUAL_SEED) + class MLP(torch.nn.Module): + def __init__(self, config={}): super().__init__() dim = config['dim'] @@ -27,6 +31,7 @@ def forward(self, x): x = layer(x) return x + def split_model_and_get_DAG(model, data_gen): model.eval() @@ -46,7 +51,7 @@ def split_model_and_get_DAG(model, data_gen): # apply transform passes annotated_model = balanced_split_pass(gm, 2) top_module, split_submodules = split_with_split_nodes_pass(annotated_model) - + topo = get_fx_topology(top_module) for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): @@ -54,6 +59,7 @@ def 
split_model_and_get_DAG(model, data_gen): return top_module, split_submodules[0]._topo + def check_input(top_module, input_partition: Partition): partition_output = input_partition.get_output_vals() arg_pos = 0 @@ -63,13 +69,14 @@ def check_input(top_module, input_partition: Partition): to_partition_and_offset = cur_checkee.get() assert len(to_partition_and_offset) == len(node.users.keys()) arg_pos += 1 - + assert arg_pos == len(partition_output) - + + def check_submod(top_module, part_id, mid_partition: Partition): partition_input = mid_partition.get_input_vals() partition_output = mid_partition.get_output_vals() - + cnt = 1 cur_node = None for node in top_module.graph.nodes: @@ -78,15 +85,15 @@ def check_submod(top_module, part_id, mid_partition: Partition): if cnt == part_id: cur_node = node break - + assert len(partition_input) == len(cur_node.args) assert len(partition_output) == len(cur_node.users) -def check_topo(top_module, topo: Topo): + +def check_topo(top_module, topo: Topo): input_partition = topo.get_input_partition() mid_partitions = topo.get_mid_partitions() - + check_input(top_module, input_partition) for part_id, submod in mid_partitions.items(): check_submod(top_module, part_id, submod) - \ No newline at end of file diff --git a/tests/test_fx/test_pipeline_passes.py b/tests/test_fx/test_pipeline_passes.py index de8a9402ba56..cc86a738fba1 100644 --- a/tests/test_fx/test_pipeline_passes.py +++ b/tests/test_fx/test_pipeline_passes.py @@ -1,12 +1,16 @@ +import pytest import torch import torch.nn as nn -import colossalai -import colossalai.nn as col_nn from torch.fx import symbolic_trace -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass, \ - uniform_split_pass, balanced_split_pass_v2 -import pytest +import colossalai +import colossalai.nn as col_nn +from colossalai.fx.passes.adding_split_node_pass import ( + balanced_split_pass, + balanced_split_pass_v2, + split_with_split_nodes_pass, + 
uniform_split_pass, +) MODEL_DIM = 16 BATCH_SIZE = 8 diff --git a/tests/test_fx/test_tracer/test_control_flow.py b/tests/test_fx/test_tracer/test_control_flow.py index ed842cff2776..a5b40488577e 100644 --- a/tests/test_fx/test_tracer/test_control_flow.py +++ b/tests/test_fx/test_tracer/test_control_flow.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn from torch.fx import GraphModule + from colossalai.fx import ColoTracer as Tracer diff --git a/tests/test_fx/test_tracer/test_functional_conv.py b/tests/test_fx/test_tracer/test_functional_conv.py index 95670b85f335..62ef7119e25e 100644 --- a/tests/test_fx/test_tracer/test_functional_conv.py +++ b/tests/test_fx/test_tracer/test_functional_conv.py @@ -1,5 +1,6 @@ import torch from torch.nn import functional as F + from colossalai.fx.tracer.meta_patch import patched_function diff --git a/tests/test_fx/test_tracer/test_patched_module.py b/tests/test_fx/test_tracer/test_patched_module.py index 94a93e16f3c7..b6fa06ca884a 100644 --- a/tests/test_fx/test_tracer/test_patched_module.py +++ b/tests/test_fx/test_tracer/test_patched_module.py @@ -1,4 +1,5 @@ import torch + from colossalai.fx.tracer.meta_patch import patched_module diff --git a/tests/test_fx/test_tracer/test_patched_op.py b/tests/test_fx/test_tracer/test_patched_op.py index 4406f02db24b..5aeec84f18b2 100644 --- a/tests/test_fx/test_tracer/test_patched_op.py +++ b/tests/test_fx/test_tracer/test_patched_op.py @@ -1,6 +1,8 @@ +from functools import partial + import torch + from colossalai.fx.tracer.meta_patch import patched_function -from functools import partial def _run(data, patch_fn): diff --git a/tests/test_gemini/test_gemini_manager.py b/tests/test_gemini/test_gemini_manager.py index 0c138f101f75..755053d4041a 100644 --- a/tests/test_gemini/test_gemini_manager.py +++ b/tests/test_gemini/test_gemini_manager.py @@ -1,73 +1,73 @@ -import pytest -import torch - -from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor - - -@pytest.mark.dist 
-def test_gemini_manager(): - # reset the manager, in case that there exists memory information left - manager = StatefulTensor.GST_MGR - manager.reset() - - # occupation 8 - st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) - # occupation 60 - st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) - - # occupation 28 - t1 = torch.empty(7, device='cuda') - # occupation 12 - t2 = torch.empty(3, device='cpu') - st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) - st4 = StatefulTensor(None, TensorState.FREE) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 60 - assert manager.total_mem['cuda'] == 36 - assert manager.state_mem['cpu'][TensorState.HOLD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 - - st4.payload_reset(t2) - st3.payload_reset(t2) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 84 - assert manager.total_mem['cuda'] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD] == 72 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 - - st1.move_to(torch.device('cpu')) - st2.move_to(torch.device('cpu')) - st3.move_to(torch.device('cuda', 0)) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 80 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - - st1.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.HOLD_AFTER_BWD) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] 
== 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 - assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 - assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 - - -if __name__ == '__main__': - test_gemini_manager() +import pytest +import torch + +from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState + + +@pytest.mark.dist +def test_gemini_manager(): + # reset the manager, in case that there exists memory information left + manager = StatefulTensor.GST_MGR + manager.reset() + + # occupation 8 + st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) + # occupation 60 + st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) + + # occupation 28 + t1 = torch.empty(7, device='cuda') + # occupation 12 + t2 = torch.empty(3, device='cpu') + st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) + st4 = StatefulTensor(None, TensorState.FREE) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 60 + assert manager.total_mem['cuda'] == 36 + assert manager.state_mem['cpu'][TensorState.HOLD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 + + st4.payload_reset(t2) + st3.payload_reset(t2) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 84 + assert manager.total_mem['cuda'] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD] == 72 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 + + 
st1.move_to(torch.device('cpu')) + st2.move_to(torch.device('cpu')) + st3.move_to(torch.device('cuda', 0)) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 80 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + + st1.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.HOLD_AFTER_BWD) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 + assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 + assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 + + +if __name__ == '__main__': + test_gemini_manager() diff --git a/tests/test_layers/test_1d/checks_1d/common.py b/tests/test_layers/test_1d/checks_1d/common.py index 8b7b28613d22..29a9a3d20330 100644 --- a/tests/test_layers/test_1d/checks_1d/common.py +++ b/tests/test_layers/test_1d/checks_1d/common.py @@ -1,15 +1,16 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import torch - -DEPTH = 4 -BATCH_SIZE = 8 -SEQ_LENGTH = 8 -IMG_SIZE = 16 -HIDDEN_SIZE = 8 -NUM_CLASSES = 8 -VOCAB_SIZE = 16 - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DEPTH = 4 +BATCH_SIZE = 8 +SEQ_LENGTH = 8 +IMG_SIZE = 16 +HIDDEN_SIZE = 8 +NUM_CLASSES = 8 +VOCAB_SIZE = 16 + + 
+def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-1) == True diff --git a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py b/tests/test_layers/test_2d/checks_2d/check_layer_2d.py index e030e473a363..27de3bddb29e 100644 --- a/tests/test_layers/test_2d/checks_2d/check_layer_2d.py +++ b/tests/test_layers/test_2d/checks_2d/check_layer_2d.py @@ -1,12 +1,23 @@ import torch + from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn import (Classifier2D, CrossEntropyLoss2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, - VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2D, - VocabParallelCrossEntropyLoss2D, VocabParallelEmbedding2D) +from colossalai.nn import ( + Classifier2D, + CrossEntropyLoss2D, + Embedding2D, + LayerNorm2D, + Linear2D, + PatchEmbedding2D, + VanillaClassifier, + VanillaPatchEmbedding, + VocabParallelClassifier2D, + VocabParallelCrossEntropyLoss2D, + VocabParallelEmbedding2D, +) from colossalai.utils import get_current_device, print_rank_0 -from .common import (BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal) +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal def check_linear(): @@ -336,7 +347,7 @@ def check_classifier_no_given_weight(): layer.weight.data.copy_(W) # W.requires_grad = True - B_shape = (OUTPUT_SIZE, ) + B_shape = (OUTPUT_SIZE,) B_master = torch.randint(5, B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) # B = torch.chunk(B_master, DEPTH, dim=0)[j] @@ -572,7 +583,7 @@ def check_loss(): out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) 
torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, DEPTH, dim=0)[i] @@ -607,7 +618,7 @@ def check_vocab_parallel_loss(): out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, DEPTH, dim=0)[i] diff --git a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py index a5e37b1ec309..73f687879bce 100644 --- a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py +++ b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py @@ -6,9 +6,9 @@ from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D -from colossalai.utils import get_current_device -from colossalai.utils import print_rank_0 -from .common import check_equal, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH +from colossalai.utils import get_current_device, print_rank_0 + +from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal def check_AB(): diff --git a/tests/test_layers/test_2d/test_2d.py b/tests/test_layers/test_2d/test_2d.py index da235d0cf168..332f5f89e80a 100644 --- a/tests/test_layers/test_2d/test_2d.py +++ b/tests/test_layers/test_2d/test_2d.py @@ -6,17 +6,26 @@ import pytest import torch import torch.multiprocessing as mp +from checks_2d.check_layer_2d import ( + check_classifier_given_embed_weight, + check_classifier_no_given_weight, + check_embed, + check_layernorm, + check_linear, + check_loss, + check_patch_embed, + 
check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, + check_vocab_parallel_embed, + check_vocab_parallel_loss, +) +from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use -from checks_2d.check_layer_2d import (check_classifier_given_embed_weight, check_classifier_no_given_weight, - check_embed, check_layernorm, check_linear, check_loss, check_patch_embed, - check_vocab_parallel_classifier_given_embed_weight, - check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, - check_vocab_parallel_loss) -from checks_2d.check_operation_2d import check_AB, check_ABT, check_ATB +from colossalai.utils import free_port CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='2d')),) diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py b/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py index a8f551093b1e..e36b64229ad2 100644 --- a/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py +++ b/tests/test_layers/test_2p5d/checks_2p5d/check_layer_2p5d.py @@ -1,11 +1,22 @@ import torch +from torch.nn import Parameter + from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn import (Classifier2p5D, CrossEntropyLoss2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, - PatchEmbedding2p5D, VanillaClassifier, VanillaPatchEmbedding, VocabParallelClassifier2p5D, - VocabParallelCrossEntropyLoss2p5D, VocabParallelEmbedding2p5D) +from colossalai.nn import ( + Classifier2p5D, + CrossEntropyLoss2p5D, + Embedding2p5D, + LayerNorm2p5D, + Linear2p5D, + PatchEmbedding2p5D, + VanillaClassifier, + VanillaPatchEmbedding, + VocabParallelClassifier2p5D, + 
VocabParallelCrossEntropyLoss2p5D, + VocabParallelEmbedding2p5D, +) from colossalai.utils import get_current_device, print_rank_0 -from torch.nn import Parameter from .common import * @@ -342,7 +353,7 @@ def check_classifier_no_given_weight(): layer.weight.data.copy_(W) # W.requires_grad = True - B_shape = (OUTPUT_SIZE, ) + B_shape = (OUTPUT_SIZE,) B_master = torch.randint(5, B_shape, dtype=dtype, device=device) torch.distributed.broadcast(B_master, src=0) # B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[j] @@ -577,7 +588,7 @@ def check_loss(): out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] @@ -612,7 +623,7 @@ def check_vocab_parallel_loss(): out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, TESSERACT_DIM, dim=0)[i] diff --git a/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py b/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py index d0c3b02fccba..f4134600067b 100644 --- a/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py +++ b/tests/test_layers/test_2p5d/checks_2p5d/check_operation_2p5d.py @@ -2,10 +2,9 @@ from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, 
Matmul_ABT_2p5D, \ - Matmul_ATB_2p5D -from colossalai.utils import get_current_device -from colossalai.utils import print_rank_0 +from colossalai.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D +from colossalai.utils import get_current_device, print_rank_0 + from .common import * diff --git a/tests/test_layers/test_2p5d/checks_2p5d/common.py b/tests/test_layers/test_2p5d/checks_2p5d/common.py index aff85f109666..c90d8fc086bd 100644 --- a/tests/test_layers/test_2p5d/checks_2p5d/common.py +++ b/tests/test_layers/test_2p5d/checks_2p5d/common.py @@ -11,4 +11,4 @@ def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) \ No newline at end of file + assert torch.allclose(A, B, rtol=1e-5, atol=1e-2) diff --git a/tests/test_layers/test_2p5d/test_2p5d.py b/tests/test_layers/test_2p5d/test_2p5d.py index 365e2d934df8..50a3760ee3e1 100644 --- a/tests/test_layers/test_2p5d/test_2p5d.py +++ b/tests/test_layers/test_2p5d/test_2p5d.py @@ -3,13 +3,14 @@ import pytest import torch import torch.multiprocessing as mp +from checks_2p5d.check_layer_2p5d import * +from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use -from checks_2p5d.check_layer_2p5d import * -from checks_2p5d.check_operation_2p5d import check_AB, check_ABT, check_ATB +from colossalai.utils import free_port CONFIG = dict(parallel=dict( pipeline=dict(size=1), diff --git a/tests/test_layers/test_3d/checks_3d/common.py b/tests/test_layers/test_3d/checks_3d/common.py index afb19c4745cc..509fc2cecf59 100644 --- a/tests/test_layers/test_3d/checks_3d/common.py +++ b/tests/test_layers/test_3d/checks_3d/common.py @@ -16,4 +16,4 @@ def check_equal(A, B): eq = torch.allclose(A, B, rtol=1e-3, atol=1e-2) 
assert eq, f"\nA = {A}\nB = {B}" - return eq \ No newline at end of file + return eq diff --git a/tests/test_layers/test_3d/test_3d.py b/tests/test_layers/test_3d/test_3d.py index 29a8b3aea239..cfae65f942ab 100644 --- a/tests/test_layers/test_3d/test_3d.py +++ b/tests/test_layers/test_3d/test_3d.py @@ -5,15 +5,24 @@ import pytest import torch import torch.multiprocessing as mp +from checks_3d.check_layer_3d import ( + check_classifier_no_given_weight, + check_embed, + check_layernorm, + check_linear, + check_loss, + check_patch_embed, + check_vocab_parallel_classifier_given_embed_weight, + check_vocab_parallel_classifier_no_given_weight, + check_vocab_parallel_embed, + check_vocab_parallel_loss, +) + from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus -from checks_3d.check_layer_3d import (check_classifier_no_given_weight, check_embed, check_layernorm, check_linear, - check_loss, check_patch_embed, check_vocab_parallel_classifier_given_embed_weight, - check_vocab_parallel_classifier_no_given_weight, check_vocab_parallel_embed, - check_vocab_parallel_loss) +from colossalai.utils import free_port CONFIG = dict( parallel=dict( diff --git a/tests/test_layers/test_cache_embedding.py b/tests/test_layers/test_cache_embedding.py index cff9072c7a06..bcfa017c6274 100644 --- a/tests/test_layers/test_cache_embedding.py +++ b/tests/test_layers/test_cache_embedding.py @@ -1,20 +1,32 @@ -import pytest +import random from functools import partial +from typing import List import numpy as np -import random - +import pytest import torch import torch.multiprocessing as mp import colossalai -from colossalai.utils import free_port +from colossalai.nn.parallel.layers import ( + CachedEmbeddingBag, + CachedParamMgr, + EvictionStrategy, + ParallelCachedEmbeddingBag, + 
ParallelCachedEmbeddingBagTablewise, + TablewiseEmbeddingBagConfig, +) +from colossalai.tensor import ( + ColoParameter, + ColoTensor, + ColoTensorSpec, + ComputePattern, + ComputeSpec, + ProcessGroup, + ShardSpec, +) from colossalai.testing import rerun_if_address_is_in_use -from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec, \ - ColoTensor, ColoTensorSpec -from colossalai.nn.parallel.layers import CachedParamMgr, CachedEmbeddingBag, ParallelCachedEmbeddingBag, EvictionStrategy, \ - ParallelCachedEmbeddingBagTablewise, TablewiseEmbeddingBagConfig -from typing import List +from colossalai.utils import free_port NUM_EMBED, EMBED_DIM = 10, 8 BATCH_SIZE = 8 @@ -248,7 +260,7 @@ def run_parallel_freq_aware_embed_tablewise(rank, world_size): input0 [1,2,3] [6,7] [] input1 [] [9] [13,15] input2 [1,5] [6,8] [11] - ↑ ↑ ↑ + ↑ ↑ ↑ rank 0 rank 0 rank 1 in KJT format ''' diff --git a/tests/test_layers/test_sequence/test_sequence.py b/tests/test_layers/test_sequence/test_sequence.py index 3862c4ccd439..5309ea37bc41 100644 --- a/tests/test_layers/test_sequence/test_sequence.py +++ b/tests/test_layers/test_sequence/test_sequence.py @@ -1,14 +1,15 @@ -import colossalai -import colossalai.nn as col_nn +from functools import partial + +import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp -import pytest -from colossalai.core import global_context as gpc +import colossalai +import colossalai.nn as col_nn from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use -from functools import partial CONFIG = dict(parallel=dict(tensor=dict(size=4, mode='sequence'))) diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index e7b9a55277c6..cb297124cb88 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -1,16 +1,18 @@ from functools import 
partial + import pytest import torch -import torch.nn as nn -import torch.multiprocessing as mp import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn + import colossalai -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.utils.moe import sync_moe_model_param from colossalai.engine.gradient_handler import MoeGradientHandler +from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.moe import sync_moe_model_param BATCH_SIZE = 4 DIM = 16 diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 62f9241642b9..e1337850034f 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -1,15 +1,17 @@ from functools import partial + import pytest import torch -import torch.nn as nn import torch.multiprocessing as mp +import torch.nn as nn + import colossalai from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts from colossalai.context.moe_context import MOE_CONTEXT +from colossalai.core import global_context as gpc +from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device BATCH_SIZE = 16 NUM_EXPERTS = 4 diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index ae0c1390c129..40500a3de859 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ 
-1,63 +1,60 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import parameterize -from colossalai.utils import free_port -from colossalai.context import MOE_CONTEXT -from colossalai.tensor import ColoParameter -from colossalai.utils.model.colo_init_context import ColoInitContext - -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import get_current_device - -from tests.test_zero.common import CONFIG -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_tensor.common_utils import debug_print - - -@parameterize("init_device_type", ['cpu', 'cuda']) -def exam_moe_colo_init(init_device_type): - world_size = dist.get_world_size() - - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - with ColoInitContext(device=init_device): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) - - if hasattr(param, "moe_info"): - param.set_process_group(param.moe_info.pg) - - if hasattr(param, "moe_info"): - assert param.process_group.dp_world_size() == param.moe_info.dp_size - else: - assert param.process_group.dp_world_size() == world_size - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) - exam_moe_colo_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [4]) -@rerun_if_address_is_in_use() -def test_moe_colo_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - 
test_moe_colo_init(world_size=4) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.tensor import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.model.colo_init_context import ColoInitContext +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_tensor.common_utils import debug_print +from tests.test_zero.common import CONFIG + + +@parameterize("init_device_type", ['cpu', 'cuda']) +def exam_moe_colo_init(init_device_type): + world_size = dist.get_world_size() + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + with ColoInitContext(device=init_device): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) + + if hasattr(param, "moe_info"): + param.set_process_group(param.moe_info.pg) + + if hasattr(param, "moe_info"): + assert param.process_group.dp_world_size() == param.moe_info.dp_size + else: + assert param.process_group.dp_world_size() == world_size + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + exam_moe_colo_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_colo_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_colo_init(world_size=4) diff --git 
a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 3126f59e246e..8f39b9cf5624 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -1,14 +1,16 @@ from functools import partial + import pytest -import torch.nn as nn -import torch.multiprocessing as mp import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn + import colossalai -from colossalai.utils import free_port, get_current_device -from colossalai.nn.layer.moe import Experts from colossalai.context.moe_context import MOE_CONTEXT -from colossalai.utils.moe import sync_moe_model_param +from colossalai.nn.layer.moe import Experts from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.utils.moe import sync_moe_model_param D_MODEL = 4 D_FF = 8 diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 04dc9c514dd0..f83d118c1032 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -1,114 +1,112 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.nn import CheckpointModule -from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize -from colossalai.utils import free_port -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer import MoeModule -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) - -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import get_current_device -from tests.test_zero.common import CONFIG - - -class MoeModel(nn.Module): - - def __init__(self, checkpoint: bool = False): - - class TestSubModule(CheckpointModule): - - def __init__(self): - 
super().__init__(checkpoint) - expert_cls = nn.Linear - expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule(dim_model=16, - num_experts=8, - use_residual=True, - expert_cls=expert_cls, - **expert_args_dict) - self.proj = nn.Linear(16, 4) - - def _forward(self, x): - x, y = self.moe(x) - x = self.proj(x) - return x, y - - super().__init__() - self.test_embed = nn.Linear(4, 16) - self.test_transform = TestSubModule() - - def forward(self, x): - MOE_CONTEXT.reset_loss() - - x = self.test_embed(x) - x, y = self.test_transform(x) - - MOE_CONTEXT.add_loss(y) - return x - - -@parameterize("init_device_type", ['cpu', 'cuda']) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_moe_zero_init(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_moe_zero_init") - - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert hasattr(param, 'colo_attr') - - # the parameters in moe experts and its gate should not be sharded - if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): - assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) - else: - assert param.colo_attr.sharded_data_tensor.is_sharded - - # the parameters in moe experts is not replicated - if 'experts' in name: - assert not param.colo_attr.is_replicated - else: - assert param.colo_attr.is_replicated - - if param.colo_attr.param_is_sharded: - assert param.colo_attr.data_payload.device.type == init_device.type, \ - 
f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' - else: - assert param.colo_attr.data_payload.device.type == 'cuda' - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) - run_moe_zero_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_moe_zero_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_moe_zero_init(world_size=2) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.logging import get_dist_logger +from colossalai.nn import CheckpointModule +from colossalai.nn.layer import MoeModule +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from tests.test_zero.common import CONFIG + + +class MoeModel(nn.Module): + + def __init__(self, checkpoint: bool = False): + + class TestSubModule(CheckpointModule): + + def __init__(self): + super().__init__(checkpoint) + expert_cls = nn.Linear + expert_args_dict = dict(in_features=16, out_features=16) + self.moe = MoeModule(dim_model=16, + num_experts=8, + use_residual=True, + expert_cls=expert_cls, + **expert_args_dict) + self.proj = nn.Linear(16, 4) + + def _forward(self, x): + x, y = self.moe(x) + x = self.proj(x) + return x, y + + super().__init__() + self.test_embed = nn.Linear(4, 16) + self.test_transform = TestSubModule() + + def forward(self, x): + MOE_CONTEXT.reset_loss() + + x = self.test_embed(x) + x, y = 
self.test_transform(x) + + MOE_CONTEXT.add_loss(y) + return x + + +@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_moe_zero_init(init_device_type, shard_strategy_class): + logger = get_dist_logger("test_moe_zero_init") + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + model_numel_tensor = torch.zeros(1, dtype=torch.int) + with ZeroInitContext(target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert hasattr(param, 'colo_attr') + + # the parameters in moe experts and its gate should not be sharded + if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): + assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) + else: + assert param.colo_attr.sharded_data_tensor.is_sharded + + # the parameters in moe experts is not replicated + if 'experts' in name: + assert not param.colo_attr.is_replicated + else: + assert param.colo_attr.is_replicated + + if param.colo_attr.param_is_sharded: + assert param.colo_attr.data_payload.device.type == init_device.type, \ + f'{param.colo_attr.data_payload.device.type} vs. 
{init_device.type}' + else: + assert param.colo_attr.data_payload.device.type == 'cuda' + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_moe_zero_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_moe_zero_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_zero_init(world_size=2) diff --git a/tests/test_ops/test_addmm_tp.py b/tests/test_ops/test_addmm_tp.py index 5182868b5bbd..fa0eef74b33f 100644 --- a/tests/test_ops/test_addmm_tp.py +++ b/tests/test_ops/test_addmm_tp.py @@ -1,14 +1,15 @@ -import colossalai -import torch +from functools import partial + import pytest -import torch.nn as nn +import torch import torch.multiprocessing as mp -from colossalai.tensor import ColoTensor, ProcessGroup -from colossalai.tensor import ColoTensorSpec +import torch.nn as nn + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from functools import partial -from tests.test_tensor.common_utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal class Conv1D(nn.Module): diff --git a/tests/test_ops/test_embedding_bag_tp.py b/tests/test_ops/test_embedding_bag_tp.py index c7a1604e5455..a79a0cbb8df0 100644 --- a/tests/test_ops/test_embedding_bag_tp.py +++ b/tests/test_ops/test_embedding_bag_tp.py @@ -1,14 +1,15 @@ -from torch.nn import functional as F from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp +from 
torch.nn import functional as F + +import colossalai +from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d +from tests.test_tensor.common_utils import split_param_col_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func): diff --git a/tests/test_ops/test_embedding_tp.py b/tests/test_ops/test_embedding_tp.py index 541dc5c09324..3502872fe2bc 100644 --- a/tests/test_ops/test_embedding_tp.py +++ b/tests/test_ops/test_embedding_tp.py @@ -1,14 +1,15 @@ -from torch.nn import functional as F from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp +from torch.nn import functional as F + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func, pg: ProcessGroup): diff --git a/tests/test_ops/test_linear_tp.py b/tests/test_ops/test_linear_tp.py index 603e98564de8..a961db1c5a9d 100644 --- a/tests/test_ops/test_linear_tp.py +++ b/tests/test_ops/test_linear_tp.py @@ -1,14 +1,15 @@ from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp import torch.nn.functional as F + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup from colossalai.testing import 
rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor -from tests.test_tensor.common_utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d +from tests.test_tensor.common_utils import split_param_col_tp1d, split_param_row_tp1d, tensor_equal, tensor_shard_equal def run_with_spec(spec_init_func, split_bias): diff --git a/tests/test_ops/test_loss_func.py b/tests/test_ops/test_loss_func.py index 9210242a0a9f..c416f7dd624b 100644 --- a/tests/test_ops/test_loss_func.py +++ b/tests/test_ops/test_loss_func.py @@ -1,52 +1,52 @@ -import torch -import pytest -import colossalai -import torch.nn.functional as F -import torch.multiprocessing as mp -from functools import partial -from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec -from colossalai.utils import get_current_device -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern - - -def check_cross_entropy(): - input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) - input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) - with torch.no_grad(): - input_ct.copy_(input_t) - - target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) - - world_size = torch.distributed.get_world_size() - pg = ProcessGroup(tp_degree=world_size) - input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) - input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) - input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) - - output = F.cross_entropy(input_t, target) - output_colo = F.cross_entropy(input_shard, target) - assert torch.allclose(output_colo, output) - - output.backward() - output_colo.backward() - - assert torch.allclose(input_t.grad, 
input_ct.grad) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - check_cross_entropy() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_loss_func(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_loss_func(1) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn.functional as F + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device + + +def check_cross_entropy(): + input_t = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + input_ct = torch.randn(4, 4, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + input_ct.copy_(input_t) + + target = torch.randint(4, (4,), dtype=torch.int64, device=get_current_device()) + + world_size = torch.distributed.get_world_size() + pg = ProcessGroup(tp_degree=world_size) + input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg)) + input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D)) + + output = F.cross_entropy(input_t, target) + output_colo = F.cross_entropy(input_shard, target) + assert torch.allclose(output_colo, output) + + output.backward() + output_colo.backward() + + assert torch.allclose(input_t.grad, input_ct.grad) + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + check_cross_entropy() + + 
+@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2]) +@rerun_if_address_is_in_use() +def test_loss_func(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_loss_func(1) diff --git a/tests/test_ops/test_op.py b/tests/test_ops/test_op.py index 8d3cf50ff2aa..4ed45f1a0e86 100644 --- a/tests/test_ops/test_op.py +++ b/tests/test_ops/test_op.py @@ -1,14 +1,15 @@ -import torch +from functools import partial + import pytest -import colossalai -import torch.nn.functional as F +import torch import torch.multiprocessing as mp -from functools import partial -from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec -from colossalai.utils import get_current_device +import torch.nn.functional as F from torch.nn import Parameter + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port +from colossalai.utils import free_port, get_current_device def _run_layer_norm(): diff --git a/tests/test_ops/test_view.py b/tests/test_ops/test_view.py index fc6fc2d3c291..dfac044c8faf 100644 --- a/tests/test_ops/test_view.py +++ b/tests/test_ops/test_view.py @@ -1,100 +1,101 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port, get_current_device -from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor, ShardSpec -from colossalai.tensor.distspec import DistPlacementPattern -from tests.test_tensor.common_utils import split_param_row_tp1d, split_param_col_tp1d, debug_print - - -def exam_view_core(pg): - # the case of replicated ColoTensors - x = torch.randn(4, 4).cuda() - x_colo = ColoTensor(x, 
ColoTensorSpec(pg)) - - y = x.view(2, -1, 2) - y_colo = x_colo.view(2, -1, 2) - - assert torch.all(y == y_colo) - assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE - # the perfect case of col-sliced ColoTensors - split_param_col_tp1d(x_colo, pg) - - z = x.view(torch.Size((2, 1, 2, -1))) - z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) - if dist.get_rank() == 0: - z = z[:, :, :, 0:2] - else: - z = z[:, :, :, 2:] - assert torch.all(z == z_colo) - assert z_colo.dist_spec == x_colo.dist_spec - # the perfect case of row-sliced ColoTensors - split_param_row_tp1d(x_colo, pg) - - z = x.view(torch.Size((-1, 2, 2))) - z_colo = x_colo.view(torch.Size((-1, 2, 2))) - if dist.get_rank() == 0: - z = z[0:2, :, :] - else: - z = z[2:, :, :] - assert torch.all(z == z_colo) - assert z_colo.dist_spec == x_colo.dist_spec - # the normal case of row-sliced ColoTensors - z = x.view(-1, 2, 2, 2) - z_colo = x_colo.view(-1, 2, 2, 2) - assert torch.all(z == z_colo) - assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE - - -def exam_view_autograd(pg): - x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) - y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) - with torch.no_grad(): - y.copy_(x) - y = ColoTensor(y, ColoTensorSpec(pg)) - y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) - - xx = x.view(2, 2, -1) - yy_slice = y_slice.view(2, 2, -1) - yy = yy_slice.to_replicate() - grad = torch.randn(2, 2, 4, device=get_current_device()) - - xx.backward(grad) - yy.backward(grad) - assert torch.all(x.grad == y.grad) - - -def exam_view_errors(pg): - x = torch.randn(8, 2, device=get_current_device()) - x = ColoTensor(x, ColoTensorSpec(pg)) - split_param_row_tp1d(x, pg) - - x.view('a', 'b', 'c') - x.view(8, -1) - x.view([-2, -2, -2]) - x.view((-1, -1, -1)) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = 
ProcessGroup(tp_degree=torch.distributed.get_world_size()) - exam_view_core(pg) - exam_view_autograd(pg) - # exam_view_errors(pg) - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) -@rerun_if_address_is_in_use() -def test_view(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_view(2) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ShardSpec +from colossalai.tensor.distspec import DistPlacementPattern +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from tests.test_tensor.common_utils import debug_print, split_param_col_tp1d, split_param_row_tp1d + + +def exam_view_core(pg): + # the case of replicated ColoTensors + x = torch.randn(4, 4).cuda() + x_colo = ColoTensor(x, ColoTensorSpec(pg)) + + y = x.view(2, -1, 2) + y_colo = x_colo.view(2, -1, 2) + + assert torch.all(y == y_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + # the perfect case of col-sliced ColoTensors + split_param_col_tp1d(x_colo, pg) + + z = x.view(torch.Size((2, 1, 2, -1))) + z_colo = x_colo.view(torch.Size((2, 1, 2, -1))) + if dist.get_rank() == 0: + z = z[:, :, :, 0:2] + else: + z = z[:, :, :, 2:] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the perfect case of row-sliced ColoTensors + split_param_row_tp1d(x_colo, pg) + + z = x.view(torch.Size((-1, 2, 2))) + z_colo = x_colo.view(torch.Size((-1, 2, 2))) + if dist.get_rank() == 0: + z = z[0:2, :, :] + else: + z = z[2:, :, :] + assert torch.all(z == z_colo) + assert z_colo.dist_spec == x_colo.dist_spec + # the normal case of row-sliced ColoTensors + z = x.view(-1, 2, 2, 2) + z_colo = x_colo.view(-1, 2, 2, 2) + 
assert torch.all(z == z_colo) + assert y_colo.dist_spec.placement == DistPlacementPattern.REPLICATE + + +def exam_view_autograd(pg): + x = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + y = torch.randn(8, 2, device=get_current_device(), requires_grad=True) + with torch.no_grad(): + y.copy_(x) + y = ColoTensor(y, ColoTensorSpec(pg)) + y_slice = y.redistribute(ShardSpec([-1], [pg.tp_world_size()])) + + xx = x.view(2, 2, -1) + yy_slice = y_slice.view(2, 2, -1) + yy = yy_slice.to_replicate() + grad = torch.randn(2, 2, 4, device=get_current_device()) + + xx.backward(grad) + yy.backward(grad) + assert torch.all(x.grad == y.grad) + + +def exam_view_errors(pg): + x = torch.randn(8, 2, device=get_current_device()) + x = ColoTensor(x, ColoTensorSpec(pg)) + split_param_row_tp1d(x, pg) + + x.view('a', 'b', 'c') + x.view(8, -1) + x.view([-2, -2, -2]) + x.view((-1, -1, -1)) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=torch.distributed.get_world_size()) + exam_view_core(pg) + exam_view_autograd(pg) + # exam_view_errors(pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [2]) +@rerun_if_address_is_in_use() +def test_view(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_view(2) diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py index f7227c2d57c0..aa7923d2cec9 100644 --- a/tests/test_optimizer/test_fused_adam.py +++ b/tests/test_optimizer/test_fused_adam.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from torch.optim.adam import Adam from torch.optim import AdamW +from torch.optim.adam import Adam from colossalai.nn.optimizer.fused_adam import FusedAdam from colossalai.testing import parameterize diff --git 
a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py index d19192add3fb..84facfb6f2b0 100644 --- a/tests/test_optimizer/test_hybrid_adam.py +++ b/tests/test_optimizer/test_hybrid_adam.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from torch.optim.adam import Adam from torch.optim import AdamW +from torch.optim.adam import Adam from colossalai.nn.optimizer.hybrid_adam import HybridAdam from colossalai.testing import parameterize diff --git a/tests/test_optimizer/test_nvme.py b/tests/test_optimizer/test_nvme.py index 243f785adaf9..0e8095fbfcaf 100644 --- a/tests/test_optimizer/test_nvme.py +++ b/tests/test_optimizer/test_nvme.py @@ -1,7 +1,8 @@ import pytest import torch -from tests.components_to_test.registry import non_distributed_component_funcs + from colossalai.nn.optimizer import CPUAdam, HybridAdam +from tests.components_to_test.registry import non_distributed_component_funcs def move_some_params_to_cuda(model, torch_model): diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py index 7ce2cd433b12..dab474a4ee21 100644 --- a/tests/test_pipeline/rpc_test_utils.py +++ b/tests/test_pipeline/rpc_test_utils.py @@ -6,13 +6,14 @@ import torch.distributed as dist import torch.distributed.rpc as rpc import torch.multiprocessing as mp -from colossalai import launch -from colossalai.logging import disable_existing_loggers -from colossalai.pipeline.pipeline_process_group import ppg from torch import nn from torch._C._distributed_rpc import _is_current_rpc_agent_set from torch.optim import SGD, Adam, Optimizer, RMSprop +from colossalai import launch +from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.pipeline_process_group import ppg + rpc_is_initialized = _is_current_rpc_agent_set @@ -20,7 +21,9 @@ def color_debug(text, prefix=' ', color='blue'): color = color.upper() print(getattr(Back, color), prefix, Style.RESET_ALL, text) + class MLP(nn.Module): + def 
__init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -32,8 +35,10 @@ def forward(self, x): for layer in self.layers: x = layer(x) return x.sum() - + + class DAG_MLP(nn.Module): + def __init__(self, dim: int, layers: int): super().__init__() self.layers = torch.nn.ModuleList() @@ -48,6 +53,7 @@ def forward(self, x, y): y = self.dag_layer(y) return x.sum(), y.sum() + class RpcTestModel(nn.Module): def __init__(self, stage_id, actual_stage_num, feat_num, h) -> None: diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_pipeline/test_cuda_rpc_chimera.py index 45ad8f828e61..ee049050ec04 100644 --- a/tests/test_pipeline/test_cuda_rpc_chimera.py +++ b/tests/test_pipeline/test_cuda_rpc_chimera.py @@ -1,10 +1,10 @@ import torch -from torch import nn import torch.autograd as autograd +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import nn from colossalai.pipeline.rpc import ChimeraPipelineEngine from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel # global variable for model created feat_num = 100 diff --git a/tests/test_pipeline/test_cuda_rpc_optimizer.py b/tests/test_pipeline/test_cuda_rpc_optimizer.py index 842566730caf..2809b6c6f5f0 100644 --- a/tests/test_pipeline/test_cuda_rpc_optimizer.py +++ b/tests/test_pipeline/test_cuda_rpc_optimizer.py @@ -1,11 +1,10 @@ import torch -from torch import nn -from torch import autograd -from torch.optim import SGD, Adam, RMSprop, Optimizer +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import autograd, nn +from torch.optim import SGD, Adam, Optimizer, RMSprop from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel # global variable for model created feat_num = 100 diff --git 
a/tests/test_pipeline/test_cuda_rpc_performance.py b/tests/test_pipeline/test_cuda_rpc_performance.py index 6a0509555862..664cde32f9bf 100644 --- a/tests/test_pipeline/test_cuda_rpc_performance.py +++ b/tests/test_pipeline/test_cuda_rpc_performance.py @@ -1,25 +1,25 @@ import os -from typing import Callable, List, Optional, Type, Union import time +from typing import Callable, List, Optional, Type, Union import pytest import torch import torch.nn as nn +from rpc_test_utils import parse_args, rpc_run from titans.dataloader.cifar10 import build_cifar from torchvision.models import resnet50 from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 from tqdm import tqdm -from rpc_test_utils import rpc_run, parse_args import colossalai import colossalai.nn as col_nn -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.trainer import Trainer, hooks -from colossalai.utils import MultiTimer, get_dataloader from colossalai.context import ParallelMode +from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.pipeline.pipelinable import PipelinableContext, PipelinableModel -from colossalai.pipeline.rpc import OneFOneBPipelineEngine, ChimeraPipelineEngine from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc import ChimeraPipelineEngine, OneFOneBPipelineEngine +from colossalai.trainer import Trainer, hooks +from colossalai.utils import MultiTimer, get_dataloader def flatten(x): diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py index 8d03e79813e8..9e2f519bba98 100644 --- a/tests/test_pipeline/test_cuda_rpc_pipeline.py +++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py @@ -1,8 +1,8 @@ import torch +from rpc_test_utils import RpcTestModel, parse_args, rpc_run from torch import nn from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine -from rpc_test_utils import 
rpc_run, parse_args, RpcTestModel # global variable for model created feat_num = 100 diff --git a/tests/test_pipeline/test_cuda_rpc_value_correctness.py b/tests/test_pipeline/test_cuda_rpc_value_correctness.py index e6713478baec..289ec6a52248 100644 --- a/tests/test_pipeline/test_cuda_rpc_value_correctness.py +++ b/tests/test_pipeline/test_cuda_rpc_value_correctness.py @@ -1,10 +1,9 @@ import torch -from torch import nn -from torch import autograd +from rpc_test_utils import RpcTestModel, parse_args, rpc_run +from torch import autograd, nn from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine from colossalai.testing import assert_close -from rpc_test_utils import rpc_run, parse_args, RpcTestModel feat_num = 100 h = 100 diff --git a/tests/test_pipeline/test_middleware_1f1b.py b/tests/test_pipeline/test_middleware_1f1b.py index c4dc617b1683..f59405ef819e 100644 --- a/tests/test_pipeline/test_middleware_1f1b.py +++ b/tests/test_pipeline/test_middleware_1f1b.py @@ -1,20 +1,21 @@ -import torch -import pytest import os -import torch.multiprocessing as mp -import torch.distributed.rpc as rpc +from functools import partial +import pytest +import torch +import torch.distributed.rpc as rpc +import torch.multiprocessing as mp +from rpc_test_utils import DAG_MLP, MLP from torch import nn from torch._C._distributed_rpc import _is_current_rpc_agent_set + from colossalai import launch +from colossalai.fx import ColoTracer +from colossalai.fx.passes.adding_split_node_pass import balanced_split_pass, split_with_split_nodes_pass from colossalai.logging import disable_existing_loggers +from colossalai.pipeline.middleware.adaptor import get_fx_topology from colossalai.pipeline.pipeline_process_group import ppg from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine -from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass -from colossalai.fx import ColoTracer -from 
colossalai.pipeline.middleware.adaptor import get_fx_topology -from rpc_test_utils import MLP, DAG_MLP -from functools import partial from colossalai.testing import parameterize, rerun_if_address_is_in_use # global variable for model created @@ -22,6 +23,7 @@ dim = 10 rpc_is_initialized = _is_current_rpc_agent_set + def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): model.eval() tracer = ColoTracer() @@ -34,13 +36,15 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): for submodule in split_submodules: if isinstance(submodule, torch.fx.GraphModule): setattr(submodule, '_topo', topo) - return split_submodules[pp_rank+1] + return split_submodules[pp_rank + 1] + def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int): torch.manual_seed(1024) partition = create_partition_module(pp_rank, stage_num, model, data_kwargs) return partition + def run_master(model_cls, world_size, forward_only): torch.manual_seed(100) @@ -50,23 +54,27 @@ def run_master(model_cls, world_size, forward_only): chunk = 1 num_microbatches = 8 use_checkpoint = 'store_true' - + if model_cls == MLP: + def data_gen(): x = torch.zeros((batch_size, dim)) kwargs = dict(x=x) return kwargs + model = model_cls(dim, stage_num * 3) if forward_only: labels = None else: labels = 1 elif model_cls == DAG_MLP: + def data_gen(): x = torch.zeros((batch_size, dim)) y = torch.zeros((batch_size, dim)) kwargs = dict(x=x, y=y) return kwargs + model = model_cls(dim, stage_num * 3) if forward_only: labels = None @@ -74,15 +82,17 @@ def data_gen(): labels = 1 else: pass - + data_kwargs = data_gen() - - engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, data_kwargs), - stage_num=stage_num, - num_microbatches=num_microbatches, - device=device, - chunk=chunk, - checkpoint=use_checkpoint,) + + engine = OneFOneBPipelineEngine( + partition_fn=partial(partition, model, data_kwargs), + stage_num=stage_num, + 
num_microbatches=num_microbatches, + device=device, + chunk=chunk, + checkpoint=use_checkpoint, + ) if not forward_only: engine.initialize_optimizer(getattr(torch.optim, 'SGD'), lr=1e-3) @@ -90,13 +100,14 @@ def data_gen(): input_x = torch.randn((batch_size, dim), device=device) input_y = torch.randn((batch_size, dim), device=device) logits = engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only) - + + def run_worker(rank, model_cls, world_size, forward_only, master_func): master_addr = 'localhost' master_port = 29020 os.environ['MASTER_ADDR'] = master_addr os.environ['MASTER_PORT'] = str(master_port) - + disable_existing_loggers() launch(dict(), rank, world_size, master_addr, master_port, 'nccl', verbose=False) @@ -113,7 +124,8 @@ def run_worker(rank, model_cls, world_size, forward_only, master_func): # barrier here if rpc_is_initialized(): rpc.shutdown() - + + @pytest.mark.skip("skip due to CI torch version 1.11") @parameterize('model_cls', [MLP, DAG_MLP]) @parameterize('forward_only', [True, False]) @@ -124,5 +136,6 @@ def test_pp_middleware_fwd(model_cls, forward_only): master_func = run_master mp.spawn(run_worker, args=(model_cls, world_size, forward_only, master_func), nprocs=world_size) + if __name__ == "__main__": - test_pp_middleware_fwd() \ No newline at end of file + test_pp_middleware_fwd() diff --git a/tests/test_pipeline/test_pipelinable.py b/tests/test_pipeline/test_pipelinable.py index c99a88550b71..bb9fe83e53c1 100644 --- a/tests/test_pipeline/test_pipelinable.py +++ b/tests/test_pipeline/test_pipelinable.py @@ -2,7 +2,6 @@ import torch.multiprocessing as mp from colossalai.pipeline.pipelinable import PipelinableContext - from colossalai.testing import rerun_on_exception NUM_CHUNKS = 1 diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_pipeline/test_pipeline_process_group.py index c67e4175df92..b4053c63cfbb 100644 --- a/tests/test_pipeline/test_pipeline_process_group.py +++ 
b/tests/test_pipeline/test_pipeline_process_group.py @@ -1,13 +1,13 @@ import os +import pytest import torch.distributed.rpc as rpc import torch.multiprocessing as mp -import pytest +from rpc_test_utils import pg_parse_args, rpc_is_initialized -from colossalai.pipeline.pipeline_process_group import ppg from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from rpc_test_utils import pg_parse_args, rpc_is_initialized +from colossalai.pipeline.pipeline_process_group import ppg def run_worker(rank, args): @@ -40,4 +40,4 @@ def run_worker(rank, args): if __name__ == "__main__": args = pg_parse_args() world_size = args.world_size - mp.spawn(run_worker, args=(args,), nprocs=world_size) \ No newline at end of file + mp.spawn(run_worker, args=(args,), nprocs=world_size) diff --git a/tests/test_tensor/common_utils/__init__.py b/tests/test_tensor/common_utils/__init__.py index 5387db70445f..9a35d02ce5ed 100644 --- a/tests/test_tensor/common_utils/__init__.py +++ b/tests/test_tensor/common_utils/__init__.py @@ -1 +1 @@ -from ._utils import * +from ._utils import * diff --git a/tests/test_tensor/core/test_dist_spec_mgr.py b/tests/test_tensor/core/test_dist_spec_mgr.py index e02f4e7977f6..c6abe1f2db26 100644 --- a/tests/test_tensor/core/test_dist_spec_mgr.py +++ b/tests/test_tensor/core/test_dist_spec_mgr.py @@ -1,13 +1,15 @@ import math +from functools import partial + +import pytest import torch import torch.distributed as dist -import pytest -import colossalai import torch.multiprocessing as mp + +import colossalai +from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec -from functools import partial def run(): diff --git a/tests/test_tensor/core/test_tensor.py b/tests/test_tensor/core/test_tensor.py index 
b48d9e9a2dfa..94ad7134d690 100644 --- a/tests/test_tensor/core/test_tensor.py +++ b/tests/test_tensor/core/test_tensor.py @@ -1,17 +1,15 @@ -import torch +from functools import partial + import pytest -from colossalai.tensor import ColoTensor +import torch +import torch.multiprocessing as mp from numpy import allclose import colossalai -from colossalai.utils import free_port -from colossalai.tensor import ColoTensorSpec from colossalai.core import global_context as gpc -import torch.multiprocessing as mp +from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, ReplicaSpec, ShardSpec, distspec from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec -from functools import partial def _run_tensor_indexing(): diff --git a/tests/test_tensor/test_colo_checkpoint_tools.py b/tests/test_tensor/test_colo_checkpoint_tools.py index aa333d55276c..04bb9b9db578 100644 --- a/tests/test_tensor/test_colo_checkpoint_tools.py +++ b/tests/test_tensor/test_colo_checkpoint_tools.py @@ -1,47 +1,47 @@ -import torch -import pytest -from functools import partial - -import torch.multiprocessing as mp -import torch.distributed as dist - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup, ColoTensorSpec -from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor -from tests.test_tensor.common_utils import tensor_shard_equal - - -def run_dist(rank, world_size, port, dp_degree, tp_degree): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) - x = torch.randn(4, 4) - param = 
ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) - spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) - param.set_tensor_spec(*spec) - - gather_tensor(param) - if dist.get_rank() == 0: - assert torch.all(x == param) - else: - assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) - dist.barrier() - - scatter_tensor(param, spec[0]) - assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) - assert param.requires_grad is True - dist.barrier() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [4]) -@rerun_if_address_is_in_use() -def test_checkpoint(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_checkpoint(world_size=4) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor +from colossalai.utils.cuda import get_current_device +from tests.test_tensor.common_utils import tensor_shard_equal + + +def run_dist(rank, world_size, port, dp_degree, tp_degree): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree) + x = torch.randn(4, 4) + param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg)) + spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D) + param.set_tensor_spec(*spec) + + gather_tensor(param) + if dist.get_rank() == 0: + assert torch.all(x == param) + else: + assert 
tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + dist.barrier() + + scatter_tensor(param, spec[0]) + assert tensor_shard_equal(x, param.data, pg.tp_local_rank(), pg.tp_world_size()) + assert param.requires_grad is True + dist.barrier() + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port(), dp_degree=2, tp_degree=world_size // 2) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_checkpoint(world_size=4) diff --git a/tests/test_tensor/test_parameter.py b/tests/test_tensor/test_parameter.py index 7c3c4b2132e4..21f9cb22fb94 100644 --- a/tests/test_tensor/test_parameter.py +++ b/tests/test_tensor/test_parameter.py @@ -1,8 +1,9 @@ -from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup -import torch import pytest +import torch from common_utils import tensor_equal + import colossalai +from colossalai.tensor import ColoParameter, ColoTensor, ColoTensorSpec, ProcessGroup from colossalai.utils import free_port diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 6fe9ee292cd0..06fa05051062 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,7 +1,8 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch -from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec + from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec physical_mesh_id = torch.arange(0, 16).reshape(2, 8) mesh_shape = (4, 4) diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py index 
72820c6a1f0d..7da061c6a824 100644 --- a/tests/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_trainer/test_pipeline/test_p2p.py @@ -7,15 +7,23 @@ import torch import torch.distributed as dist import torch.multiprocessing as mp -from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward, - send_backward_recv_forward, send_forward, send_forward_recv_backward, - send_obj_meta) + +from colossalai.communication import ( + recv_backward, + recv_forward, + recv_obj_meta, + send_backward, + send_backward_recv_forward, + send_forward, + send_forward_recv_backward, + send_obj_meta, +) from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch from colossalai.logging import get_dist_logger -from colossalai.utils import free_port, get_current_device from colossalai.testing import rerun_on_exception +from colossalai.utils import free_port, get_current_device BATCH_SIZE = 4 SEQ_LENGTH = 2 diff --git a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py index 48f729658134..dcd3f593edff 100644 --- a/tests/test_trainer/test_pipeline/test_pipeline_schedule.py +++ b/tests/test_trainer/test_pipeline/test_pipeline_schedule.py @@ -5,30 +5,25 @@ from functools import partial from pathlib import Path -import colossalai import pytest import torch -import torch.nn as nn import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.context import ParallelMode -from colossalai.initialize import launch -from colossalai.utils import free_port, get_dataloader, print_rank_0 -from colossalai.testing import rerun_on_exception +import torch.nn as nn from torchvision import transforms from torchvision.datasets import CIFAR10 from torchvision.models import resnet18 +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import 
global_context as gpc +from colossalai.initialize import launch +from colossalai.testing import rerun_on_exception +from colossalai.utils import free_port, get_dataloader, print_rank_0 BATCH_SIZE = 8 -CONFIG=dict( - NUM_MICRO_BATCHES=2, - parallel = dict( - pipeline=dict(size=2), - tensor=dict(size=1, mode=None) - ) -) +CONFIG = dict(NUM_MICRO_BATCHES=2, parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode=None))) + def run_schedule(rank, world_size, port): launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py index b013433293cd..0259e5ab6f3c 100644 --- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py +++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py @@ -1,15 +1,16 @@ from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp + +import colossalai from colossalai.amp.amp_type import AMP_TYPE from colossalai.logging import get_dist_logger +from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.trainer import Trainer from colossalai.utils import MultiTimer, free_port from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize, rerun_if_address_is_in_use BATCH_SIZE = 4 IMG_SIZE = 32 diff --git a/tests/test_trainer/test_trainer_with_pipe_schedule.py b/tests/test_trainer/test_trainer_with_pipe_schedule.py index 3698526a8e6c..926bb52e9073 100644 --- a/tests/test_trainer/test_trainer_with_pipe_schedule.py +++ b/tests/test_trainer/test_trainer_with_pipe_schedule.py @@ -2,22 +2,23 @@ from functools import partial from pathlib import Path -import colossalai import pytest import torch import torch.multiprocessing as mp import torch.nn as nn +from torch.optim import Adam +from torchvision import transforms +from 
torchvision.datasets import CIFAR10 +from torchvision.models import resnet18 + +import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.engine.schedule import PipelineSchedule from colossalai.logging import get_dist_logger +from colossalai.testing import rerun_if_address_is_in_use from colossalai.trainer import Trainer from colossalai.utils import MultiTimer, free_port, get_dataloader -from torch.optim import Adam -from torchvision import transforms -from torchvision.datasets import CIFAR10 -from torchvision.models import resnet18 -from colossalai.testing import rerun_if_address_is_in_use BATCH_SIZE = 4 IMG_SIZE = 32 diff --git a/tests/test_utils/test_activation_checkpointing.py b/tests/test_utils/test_activation_checkpointing.py index 3ac75fb00c86..fc1dfac20bb3 100644 --- a/tests/test_utils/test_activation_checkpointing.py +++ b/tests/test_utils/test_activation_checkpointing.py @@ -4,8 +4,9 @@ import pytest import torch import torch.nn.functional as F + from colossalai.context.parallel_mode import ParallelMode -from colossalai.context.random import add_seed, seed, set_mode, reset_seeds +from colossalai.context.random import add_seed, reset_seeds, seed, set_mode from colossalai.utils.activation_checkpoint import checkpoint diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py index 8a0fea9ae47a..e857bd741ce8 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_1d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_1d.py @@ -1,80 +1,81 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import 
launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_1d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, 
pattern=".*Address already in use.*") -def test_checkpoint_1d(): - world_size = 8 - run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_1d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus +from colossalai.utils import free_port, is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_1d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="1d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = 
nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_checkpoint_1d(): + world_size = 8 + run_func = partial(check_checkpoint_1d, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == "__main__": + test_checkpoint_1d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py index 26314290d4de..8463a89ed455 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2d.py @@ -1,80 +1,81 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - 
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_2d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_2d(): - world_size = 8 - run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_2d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint +from functools import partial + +import pytest +import torch +import 
torch.multiprocessing as mp +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus +from colossalai.utils import free_port, get_current_device, is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, mode="2d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for 
k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_checkpoint_2d(): + world_size = 8 + run_func = partial(check_checkpoint_2d, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == "__main__": + test_checkpoint_2d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py index 3dbd340fd42d..f20fb9f6d0e5 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_2p5d.py @@ -1,80 +1,81 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def 
check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_2p5d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_2p5d(): - world_size = 8 - run_func = partial(check_checkpoint_2p5d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_2p5d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus +from 
colossalai.utils import free_port, get_current_device, is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_2p5d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=4, depth=1, mode="2.5d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_checkpoint_2p5d(): + world_size = 8 + run_func = 
partial(check_checkpoint_2p5d, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == "__main__": + test_checkpoint_2p5d() diff --git a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py index 38f650547585..fb9936342846 100644 --- a/tests/test_utils/test_checkpoint/test_checkpoint_3d.py +++ b/tests/test_utils/test_checkpoint/test_checkpoint_3d.py @@ -1,80 +1,81 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pprint -from functools import partial - -import colossalai.nn as col_nn -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.initialize import launch -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device, is_using_pp -from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint -from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus - - -def build_pipeline(model): - from colossalai.pipeline.utils import partition_uniform - - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - depth = len(model) - start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] - layers = [] - for i in range(depth): - if start <= i < end: - layers.append(model[i]) - else: - layers.append(nn.Identity()) - return nn.Sequential(*tuple(layers)) - - -def check_equal(A, B): - assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) - - -def check_checkpoint_3d(rank, world_size, port): - config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) - - disable_existing_loggers() - launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") - - 
m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) - sd1 = m1.state_dict() - if gpc.get_global_rank() == 0: - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") - save_checkpoint("test.pt", 0, m1) - - m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) - if is_using_pp(): - m2 = build_pipeline(m2) - - load_checkpoint("test.pt", m2) - sd2 = m2.state_dict() - if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: - sd2 = gather_pipeline_parallel_state_dict(sd2) - print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") - - if gpc.get_global_rank() == 0: - for k, v in sd1.items(): - assert k in sd2 - check_equal(v, sd2[k].to(torch.device("cpu"))) - - -@pytest.mark.dist -@pytest.mark.skip("takes too long") -@skip_if_not_enough_gpus(min_gpus=8) -@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") -def test_checkpoint_3d(): - world_size = 8 - run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == "__main__": - test_checkpoint_3d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import pprint +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +import colossalai.nn as col_nn +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_on_exception, skip_if_not_enough_gpus +from colossalai.utils import free_port, get_current_device, is_using_pp +from colossalai.utils.checkpointing import gather_pipeline_parallel_state_dict, load_checkpoint, save_checkpoint + + +def build_pipeline(model): + from colossalai.pipeline.utils import partition_uniform + + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = 
gpc.get_local_rank(ParallelMode.PIPELINE) + depth = len(model) + start, end = partition_uniform(depth, pipeline_size, 1)[pipeline_rank][0] + layers = [] + for i in range(depth): + if start <= i < end: + layers.append(model[i]) + else: + layers.append(nn.Identity()) + return nn.Sequential(*tuple(layers)) + + +def check_equal(A, B): + assert torch.allclose(A, B, rtol=1e-3, atol=1e-2) + + +def check_checkpoint_3d(rank, world_size, port): + config = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=8, mode="3d")),) + + disable_existing_loggers() + launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + + m1 = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)) + sd1 = m1.state_dict() + if gpc.get_global_rank() == 0: + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd1)}\n") + save_checkpoint("test.pt", 0, m1) + + m2 = nn.Sequential(col_nn.Linear(4, 8), col_nn.Linear(8, 4)) + if is_using_pp(): + m2 = build_pipeline(m2) + + load_checkpoint("test.pt", m2) + sd2 = m2.state_dict() + if is_using_pp() and gpc.get_local_rank(ParallelMode.TENSOR) == 0: + sd2 = gather_pipeline_parallel_state_dict(sd2) + print(f"Rank {gpc.get_global_rank()}:\n{pprint.pformat(sd2)}\n") + + if gpc.get_global_rank() == 0: + for k, v in sd1.items(): + assert k in sd2 + check_equal(v, sd2[k].to(torch.device("cpu"))) + + +@pytest.mark.dist +@pytest.mark.skip("takes too long") +@skip_if_not_enough_gpus(min_gpus=8) +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_checkpoint_3d(): + world_size = 8 + run_func = partial(check_checkpoint_3d, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == "__main__": + test_checkpoint_3d() diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py index 6d89fb90c574..3bc7d51fa16c 100644 --- 
a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py +++ b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py @@ -1,8 +1,9 @@ import torch import torch.nn as nn +from torch.optim import Adam + from colossalai.utils.checkpoint_io.meta import ParamDistMeta from colossalai.utils.checkpoint_io.utils import build_checkpoints -from torch.optim import Adam class DummyModel(nn.Module): diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py index 780c13dc534a..c767c4cb41eb 100644 --- a/tests/test_utils/test_checkpoint_io/test_load.py +++ b/tests/test_utils/test_checkpoint_io/test_load.py @@ -3,20 +3,21 @@ from tempfile import TemporaryDirectory from typing import Dict -import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.utils.checkpoint_io.io import load, save -from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta) from torch import Tensor from torch.nn import Module from torch.optim import Adam, Optimizer +import colossalai +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.io import load, save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta + def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: assert set(a.keys()) == set(b.keys()) diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py index 04e454dcb713..5e3412b5b3b9 100644 --- a/tests/test_utils/test_checkpoint_io/test_merge.py +++ b/tests/test_utils/test_checkpoint_io/test_merge.py @@ -1,18 +1,20 @@ -from colossalai.utils.checkpoint_io.meta import ParamDistMeta 
-from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import save, merge -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import free_port -from tempfile import TemporaryDirectory -from torch.optim import Adam -from functools import partial -import torch import os +from functools import partial +from tempfile import TemporaryDirectory + import pytest -import colossalai -import torch.nn as nn +import torch import torch.distributed as dist import torch.multiprocessing as mp +import torch.nn as nn +from torch.optim import Adam + +import colossalai +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import merge, save +from colossalai.utils.checkpoint_io.meta import ParamDistMeta class DummyModel(nn.Module): diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py index 5da2ae4fe1f8..fcb8a40b9bd6 100644 --- a/tests/test_utils/test_checkpoint_io/test_merge_param.py +++ b/tests/test_utils/test_checkpoint_io/test_merge_param.py @@ -1,6 +1,7 @@ import torch + +from colossalai.utils.checkpoint_io.distributed import gather_tp_param, merge_param, unflatten_zero_param from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param def test_unflatten_zero_param_even() -> None: diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py index 6e76f3167e31..d7474aa7ec21 100644 --- a/tests/test_utils/test_checkpoint_io/test_redist.py +++ b/tests/test_utils/test_checkpoint_io/test_redist.py @@ -2,19 +2,25 @@ from functools import partial from tempfile import TemporaryDirectory -import colossalai import 
pytest import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from torch.optim import Adam + +import colossalai from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME from colossalai.utils.checkpoint_io.io import redist, save -from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, - RedistMeta) -from torch.optim import Adam +from colossalai.utils.checkpoint_io.meta import ( + ParamDistMeta, + ParamRedistMeta, + PipelineRedistMeta, + RankRedistMeta, + RedistMeta, +) class DummyModel(nn.Module): diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py index 5ff9d0aa2217..81504d9dd319 100644 --- a/tests/test_utils/test_checkpoint_io/test_save.py +++ b/tests/test_utils/test_checkpoint_io/test_save.py @@ -3,20 +3,25 @@ from tempfile import TemporaryDirectory from typing import Dict -import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn +from torch import Tensor +from torch.optim import Adam + +import colossalai from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME, - OTHER_CKPT_FILE_NAME) +from colossalai.utils.checkpoint_io.constant import ( + GLOBAL_META_FILE_NAME, + META_CKPT_FILE_NAME, + MODEL_CKPT_FILE_NAME, + OTHER_CKPT_FILE_NAME, +) from colossalai.utils.checkpoint_io.io import save from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from torch import Tensor -from torch.optim import Adam def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py 
b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py index 8b83caa12359..11724296c9e8 100644 --- a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py +++ b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py @@ -1,6 +1,7 @@ import torch -from colossalai.utils.checkpoint_io.meta import ParamRedistMeta + from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param +from colossalai.utils.checkpoint_io.meta import ParamRedistMeta def test_flatten_zero_param_even() -> None: diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index a5ea75fffc36..9cd1c417a4dd 100644 --- a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -1,25 +1,23 @@ -import os, shutil -import torch -import pytest +import os +import shutil from copy import deepcopy from functools import partial -import torch.multiprocessing as mp +import pytest +import torch import torch.distributed as dist - -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import MultiplicativeLR -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +import torch.multiprocessing as mp +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR import colossalai +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +from colossalai.nn.optimizer import ColossalaiOptimizer +from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port +from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint +from colossalai.utils.cuda import get_current_device from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup -from 
colossalai.utils.checkpoint import save_checkpoint, load_checkpoint -from colossalai.nn.optimizer import ColossalaiOptimizer - from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index 0ecb7446c788..25aa59ac2a39 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -1,13 +1,12 @@ -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.zero.sharded_param import ShardedTensor -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline -import colossalai - import torch - import torch.multiprocessing as mp +import colossalai +from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.sharded_param import ShardedTensor + def run_tensor_move(rank): colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') diff --git a/tests/test_utils/test_lazy_init_ctx.py b/tests/test_utils/test_lazy_init_ctx.py index 97efb3367490..350131ca9fe3 100644 --- a/tests/test_utils/test_lazy_init_ctx.py +++ b/tests/test_utils/test_lazy_init_ctx.py @@ -1,8 +1,10 @@ -import torch -from colossalai.utils.model.lazy_init_context import LazyInitContext -from torchvision.models import resnet34 import random + import numpy as np +import torch +from torchvision.models import resnet34 + +from colossalai.utils.model.lazy_init_context import LazyInitContext MANUAL_SEED = 0 random.seed(MANUAL_SEED) diff --git a/tests/test_utils/test_memory.py b/tests/test_utils/test_memory.py index 46a5aeba505b..4c858faa0928 100644 --- a/tests/test_utils/test_memory.py +++ b/tests/test_utils/test_memory.py @@ -1,12 +1,12 @@ +from functools import partial + import pytest 
+import torch.multiprocessing as mp import colossalai -from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_set_process_memory_fraction, colo_device_memory_capacity from colossalai.utils import free_port - -from functools import partial -import torch.multiprocessing as mp +from colossalai.utils.cuda import get_current_device +from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): diff --git a/tests/test_utils/test_norm_gradient_clipping.py b/tests/test_utils/test_norm_gradient_clipping.py index 259286663033..baeb7362ac25 100644 --- a/tests/test_utils/test_norm_gradient_clipping.py +++ b/tests/test_utils/test_norm_gradient_clipping.py @@ -1,16 +1,18 @@ -from colossalai.tensor import distspec, ColoTensorSpec, ProcessGroup -from colossalai.tensor.colo_parameter import ColoParameter -import colossalai +from functools import partial + import pytest import torch import torch.multiprocessing as mp -from colossalai.logging import disable_existing_loggers -from colossalai.utils import free_port, get_current_device +from torch.nn.parameter import Parameter from torch.nn.utils import clip_grad_norm_ -from functools import partial + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.tensor import ColoTensorSpec, ProcessGroup, distspec +from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device from colossalai.utils.common import clip_grad_norm -from torch.nn.parameter import Parameter def close(num: float, other: float, rtol: float = 1e-5, atol: float = 1e-8): diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py index 8bdae88464b1..57479ca16ebe 100644 --- 
a/tests/test_utils/test_zero_gradient_clippling.py +++ b/tests/test_utils/test_zero_gradient_clippling.py @@ -2,21 +2,22 @@ # -*- encoding: utf-8 -*- import copy +from functools import partial -import colossalai -from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn -from colossalai.logging import disable_existing_loggers -from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ -from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy -from functools import partial + +import colossalai +from colossalai.logging import disable_existing_loggers from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port +from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy +from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 def checkpoint_wrapper(module, enable=True): diff --git a/tests/test_zero/common.py b/tests/test_zero/common.py index bc6cd75a6a60..bb481636e896 100644 --- a/tests/test_zero/common.py +++ b/tests/test_zero/common.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist + from colossalai.logging import get_dist_logger from colossalai.utils import checkpoint from colossalai.zero.shard_utils import TensorShardStrategy diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py index 34283f5015e1..b53d83c2455e 100644 --- a/tests/test_zero/test_found_inf.py +++ b/tests/test_zero/test_found_inf.py @@ -1,72 +1,72 @@ -from functools import partial - -import colossalai -from colossalai.utils.cuda import get_current_device -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.nn.optimizer import HybridAdam 
-from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.sharded_optim._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_zero.test_sharded_optim_v2 import _run_step - -from common import CONFIG - - -@parameterize("cpu_offload", [True, False]) -@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers'] - shard_strategy = shard_strategy_class() - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=True, - ) - - sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - for i, (data, label) in enumerate(train_dataloader): - if i > 1: - break - assert zero_model.overflow_counter == 0 - data, label = data.cuda(), label.cuda() - _run_step(zero_model, sharded_optim, data, label, criterion, False) - for param in zero_model.parameters(): - assert not has_inf_or_nan(param.colo_attr.data_payload) - - -def 
_run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_found_inf() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_found_inf(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_found_inf(world_size=2) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from common import CONFIG + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero.init_ctx import ZeroInitContext +from colossalai.zero.shard_utils import BucketTensorShardStrategy +from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_zero.test_sharded_optim_v2 import _run_step + + +@parameterize("cpu_offload", [True, False]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) +@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) +def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): + test_models = ['repeated_computed_layers'] + shard_strategy = shard_strategy_class() + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + 
shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', + reuse_fp16_shard=True, + ) + + sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + for i, (data, label) in enumerate(train_dataloader): + if i > 1: + break + assert zero_model.overflow_counter == 0 + data, label = data.cuda(), label.cuda() + _run_step(zero_model, sharded_optim, data, label, criterion, False) + for param in zero_model.parameters(): + assert not has_inf_or_nan(param.colo_attr.data_payload) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_test_found_inf() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_found_inf(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_found_inf(world_size=2) diff --git a/tests/test_zero/test_shard_param.py b/tests/test_zero/test_shard_param.py index 8db2b7e79604..27bef958d560 100644 --- a/tests/test_zero/test_shard_param.py +++ b/tests/test_zero/test_shard_param.py @@ -1,17 +1,18 @@ from copy import deepcopy from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp + +import colossalai +from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) +from colossalai.zero.shard_utils import 
BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.sharded_param import ShardedTensor from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from tests.test_zero.common import CONFIG, allclose -from colossalai.gemini.stateful_tensor import StatefulTensor @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) diff --git a/tests/test_zero/test_sharded_optim_state_dict.py b/tests/test_zero/test_sharded_optim_state_dict.py index f8c42930b281..2b12884959e7 100644 --- a/tests/test_zero/test_sharded_optim_state_dict.py +++ b/tests/test_zero/test_sharded_optim_state_dict.py @@ -1,20 +1,21 @@ +from functools import partial + import pytest -import colossalai import torch import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from functools import partial -from tests.test_tensor.common_utils import set_seed -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize + +import colossalai from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.shard_utils import TensorShardStrategy from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from colossalai.tensor import ProcessGroup +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed def init_zero(model_builder, placement_policy): diff --git a/tests/test_zero/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_sharded_optim_with_sync_bn.py index 
ea5b315188a3..b34d4763a5db 100644 --- a/tests/test_zero/test_sharded_optim_with_sync_bn.py +++ b/tests/test_zero/test_sharded_optim_with_sync_bn.py @@ -3,18 +3,19 @@ from functools import partial -import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp +from torchvision.models import resnet50 + +import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.zero.init_ctx import ZeroInitContext from colossalai.zero.shard_utils import TensorShardStrategy -from torchvision.models import resnet50 def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_state_dict.py index 7ac9b151e4d6..bdcdfa1272e1 100644 --- a/tests/test_zero/test_state_dict.py +++ b/tests/test_zero/test_state_dict.py @@ -4,20 +4,20 @@ from copy import deepcopy from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp +from common import CONFIG + +import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) +from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs -from common import CONFIG - @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_zero_state_dict(shard_strategy_class): diff --git a/tests/test_zero/test_tensor_utils.py b/tests/test_zero/test_tensor_utils.py index 
81855ff5e10a..226ec80b959b 100644 --- a/tests/test_zero/test_tensor_utils.py +++ b/tests/test_zero/test_tensor_utils.py @@ -1,18 +1,21 @@ +from functools import partial + import pytest +import torch +import torch.multiprocessing as mp import colossalai -from colossalai.utils.cuda import get_current_device -from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move, - colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, - colo_model_tensor_clone) from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.utils import free_port +from colossalai.gemini.tensor_utils import ( + colo_model_data_move_to_cpu, + colo_model_data_tensor_move, + colo_model_data_tensor_move_inline, + colo_model_tensor_clone, + colo_tensor_mem_usage, +) from colossalai.testing import rerun_if_address_is_in_use - -import torch - -from functools import partial -import torch.multiprocessing as mp +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device def _run_colo_tensor_mem_usage(): diff --git a/tests/test_zero/test_zero_engine.py b/tests/test_zero/test_zero_engine.py index 80ded65d634c..302ee2836450 100644 --- a/tests/test_zero/test_zero_engine.py +++ b/tests/test_zero/test_zero_engine.py @@ -3,11 +3,14 @@ from functools import partial -import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp +from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai from colossalai.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port @@ -15,9 +18,6 @@ from colossalai.zero.sharded_model.utils import col_model_deepcopy from colossalai.zero.sharded_optim._utils import has_inf_or_nan from tests.components_to_test.registry import 
non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP - -from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params) def run_dist(rank, world_size, port, parallel_config):