diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index b7ffc3d0bfd..4ad1cab4818 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -294,7 +294,8 @@ c10::intrusive_ptr postScatter( input_tensors.front().push_back(output_tensor); continue; } - input_tensors.front().push_back(input_tensor.slice(0, j, j + 1)); + input_tensors.front().push_back( + input_tensor.slice(0, j, j + 1).contiguous()); j++; } diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index 6c114448744..e11e491a889 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -217,8 +217,9 @@ DeviceMesh mesh1({1}); DeviceMesh mesh2({0, 1, 2, 3}); DeviceMesh mesh3({0, 2, 3}); DeviceMesh mesh4({1, 0, 2}); -auto all_meshes = testing::Values(mesh0, mesh1, mesh2, mesh3, mesh4); -auto all_nontrivial_meshes = testing::Values(mesh2, mesh3, mesh4); +DeviceMesh mesh5({1, 0}); +auto all_meshes = testing::Values(mesh0, mesh1, mesh2, mesh3, mesh4, mesh5); +auto all_nontrivial_meshes = testing::Values(mesh2, mesh3, mesh4, mesh5); } // namespace