
[Bug][EP][compile] non_blocking=True D2H race in input_splits under per-layer torch.compile#2951

Open
weifengpy wants to merge 1 commit into pytorch:main from weifengpy:ep-fix-non-blocking-race-v2

Conversation

Contributor

@weifengpy weifengpy commented Apr 13, 2026

This is needed for per-layer compilation (#2741).

repro command: NGPU=8 MODULE=deepseek_v3 CONFIG=deepseek_v3_16b ./run_train.sh --compile.enable --compile.components model,loss --parallelism.expert_parallel_degree 4 --training.steps 20

error: RuntimeError: Eq(u27 + u28 + u29 + u30, 98304)

Claude's guessed root cause:

Inductor does not insert a proper CUDA synchronization between a non_blocking=True D2H copy and the subsequent .item() call that reads the CPU tensor, at torchtitan/distributed/expert_parallel.py:166 (the identical pattern appears at line 504 for TorchAOExpertParallel):

.to(torch.device("cpu"), non_blocking=True)
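To illustrate the race: a non_blocking=True device-to-host copy returns before the data has landed in host memory, so reading the CPU tensor without a synchronization can observe stale values. A minimal sketch of the pattern and the explicit sync that closes the race (the helper name is hypothetical, not the actual torchtitan code):

```python
import torch

def safe_splits_to_cpu(input_splits: torch.Tensor) -> list[int]:
    # Async D2H copy: returns immediately, data may not have arrived yet.
    splits_cpu = input_splits.to(torch.device("cpu"), non_blocking=True)
    if input_splits.is_cuda:
        # Without this sync, .tolist()/.item() below can read garbage.
        # In eager mode PyTorch inserts the sync for .item(); under
        # per-layer torch.compile, inductor reportedly does not.
        torch.cuda.current_stream().synchronize()
    return splits_cpu.tolist()
```

On a CPU tensor the copy is synchronous and the sync branch is skipped, so the helper degrades gracefully.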

raw error stack:

      File "torch/distributed/elastic/multiprocessing/errors/__init__.py", line 367, in wrapper                             
        return f(*args, **kwargs)                                                                                           
      File "torchtitan/trainer.py", line 843, in train                                                                      
        self.train_step(data_iterator)                                                                                      
      File "torchtitan/trainer.py", line 753, in train_step                                                                 
        loss = self.forward_backward_step(                                                                                  
      File "torchtitan/trainer.py", line 701, in forward_backward_step                                                      
        pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)                                                     
      File "torch/nn/modules/module.py", line 1778, in _wrapped_call_impl                                                   
        return self._call_impl(*args, **kwargs)                                                                             
      File "torch/nn/modules/module.py", line 1884, in _call_impl                                                           
        return inner()                                                                                                      
      File "torch/nn/modules/module.py", line 1832, in inner                                                                
        result = forward_call(*args, **kwargs)
      File "torchtitan/models/common/decoder.py", line 134, in forward                                                      
        h = layer(h, self.freqs_cis, attention_masks, positions)                                                            
      File "torch/nn/modules/module.py", line 1776, in _wrapped_call_impl                                                   
        return self._compiled_call_impl(*args, **kwargs)                                                                    
      File "torch/_dynamo/eval_frame.py", line 1041, in compile_wrapper                                                     
        return fn(*args, **kwargs)                                                                                          
      File "torch/nn/modules/module.py", line 1884, in _call_impl                                                         
        return inner()                                                                                                      
      File "torch/nn/modules/module.py", line 1832, in inner                                                              
        result = forward_call(*args, **kwargs)                                                                              
      File "torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 143, in forward                         
      File "torch/_dynamo/eval_frame.py", line 1277, in _fn                                                                 
        return fn(*args, **kwargs)                                                                                          
      File "torch/_functorch/aot_autograd.py", line 1273, in forward                                                        
        return compiled_fn(full_args)                                                                                       
      File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 777, in runtime_wrapper                             
        all_outs = compiled_invoker.run(args, on_before_call=exit_prologue)                                                 
      File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 506, in run                                         
        return call_func_at_runtime_with_args(                                                                              
      File "torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args                         
        out = normalize_as_list(f(args))                                                                                    
      File "torch/_functorch/_aot_autograd/utils.py", line 93, in g                                                       
        return f(*args)                                                                                                     
      File "torch/autograd/function.py", line 596, in apply                                                               
        return super().apply(*args, **kwargs)                                                                               
      File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2844, in forward
        fw_outs = call_func_at_runtime_with_args(                                                                           
      File "torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args                           
        out = normalize_as_list(f(args))
      File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 850, in wrapper                                       
        return compiled_fn(runtime_args)                                                                                    
      File "<subclass_wrapper>", line 42, in inner_fn
      File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1055, in inner_fn                                     
        outs = compiled_fn(args)                                                                                            
      File "torch/_inductor/output_code.py", line 710, in __call__
        return self.current_callable(inputs)                                                                                
      File "torch/_inductor/utils.py", line 3495, in run                                                                  
        out = model(new_inputs)                                                                                             
      File "/tmp/torchinductor_weif/.../c22q3ucc33sz....py", line 2394, in call
        raise RuntimeError('Eq(u27 + u28 + u29 + u30, 98304)')                                                              
    RuntimeError: Eq(u27 + u28 + u29 + u30, 98304)   

Consolidate apply_compile_dense and apply_compile_sparse into a single
apply_compile function. The only difference was capture_scalar_outputs
which is harmless for dense models.

Remove the _run_experts_grouped_mm separate compile boundary and EP
wrapper — no longer needed after fixing the symbolic shape issue in
_generate_permute_indices.

Fix _generate_permute_indices to use torch.arange(seg_ids.shape[0])
instead of torch.arange(total), reusing the unbacked symint from
repeat_interleave rather than creating a redundant one that produces
an Eq(u1, u2) constraint inductor cannot lower.
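A simplified sketch of the fix (names simplified from _generate_permute_indices; this is not the actual torchtitan code): under torch.compile, repeat_interleave produces a tensor whose length is an unbacked symint u1. Calling torch.arange on a separately computed total would mint a second unbacked symint u2 plus an Eq(u1, u2) runtime assert; reusing the shape keeps a single symint.

```python
import torch

def permute_positions(num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
    # Length of seg_ids is an unbacked symint (u1) under torch.compile.
    seg_ids = torch.repeat_interleave(
        torch.arange(num_tokens_per_expert.numel()), num_tokens_per_expert
    )
    # Before the fix: torch.arange(total) with
    # total = num_tokens_per_expert.sum().item() created a redundant
    # unbacked symint (u2) and an Eq(u1, u2) constraint that inductor
    # could not lower. Reusing seg_ids.shape[0] avoids that.
    return torch.arange(seg_ids.shape[0])
```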

Remove the x[:total_tokens] slice in _run_experts_for_loop — padding
was removed in pytorch#2774, so sum(num_tokens_per_expert) == x.shape[0] and
the slice is a no-op. This also eliminates the need for torch._check
guards on the unbacked symint.
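A sketch of the resulting loop shape, with a hypothetical identity "expert" so the example is self-contained (the real function applies per-expert weights):

```python
import torch

def run_experts_for_loop(
    x: torch.Tensor, num_tokens_per_expert: torch.Tensor
) -> torch.Tensor:
    # After pytorch#2774 removed padding, sum(num_tokens_per_expert)
    # equals x.shape[0], so the old x[:total_tokens] slice was a no-op
    # that only forced torch._check guards on the unbacked symint.
    outs = []
    offset = 0
    for n in num_tokens_per_expert.tolist():
        outs.append(x[offset : offset + n])  # tokens routed to this expert
        offset += n
    return torch.cat(outs)
```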
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Apr 13, 2026
@weifengpy weifengpy marked this pull request as ready for review April 13, 2026 21:42
@weifengpy weifengpy changed the title [EP][compile] non_blocking=True D2H race in input_splits under per-layer torch.compile [Bug][EP][compile] non_blocking=True D2H race in input_splits under per-layer torch.compile Apr 13, 2026
@weifengpy weifengpy force-pushed the ep-fix-non-blocking-race-v2 branch from 1fbf4d4 to 612002d Compare April 13, 2026 21:43
@tianyu-l tianyu-l added module: torch.compile bug Something isn't working labels Apr 14, 2026

Labels

bug Something isn't working ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot. module: torch.compile
