diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 3e96310e1890..267c4529eb95 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -503,4 +503,4 @@ def all_to_all_cost(self, num_bytes, mesh_dim): num_devices = self.logical_mesh_id.shape[mesh_dim] penalty_factor = num_devices / 2.0 return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * - (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) \ No newline at end of file + (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index c025daaeccc7..7e97bee01b33 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -287,4 +287,4 @@ def reduce_forward(input_, process_group): def reduce_backward(input_, process_group): - return _ReduceBackward.apply(input_, process_group) \ No newline at end of file + return _ReduceBackward.apply(input_, process_group) diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 8051433e8d71..085e3150c697 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -80,7 +80,7 @@ class PolicyLocation: PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"), "transformers.models.opt.modeling_opt.OPTForQuestionAnswering": PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"), - + # Bloom "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(file_name="bloom", class_name="BloomModelPolicy"), diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 1f8db99c9236..590d6966bff6 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -86,4 +86,4 @@ def test_device_mesh_from_process_group(): if __name__ == '__main__': test_device_mesh() - test_device_mesh_from_process_group() \ No newline at end of file + test_device_mesh_from_process_group() diff --git a/tests/test_shardformer/test_layer/test_layernorm.py b/tests/test_shardformer/test_layer/test_layernorm.py index 080fae034956..a117845545be 100644 --- a/tests/test_shardformer/test_layer/test_layernorm.py +++ b/tests/test_shardformer/test_layer/test_layernorm.py @@ -41,4 +41,4 @@ def test_layernorm(): if __name__ == '__main__': - test_layernorm_1d() \ No newline at end of file + test_layernorm_1d()