diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index 8c8f898b8..01a1c9e3d 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -428,52 +428,56 @@ def _layer(self) -> dict: from distributed.shuffle._shuffle import shuffle_barrier dsk = {} - name_left = "hash-join-transfer-" + self.left._name - name_right = "hash-join-transfer-" + self.right._name + token = self._name.split("-")[-1] + token_left = token + "-left" + token_right = token + "-right" + _barrier_key_left = barrier_key(ShuffleId(token_left)) + _barrier_key_right = barrier_key(ShuffleId(token_right)) + + transfer_name_left = "hash-join-transfer-" + token_left + transfer_name_right = "hash-join-transfer-" + token_right transfer_keys_left = list() transfer_keys_right = list() func = create_assign_index_merge_transfer() for i in range(self.left.npartitions): - transfer_keys_left.append((name_left, i)) - dsk[(name_left, i)] = ( + transfer_keys_left.append((transfer_name_left, i)) + dsk[(transfer_name_left, i)] = ( func, (self.left._name, i), self.shuffle_left_on, _HASH_COLUMN_NAME, self.npartitions, - self.left._name, + token_left, i, self.left._meta, self._partitions, ) for i in range(self.right.npartitions): - transfer_keys_right.append((name_right, i)) - dsk[(name_right, i)] = ( + transfer_keys_right.append((transfer_name_right, i)) + dsk[(transfer_name_right, i)] = ( func, (self.right._name, i), self.shuffle_right_on, _HASH_COLUMN_NAME, self.npartitions, - self.right._name, + token_right, i, self.right._meta, self._partitions, ) - _barrier_key_left = barrier_key(ShuffleId(self.left._name)) - _barrier_key_right = barrier_key(ShuffleId(self.right._name)) - dsk[_barrier_key_left] = (shuffle_barrier, self.left._name, transfer_keys_left) + dsk[_barrier_key_left] = (shuffle_barrier, token_left, transfer_keys_left) dsk[_barrier_key_right] = ( shuffle_barrier, - self.right._name, + token_right, transfer_keys_right, ) for part_out in self._partitions: dsk[(self._name, part_out)] = ( merge_unpack, - self.left._name, - self.right._name, + token_left, + token_right, part_out, _barrier_key_left, _barrier_key_right, diff --git a/dask_expr/tests/test_distributed.py b/dask_expr/tests/test_distributed.py index e83836eca..4bd59860a 100644 --- a/dask_expr/tests/test_distributed.py +++ b/dask_expr/tests/test_distributed.py @@ -8,6 +8,7 @@ distributed = pytest.importorskip("distributed") from distributed import Client, LocalCluster +from distributed.shuffle._core import id_from_key from distributed.utils_test import client as c # noqa F401 from distributed.utils_test import gen_cluster @@ -55,6 +56,46 @@ async def test_merge_p2p_shuffle(c, s, a, b, npartitions_left): lib.testing.assert_frame_equal(x.reset_index(drop=True), df_left.merge(df_right)) +@gen_cluster(client=True) +async def test_self_merge_p2p_shuffle(c, s, a, b): + pdf = lib.DataFrame({"a": range(100), "b": range(0, 200, 2)}) + ddf = from_pandas(pdf, npartitions=5) + + out = ddf.merge(ddf, left_on="a", right_on="b", shuffle_backend="p2p") + # Generate unique shuffle IDs if the input frame is the same but parameters differ + assert sum(id_from_key(k) is not None for k in out.dask) == 2 + x = await c.compute(out) + expected = pdf.merge(pdf, left_on="a", right_on="b") + lib.testing.assert_frame_equal( + x.sort_values("a_x", ignore_index=True), + expected.sort_values("a_x", ignore_index=True), + ) + + +@gen_cluster(client=True) +async def test_merge_p2p_shuffle_reused_dataframe(c, s, a, b): + pdf1 = lib.DataFrame({"a": range(100), "b": range(0, 200, 2)}) + pdf2 = lib.DataFrame({"x": range(200), "y": [1, 2, 3, 4] * 50}) + ddf1 = from_pandas(pdf1, npartitions=5) + ddf2 = from_pandas(pdf2, npartitions=10) + + out = ( + ddf1.merge(ddf2, left_on="a", right_on="x", shuffle_backend="p2p") + .repartition(20) + .merge(ddf2, left_on="b", right_on="x", shuffle_backend="p2p") + ) + # Generate unique shuffle IDs if the input frame is the same but parameters differ + assert sum(id_from_key(k) is not None for k in out.dask) == 4 + x = await c.compute(out) + expected = pdf1.merge(pdf2, left_on="a", right_on="x").merge( + pdf2, left_on="b", right_on="x" + ) + lib.testing.assert_frame_equal( + x.sort_values("a", ignore_index=True), + expected.sort_values("a", ignore_index=True), + ) + + @pytest.mark.parametrize("npartitions_left", [5, 6]) @gen_cluster(client=True) async def test_index_merge_p2p_shuffle(c, s, a, b, npartitions_left):