diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py index c85f53cc44c3..5234b6b1a1b5 100644 --- a/colossalai/nn/layer/base_layer.py +++ b/colossalai/nn/layer/base_layer.py @@ -1,11 +1,12 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from contextlib import contextmanager + import torch.nn as nn from colossalai.context import ParallelMode from colossalai.core import global_context as gpc -from contextlib import contextmanager class ParallelLayer(nn.Module): diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py index 3be057b3a98b..789ce8ab35b8 100644 --- a/tests/test_device/test_device_mesh.py +++ b/tests/test_device/test_device_mesh.py @@ -1,6 +1,7 @@ -from colossalai.device.device_mesh import DeviceMesh import torch +from colossalai.device.device_mesh import DeviceMesh + def test_device_mesh(): physical_mesh_id = torch.arange(0, 16).reshape(2, 8)