
Conversation

@Kosinkadink
Collaborator

@Kosinkadink Kosinkadink commented Mar 4, 2025

Overview

This PR adds support for MultiGPU acceleration via 'work unit' splitting - by default, conditionings are treated as the work units. Any model that uses more than a single conditioning can be sped up via MultiGPU Work Units - positive+negative, multiple positive/masked conditionings, etc. The code is extensible so that extensions can implement their own work units; as a proof of concept, I have implemented AnimateDiff-Evolved contexts to behave as work units.

As long as there is a heavy bottleneck on the GPU, there will be a noticeable performance improvement. If the GPU is only lightly loaded (e.g. an RTX 4090 sampling a single 512x512 SD1.5 image), the overhead of splitting and combining work units will result in a performance loss compared to using just one GPU.

The MultiGPU Work Units node can be placed in (almost) any existing workflow. When only one device is found, the node does effectively nothing, so workflows making use of the node will stay compatible between single and multi-GPU setups:
image

The feature works best when work splitting is symmetrical (GPUs are the same/have roughly the same performance), with the slowest GPU acting as the limiter. For asymmetrical setups, the MultiGPU Options node can be used to inform load balancing code about the relative performance of the MultiGPU setup:
image

Nvidia (CUDA): Tested, works ✅.
AMD (ROCm): Untested, will validate soon.
AMD (DirectML): Untested.
Intel (Arc XPU): Tested, does not work on Windows but works on Linux ⚠️.

Implementation Details

Based on max_gpus and the number of available devices, the main ModelPatcher is cloned and relevant properties (like model) are deepcloned after their values are unloaded. MultiGPU clones are stored on the ModelPatcher's additional_models under the key multigpu. During sampling, the deepcloned ModelPatchers are re-cloned with the values from the main ModelPatcher, with any additional_models kept consistent. To avoid unnecessarily deepcloning models, currently_loaded_models from comfy.model_management is checked for a matching deepcloned model, in which case it is (soft) cloned and made to match the main ModelPatcher.
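For illustration, here is a minimal, self-contained sketch (not the PR's actual code) of why a deepclone is needed for a secondary device: the clone has to own its own copy of the weights rather than sharing storage with the main ModelPatcher's model.

import copy
import torch

# Stand-in for the (unloaded, 'clean') patched diffusion model.
main_model = torch.nn.Linear(8, 8)

# Deep copy creates new storage; nothing is shared with the original.
clone = copy.deepcopy(main_model)
if torch.cuda.device_count() > 1:
    clone.to("cuda:1")  # the secondary device gets its own copy of the weights

assert clone.weight.data_ptr() != main_model.weight.data_ptr()

A soft ModelPatcher clone, by contrast, shares the underlying model, which is why it can be reused when a matching deepcloned model is already present in currently_loaded_models.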

When native conds are used as the work units, _calc_cond_batch calls and returns _calc_cond_batch_multigpu, to avoid a potential performance regression that could come from refactoring the single-GPU code path. In the future, this can be revisited to reuse the same code while carefully comparing performance for various models. No processes are created, only Python threads; while the GIL does limit CPU performance, the GPU being the bottleneck makes diffusion I/O-bound rather than CPU-bound. This vastly improves compatibility with existing code.
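To make the threading pattern concrete, here is a minimal, self-contained sketch (not the PR's _calc_cond_batch_multigpu; work_units and model_fn are hypothetical stand-ins for the per-cond batches and the model call): one Python thread per device, round-robin assignment of work units, and results gathered back on the first device. The GIL is released while each thread waits on its GPU, so the devices genuinely run concurrently.

import threading
import torch

def run_work_units(work_units, devices, model_fn):
    # work_units: list of input tensors (e.g. one per conditioning)
    # devices: list of torch devices; the first one acts as the output device
    # model_fn: device-agnostic callable standing in for the diffusion model
    results = [None] * len(work_units)

    def worker(indices, device):
        for i in indices:
            x = work_units[i].to(device)
            with torch.no_grad():
                results[i] = model_fn(x).to(devices[0])  # gather on the main device

    # Round-robin assignment of work units to devices.
    assignments = {device: [] for device in devices}
    for i in range(len(work_units)):
        assignments[devices[i % len(devices)]].append(i)

    threads = [threading.Thread(target=worker, args=(idx, dev)) for dev, idx in assignments.items()]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

With positive+negative conditioning and two equal GPUs, each device evaluates one cond per step, which is where the roughly 1.85-1.9x uplift reported below comes from.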

Since deepcloning requires that the base model is 'clean', comfy.model_management has received a unload_model_and_clones function to unload only specific models and their clones.

The --cuda-device startup argument has been refactored to accept a string rather than an int, allowing multiple ids to be provided while not breaking any existing usage:
image
image
This can be used not only to limit ComfyUI's visibility to a subset of devices per instance, but also to control their order (the first id is treated as device:0, the second as device:1, etc.)
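For example (assuming the ids are comma-separated, matching the "Set cuda device to: 0,1" startup log further down in this thread), launching ComfyUI with --cuda-device 1,0 would expose only those two GPUs and make physical GPU 1 act as device:0 and physical GPU 0 as device:1.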

Performance (will add more examples soon)

Wan 1.3B t2v: 1.85x uplift for 2 RTX 4090s vs 1 RTX 4090.
image
image

Wan 14B t2v: 1.89x uplift for 2 RTX 4090s vs 1 RTX 4090
image
image

…about the full scope of current sampling run, fix Hook Keyframes' guarantee_steps=1 inconsistent behavior with sampling split across different Sampling nodes/sampling runs by referencing 'sigmas'
…ches to use target_dict instead of target so that more information can be provided about the current execution environment if needed
… to separate out Wrappers/Callbacks/Patches into different hook types (all affect transformer_options)
… hook_type, modified necessary code to no longer need to manually separate out hooks by hook_type
…ptions to not conflict with the "sigmas" that will overwrite "sigmas" in _calc_cond_batch
…ade AddModelsHook operational and compliant with should_register result, moved TransformerOptionsHook handling out of ModelPatcher.register_all_hook_patches, support patches in TransformerOptionsHook properly by casting any patches/wrappers/hooks to proper device at sample time
…ops nodes by properly caching between positive and negative conds, make hook_patches_backup behave as intended (in the case that something pre-registers WeightHooks on the ModelPatcher instead of registering it at sample time)
…added some doc strings and removed a so-far unused variable
…ok to InjectionsHook (not yet implemented, but at least getting the naming figured out)
@Kosinkadink
Collaborator Author

Kosinkadink commented Aug 14, 2025

I've been given the greenlight to finish this PR!

For anyone who experienced black images, could you tell me your operating system, PyTorch version, and hardware? I am able to reproduce, but I want to confirm the scale of the problem.

For anyone who experienced issues with changing LoRAs causing memory-leak messages, could you give a step-by-step guide to reproduce? I think I can brute-force the steps, but having a good way to reproduce would be great!

@monstari I am able to reproduce your issue with Sage Attention, but haven't tried torch compile yet. What happens if you do torch compile without sage attention? From my initial look, it may be a triton bug w/ sage attention, but I'll need to reconfirm later.

Since I have a way to reproduce the memory management issues when close to the VRAM cap, I'll work on that too. The goal is for the major problems to be solved over the course of 1-3 weeks, and then to submit the PR for review.

@jkyamog

jkyamog commented Aug 15, 2025

> I've been given the greenlight to finish this PR! […]

Thanks very much for working on this. I have been using it for about a month or two now. Started with 2x 3090 and recently 3x 3090 (not great, it really has to be 4x 3090). So far it's been OK aside from the OOM after repeated runs. I just updated and am now going to test again to see if it's resolved.

@Kosinkadink
Collaborator Author

@jkyamog for the OOMs, could you try to purposely make them happen with a 'simple' workflow and detail the steps here? I actually have a dual 3090 setup for testing at the moment as well, so that would be very helpful! Also, what operating system and pytorch version are you running, just so I can take a note?

@jkyamog

jkyamog commented Aug 15, 2025

> @jkyamog for the OOMs, could you try to purposely make them happen with a 'simple' workflow and detail the steps here? […]

Sure, this is the stock workflow from the ComfyUI wiki; I then added MultiGPU and set it to 2 GPUs. I got an OOM error after a 2nd run. Flux doesn't really work well with multi-GPU, but several batches or a more complicated workflow still benefit from it. This is just a simple workflow to trigger the OOM error. I have attached the image here so you can easily import it into ComfyUI.

Might not be related, but I have noticed after updating to head that when I set 3 GPU units it will keep eating system RAM and swap until the kernel kills WSL. Before, I could run 3 GPUs, but it wasn't really helpful as 1 GPU is typically idle. So it's not really an issue and it might be unrelated. I have 96 GB RAM, 88 GB allocated to WSL, with 32 GB swap.

SamplerCustomAdvanced
Allocation on device would exceed allowed memory. (out of memory)
Currently allocated : 23.19 GiB
Requested : 126.00 MiB
Device limit : 24.00 GiB
Free (according to CUDA): 0 bytes
PyTorch limit (set by user-supplied memory fraction): 17179869184.00 GiB
This error means you ran out of memory on your GPU.

TIPS: If the workflow worked before you might have accidentally set the batch_size to a large number.

I am running in WSL. Here are the relevant libs from conda/pip:
- pytorch-lightning==2.5.2
- torch==2.8.0
- torchaudio==2.8.0
- torchmetrics==1.7.3
- torchsde==0.2.6
- torchvision==0.23.0
- nvidia-cublas-cu12==12.8.4.1
- nvidia-cuda-cupti-cu12==12.8.90
- nvidia-cuda-nvrtc-cu12==12.8.93
- nvidia-cuda-runtime-cu12==12.8.90
- nvidia-cudnn-cu12==9.10.2.21
- nvidia-cufft-cu12==11.3.3.83
- nvidia-cufile-cu12==1.13.1.3
- nvidia-curand-cu12==10.3.9.90
- nvidia-cusolver-cu12==11.7.3.90
- nvidia-cusparse-cu12==12.5.8.93
- nvidia-cusparselt-cu12==0.7.1
- nvidia-nccl-cu12==2.27.3
- nvidia-nvjitlink-cu12==12.8.93
- nvidia-nvtx-cu12==12.8.90
ComfyUI_00299_

@Kosinkadink Kosinkadink requested a review from guill as a code owner August 15, 2025 23:50
@monstari

I noticed that there's no speed boost when using distilled models with CFG=1. Since Normalized Attention Guidance already provides similar negative conditioning at CFG=1, would it be possible to explore solutions similar to xDiT's parallel processing in the future?

Also, if we have more than two GPUs, I assume this solution wouldn’t be as useful, since we can only apply two conditioning streams.

Thanks for all your work still!

@QUTGXX

QUTGXX commented Aug 29, 2025

> I tested this on a single RTX 5090 with the Wan 14B t2v model, and generation took about 19.54 seconds. When combining the RTX 5090 with an RTX 4090 via MultiGPU work unit splitting, the time dropped to 11.52 seconds.
>
> That's roughly a 1.7x speedup (RTX 5090 + RTX 4090).
>
> That said, I still haven't been able to get Sage Attention and Torch Compile to work with the MultiGPU setup, and I hope this can be resolved soon.
>
> Overall, this feature is very promising, especially for users running mixed GPU configurations.
>
> Note, some of the errors I hit when running Sage:
>
> !!! Exception during processing !!! Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
> Traceback (most recent call last):
>   File "/home/rtl-6/execution.py", line 496, in execute
>     output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
>   ...
>     process_inputs(input_dict, i)
>   File "/home/rtl-6/execution.py", line 277, in process_inputs
>     result = f(**inputs)
>   File "/home/rtl-6/nodes.py", line 1521, in sample
>     return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise)
>   File "/home/rtl-6/nodes.py", line 1488, in common_ksampler
>     samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
>   File "/home/rtl-6/comfy/sample.py", line 45, in sample
>     samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
>   File "/home/rtl-6/comfy/samplers.py", line 1355, in sample
>     return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
>   ...
>     cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
>   File "/home/rtl-6/comfy/samplers.py", line 1230, in sample
>     output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
>   File "/home/rtl-6/comfy/patcher_extension.py", line 113, in execute
>     return self.original(*args, **kwargs)
>   File "/home/rtl-6/comfy/samplers.py", line 1196, in outer_sample
>     output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
>   File "/home/rtl-6/comfy/samplers.py", line 1175, in inner_sample
>     samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
>   File "/home/rtl-6/comfy/patcher_extension.py", line 113, in execute
>     return self.original(*args, **kwargs)
>   File "/home/rtl-6/comfy/samplers.py", line 954, in sample
>     samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
>   File "/home/rtl-6/comfy/k_diffusion/sampling.py", line 190, in sample_euler
>     denoised = model(x, sigma_hat * s_in, **extra_args)
>   File "/home/rtl-6/comfy/samplers.py", line 604, in __call__
>     out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
>   File "/home/rtl-6/comfy/samplers.py", line 1155, in __call__
>     return self.predict_noise(*args, **kwargs)
>   File "/home/rtl-6/comfy/samplers.py", line 1158, in predict_noise
>     return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
>   ...
>   File "/home/rtl-6/comfy/samplers.py", line 211, in calc_cond_batch
>     return executor.execute(model, conds, x_in, timestep, model_options)
>   File "/home/rtl-6/comfy/patcher_extension.py", line 113, in execute
>     return self.original(*args, **kwargs)
>   File "/home/rtl-6/comfy/samplers.py", line 215, in _calc_cond_batch
>     return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
>   File "/home/rtl-6/comfy/samplers.py", line 530, in _calc_cond_batch_multigpu
>     raise error
>   File "/usr/local/lib/python3.12/threading.py", line 1052, in _bootstrap_inner
>     self.run()
>   File "/usr/local/lib/python3.12/threading.py", line 989, in run
>     self._target(*self.args, **self.kwargs)
>   File "/home/rtl-6/comfy/samplers.py", line 511, in _handle_batch
>     output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
>   File "/home/rtl-6/comfy/model_base.py", line 152, in apply_model
>     return comfy.patcher_extension.WrapperExecutor.new_class_executor(
>   File "/home/rtl-6/comfy/patcher_extension.py", line 113, in execute
>     return self.original(*args, **kwargs)
>   File "/home/rtl-6/comfy/model_base.py", line 190, in apply_model
>     model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
>     return self._call_impl(*args, **kwargs)
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
>     return forward_call(*args, **kwargs)
>   File "/home/rtl-6/comfy/ldm/wan/model.py", line 580, in forward
>     return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
>   File "/home/rtl-6/comfy/ldm/wan/model.py", line 550, in forward_orig
>     x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
>     return self._call_impl(*args, **kwargs)
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
>     return forward_call(*args, **kwargs)
>   File "/home/rtl-6/comfy/ldm/wan/model.py", line 221, in forward
>     y = self.self_attn(
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
>     return self._call_impl(*args, **kwargs)
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
>     return forward_call(*args, **kwargs)
>   File "/home/rtl-6/comfy/ldm/wan/model.py", line 72, in forward
>     x = optimized_attention(
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 899, in _fn
>     return fn(*args, **kwargs)
>   File "/home/rtl-6/custom_nodes/comfyui-kjnodes/nodes/model_optimization_nodes.py", line 81, in attention_sage
>     out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
>   File "/home/rtl-6/custom_nodes/comfyui-kjnodes/nodes/model_optimization_nodes.py", line 36, in func
>     return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout)
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/sageattention/core.py", line 105, in sageattn
>     q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, sm_scale=sm_scale, tensor_layout=tensor_layout)
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/sageattention/quant_per_block.py", line 63, in per_block_int8
>     quant_per_block_int8_kernel[grid](
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/triton/runtime/jit.py", line 347, in <lambda>
>     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/triton/runtime/jit.py", line 591, in run
>     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata,
>   File "/home/rtl-6/Python-3.12.0/comfy-env-3.12/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 529, in __call__
>     self.launch(gridX, gridY, gridZ, stream, function, self.launch_cooperative_grid, global_scratch, *args)
> ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
>
> Sage Attention and Torch Compile work when I set max_gpus to 1, but not with multiple GPUs.

Hello, I also want to run this model with 4x 4090 GPUs. Can you share the workflow?

@DKingAlpha

DKingAlpha commented Aug 30, 2025

@Kosinkadink AMD users will need this patch

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 4ac04b8b..b396a034 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -194,7 +194,7 @@ def get_all_torch_devices(exclude_current=False):
     global cpu_state
     devices = []
     if cpu_state == CPUState.GPU:
-        if is_nvidia():
+        if is_nvidia() or is_amd():
             for i in range(torch.cuda.device_count()):
                 devices.append(torch.device(i))
         elif is_intel_xpu():

I couldn't confirm further at first - my multi-GPU setup broke recently and I couldn't tell whether the problem was in the repo or in my ROCm stack.
Oh, I got it working. Turned out my local (outdated) flash attention installation was broken to some extent.

@jkyamog

jkyamog commented Sep 3, 2025

Thanks for all the work done here. I added 1 more GPU to my 3x 3090 setup. I was trying the WAN video models, but it only used 2 GPUs since 3 is not a power of two. So I took a smaller 3060 12GB GPU from another system so I could run 4 GPUs. I then downgraded to WAN2.1 t2v 1.3B so it fits in VRAM on all GPUs, including the 3060. But it behaves similarly to running 3 GPUs: only 2 GPUs are actually doing work, even though all of the GPUs have VRAM loaded. Is this expected? Here is what the typical load looks like through a video generation where only 2 GPUs are doing work. Btw, I did reorder the GPUs using CUDA_VISIBLE_DEVICES.
Screenshot 2025-09-03 at 2 27 22 PM

@Kosinkadink
Collaborator Author

Kosinkadink commented Sep 10, 2025

Thank you for the additional info!

@DKingAlpha thanks for the heads up!

Firstly, has anyone here been able to get this working on Linux (not WSL)? And if so, what type of GPUs were they?

Secondly, @jkyamog this PR currently only does conditioning splitting - making the conditionings run on separate GPUs. Wan2.1 has only two conditionings (positive and negative) without masking, so you can only accelerate it 2x with 2 GPUs - the other GPUs will have no work split off for them. This is also the issue with models that have only one conditioning - there is nothing to split. I will be looking at some parallel attention schemes to try to overcome this limitation soon.
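A tiny, hypothetical illustration of that limit (round-robin assignment of work units to devices; not the PR's code):

conds = ["positive", "negative"]  # Wan2.1 without masking: only two work units
devices = ["cuda:0", "cuda:1", "cuda:2", "cuda:3"]
assignment = {d: [] for d in devices}
for i, cond in enumerate(conds):
    assignment[devices[i % len(devices)]].append(cond)
print(assignment)  # {'cuda:0': ['positive'], 'cuda:1': ['negative'], 'cuda:2': [], 'cuda:3': []}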

I did not have as much time to look into the remaining issues as I thought; I apologize for the delay. I will keep looking into it, plus into acceleration beyond just conditioning, soon.

@ExpandedMancho

Hi, in order to make this setup work with 2 GPUs do you need enough VRAM to be able to run the Wan model 2 times on your first GPU?

I noticed I get OOM errors when the deepclone part starts. I'm guessing that the clone requires the full model to load, and then also a copy of the model, before it can be moved to the 2nd GPU?

Thanks.

@Kosinkadink
Collaborator Author

That should not be a requirement. What are your exact errors? (post full stack trace + workflow)

@ExpandedMancho

> That should not be a requirement. What are your exact errors? (post full stack trace + workflow)

Hey, I found out that I'm going OOM when I use load_device = main_device in the WanVideo Model Loader with the MultiGPU Work Units WAN node set to max_gpus = 2. However, when I use 1 GPU with this exact setup (load_device = main_device), it does work.

  • I've tried with and without block swap and accelerator LoRAs, and it didn't make a difference for me.
  • If I use load_device = offload_device instead of main_device, the workflow works and I get no OOM, but then there's no deep cloning happening (checked the CLI) and the 2nd GPU doesn't get used at all.
  • image

Specs

GPU: 2x RTX 5090 (32 GB VRAM each)
CUDA Version: 12.8
RAM: 186 GB
OS: Linux
Python version: Python 3.11.11

Relevant libs

nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.2.5
nvidia-nvtx-cu12==12.8.90
open_clip_torch==2.32.0
pytorch-triton==3.3.1+gitc8757738
rotary-embedding-torch==0.8.8
torch==2.9.0.dev20250629+cu128
torchaudio==2.8.0.dev20250629+cu128
torchsde==0.2.6
torchvision==0.23.0.dev20250629+cu128

ComfyUI startup (pytorch attention)

Set cuda device to: 0,1
Total VRAM 32120 MB, total RAM 386469 MB
pytorch version: 2.9.0.dev20250629+cu128
Enabled fp16 accumulation.
Set vram state to: NORMAL_VRAM
Device: cuda:0 NVIDIA GeForce RTX 5090 : cudaMallocAsync
Device: cuda:1 NVIDIA GeForce RTX 5090 : cudaMallocAsync
Using pytorch attention
Python version: 3.11.11 (main, Dec 4 2024, 08:55:07) [GCC 11.4.0]
ComfyUI version: 0.3.46
ComfyUI frontend version: 1.23.4

Workflow

https://pastebin.com/g2xnixXt

Stack Trace

==========SERVER got prompt==========
prompt event_id: []
Failed to copy /tmp/models/diffusion_models/MelBandRoformer_fp16.safetensors to temp dir: '/tmp/models/diffusion_models/MelBandRoformer_fp16.safetensors' is not in the subpath of '/workspace/ComfyUI' OR one path is relative and the other is absolute. falling back to original path
Converted mono input to stereo.
Resampling input 8000 to 44100
Processing chunks: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.97it/s]
[MultiTalk] --- Raw speaker lengths (samples) ---
  speaker 1: 244000 samples (shape: torch.Size([1, 1, 244000]))
[MultiTalk] total raw duration = 15.250s
[MultiTalk] multi_audio_type=para | final waveform shape=torch.Size([1, 1, 244000]) | length=244000 samples | seconds=15.250s (expected max of raw)
Failed to copy /tmp/models/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors to temp dir: '/tmp/models/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors' is not in the subpath of '/workspace/ComfyUI' OR one path is relative and the other is absolute. falling back to original path
CLIP layer names written to clip_layers.txt
clip_target: <comfy.sd.load_text_encoder_state_dicts.<locals>.EmptyClass object at 0x7785e868ee10> parameters: 5685458817 model_options: {'load_device': device(type='cpu'), 'offload_device': device(type='cpu')}
Using scaled fp8: fp8 matrix mult: False, scale input: False
CLIP/text encoder model load device: cpu, offload device: cpu, current: cuda:0, dtype: torch.float16
Requested to load WanTEModel
loaded completely 9.5367431640625e+25 6419.477203369141 True
Requested to load CLIPVisionModelProjection
loaded completely 28918.5119140625 1208.09814453125 True
Clip embeds shape: torch.Size([1, 257, 1280]), dtype: torch.float32
Combined clip embeds shape: torch.Size([1, 257, 1280])
CUDA Compute Capability: 12.0
Detected model in_channels: 36
Model cross attention type: i2v, num_heads: 40, num_layers: 40
Model variant detected: i2v_480
InfiniteTalk detected, patching model...
model_type FLOW
Creating deepclone of WanVideoModel for cuda:1.
!!! Exception during processing !!! class_type: MultiGPU_WorkUnitsWAN node_id: 377 ex: Allocation on device 
Traceback (most recent call last):
  File "/workspace/ComfyUI/execution.py", line 482, in execute
    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/ComfyUI/execution.py", line 292, in get_output_data
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/ComfyUI/execution.py", line 266, in _async_map_node_over_list
    await process_inputs(input_dict, i)
  File "/workspace/ComfyUI/execution.py", line 254, in process_inputs
    result = f(**inputs)
             ^^^^^^^^^^^
  File "/workspace/ComfyUI/comfy_extras/nodes_multigpu.py", line 71, in init_multigpu
    model = comfy.multigpu.create_multigpu_deepclones(model, max_gpus, gpu_options, reuse_loaded=True)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/ComfyUI/comfy/multigpu.py", line 90, in create_multigpu_deepclones
    device_patcher = model.deepclone_multigpu(new_load_device=device)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/ComfyUI/comfy/model_patcher.py", line 349, in deepclone_multigpu
    n.model = copy.deepcopy(n.model)
              ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 153, in deepcopy
    y = copier(memo)
        ^^^^^^^^^^^^
  File "/workspace/venv_cu128/lib/python3.11/site-packages/torch/_tensor.py", line 178, in __deepcopy__
    new_storage = self._typed_storage()._deepcopy(memo)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/venv_cu128/lib/python3.11/site-packages/torch/storage.py", line 1139, in _deepcopy
    return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/copy.py", line 153, in deepcopy
    y = copier(memo)
        ^^^^^^^^^^^^
  File "/workspace/venv_cu128/lib/python3.11/site-packages/torch/storage.py", line 243, in __deepcopy__
    new_storage = self.clone()
                  ^^^^^^^^^^^^
  File "/workspace/venv_cu128/lib/python3.11/site-packages/torch/storage.py", line 257, in clone
    return type(self)(self.nbytes(), device=self.device).copy_(self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.OutOfMemoryError: Allocation on device 

Got an OOM, unloading all loaded models. node_id: 377 class_type: MultiGPU_WorkUnitsWAN class_def: <class '/workspace/ComfyUI/comfy_extras/nodes_multigpu.MultiGPUWorkUnitsNodeWAN'>
Prompt executed in 32.86 seconds

@Kosinkadink
Collaborator Author

Kosinkadink commented Sep 17, 2025

MultiGPU work units are a feature only for nodes that use native sampling or that specifically reimplement support - the node you're looking at is a wrapper custom node that does not use native sampling.

@edflyer

edflyer commented Oct 19, 2025

Do you have a workflow that I can test to see if I installed this correctly?

@edflyer

edflyer commented Oct 19, 2025

> Thanks for all the work done here. I added 1 more GPU to my 3x 3090 setup. […] But it behaves similarly to running 3 GPUs: only 2 GPUs are actually doing work, even though all of the GPUs have VRAM loaded. […]

What workflow are you using?

@rattus128
Contributor

> Firstly, has anyone here been able to get this working on Linux (not WSL)? And if so, what type of GPUs were they?

I think I have it working with a fix.

2x A40 on a RunPod. I reproduced black outputs and colorful noise in flux-dev fp8, cfg=1.1.

root@f00a481f73a0:~/ComfyUI# cat /etc/*-release
DISTRIB_ID=Ubuntu
DISTRIB_RELEASE=24.04
DISTRIB_CODENAME=noble
DISTRIB_DESCRIPTION="Ubuntu 24.04.3 LTS"
PRETTY_NAME="Ubuntu 24.04.3 LTS"

I got a black screen in some async tensor casting experiments I was doing for another change, and debugged it to a race between the CUDA streams and the PyTorch garbage collector, so I thought I'd check for the same bug here. I remember @Kosinkadink saying in a Discord post that this was blocked by black screens.

So I think something similar is going on here, where the GPU->GPU ops are asynchronous with respect to the CPU, and the CPU is able to run ahead and queue a cudaAsyncFree on one GPU while the other is still bus-mastering the .to transfers, depending on who is the bus master and tensor owner. In the case of pull DMA this can easily be a race that corrupts tensors before the transfer completes. PyTorch documentation is sparse on this, so it's all theory.

So if I'm right, this can be fixed by always bounce-buffering through RAM, which syncs the CPU:

diff --git a/comfy/samplers.py b/comfy/samplers.py
index ed702304..a93dbde4 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -158,7 +158,7 @@ def cond_cat(c_list, device=None):
         conds = temp[k]
         out[k] = conds[0].concat(conds[1:])
         if device is not None and hasattr(out[k], 'to'):
-            out[k] = out[k].to(device)
+            out[k] = out[k].cpu().to(device)
 
     return out
 
@@ -470,7 +470,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
                         patches = p.patches
 
                     batch_chunks = len(cond_or_uncond)
-                    input_x = torch.cat(input_x).to(device)
+                    input_x = torch.cat(input_x).cpu().to(device)
                     c = cond_cat(c, device=device)
                     timestep_ = torch.cat([timestep.to(device)] * batch_chunks)
 
@@ -500,9 +500,9 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
                         c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)
 
                     if 'model_function_wrapper' in model_options:
-                        output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
+                        output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).cpu().to(output_device).chunk(batch_chunks)
                     else:
-                        output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
+                        output = model_current.apply_model(input_x, timestep_, **c).cpu().to(output_device).chunk(batch_chunks)
                     results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
         except Exception as e:
             results.append(thread_result(None, None, None, None, None, error=e))

There is in theory a performance penalty here, as it changes the DMA path from master-slave to master-RAM-master, but I'm not observing any penalty in my initial tests.

Here are B=4 1024x1024 cfg=1.1 Flux dev speeds:

1 GPU

100%|████████████████████████████████████████████████████████| 20/20 [02:20<00:00,  7.00s/it]

2 GPUs - This branch unchanged (corrupted output)

100%|███████████████████████████████████████████████████| 20/20 [01:22<00:00,  4.11s/it]

2 GPUs - With above fix

100%|███████████████████████████████████████████████████| 20/20 [01:22<00:00,  4.11s/it]


Properly syncing the GPU->GPU DMA is a complex web of driver specifics, so this is a lot easier.

If this ends up being slow for other use cases (very large latents), you could chunk the .to as a series of queued copies instead, so the two bus masters start overlapping work and the performance will likely converge on something very close to master-slave, given the above.
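A minimal sketch of that chunked-copy idea (illustrative only; the helper name and chunk count are made up):

import torch

def chunked_to(t: torch.Tensor, device, chunks: int = 4) -> torch.Tensor:
    # Queue several smaller copies instead of one monolithic .to(), so the two
    # bus masters can start overlapping work on pieces that have already landed.
    out = torch.empty_like(t, device=device)
    for src, dst in zip(t.chunk(chunks), out.chunk(chunks)):
        dst.copy_(src)
    return out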

Screenshot from 2025-10-26 18-20-53
