
mlm training fails due to large message size for nested_gather on torch_xla #16005

@miladm

Description


The PyTorch/XLA TPU tests for the Hugging Face mlm-bert and mlm-roberta examples fail with the error shown below. I have tested this extensively on both 2VM and 1VM setups. On both, the test passes as expected with --num_cores 1 and fails with --num_cores 8.

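The error suggests that the mesh_reduce API, reached via evaluate() > evaluation_loop() > _nested_gather() > nested_xla_mesh_reduce(), sends a far larger tensor payload than the rendezvous service accepts: 950146944 bytes received versus a maximum of 4194304 bytes (4 MiB). According to the traceback, the payload being gathered is the evaluation logits.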
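To illustrate why the failure only appears with multiple cores, here is a toy, single-process model of the gather-by-rendezvous pattern. It is not torch_xla code. The names `toy_rendezvous`, `toy_mesh_reduce` and `RendezvousError` are hypothetical. The model makes three assumptions. First, every replica's serialized payload is returned to each caller, so the received message grows with the number of replicas. Second, messages above 4194304 bytes are rejected, as in the error. Third, no rendezvous takes place when only one replica runs.

```python
# Toy, single-process model of the mesh_reduce pattern (NOT torch_xla code).
# Assumptions (hypothetical, for illustration only):
#   * each replica contributes its tensor as serialized bytes,
#   * the rendezvous returns every replica's payload to each caller,
#     so the received message is roughly the sum of all payloads,
#   * messages above MAX_MESSAGE_BYTES are rejected,
#   * with a single replica, no rendezvous happens at all.
from typing import Callable, List

MAX_MESSAGE_BYTES = 4 * 1024 * 1024  # 4194304, the limit reported in the error


class RendezvousError(RuntimeError):
    """Raised when the combined message exceeds the size limit."""


def toy_rendezvous(tag: str, payloads: List[bytes]) -> List[bytes]:
    """Exchange payloads among all replicas; reject oversized messages."""
    if len(payloads) == 1:
        # Assumed single-replica shortcut: nothing to exchange.
        return []
    received = sum(len(p) for p in payloads)
    if received > MAX_MESSAGE_BYTES:
        raise RendezvousError(
            f"Failed to meet rendezvous '{tag}': Received message larger than max "
            f"({received} vs. {MAX_MESSAGE_BYTES})"
        )
    return list(payloads)


def toy_mesh_reduce(tag: str, payloads: List[bytes],
                    reduce_fn: Callable[[List[bytes]], bytes]) -> bytes:
    """Gather every replica's payload, then reduce (e.g. concatenate) them."""
    gathered = toy_rendezvous(tag, payloads)
    # With one replica nothing is gathered, so the local data is returned as-is.
    return reduce_fn(gathered) if gathered else payloads[0]


# One replica with an 8 MiB payload never reaches the size check.
big = b"\0" * (8 * 1024 * 1024)
assert toy_mesh_reduce("nested_gather", [big], b"".join) == big

# Eight replicas with 1 MiB each add up to an 8 MiB message and are rejected.
try:
    toy_mesh_reduce("nested_gather", [b"\0" * (1024 * 1024)] * 8, b"".join)
except RendezvousError as err:
    print(err)
```

In this model a single replica never hits the size check, however large its payload, while eight modest payloads together exceed the limit. That would be consistent with the --num_cores 1 versus --num_cores 8 behavior described above.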

Reference to an older issue that seems relevant here.

Repro command:

python3 examples/pytorch/xla_spawn.py  --num_cores 8  examples/pytorch/language-modeling/run_mlm.py  --logging_dir ./tensorboard-metric --cache_dir ./cache_dir  --dataset_name wikitext  --dataset_config_name wikitext-2-raw-v1  --do_train  --do_eval  --overwrite_output_dir  --output_dir language-modeling  --logging_steps 30  --save_steps 3000  --overwrite_cache  --tpu_metrics_debug  --model_type=bert --tokenizer=bert-base-cased --num_train_epochs 1 --per_device_train_batch_size 16 --per_device_eval_batch_size 4

Error message:

***** train metrics *****
  epoch                    =        1.0
  train_loss               =      8.969
  train_runtime            = 0:02:58.03
  train_samples            =       4771
  train_samples_per_second =     26.798
  train_steps_per_second   =      0.213
03/09/2022 03:22:36 - INFO - run_mlm - *** Evaluate ***
[INFO|trainer.py:570] 2022-03-09 03:22:36,278 >> The following columns in the evaluation set  don't have a corresponding argument in `BertForMaskedLM.forward` and have been ignored: special_tokens_mask. If special_tokens_mask are not expected by `BertForMaskedLM.forward`,  you can safely ignore this message.
[INFO|trainer.py:2403] 2022-03-09 03:22:36,281 >> ***** Running Evaluation *****
[INFO|trainer.py:2405] 2022-03-09 03:22:36,281 >>   Num examples = 493
[INFO|trainer.py:2408] 2022-03-09 03:22:36,281 >>   Batch size = 2
Exception in device=TPU:7: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
Exception in device=TPU:2: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
Exception in device=TPU:0: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
Exception in device=TPU:3: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)

...

Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/home/miladmo/transformers/examples/pytorch/language-modeling/run_mlm.py", line 582, in _mp_fn
    main()
  File "/home/miladmo/transformers/examples/pytorch/language-modeling/run_mlm.py", line 545, in main
    metrics = trainer.evaluate()
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2271, in evaluate
    output = eval_loop(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2460, in evaluation_loop
    logits = self._nested_gather(logits)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2546, in _nested_gather
    tensors = nested_xla_mesh_reduce(tensors, name)
  File "/home/miladmo/transformers/examples/pytorch/language-modeling/run_mlm.py", line 582, in _mp_fn
    main()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 163, in nested_xla_mesh_reduce
    return xm.mesh_reduce(name, tensors, torch.cat)
  File "/home/miladmo/transformers/examples/pytorch/language-modeling/run_mlm.py", line 545, in main
    metrics = trainer.evaluate()
  File "/home/miladmo/transformers/examples/pytorch/language-modeling/run_mlm.py", line 582, in _mp_fn
    main()
  File "/home/miladmo/transformers/examples/pytorch/language-modeling/run_mlm.py", line 545, in main
    metrics = trainer.evaluate()
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 974, in mesh_reduce
    xdata = rendezvous(tag, bio.getvalue())
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2271, in evaluate
    output = eval_loop(
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2271, in evaluate
    output = eval_loop(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 926, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2460, in evaluation_loop
    logits = self._nested_gather(logits)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2460, in evaluation_loop
    logits = self._nested_gather(logits)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2546, in _nested_gather
    tensors = nested_xla_mesh_reduce(tensors, name)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 163, in nested_xla_mesh_reduce
    return xm.mesh_reduce(name, tensors, torch.cat)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer.py", line 2546, in _nested_gather
    tensors = nested_xla_mesh_reduce(tensors, name)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
  File "/home/miladmo/.local/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 163, in nested_xla_mesh_reduce
    return xm.mesh_reduce(name, tensors, torch.cat)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 974, in mesh_reduce
    xdata = rendezvous(tag, bio.getvalue())
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 974, in mesh_reduce
    xdata = rendezvous(tag, bio.getvalue())
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 926, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 926, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'nested_gather': Received message larger than max (950146944 vs. 4194304) (8)
Traceback (most recent call last):
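The reported message size can be roughly reconstructed from the shape of the evaluation logits. The following back-of-the-envelope sketch takes the per-device batch size of 2 from the evaluation log. It also relies on three assumptions that the issue does not state: a sequence length of 512 (bert-base-cased's `model_max_length`, which run_mlm.py uses when `--max_seq_length` is not given), a vocabulary of 28996 tokens (`len(tokenizer)` for bert-base-cased), and float32 logits. It further assumes that the rendezvous response carries all 8 replicas' payloads.

```python
# Back-of-the-envelope estimate of the 'nested_gather' message size.
# Only REPORTED_BYTES and MAX_MESSAGE_BYTES come straight from the error;
# the other values are assumptions about this run.
BATCH_PER_DEVICE = 2       # from the "Batch size = 2" log line
SEQ_LEN = 512              # assumed: bert-base-cased model_max_length
VOCAB = 28996              # assumed: len(tokenizer) for bert-base-cased
BYTES_PER_ELEMENT = 4      # assumed: float32 logits
NUM_REPLICAS = 8           # --num_cores 8

REPORTED_BYTES = 950146944     # "Received message larger than max (950146944 ..."
MAX_MESSAGE_BYTES = 4194304    # "... vs. 4194304)"

# One device's logits tensor, shape [BATCH_PER_DEVICE, SEQ_LEN, VOCAB].
per_replica = BATCH_PER_DEVICE * SEQ_LEN * VOCAB * BYTES_PER_ELEMENT
# Assumed: the rendezvous response carries every replica's payload.
estimated_total = per_replica * NUM_REPLICAS
# Under these assumptions, the remainder would be serialization overhead.
overhead = REPORTED_BYTES - estimated_total
ratio = REPORTED_BYTES / MAX_MESSAGE_BYTES

print(per_replica)        # 118767616
print(estimated_total)    # 950140928
print(overhead)           # 6016
print(f"{ratio:.1f}")     # 226.5
```

If these assumptions hold, the estimate falls only 6016 bytes short of the reported size, roughly 750 bytes per replica. That gap is small enough to be per-payload serialization overhead. It would also mean that the full-vocabulary logits of a single evaluation step already produce a message more than 226 times the 4 MiB limit.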
