From 7662695cf6c11f713fc7e6a7664a0c32a663824f Mon Sep 17 00:00:00 2001
From: stevhliu
Date: Mon, 1 Dec 2025 14:57:00 -0800
Subject: [PATCH 1/3] quickstart

---
 docs/source/en/_toctree.yml             |  14 ++
 docs/source/en/optimization_overview.md | 185 ++++++++++++++++++++++++
 2 files changed, 199 insertions(+)
 create mode 100644 docs/source/en/optimization_overview.md

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index b00b975eb162..d9cfb716be3a 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -67,6 +67,20 @@
     - local: perplexity
       title: Perplexity of fixed-length models
     title: Generate API
+  - sections:
+    - local: optimization_overview
+      title: Overview
+    - local: perf_torch_compile
+      title: torch.compile
+    - local: perf_infer_gpu_one
+      title: GPU
+    - local: perf_infer_gpu_multi
+      title: Distributed inference
+    - local: perf_infer_cpu
+      title: CPU
+    - local: perplexity
+      title: Perplexity of fixed-length models
+    title: Generate API
   - sections:
     - local: attention_interface
       title: Attention backends
diff --git a/docs/source/en/optimization_overview.md b/docs/source/en/optimization_overview.md
new file mode 100644
index 000000000000..40730b1672bd
--- /dev/null
+++ b/docs/source/en/optimization_overview.md
@@ -0,0 +1,185 @@
+
+

# Overview

Transformers provides multiple inference optimization techniques to make models fast, affordable, and accessible. Options include alternative attention mechanisms for reduced memory traffic, code compilation for faster execution, and optimized kernels for throughput. Combine these techniques for maximum performance.

> [!NOTE]
> Memory and speed are closely related but not the same. Shrinking your memory footprint makes a model "faster" because there is less data to move around. Pure speed optimizations don't always reduce memory and sometimes increase usage. Choose the appropriate optimization based on your use case and hardware.

Use the table below to pick an optimization technique.

| Technique | Speed | Memory |
|---|:---:|:---:|
| [Compilation](#compilation) | ✅ | |
| [Attention backends](#attention-backends) | ✅ | ✅ |
| [Kernels](#kernels) | ✅ | |
| [Quantization](#quantization) | ✅ | ✅ |
| [Caching](#caching) | ✅ | |
| [Parallelism](#parallelism) | ✅ | |
| [Continuous batching](#continuous-batching) | ✅ | |

This guide gives you a quick start on optimization in Transformers.

## Compilation

[torch.compile](./perf_torch_compile) reduces Python overhead, fuses operations, and creates kernels tuned for your shapes and hardware. The first run warms it up and subsequent runs use the faster compiled path.

Call `torch.compile()` on a model to enable it.

```py
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
compiled_model = torch.compile(model)
```

## Attention backends

Alternative [attention backends](./attention_interface) like FlashAttention lower memory traffic. They tile attention computations and avoid large intermediate tensors to reduce memory footprint.

Set `attn_implementation` in [`~PreTrainedModel.from_pretrained`] to load an optimized attention backend.

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", attn_implementation="flash_attention_2")
```

## Kernels

Kernels fuse operations to boost throughput.
The [Kernels](https://huggingface.co/docs/kernels/en/index) library loads optimized compute kernels from the [Hub](https://huggingface.co/kernels-community) in a flexible and version-safe way.

The example below loads an optimized FlashAttention-2 kernel without installing the package.

```py
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2"
)
```

## Quantization

[Quantization](./quantization/overview) shrinks the size of every parameter, which lowers the memory footprint and can increase speed because less data has to move between memory and compute.

Pass a quantization config to the `quantization_config` argument in [`~PreTrainedModel.from_pretrained`]. Each quantization backend has a different config with different arguments. The example below quantizes a model to 4-bit precision and configures the computation dtype with the [bitsandbytes](./quantization/bitsandbytes) backend.

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

model = AutoModelForCausalLM.from_pretrained(
    "allenai/Olmo-3-7B-Think", quantization_config=bnb_config
)
```

## Caching

[Caching](./kv_cache) increases speed by reusing past keys and values instead of recomputing them for every token. All Transformers models use a [`DynamicCache`] by default, which lets the cache grow as more tokens are decoded.

Pick a caching strategy that fits your use case. If you want maximum speed, consider a [`StaticCache`], a fixed-size cache that is compatible with [torch.compile](#compilation).

Use the `cache_implementation` argument in [`~GenerationMixin.generate`] to set a cache strategy.

```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2"
)
inputs = tokenizer("The Le Décret Pain states that a baguette must,", return_tensors="pt")
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=50, cache_implementation="static")
```

## Parallelism

[Parallelism](./perf_infer_gpu_multi) distributes a model across devices so models too big for one device run fast. This approach uses more memory due to sharding overhead and communication to sync results.

[Tensor parallelism](./perf_infer_gpu_multi) splits a model layer across devices. Set `tp_plan="auto"` in [`~PreTrainedModel.from_pretrained`] to enable it.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", tp_plan="auto")
print(model._tp_plan)
```

[Expert parallelism](./expert_parallelism) distributes experts across devices for mixture-of-experts (MoE) models. Set `enable_expert_parallel` in [`DistributedConfig`] to enable it.
+ +```py +from transformers import AutoModelForCausalLM +from transformers.distributed.configuration_utils import DistributedConfig + +distributed_config = DistributedConfig(enable_expert_parallel=True) +model = AutoModelForCausalLM.from_pretrained( + "openai/gpt-oss-120b", + distributed_config=distributed_config, +) +``` + +## Continuous batching + +[Continuous batching](./continuous_batching) maximizes throughput by keeping the GPU busy with dynamic scheduling and chunked prefill. [Serving](./serving.md) applications use it to process multiple incoming requests concurrently. + +Use [`~ContinuousMixin.generate_batch`] to enable continuous batching. + +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.generation import GenerationConfig + +model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-0.6B", + attn_implementation="paged|sdpa", + device_map="cuda", + torch_dtype=torch.bfloat16, +) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + +prompts = [ + "The Le Décret Pain states that a baguette must", + "Explain gravity in one sentence.", + "Name the capital of France.", +] +inputs = [tokenizer.encode(p) for p in prompts] + +generation_config = GenerationConfig( + max_new_tokens=32, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + do_sample=False, + max_batch_tokens=512, +) + +outputs = model.generate_batch( + inputs=inputs, + generation_config=generation_config, +) + +for request_id, output in outputs.items(): + text = tokenizer.decode(output.generated_tokens, skip_special_tokens=True) + print(f"[{request_id}] {text}") +``` \ No newline at end of file From 3cc148aaf3e812f2f91f620599789de824a8dc86 Mon Sep 17 00:00:00 2001 From: stevhliu Date: Tue, 9 Dec 2025 14:52:37 -0800 Subject: [PATCH 2/3] feedback --- docs/source/en/_toctree.yml | 12 ----------- docs/source/en/optimization_overview.md | 28 +++++++++++++++---------- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d9cfb716be3a..d12667e9a0c6 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -70,18 +70,6 @@ - sections: - local: optimization_overview title: Overview - - local: perf_torch_compile - title: torch.compile - - local: perf_infer_gpu_one - title: GPU - - local: perf_infer_gpu_multi - title: Distributed inference - - local: perf_infer_cpu - title: CPU - - local: perplexity - title: Perplexity of fixed-length models - title: Generate API - - sections: - local: attention_interface title: Attention backends - local: continuous_batching diff --git a/docs/source/en/optimization_overview.md b/docs/source/en/optimization_overview.md index 40730b1672bd..ac38e1341912 100644 --- a/docs/source/en/optimization_overview.md +++ b/docs/source/en/optimization_overview.md @@ -27,9 +27,9 @@ Use the table below to pick an optimization technique. |---|:---:|:---:| | [Compilation](#compilation) | ✅ | | | [Attention backends](#attention-backends) | ✅ | ✅ | -| [Kernels](#kernels) | ✅ | | +| [Kernels](#kernels) | ✅ | ✅ | | [Quantization](#quantization) | ✅ | ✅ | -| [Caching](#caching) | ✅ | | +| [Caching](#caching) | ✅ | ✅ | | [Parallelism](#parallelism) | ✅ | | | [Continuous batching](#continuous-batching) | ✅ | | @@ -39,16 +39,23 @@ This guide gives you a quick start on optimization in Transformers. [torch.compile](./perf_torch_compile) reduces Python overhead, fuses operations, and creates kernels tuned for your shapes and hardware. 
The first run warms it up and subsequent runs use the faster compiled path. -Call `torch.compile()` on a model to enable it. +Pass a [fixed size cache](./kv_cache#fixed-size-cache) to [`~GenerationMixin.generate`] to trigger `torch.compile` automatically. ```py import torch -from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer, AutoModelForCausalLM + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.float16, device_map="auto") +input = tokenizer("The French Bread Law states", return_tensors="pt").to(model.device) -model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B") -compiled_model = torch.compile(model) +output = model.generate(**input, do_sample=False, max_new_tokens=20, cache_implementation="static") +tokenizer.batch_decode(output, skip_special_tokens=True)[0] ``` +> [!WARNING] +> Avoid calling `torch.compile(model)` outside of [`~GenerationMixin.generate`] to prevent the model from recompiling every step. + ## Attention backends Alternative [attention backends](./attention_interface) like FlashAttention lower memory traffic. They tile attention computations and avoid large intermediate tensors to reduce memory footprint. @@ -63,7 +70,7 @@ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", attn_implementat ## Kernels -Kernels fuse operations to boost throughput. The [Kernels](https://huggingface.co/docs/kernels/en/index) library loads optimized compute kernels from the [Hub](https://huggingface.co/kernels-community) in a flexible and version-safe way. +Kernels fuse operations to boost throughput and reduce memory usage. The [Kernels](https://huggingface.co/docs/kernels/en/index) library loads optimized compute kernels from the [Hub](https://huggingface.co/kernels-community) in a flexible and version-safe way. The example below loads an optimized FlashAttention-2 kernel without installing the package. @@ -95,9 +102,8 @@ model = AutoModelForCausalLM.from_pretrained( ## Caching -[Caching](./kv_cache) increases speed by reusing past keys and values instead of recomputing them for every token. All Transformers models use a [`DynamicCache`] by default to allow the cache to grow proportionally with decoding. - -Pick a caching strategy that fits your use case. If you want maximum speed, consider a [`StaticCache`]. A [`StaticCache`] is a fixed-size cache compatible with [torch.compile](#compilation). +[Caching](./kv_cache) speeds up generation by reusing past keys and values instead of recomputing them for every token. To offset and reduce the memory cost of storing past keys and values, Transformers +supports offloading the cache to the CPU. Only the current layer remains on the GPU. Use the `cache_implementation` argument in [`~GenerationMixin.generate`] to set a cache strategy. 
@@ -110,7 +116,7 @@ model = AutoModelForCausalLM.from_pretrained( "Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2" ) inputs = tokenizer("The Le Décret Pain states that a baguette must,", return_tensors="pt") -outputs = model.generate(**inputs, do_sample=False, max_new_tokens=50, cache_implementation="static") +outputs = model.generate(**inputs, do_sample=False, max_new_tokens=50, cache_implementation="offloaded") ``` ## Parallelism From 070fcb8b52963c20bdacc4ee24e14a445d8c4189 Mon Sep 17 00:00:00 2001 From: stevhliu Date: Mon, 15 Dec 2025 14:16:33 -0800 Subject: [PATCH 3/3] feedback --- docs/source/en/optimization_overview.md | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/docs/source/en/optimization_overview.md b/docs/source/en/optimization_overview.md index ac38e1341912..86f3d7b009e7 100644 --- a/docs/source/en/optimization_overview.md +++ b/docs/source/en/optimization_overview.md @@ -16,7 +16,7 @@ rendered properly in your Markdown viewer. # Overview -Transformers provides multiple inference optimization techniques to make models fast, affordable, and accessible. Options include alternative attention mechanisms for reduced memory traffic, code compilation for faster execution, and optimized kernels for throughput. Combine these techniques for maximum performance. +Transformers provides multiple inference optimization techniques to make models fast, affordable, and accessible. Options include alternative attention mechanisms for reduced memory traffic, code compilation for faster execution, and optimized kernels for throughput. Stack these techniques for maximum performance. > [!NOTE] > Memory and speed are closely related but not the same. Shrinking your memory footprint makes a model "faster" because there is less data to move around. Pure speed optimizations don't always reduce memory and sometimes increase usage. Choose the appropriate optimization based on your use case and hardware. @@ -33,7 +33,7 @@ Use the table below to pick an optimization technique. | [Parallelism](#parallelism) | ✅ | | | [Continuous batching](#continuous-batching) | ✅ | | -This guide gives you a quick start on optimization in Transformers. +This guide gives you a quick start on Transformers optimizations. ## Compilation @@ -58,7 +58,7 @@ tokenizer.batch_decode(output, skip_special_tokens=True)[0] ## Attention backends -Alternative [attention backends](./attention_interface) like FlashAttention lower memory traffic. They tile attention computations and avoid large intermediate tensors to reduce memory footprint. +Alternative [attention backends](./attention_interface) lower memory traffic. For example, FlashAttention tiles attention computations and avoids large intermediate tensors to reduce memory footprint. Set `attn_implementation` in [`~PreTrainedModel.from_pretrained`] to load an optimized attention backend. @@ -133,19 +133,6 @@ model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruc print(model._tp_plan) ``` -[Expert parallelism](./expert_parallelism) distributes experts across devices for mixture-of-experts (MoE) models. Set `enable_expert_parallel` in [`DistributedConfig`] to enable it. 
- -```py -from transformers import AutoModelForCausalLM -from transformers.distributed.configuration_utils import DistributedConfig - -distributed_config = DistributedConfig(enable_expert_parallel=True) -model = AutoModelForCausalLM.from_pretrained( - "openai/gpt-oss-120b", - distributed_config=distributed_config, -) -``` - ## Continuous batching [Continuous batching](./continuous_batching) maximizes throughput by keeping the GPU busy with dynamic scheduling and chunked prefill. [Serving](./serving.md) applications use it to process multiple incoming requests concurrently.
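
The techniques above are designed to stack. As a closing illustration, here is a minimal sketch (added for this guide, not code from the patches above) that combines 4-bit quantization, the FlashAttention-2 backend, and a fixed-size cache on the same small model used throughout. It assumes a CUDA GPU with `bitsandbytes` and FlashAttention-2 installed.

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit quantization shrinks the memory footprint; the attention backend lowers memory traffic
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

inputs = tokenizer("Explain gravity in one sentence.", return_tensors="pt").to(model.device)
# a fixed-size ("static") cache avoids reallocating the KV cache as tokens are decoded
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=50, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```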