Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,8 @@
title: FullyShardedDataParallel
- local: deepspeed
title: DeepSpeed
- local: tensor_parallelism
title: Tensor parallelism
- local: debugging
title: Multi-GPU debugging
- local: perf_train_cpu_many
Expand Down
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_multi.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,5 @@ The `placement` attribute tells PyTorch how to place a tensor on devices in `Dev
- Check the [expert parallelism](./expert_parallelism) guide if you're using a mixture-of-experts (MoE) model. These models support tensor parallelism and expert parallelism.

- Read the [Tensor Parallelism (TP) in Transformers: 5 Minutes to Understand](https://huggingface.co/blog/qgallouedec/tp) blog post for a quick overview of tensor parallelism and learn how column and row parallel setups differ.

- See the [Tensor parallelism](./tensor_parallelism) training guide to learn how to use it in a training setting.
101 changes: 101 additions & 0 deletions docs/source/en/tensor_parallelism.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
<!--Copyright 2026 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Tensor parallelism

Tensor parallelism (TP) splits weight matrices column-wise or row-wise across GPUs. Each GPU holds a shard, computes a partial result, and synchronizes with an all-reduce to produce the full output.

TP relies on frequent cross-GPU communication. It works best on hardware with fast intra-node links such as NVLink.

```text
┌─────────────────────────────┐
│ X (replicated) │
└────┬──────────┬─────────┬───┘
│ │ │
┌────▼───┐ ┌────▼───┐ ┌───▼────┐
│ ▓▓▓ W₀ │ │ ░░░ W₁ │ │ ███ W₂ │
│ X@W₀ │ │ X@W₁ │ │ X@W₂ │
└────┬───┘ └────┬───┘ └───┬────┘
└──────────┼─────────┘
Y₀+Y₁+Y₂
┌────────────────────────────┐
│ Y (full) │
└────────────────────────────┘
```

Transformers supports TP for architectures whose config defines `base_model_tp_plan`. Check that field first to see whether a model supports native TP.

```py
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
print(config.base_model_tp_plan is not None)
print(config.base_model_tp_plan)
```

If a model supports TP, set `tp_plan="auto"` in [`~PreTrainedModel.from_pretrained`]. Transformers initializes the device mesh and shards the supported layers for you.

> [!WARNING]
> Don't use `device_map` with `tp_plan`. The two conflict at the weight-loading level. `device_map` places whole modules on specific GPUs, while `tp_plan` shards those same parameters across all GPUs.

```py
import torch

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B",
dtype=torch.bfloat16,
tp_plan="auto",
)
```

[`Trainer`] detects `tp_plan`, reads `tp_size` from the model, and creates a [`~accelerate.parallelism_config.ParallelismConfig`] automatically.

Launch training on one node with 4 GPUs.

```shell
torchrun --nproc-per-node 4 train_tp.py
```

## ParallelismConfig

Pass [`~accelerate.parallelism_config.ParallelismConfig`] explicitly when combining TP with other parallelism techniques like [FSDP](./fsdp).

```py
import torch

from accelerate import ParallelismConfig
from transformers import AutoModelForCausalLM, TrainingArguments

model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B",
dtype=torch.bfloat16,
tp_plan="auto",
)

parallelism_config = ParallelismConfig(tp_size=4)

args = TrainingArguments(
...,
parallelism_config=parallelism_config,
)
```

## Next steps

- Read the [Tensor Parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) chapter from The Ultra-Scale Playbook for more details about how it works.
- Read the [tensor parallelism inference guide](./perf_infer_gpu_multi) to learn more about partitioning strategies, manual TP plans, and implementation details.
Loading