Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions examples/agent/README_NPU.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Add Support for Huawei Ascend Devices on DeepEyes

# Installation

## Basic Environment Preparation

| software | version |
| :-------- | :---------- |
| Python | ==3.10 |
| CANN | ==8.1.RC1 |
| torch | ==2.5.1 |
| torch_npu | ==2.5.1.RC1 |

## Install vllm & vllm-ascend

To ensure vllm works correctly with verl, we recommend installing vllm and vllm-ascend from source.

### Install vllm 0.9.1

```bash
git clone -b v0.9.1 https://github.com/vllm-project/vllm.git
```

Comment out the `torch` dependency in `requirements/build.txt`, then install:

```bash
cd vllm
pip install -r ./requirements/build.txt
VLLM_TARGET_DEVICE=empty pip install -e .
```

### Install vllm-ascend 0.9.1

```bash
git clone -b v0.9.1-dev https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend

export COMPILE_CUSTOM_KERNELS=1
pip install -e . --no-build-isolation
```

## Install verl

```bash
cd verl
pip install -r requirements-npu.txt
pip install -e .
```
# Start Training
We use Qwen-2.5-VL-7B-Instruct as our foundation model for RL training. Qwen-2.5-VL-32B-Instruct is also supported.

Step 1: Start a vllm serving of Qwen-2.5-72B-Instruct for llm-as-a-judge verification.
```bash
# download Qwen-2.5-72B-Instruct model
huggingface-cli download --resume-download Qwen/Qwen2.5-72B-Instruct --local-dir /path/to/your/local/filedir --local-dir-use-symlinks False

# start vllm serving
vllm serve /path/to/your/local/filedir \
--port 18901 \
--gpu-memory-utilization 0.8 \
--max-model-len 32768 \
--tensor-parallel-size 8 \
--served-model-name "judge" \
--trust-remote-code \
--disable-log-requests
```

Step 2: Build a ray cluster for all of the training nodes. Prepare data before starting training. Our training dataset can be downloaded from [huggingface](https://huggingface.co/datasets/ChenShawn/DeepEyes-Datasets-47k).

Step 3: Use one of the following scripts to start training.

```bash
# your wandb access key here...
wandb login

# the IP and port for your Qwen-2.5-72B-Instruct vllm serving
export LLM_AS_A_JUDGE_BASE="http://your.vllm.machine.ip:18901/v1"

# number of training nodes
export WORLD_SIZE=8

# config for 7B
bash examples/agent/train_qwen25vl_grpo_agent_npu.sh
```

80 changes: 80 additions & 0 deletions examples/agent/train_qwen25vl_grpo_agent_npu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
set -x

export VLLM_USE_V1=1
export LLM_AS_A_JUDGE_BASE="http://localhost:18901/v1"
export WANDB_MODE=offline
export WORLD_SIZE=1
export USE_OPTIMIZED_MODEL=0


unset PYTORCH_CUDA_ALLOC_CONF
PROJECT_NAME="agent_vlagent"
EXPERIMENT_NAME="debug_for_single_node"

export SAVE_CHECKPOINT_DIR=./verl_checkpoints

BASEDIR=./datasets_agent
VISUAL_DATASET_TRAIN_0_6_2=${BASEDIR}/data_v0.6.2_reason.parquet
VISUAL_DATASET_TRAIN_0_1_2=${BASEDIR}/data_0.1.2_visual_toolbox_v2.parquet
VISUAL_DATASET_TRAIN_0_8=${BASEDIR}/data_v0.8_visual_toolbox_v2.parquet
VISUAL_DATASET_TEST=${BASEDIR}/seekworld_test.parquet
EUREKA_DATASET_TRAIN=${BASEDIR}/data_thinklite_reasoning_acc.parquet

REF_MODEL_PATH=./Qwen2.5-VL-7B-Instruct
PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
+debug=False \
+vs_debug=False \
data.train_files=[${VISUAL_DATASET_TRAIN_0_1_2},${VISUAL_DATASET_TRAIN_0_8},${EUREKA_DATASET_TRAIN}] \
data.val_files=[${EUREKA_DATASET_TRAIN}] \
data.train_batch_size=64 \
data.max_prompt_length=8192 \
data.max_response_length=20480 \
data.return_raw_chat=True \
data.filter_overlong_prompts=True \
algorithm.adv_estimator=grpo \
algorithm.kl_ctrl.kl_coef=0.0 \
actor_rollout_ref.model.path=${REF_MODEL_PATH} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.0 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0.0 \
actor_rollout_ref.actor.checkpoint.contents=['model','hf_model','optimizer','extra'] \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.rollout.max_num_batched_tokens=32768 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.enforce_eager=True \
actor_rollout_ref.rollout.free_cache_engine=True \
actor_rollout_ref.rollout.enable_chunked_prefill=False \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.use_torch_compile=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.rollout.agent.activate_agent=True \
actor_rollout_ref.rollout.agent.tool_name_key=env_name \
actor_rollout_ref.rollout.agent.single_response_max_tokens=10240 \
actor_rollout_ref.rollout.agent.max_turns=5 \
actor_rollout_ref.rollout.agent.concurrent_workers=1 \
actor_rollout_ref.rollout.agent.show_tqdm=True \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.val_before_train=False \
trainer.n_gpus_per_node=8 \
trainer.nnodes=${WORLD_SIZE} \
trainer.save_freq=32 \
trainer.test_freq=10000 \
+trainer.device_name=npu \
trainer.project_name=${PROJECT_NAME} \
trainer.experiment_name=${EXPERIMENT_NAME} \
trainer.default_local_dir=${SAVE_CHECKPOINT_DIR}/${PROJECT_NAME}/${EXPERIMENT_NAME} \
+trainer.tensorboard_dir=${SAVE_CHECKPOINT_DIR}/logs/tensorboard \
+trainer.rl_logging_board_dir=${SAVE_CHECKPOINT_DIR}/logs/rl_logging_board \
trainer.total_epochs=32 2>&1 | tee $(date +%Y%m%d%H%M%S).log
33 changes: 28 additions & 5 deletions verl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import pkg_resources
from packaging.version import parse as parse_version
from pkg_resources import DistributionNotFound

from .protocol import DataProto
from .utils.device import is_npu_available
from .utils.logging_utils import set_basic_config

version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__)))

with open(os.path.join(version_folder, "version/version")) as f:
__version__ = f.read().strip()

import logging

from .protocol import DataProto
from .utils.logging_utils import set_basic_config

set_basic_config(level=logging.WARNING)

from . import single_controller

__all__ = ["DataProto", "__version__"]

Expand All @@ -39,3 +43,22 @@
from modelscope.utils.hf_util import patch_hub

patch_hub()

if is_npu_available:
from .models.transformers import npu_patch as npu_patch

package_name = "transformers"
required_version_spec = "4.52.4"
try:
installed_version = pkg_resources.get_distribution(package_name).version
installed = parse_version(installed_version)
required = parse_version(required_version_spec)
if not installed >= required:
raise ValueError(
f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, "
f"current version is {installed}."
)
except DistributionNotFound as e:
raise ImportError(
f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}"
) from e
51 changes: 51 additions & 0 deletions verl/models/transformers/npu_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Copyright 2025 The Qwen Team and The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
import torch_npu
from torch_npu import npu_rotary_mul as apply_rotary_emb
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm


# This patch takes effect when using apply_rotary_pos_emb_flashatt on
# qwen2_5_vl and will be removed in subsequent versions
# https://github.com/huggingface/transformers/pull/38491
def apply_rotary_pos_emb_flashatt_npu(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
cos = cos.repeat(1, 2)
sin = sin.repeat(1, 2)
q_embed = apply_rotary_emb(
q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()
).type_as(q)
k_embed = apply_rotary_emb(
k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float()
).type_as(k)
return q_embed, k_embed


# This api can improve performance on ASCEND NPU
def rms_norm_forward(self, x):
return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0]


Qwen2RMSNorm.forward = rms_norm_forward
modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu
3 changes: 1 addition & 2 deletions verl/models/transformers/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
)

try:
from flash_attn import flash_attn_func, flash_attn_varlen_func

from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError:
flash_attn_varlen_func = None
Expand Down
5 changes: 3 additions & 2 deletions verl/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from torch.utils.data import DataLoader

from verl.utils.py_functional import union_two_dict
from verl.utils.device import get_torch_device

__all__ = ["DataProto", "union_tensor_dict"]

Expand Down Expand Up @@ -249,7 +250,7 @@ def __setstate__(self, data):
batch_deserialized_bytes, non_tensor_batch, meta_info = data
batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
batch = torch.load(
batch_deserialized, weights_only=False, map_location="cpu" if not torch.cuda.is_available() else None
batch_deserialized, weights_only=False, map_location="cpu" if not get_torch_device().is_available() else None
)
self.batch = batch
self.non_tensor_batch = non_tensor_batch
Expand Down Expand Up @@ -770,7 +771,7 @@ def all_gather_data_proto(data: DataProto, process_group):
group_size = torch.distributed.get_world_size(group=process_group)
assert isinstance(data, DataProto)
prev_device = data.batch.device
data.batch = data.batch.cuda(device=torch.cuda.current_device())
data.batch = data.batch.to(get_torch_device().current_device())
data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0)
data.batch = data.batch.to(prev_device)
# all gather non_tensor_batch
Expand Down
11 changes: 6 additions & 5 deletions verl/single_controller/base/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from dataclasses import dataclass

from .decorator import Dispatch, Execute, register
from verl.utils.device import get_torch_device


@dataclass
Expand Down Expand Up @@ -137,7 +138,7 @@ def __init__(self, cuda_visible_devices=None) -> None:

###
# [SUPPORT AMD: torch]
if "AMD" in torch.cuda.get_device_name():
if "AMD" in get_torch_device().get_device_name():
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("ROCR_VISIBLE_DEVICES")
os.environ["LOCAL_RANK"] = os.environ.get("RAY_LOCAL_RANK")
###
Expand All @@ -155,13 +156,13 @@ def __init__(self, cuda_visible_devices=None) -> None:

###
# [SUPPORT AMD: torch]
if "AMD" in torch.cuda.get_device_name():
if "AMD" in get_torch_device().get_device_name():
self.local_rank = int(os.environ["LOCAL_RANK"])
###

###
# [SUPPORT AMD: torch]
if "AMD" in torch.cuda.get_device_name():
if "AMD" in get_torch_device().get_device_name():
cuda_visible_devices = str(local_rank)
###

Expand All @@ -182,8 +183,8 @@ def __init__(self, cuda_visible_devices=None) -> None:
###
# [SUPPORT AMD: torch]
# torch.cuda.set_device(local_rank)
if "AMD" in torch.cuda.get_device_name():
torch.cuda.set_device(int(cuda_visible_devices))
if "AMD" in get_torch_device().get_device_name():
get_torch_device().set_device(int(cuda_visible_devices))
###

def _configure_with_meta(self, meta: WorkerMeta):
Expand Down
Loading