From 8546321bd6e3a74744b942f1cfbca2e39fc48c3d Mon Sep 17 00:00:00 2001 From: "X. HU" Date: Tue, 29 Jul 2025 15:22:48 +0800 Subject: [PATCH 1/5] [NPU] feat: Support FSDP worker and vLLM Ascend --- verl/models/transformers/qwen2_vl.py | 3 +- verl/protocol.py | 5 +- verl/single_controller/base/worker.py | 11 +-- verl/single_controller/ray/base.py | 28 +++--- verl/trainer/fsdp_sft_trainer.py | 47 ++++++---- verl/trainer/main_generation.py | 3 +- verl/trainer/main_ppo.py | 2 + verl/trainer/ppo/ray_trainer.py | 7 +- .../checkpoint/fsdp_checkpoint_manager.py | 9 +- verl/utils/debug/performance.py | 5 +- verl/utils/device.py | 86 +++++++++++++++++++ verl/utils/distributed.py | 5 +- verl/utils/flops_counter.py | 3 +- verl/utils/fsdp_utils.py | 15 ++-- verl/workers/actor/dp_actor.py | 14 ++- verl/workers/critic/dp_critic.py | 13 ++- verl/workers/fsdp_workers.py | 53 ++++++------ verl/workers/rollout/hf_rollout.py | 3 +- verl/workers/sharding_manager/fsdp_vllm.py | 31 +++---- 19 files changed, 235 insertions(+), 108 deletions(-) create mode 100644 verl/utils/device.py diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 283abc43..a6cfe3a9 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -28,8 +28,7 @@ ) try: - from flash_attn import flash_attn_func, flash_attn_varlen_func - + from transformers.modeling_flash_attention_utils import flash_attn_func, flash_attn_varlen_func _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) except ImportError: flash_attn_varlen_func = None diff --git a/verl/protocol.py b/verl/protocol.py index fcb1ecfd..6ddd34f9 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -31,6 +31,7 @@ from torch.utils.data import DataLoader from verl.utils.py_functional import union_two_dict +from verl.utils.device import get_torch_device __all__ = ["DataProto", "union_tensor_dict"] @@ -249,7 +250,7 @@ def __setstate__(self, data): 
batch_deserialized_bytes, non_tensor_batch, meta_info = data batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) batch = torch.load( - batch_deserialized, weights_only=False, map_location="cpu" if not torch.cuda.is_available() else None + batch_deserialized, weights_only=False, map_location="cpu" if not get_torch_device().is_available() else None ) self.batch = batch self.non_tensor_batch = non_tensor_batch @@ -770,7 +771,7 @@ def all_gather_data_proto(data: DataProto, process_group): group_size = torch.distributed.get_world_size(group=process_group) assert isinstance(data, DataProto) prev_device = data.batch.device - data.batch = data.batch.cuda(device=torch.cuda.current_device()) + data.batch = data.batch.to(get_torch_device().current_device()) data.batch = allgather_dict_tensors(data.batch.contiguous(), size=group_size, group=process_group, dim=0) data.batch = data.batch.to(prev_device) # all gather non_tensor_batch diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py index 59ff599f..cf452aac 100644 --- a/verl/single_controller/base/worker.py +++ b/verl/single_controller/base/worker.py @@ -20,6 +20,7 @@ from dataclasses import dataclass from .decorator import Dispatch, Execute, register +from verl.utils.device import get_torch_device @dataclass @@ -137,7 +138,7 @@ def __init__(self, cuda_visible_devices=None) -> None: ### # [SUPPORT AMD: torch] - if "AMD" in torch.cuda.get_device_name(): + if "AMD" in get_torch_device().get_device_name(): os.environ["CUDA_VISIBLE_DEVICES"] = os.environ.get("ROCR_VISIBLE_DEVICES") os.environ["LOCAL_RANK"] = os.environ.get("RAY_LOCAL_RANK") ### @@ -155,13 +156,13 @@ def __init__(self, cuda_visible_devices=None) -> None: ### # [SUPPORT AMD: torch] - if "AMD" in torch.cuda.get_device_name(): + if "AMD" in get_torch_device().get_device_name(): self.local_rank = int(os.environ["LOCAL_RANK"]) ### ### # [SUPPORT AMD: torch] - if "AMD" in torch.cuda.get_device_name(): + if "AMD" 
in get_torch_device().get_device_name(): cuda_visible_devices = str(local_rank) ### @@ -182,8 +183,8 @@ def __init__(self, cuda_visible_devices=None) -> None: ### # [SUPPORT AMD: torch] # torch.cuda.set_device(local_rank) - if "AMD" in torch.cuda.get_device_name(): - torch.cuda.set_device(int(cuda_visible_devices)) + if "AMD" in get_torch_device().get_device_name(): + get_torch_device().set_device(int(cuda_visible_devices)) ### def _configure_with_meta(self, meta: WorkerMeta): diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py index a55f157b..ac6ebf8b 100644 --- a/verl/single_controller/ray/base.py +++ b/verl/single_controller/ray/base.py @@ -89,7 +89,7 @@ def __init__( self.pgs = None self.detached = detached - def get_placement_groups(self, strategy="STRICT_PACK", name=None): + def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"): if self.pgs is not None: return self.pgs @@ -97,13 +97,11 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None): name if name else f"{self.name_prefix}verl_group_{'_'.join([str(count) for count in self._store])}:" ) # print(f"pg_name_prefix = {pg_name_prefix}") - pg_scheme = [ - [ - {"CPU": self.max_colocate_count, "GPU": 1} if self.use_gpu else {"CPU": self.max_colocate_count} - for _ in range(process_count) - ] - for process_count in self._store - ] + if device_name == "npu": + device_name = "NPU" + elif device_name == "cuda": + device_name = "GPU" + pg_scheme = [[{"CPU": self.max_colocate_count, device_name: 1} if self.use_gpu else {"CPU": self.max_colocate_count} for _ in range(process_count)] for process_count in self._store] lifetime = "detached" if self.detached else None @@ -172,7 +170,7 @@ def update_options(self, options: Dict): self._options.update(options) def __call__( - self, placement_group, placement_group_bundle_idx, use_gpu: bool = True, num_gpus=1, sharing_with=None + self, placement_group, placement_group_bundle_idx, use_gpu: bool = 
True, num_gpus=1, sharing_with=None, device_name="cuda" ) -> Any: if sharing_with is not None: target_node_id = ray.get(sharing_with.get_node_id.remote()) @@ -189,8 +187,11 @@ def __call__( } options.update(self._options) - if use_gpu: + if use_gpu and device_name == "cuda": options["num_gpus"] = num_gpus + if use_gpu and device_name == "npu": + options["resources"] = {"NPU": num_gpus} + if len(self._additional_resource) > 1: for k, v in self._additional_resource.items(): @@ -212,13 +213,14 @@ def __init__( detached=False, worker_names=None, ray_wait_register_center_timeout: int = 300, + device_name="cuda", **kwargs, ) -> None: super().__init__(resource_pool=resource_pool, **kwargs) self.ray_cls_with_init = ray_cls_with_init self.name_prefix = get_random_string(length=6) if name_prefix is None else name_prefix self._ray_wait_register_center_timeout = ray_wait_register_center_timeout - + self.device_name = device_name if worker_names is not None: assert self._is_init_with_detached_workers self._worker_names = worker_names @@ -248,7 +250,7 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d strategy = "PACK" if bin_pack: strategy = "STRICT_PACK" - pgs = resource_pool.get_placement_groups(strategy=strategy) + pgs = resource_pool.get_placement_groups(strategy=strategy, device_name=self.device_name) world_size = resource_pool.world_size self._world_size = world_size # cia.add_kwarg("_world_size", world_size) @@ -288,7 +290,7 @@ def _init_with_resource_pool(self, resource_pool, ray_cls_with_init, bin_pack, d # create a worker worker = ray_cls_with_init( - placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus + placement_group=pg, placement_group_bundle_idx=local_rank, use_gpu=use_gpu, num_gpus=num_gpus, device_name=self.device_name ) self._workers.append(worker) self._worker_names.append(name) diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index e6b895ec..cafc9f80 
100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -29,7 +29,6 @@ import torch import torch.distributed -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from peft import LoraConfig, TaskType, get_peft_model from tensordict import TensorDict from torch import nn, optim @@ -54,7 +53,11 @@ ulysses_pad_and_slice_inputs, ) from verl.workers.sharding_manager import FSDPUlyssesShardingManager - +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_SFT_LOGGING_LEVEL", "WARN")) @@ -85,6 +88,7 @@ def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceM self.device_mesh = device_mesh self.ulysses_device_mesh = ulysses_device_mesh self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) + self.device_name = get_device_name() # build tokenizer first local_model_path = copy_to_local(src=self.config.model.partial_pretrain, verbose=True) from verl.utils import hf_tokenizer @@ -274,7 +278,7 @@ def _build_model_optimizer(self): mixed_precision=mixed_precision, device_mesh=self.device_mesh, sync_module_states=True, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), cpu_offload=cpu_offload, use_orig_params=False, ) @@ -316,15 +320,15 @@ def _compute_loss_and_backward(self, batch, do_backward=True): use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 # Move inputs to GPU and prepare loss mask - input_ids = batch["input_ids"].cuda() - attention_mask = batch["attention_mask"].cuda() - position_ids = batch["position_ids"].cuda() - loss_mask 
= batch.pop("loss_mask")[:, :-1].reshape(-1).cuda() + input_ids = batch["input_ids"].to(self.device_name) + attention_mask = batch["attention_mask"].to(self.device_name) + position_ids = batch["position_ids"].to(self.device_name) + loss_mask = batch.pop("loss_mask")[:, :-1].reshape(-1).to(self.device_name) loss_fct = nn.CrossEntropyLoss(reduction="none") # Context manager for sequence parallel if needed context = self.sharding_manager if use_sp else nullcontext() - with context, torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with context, torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): if not use_sp: # Standard forward pass without sequence parallel labels = input_ids[:, 1:].contiguous() @@ -446,15 +450,23 @@ def training_step(self, batch: TensorDict): log_gpu_memory_usage("After offload weights", logger=logger) - step_loss = torch.tensor(step_loss).cuda() - torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) - return {"train/loss": step_loss.detach().item(), "train/lr(1e-3)": lr * 1e3} + step_loss = torch.tensor(step_loss).to(self.device_name) + if is_cuda_available: + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(step_loss) + step_loss /= self.ulysses_device_mesh.size(0) + return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} def validation_step(self, batch: TensorDict): self.fsdp_model.eval() with torch.no_grad(): loss = self._compute_loss_and_backward(batch, do_backward=False) - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + if is_cuda_available: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + elif is_npu_available: + torch.distributed.all_reduce(loss) + loss /= self.ulysses_device_mesh.size(0) return loss def save_checkpoint(self, step): @@ -508,7 +520,7 @@ def fit(self): desc=f"Epoch {epoch + 1}/{self.config.trainer.total_epochs}", ): global_step 
+= 1 - data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() + data = TensorDict(data, batch_size=self.config.data.train_batch_size).to(self.device_name) metric = self.training_step(data) if rank == 0: tracking.log(data=metric, step=global_step) @@ -518,7 +530,7 @@ def fit(self): # Perform final validation val_losses = [] for val_data in self.val_dataloader: - val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) val_loss = self.validation_step(val_data) val_losses.append(val_loss) if rank == 0: @@ -534,7 +546,7 @@ def fit(self): # validation val_losses = [] for data in self.val_dataloader: - data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).to(self.device_name) val_loss = self.validation_step(data) val_losses.append(val_loss) if rank == 0: @@ -556,12 +568,13 @@ def fit(self): @hydra.main(config_path="config", config_name="sft_trainer", version_base=None) def main(config): + device_name = get_device_name() local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) + device_mesh = init_device_mesh(device_type=device_name, mesh_shape=(world_size,), mesh_dim_names=("fsdp",)) dp_size = world_size // config.ulysses_sequence_parallel_size ulysses_device_mesh = init_device_mesh( - device_type="cuda", mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") + device_type=device_name, mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), mesh_dim_names=("dp", "sp") ) trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) trainer.fit() diff --git a/verl/trainer/main_generation.py 
b/verl/trainer/main_generation.py index 0d12b5b4..0f80b7ca 100644 --- a/verl/trainer/main_generation.py +++ b/verl/trainer/main_generation.py @@ -38,6 +38,7 @@ from verl.utils.hdfs_io import makedirs from verl.utils.model import compute_position_id_with_mask from verl.workers.fsdp_workers import ActorRolloutRefWorker +from verl.utils.device import is_cuda_available @hydra.main(config_path="config", config_name="generation", version_base=None) @@ -81,7 +82,7 @@ def main_task(config): ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role="rollout") resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes) - wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init) + wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init, device_name="cuda" if is_cuda_available else "npu") wg.init_model() total_samples = len(dataset) diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 41e7b7c2..9f611513 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -21,6 +21,7 @@ import ray from verl.trainer.ppo.ray_trainer import RayPPOTrainer +from verl.utils.device import is_cuda_available def get_custom_reward_fn(config): @@ -203,6 +204,7 @@ def run(self, config): ray_worker_group_cls=ray_worker_group_cls, reward_fn=reward_fn, val_reward_fn=val_reward_fn, + device_name="cuda" if is_cuda_available else "npu" ) trainer.init_workers() trainer.fit() diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 01679e9c..aa5b5af7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -122,8 +122,7 @@ def get_n_gpus(self) -> int: def _check_resource_available(self): """Check if the resource pool can be satisfied in this ray cluster.""" node_available_resources = ray.state.available_resources_per_node() - node_available_gpus = {node: node_info.get("GPU", 0) 
for node, node_info in node_available_resources.items()} - + node_available_gpus = {node: node_info.get("GPU", 0) if "GPU" in node_info else node_info.get("NPU", 0) for node, node_info in node_available_resources.items()} # check total required gpus can be satisfied total_available_gpus = sum(node_available_gpus.values()) total_required_gpus = sum( @@ -277,6 +276,7 @@ def __init__( processor=None, reward_fn=None, val_reward_fn=None, + device_name="cuda", ): # assert torch.cuda.is_available(), 'cuda must be available on driver' @@ -297,6 +297,7 @@ def __init__( self.use_reference_policy = Role.RefPolicy in role_worker_mapping self.use_rm = Role.RewardModel in role_worker_mapping self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name self.validation_generations_logger = ValidationGenerationsLogger() # define in-reward KL control @@ -721,7 +722,7 @@ def init_workers(self): for resource_pool, class_dict in self.resource_pool_to_cls.items(): worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict) wg_dict = self.ray_worker_group_cls( - resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, **wg_kwargs + resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls, device_name=self.device_name, **wg_kwargs ) spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys()) all_wg.update(spawn_wg) diff --git a/verl/utils/checkpoint/fsdp_checkpoint_manager.py b/verl/utils/checkpoint/fsdp_checkpoint_manager.py index 0bbc3d38..977473eb 100644 --- a/verl/utils/checkpoint/fsdp_checkpoint_manager.py +++ b/verl/utils/checkpoint/fsdp_checkpoint_manager.py @@ -23,6 +23,7 @@ from transformers import PreTrainedTokenizer, ProcessorMixin from verl.utils.fs import copy_to_local, is_non_local +from verl.utils.device import is_cuda_available from .checkpoint_manager import BaseCheckpointManager @@ -100,8 +101,8 @@ def load_checkpoint(self, local_path: str, hdfs_path: str = None, del_local_afte lr_scheduler_state_dict = 
extra_state_dict["lr_scheduler"] - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): self.model.load_state_dict(model_state_dict) if self.optimizer is not None: @@ -136,8 +137,8 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i torch.distributed.barrier() # every rank will save its own model and optim shard - state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True) - optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True) + state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True if is_cuda_available else False) + optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True if is_cuda_available else False) with warnings.catch_warnings(): warnings.simplefilter("ignore") with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg): diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py index aacaaba5..0b263115 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -18,6 +18,7 @@ import torch.distributed as dist from verl.utils.logger.aggregate_logger import DecoratorLoggerBase +from verl.utils.device import get_torch_device def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): @@ -62,8 +63,8 @@ def f(*args, **kwargs): return f def log(self, func, *args, **kwargs): - memory_allocated = torch.cuda.memory_allocated() / 1024**3 - memory_reserved = torch.cuda.memory_reserved() / 1024**3 + memory_allocated = get_torch_device().memory_allocated() / 1024**3 + memory_reserved = get_torch_device().memory_reserved() / 1024**3 message = 
f"Before {func.__name__}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}" self.logging_function(message) diff --git a/verl/utils/device.py b/verl/utils/device.py new file mode 100644 index 00000000..ed85b0d5 --- /dev/null +++ b/verl/utils/device.py @@ -0,0 +1,86 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# This code is inspired by the torchtune. +# https://github.com/pytorch/torchtune/blob/main/torchtune/utils/_device.py +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license in https://github.com/pytorch/torchtune/blob/main/LICENSE + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def is_torch_npu_available() -> bool: + """Check the availability of NPU""" + try: + import torch_npu # noqa: F401 + + return torch.npu.is_available() + except ImportError: + return False + + +is_cuda_available = torch.cuda.is_available() +is_npu_available = is_torch_npu_available() + + +def get_visible_devices_keyword() -> str: + """Function that gets visible devices keyword name. + Returns: + 'CUDA_VISIBLE_DEVICES' or `ASCEND_RT_VISIBLE_DEVICES` + """ + return "CUDA_VISIBLE_DEVICES" if is_cuda_available else "ASCEND_RT_VISIBLE_DEVICES" + + +def get_device_name() -> str: + """Function that gets the torch.device based on the current machine. + This currently only supports CPU, CUDA, NPU. + Returns: + device + """ + if is_cuda_available: + device = "cuda" + elif is_npu_available: + device = "npu" + else: + device = "cpu" + return device + + +def get_torch_device() -> any: + """Return the corresponding torch attribute based on the device type string. + Returns: + module: The corresponding torch device namespace, or torch.cuda if not found. 
+ """ + device_name = get_device_name() + try: + return getattr(torch, device_name) + except AttributeError: + logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.") + return torch.cuda + + +def get_device_id() -> int: + """Return current device id based on the device type. + Returns: + device index + """ + return get_torch_device().current_device() + + +def get_nccl_backend() -> str: + """Return nccl backend type based on the device type. + Returns: + nccl backend type string. + """ + if is_cuda_available: + return "nccl" + elif is_npu_available: + return "hccl" + else: + raise RuntimeError(f"No available nccl backend found on device type {get_device_name()}.") diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 7aa30c16..82363351 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -14,6 +14,7 @@ """Utilities for distributed training.""" import os +from verl.utils.device import is_cuda_available, get_torch_device def initialize_global_process_group(timeout_second=36000): @@ -21,11 +22,11 @@ import torch.distributed - torch.distributed.init_process_group("nccl", timeout=timedelta(seconds=timeout_second)) + torch.distributed.init_process_group("nccl" if is_cuda_available else "hccl", timeout=timedelta(seconds=timeout_second)) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) if torch.distributed.is_initialized(): - torch.cuda.set_device(local_rank) + get_torch_device().set_device(local_rank) return local_rank, rank, world_size diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py index 4fb11943..ba31303a 100644 --- a/verl/utils/flops_counter.py +++ b/verl/utils/flops_counter.py @@ -14,6 +14,7 @@ import torch from transformers import PretrainedConfig +from verl.utils.device import is_cuda_available, get_torch_device VALID_CONFIG_TYPE = 
{"llama", "qwen2", "qwen2_vl", "qwen2_5_vl", "deepseek_v3"} @@ -29,7 +30,7 @@ def unit_convert(number, level): ptr += 1 return number - device_name = torch.cuda.get_device_name() + device_name = get_torch_device().get_device_name() flops = float("inf") # INF flops for unkown gpu type if "MI300X" in device_name: diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 555e2e8b..0b28e804 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -28,12 +28,13 @@ from torch.distributed.fsdp._runtime_utils import _lazy_init from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name +from verl.utils.device import get_torch_device, get_device_name def init_fn(x: torch.nn.Module): if torch.distributed.get_rank() != 0: - x = x.to_empty(device=torch.cuda.current_device(), recurse=False) - torch.cuda.empty_cache() + x = x.to_empty(device=get_torch_device().current_device(), recurse=False) + get_torch_device().empty_cache() return x @@ -135,7 +136,7 @@ def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True): flat_param._local_shard = flat_param.data assert id(flat_param._local_shard) != id(flat_param.data) if empty_cache: - torch.cuda.empty_cache() + get_torch_device().empty_cache() @torch.no_grad() @@ -144,12 +145,12 @@ def load_fsdp_model_to_gpu(model: FSDP): # lazy init FSDP model _lazy_init(model, model) assert model._is_root, "Only support root model loading to GPU" - device_id = torch.cuda.current_device() + device_id = get_torch_device().current_device() for handle in model._all_handles: if handle._offload_params: continue flat_param = handle.flat_param - handle.flat_param_to(torch.device(f"cuda:{device_id}"), non_blocking=True) + handle.flat_param_to(torch.device(f"{get_device_name()}:{device_id}"), non_blocking=True) # the following still keeps id(._local_shard) != id(.data) flat_param._local_shard = flat_param.data @@ 
-251,7 +252,7 @@ def parallel_load_safetensors(filepath): ckpt_chunks = [ckpt_chunks[rank * size : rank * size + size] for rank in range(world_size)] shard_states = {} - device = torch.cuda.current_device() + device = get_torch_device().current_device() for rank, files in enumerate(ckpt_chunks): if rank == dist.get_rank(): for file in files: @@ -291,7 +292,7 @@ def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, tor @torch.no_grad() def create_and_sync_state(param_name, state, is_param): assert param_name in shard_states, f"{param_name} not loaded" - device = torch.cuda.current_device() + device = get_torch_device().current_device() if is_param: param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) else: # buffer diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 84d85b73..1338bf0e 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -21,7 +21,6 @@ from typing import Tuple import torch -from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -34,6 +33,12 @@ from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.actor import BasePPOActor +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available + +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis __all__ = ["DataParallelPPOActor"] @@ -57,6 +62,7 @@ def __init__(self, config, actor_module: nn.Module, actor_optimizer: torch.optim if self.config.get("use_torch_compile", True) # use torch compile by default else 
verl_F.entropy_from_logits ) + self.device_name = get_device_name() def _forward_micro_batch( self, micro_batch, temperature, calculate_entropy=False @@ -74,7 +80,7 @@ [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 ) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -310,9 +316,9 @@ def update_policy(self, data: DataProto): for data in micro_batches: # Support all hardwares if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = {**data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch} else: - data = data.to(torch.cuda.current_device()) # actor device is cpu when using offload + data = data.to(get_torch_device().current_device()) # actor device is cpu when using offload responses = data["responses"] response_length = responses.size(1) action_or_attn_mask = data['action_mask'] if 'action_mask' in data.keys() else data['attention_mask'] diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index b6e133e9..8b66d2d2 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -33,8 +33,12 @@ from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.critic import BasePPOCritic +from verl.utils.device import get_device_name, get_torch_device, is_npu_available, is_cuda_available -__all__ = ["DataParallelPPOCritic"] +if is_cuda_available: + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis +elif is_npu_available: + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis logger = 
logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -49,6 +53,7 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt print(f"Critic use_remove_padding={self.use_remove_padding}") self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1) + self.device_name = get_device_name() def _forward_micro_batch(self, micro_batch): response_length = micro_batch["responses"].size(-1) @@ -59,7 +64,7 @@ def _forward_micro_batch(self, micro_batch): [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 ) - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -221,9 +226,9 @@ def update_critic(self, data: DataProto): for data in micro_batches: # Support all devices if isinstance(data, DataProto): - data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch} + data = {**data.batch.to(get_torch_device().current_device()), **data.non_tensor_batch} else: - data = data.to(torch.cuda.current_device()) # critic device is cpu when using offload + data = data.to(get_torch_device().current_device()) # critic device is cpu when using offload input_ids = data["input_ids"] responses = data["responses"] diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 979a5020..3c7d914d 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -47,17 +47,20 @@ from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available logger = logging.getLogger(__file__) 
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) +device_name = get_device_name() + def create_device_mesh(world_size, fsdp_size): if fsdp_size < 0 or fsdp_size >= world_size: - device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) + device_mesh = init_device_mesh(device_name, mesh_shape=(world_size,), mesh_dim_names=["fsdp"]) else: device_mesh = init_device_mesh( - "cuda", mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] + device_name, mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"] ) return device_mesh @@ -99,7 +102,7 @@ def __init__(self, config: DictConfig, role: str): dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -279,7 +282,7 @@ def _build_model_optimizer( param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, @@ -336,7 +339,7 @@ def _build_rollout(self, trust_remote_code=False): assert self.world_size % infer_tp == 0, ( f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" ) - rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) + rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) rollout_name = self.config.rollout.name if rollout_name == "hf": from verl.workers.rollout import HFRollout @@ -499,13 +502,13 @@ def 
init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_actor(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) assert self._is_actor if self._is_offload_param: load_fsdp_model_to_gpu(self.actor_module_fsdp) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_torch_device().current_device()) with self.ulysses_sharding_manager: data = self.ulysses_sharding_manager.preprocess_data(data=data) @@ -518,8 +521,8 @@ def update_actor(self, data: DataProto): metrics["perf/mfu/actor"] = ( estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size ) - metrics["perf/max_memory_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3) - metrics["perf/max_memory_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3) + metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3) + metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3) metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3) self.actor_lr_scheduler.step() @@ -542,7 +545,7 @@ def update_actor(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): # Support all hardwares - prompts = prompts.to(torch.cuda.current_device()) + prompts = prompts.to(get_torch_device().current_device()) assert self._is_rollout if self._is_offload_param: @@ -572,7 +575,7 @@ def generate_sequences(self, prompts: DataProto): output = output.to("cpu") # clear kv cache - torch.cuda.empty_cache() + get_torch_device().empty_cache() return output @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) @@ -582,7 +585,7 @@ def compute_log_prob(self, data: DataProto): 
load_fsdp_model_to_gpu(self.actor_module_fsdp) # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) # we should always recompute old_log_probs when it is HybridEngine data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu @@ -615,7 +618,7 @@ def compute_ref_log_prob(self, data: DataProto): assert self._is_ref # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info["micro_batch_size"] = micro_batch_size @@ -676,7 +679,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -691,7 +694,7 @@ def __init__(self, config): dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -815,7 +818,7 @@ def _build_critic_model_optimizer(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, @@ -889,7 +892,7 @@ def init_model(self): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def 
compute_values(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) @@ -912,11 +915,11 @@ def compute_values(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def update_critic(self, data: DataProto): # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._is_offload_param: load_fsdp_model_to_gpu(self.critic_module) if self._is_offload_optimizer: - load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=torch.cuda.current_device()) + load_fsdp_optimizer(optimizer=self.critic_optimizer, device_id=get_torch_device().current_device()) # perform forward computation with self.ulysses_sharding_manager: @@ -990,7 +993,7 @@ def __init__(self, config): import torch.distributed if not torch.distributed.is_initialized(): - torch.distributed.init_process_group(backend="nccl") + torch.distributed.init_process_group(backend="nccl" if is_cuda_available else "hccl") self.config = config # build device mesh for Ulysses Sequence Parallel @@ -1005,7 +1008,7 @@ def __init__(self, config): dp = world_size // self.ulysses_sequence_parallel_size if self.ulysses_sequence_parallel_size > 1: self.ulysses_device_mesh = init_device_mesh( - "cuda", mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] + device_name, mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"] ) self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) @@ -1073,7 +1076,7 @@ def _build_model(self, config): param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=torch.cuda.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, 
cpu_offload=CPUOffload(offload_params=True), @@ -1094,7 +1097,7 @@ def _forward_micro_batch(self, micro_batch): from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + with torch.no_grad(), torch.autocast(device_type=device_name, dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -1229,7 +1232,7 @@ def compute_rm_score(self, data: DataProto): from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches # Support all hardwares - data = data.to(torch.cuda.current_device()) + data = data.to(get_torch_device().current_device()) if self._do_switch_chat_template: rm_data = self._switch_chat_template(data) else: @@ -1244,7 +1247,7 @@ def compute_rm_score(self, data: DataProto): rm_data = DataProto.from_dict(rm_inputs) # Support all hardwares - rm_data.batch = rm_data.batch.to(torch.cuda.current_device()) + rm_data.batch = rm_data.batch.to(get_torch_device().current_device()) # perform forward computation with self.ulysses_sharding_manager: diff --git a/verl/workers/rollout/hf_rollout.py b/verl/workers/rollout/hf_rollout.py index 60f9bc9d..3d7208f7 100644 --- a/verl/workers/rollout/hf_rollout.py +++ b/verl/workers/rollout/hf_rollout.py @@ -28,6 +28,7 @@ from verl import DataProto from verl.utils.torch_functional import get_response_mask +from verl.utils.device import get_torch_device from .base import BaseRollout @@ -138,7 +139,7 @@ def _generate_minibatch(self, prompts: DataProto) -> DataProto: ) # empty cache before compute old_log_prob - torch.cuda.empty_cache() + get_torch_device().empty_cache() self.module.train() return DataProto(batch=batch) diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 8b20efa8..16396c4f 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ 
b/verl/workers/sharding_manager/fsdp_vllm.py @@ -26,6 +26,7 @@ from verl.third_party.vllm import LLM, vllm_version from verl.third_party.vllm import parallel_state as vllm_ps from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage +from verl.utils.device import get_torch_device from .base import BaseShardingManager from .patch import patched_ds_v3_load_weights, patched_qwen_moe_load_weights @@ -65,26 +66,26 @@ def __init__( self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() # Note that torch_random_states may be different on each dp rank - self.torch_random_states = torch.cuda.get_rng_state() + self.torch_random_states = get_torch_device.get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - torch.cuda.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + get_torch_device.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = get_torch_device.get_rng_state() + get_torch_device.set_rng_state(self.torch_random_states) else: self.gen_random_states = None @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __enter__(self): - # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and + # NOTE: Basically, we only need `get_torch_device.empty_cache()` before vllm wake_up and # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator. # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory # to speed up memory allocations. 
# # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 - torch.cuda.empty_cache() + get_torch_device.empty_cache() log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) params = self.module.state_dict() @@ -109,7 +110,7 @@ def __enter__(self): self.update_params(params) log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params - torch.cuda.empty_cache() + get_torch_device.empty_cache() if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: self.inference_engine.wake_up(tags=["kv_cache"]) @@ -118,14 +119,14 @@ def __enter__(self): # TODO: offload FSDP model weights # self.module.cpu() - # torch.cuda.empty_cache() + # get_torch_device.empty_cache() # if torch.distributed.get_rank() == 0: - # print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') + # print(f'after model to cpu in sharding manager memory allocated: {get_torch_device.memory_allocated() / 1e9}GB, reserved: {get_torch_device.memory_reserved() / 1e9}GB') # important: need to manually set the random states of each tp to be identical. 
if self.device_mesh is not None: - self.torch_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.gen_random_states) + self.torch_random_states = get_torch_device.get_rng_state() + get_torch_device.set_rng_state(self.gen_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __exit__(self, exc_type, exc_value, traceback): @@ -140,17 +141,17 @@ def __exit__(self, exc_type, exc_value, traceback): # self.module.to('cuda') # if torch.distributed.get_rank() == 0: - # print(f'after actor module to cuda in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB') + # print(f'after actor module to cuda in sharding manager memory allocated: {get_torch_device.memory_allocated() / 1e9}GB, reserved: {get_torch_device.memory_reserved() / 1e9}GB') self.module.train() # add empty cache after each compute - torch.cuda.empty_cache() + get_torch_device.empty_cache() # restore random states if self.device_mesh is not None: - self.gen_random_states = torch.cuda.get_rng_state() - torch.cuda.set_rng_state(self.torch_random_states) + self.gen_random_states = get_torch_device.get_rng_state() + get_torch_device.set_rng_state(self.torch_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def preprocess_data(self, data: DataProto) -> DataProto: From 756b5534957fb2c4f7c3d8db727be35643269503 Mon Sep 17 00:00:00 2001 From: zyang6 Date: Thu, 14 Aug 2025 09:39:15 +0000 Subject: [PATCH 2/5] Add Fused NPU Operators --- verl/__init__.py | 33 ++++++++++++++--- verl/models/transformers/npu_patch.py | 51 +++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 5 deletions(-) create mode 100644 verl/models/transformers/npu_patch.py diff --git a/verl/__init__.py b/verl/__init__.py index 89fab359..61720345 100644 --- a/verl/__init__.py +++ b/verl/__init__.py @@ -12,21 +12,25 @@ # See the License for the specific language governing permissions and # 
limitations under the License. +import logging import os +import pkg_resources +from packaging.version import parse as parse_version +from pkg_resources import DistributionNotFound + +from .protocol import DataProto +from .utils.device import is_npu_available +from .utils.logging_utils import set_basic_config + version_folder = os.path.dirname(os.path.join(os.path.abspath(__file__))) with open(os.path.join(version_folder, "version/version")) as f: __version__ = f.read().strip() -import logging - -from .protocol import DataProto -from .utils.logging_utils import set_basic_config set_basic_config(level=logging.WARNING) -from . import single_controller __all__ = ["DataProto", "__version__"] @@ -39,3 +43,22 @@ from modelscope.utils.hf_util import patch_hub patch_hub() + +if is_npu_available: + from .models.transformers import npu_patch as npu_patch + + package_name = "transformers" + required_version_spec = "4.52.4" + try: + installed_version = pkg_resources.get_distribution(package_name).version + installed = parse_version(installed_version) + required = parse_version(required_version_spec) + if not installed >= required: + raise ValueError( + f"{package_name} version >= {required_version_spec} is required on ASCEND NPU, " + f"current version is {installed}." + ) + except DistributionNotFound as e: + raise ImportError( + f"package {package_name} is not installed, please run pip install {package_name}=={required_version_spec}" + ) from e diff --git a/verl/models/transformers/npu_patch.py b/verl/models/transformers/npu_patch.py new file mode 100644 index 00000000..6378ae73 --- /dev/null +++ b/verl/models/transformers/npu_patch.py @@ -0,0 +1,51 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Copyright 2025 The Qwen Team and The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch_npu +from torch_npu import npu_rotary_mul as apply_rotary_emb +from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm + + +# This patch takes effect when using apply_rotary_pos_emb_flashatt on +# qwen2_5_vl and will be removed in subsequent versions +# https://github.com/huggingface/transformers/pull/38491 +def apply_rotary_pos_emb_flashatt_npu( + q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + cos = cos.chunk(2, dim=-1)[0].contiguous() + sin = sin.chunk(2, dim=-1)[0].contiguous() + cos = cos.repeat(1, 2) + sin = sin.repeat(1, 2) + q_embed = apply_rotary_emb( + q.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() + ).type_as(q) + k_embed = apply_rotary_emb( + k.float(), cos.unsqueeze(0).unsqueeze(2).float(), sin.unsqueeze(0).unsqueeze(2).float() + ).type_as(k) + return q_embed, k_embed + + +# This api can improve performance on ASCEND NPU +def rms_norm_forward(self, x): + return torch_npu.npu_rms_norm(x, self.weight, epsilon=self.variance_epsilon)[0] + + +Qwen2RMSNorm.forward = rms_norm_forward +modeling_qwen2_5_vl.apply_rotary_pos_emb_flashatt = apply_rotary_pos_emb_flashatt_npu From 8d380655c88d6f5f3dc5a89fc991d25ef825d656 Mon Sep 17 00:00:00 2001 From: icerain-alt <450125138@qq.com> Date: Fri, 15 Aug 2025 03:05:37 +0000 Subject: [PATCH 3/5] feat: Add NPU device support for Verl --- verl/utils/debug/performance.py | 
25 ++++++----- verl/utils/distributed.py | 7 ++- verl/utils/fsdp_utils.py | 7 +-- verl/workers/actor/dp_actor.py | 26 ++++++++--- .../envs/visual_agent/mm_search_engine.py | 45 ++++++++++--------- verl/workers/fsdp_workers.py | 10 +++-- verl/workers/rollout/vllm_rollout/__init__.py | 27 ++++++----- verl/workers/sharding_manager/fsdp_vllm.py | 33 +++++++------- 8 files changed, 103 insertions(+), 77 deletions(-) diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py index 0b263115..b849df66 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -17,8 +17,8 @@ import torch import torch.distributed as dist -from verl.utils.logger.aggregate_logger import DecoratorLoggerBase from verl.utils.device import get_torch_device +from verl.utils.logger.aggregate_logger import DecoratorLoggerBase def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): @@ -36,10 +36,10 @@ def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging class GPUMemoryLogger(DecoratorLoggerBase): """_summary_ - + Usage: For example, in actor function, we initialize a GPUMemoryLogger - + ``` from verl.utils.debug.performance import GPUMemoryLogger @GPUMemoryLogger(role="actor") @@ -47,28 +47,31 @@ def update_actor(self, batch): # do something return ``` - + """ - + def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): if dist.is_initialized() and dist.get_world_size() > 1: rank = dist.get_rank() else: rank = 0 super().__init__(role, logger, level, rank, log_only_rank_0) - + def __call__(self, decorated_function: callable): def f(*args, **kwargs): return self.log(decorated_function, *args, **kwargs) + return f - + def log(self, func, *args, **kwargs): - memory_allocated = get_device_name().memory_allocated() / 1024**3 - memory_reserved = get_device_name().memory_reserved() / 1024**3 + memory_allocated = 
get_torch_device().memory_allocated() / 1024**3 + memory_reserved = get_torch_device().memory_reserved() / 1024**3 message = f"Before {func.__name__}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}" self.logging_function(message) output = func(*args, **kwargs) - message = f"After {func.__name__}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}" + message = ( + f"After {func.__name__}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}" + ) self.logging_function(message) - return output \ No newline at end of file + return output diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 82363351..69b2214c 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -14,7 +14,8 @@ """Utilities for distributed training.""" import os -from verl.utils.device import is_cuda_available, get_torch_device + +from verl.utils.device import get_torch_device, is_cuda_available def initialize_global_process_group(timeout_second=36000): @@ -22,7 +23,9 @@ def initialize_global_process_group(timeout_second=36000): import torch.distributed - torch.distributed.init_process_group("nccl" if is_cuda_available else "hccl",, timeout=timedelta(seconds=timeout_second)) + torch.distributed.init_process_group( + "nccl" if is_cuda_available else "hccl", timeout=timedelta(seconds=timeout_second) + ) local_rank = int(os.environ["LOCAL_RANK"]) rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 0b28e804..195aa1fe 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -28,13 +28,14 @@ from torch.distributed.fsdp._runtime_utils import _lazy_init from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name -from verl.utils.device import get_torch_device, 
get_device_name + +from verl.utils.device import get_device_name, get_torch_device def init_fn(x: torch.nn.Module): if torch.distributed.get_rank() != 0: - x = x.to_empty(device=get_device_name().current_device(), recurse=False) - get_device_name().empty_cache() + x = x.to_empty(device=get_torch_device().current_device(), recurse=False) + get_torch_device().empty_cache() return x diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 1338bf0e..4c628eb3 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -28,17 +28,17 @@ from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty from verl.utils.debug import GPUMemoryLogger +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.actor import BasePPOActor -from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available if is_cuda_available: - from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis + from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input elif is_npu_available: - from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis + from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input __all__ = ["DataParallelPPOActor"] @@ -80,7 +80,7 @@ def _forward_micro_batch( [inputs[key] for inputs in micro_batch["multi_modal_inputs"]], dim=0 ) - with torch.autocast(device_type=self.deivce_name, dtype=torch.bfloat16): + with torch.autocast(device_type=self.device_name, 
dtype=torch.bfloat16): input_ids = micro_batch["input_ids"] batch_size, seqlen = input_ids.shape attention_mask = micro_batch["attention_mask"] @@ -275,7 +275,15 @@ def update_policy(self, data: DataProto): temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error - select_keys = ["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages", "action_mask"] + select_keys = [ + "responses", + "input_ids", + "attention_mask", + "position_ids", + "old_log_probs", + "advantages", + "action_mask", + ] if self.config.use_kl_loss: select_keys.append("ref_log_prob") batch = data.select(batch_keys=select_keys, strict=False).batch @@ -300,7 +308,9 @@ def update_policy(self, data: DataProto): self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu ) num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys, strict=False).chunk(num_micro_batches) + micro_batches = data.select(select_keys, non_tensor_select_keys, strict=False).chunk( + num_micro_batches + ) elif self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) @@ -321,7 +331,9 @@ def update_policy(self, data: DataProto): data = data.to(get_torch_device().current_device()) # actor device is cpu when using offload responses = data["responses"] response_length = responses.size(1) - action_or_attn_mask = data['action_mask'] if 'action_mask' in data.keys() else data['attention_mask'] + action_or_attn_mask = ( + data["action_mask"] if "action_mask" in data.keys() else data["attention_mask"] + ) response_mask = action_or_attn_mask[:, -response_length:] old_log_prob = data["old_log_probs"] diff --git a/verl/workers/agent/envs/visual_agent/mm_search_engine.py 
b/verl/workers/agent/envs/visual_agent/mm_search_engine.py index d5716cd2..2778b22a 100644 --- a/verl/workers/agent/envs/visual_agent/mm_search_engine.py +++ b/verl/workers/agent/envs/visual_agent/mm_search_engine.py @@ -1,19 +1,14 @@ -import re -import os import json -import random -import requests -import numpy as np from io import BytesIO -from PIL import Image -from time import sleep -from playwright.sync_api import sync_playwright, Playwright from duckduckgo_search import DDGS +from PIL import Image +from playwright.sync_api import Playwright, sync_playwright -from verl.utils.dataset.rl_dataset import process_image +# from verl.utils.dataset.rl_dataset import process_image from verl.workers.agent.tool_envs import ToolBase, extract_tool_call_contents + class MMSearchEngine(ToolBase): name = "mm_search" @@ -32,20 +27,22 @@ def __init__(self, _name, **kwargs): super().__init__(name=self.name) def execute(self, action_string, **kwargs): - self.chatml_history.append({ - "role": "assistant", - "content": action_string, - }) + self.chatml_history.append( + { + "role": "assistant", + "content": action_string, + } + ) answers = extract_tool_call_contents(self.answer_start, self.answer_end, action_string) if answers: # print(f' [DEBUG] found answer in {action_string=}') - return '', 0.0, True, {} + return "", 0.0, True, {} search_list = extract_tool_call_contents(self.search_start, self.search_end, action_string) browse_list = extract_tool_call_contents(self.browse_start, self.browse_end, action_string) if len(search_list) > 0: - search_key = ' '.join([item.strip() for item in search_list]) + search_key = " ".join([item.strip() for item in search_list]) search_results = self.ddgs.text(search_key, max_results=self.top_k) result_text = self.convert_search_to_text(search_results) result_text = f"\n\n{result_text}\n\n" @@ -54,7 +51,7 @@ def execute(self, action_string, **kwargs): elif len(browse_list) > 0: browse_list = [url.strip() for url in browse_list] img_list = 
[self.get_screenshot_from_url(url) for url in browse_list] - self.multi_modal_data['image'] += img_list + self.multi_modal_data["image"] += img_list prompt_list = [f"Screenshot for website {url}\n" for url in browse_list] prompt_text = "\n\n".join(prompt_list) @@ -63,11 +60,11 @@ def execute(self, action_string, **kwargs): "prompt": prompt_text, "multi_modal_data": {"image": img_list}, } - print(f' [DEBUG browser] return {len(img_list)} images for {browse_list=}') + print(f" [DEBUG browser] return {len(img_list)} images for {browse_list=}") return obs, 0.0, False, {} else: # print(f' [DEBUG browser] no action_list in {action_string=}') - return '', 0.0, True, {} + return "", 0.0, True, {} def reset(self, raw_prompt, multi_modal_data, origin_multi_modal_data, **kwargs): """ @@ -80,7 +77,11 @@ def reset(self, raw_prompt, multi_modal_data, origin_multi_modal_data, **kwargs) """ self.ddgs = DDGS() self.chatml_history = raw_prompt.tolist() - if origin_multi_modal_data is None or not isinstance(origin_multi_modal_data, dict) or 'image' not in origin_multi_modal_data.keys(): + if ( + origin_multi_modal_data is None + or not isinstance(origin_multi_modal_data, dict) + or "image" not in origin_multi_modal_data.keys() + ): self.multi_modal_data = {"image": []} else: self.multi_modal_data = origin_multi_modal_data @@ -90,11 +91,11 @@ def convert_search_to_text(self, search_results): for result in search_results: docstr = json.dumps(result, ensure_ascii=False, indent=2) search_json_list.append(docstr) - return '\n'.join(search_json_list) + return "\n".join(search_json_list) def get_screenshot_from_url(self, url): def run_single(playwright: Playwright): - chromium = playwright.chromium # or "firefox" or "webkit". + chromium = playwright.chromium # or "firefox" or "webkit". 
browser = chromium.launch() page = browser.new_page() page.goto(url) @@ -105,4 +106,4 @@ def run_single(playwright: Playwright): with sync_playwright() as pw: img_bytes = run_single(pw) img_pil = Image.open(BytesIO(img_bytes)) - return img_pil \ No newline at end of file + return img_pil diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 3c7d914d..d42977fe 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -33,6 +33,7 @@ from verl.utils import hf_processor, hf_tokenizer from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( @@ -47,7 +48,6 @@ from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager -from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -282,7 +282,7 @@ def _build_model_optimizer( param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, - device_id=get_torch_device.current_device(), + device_id=get_torch_device().current_device(), sharding_strategy=sharding_strategy, # zero3 mixed_precision=mixed_precision, sync_module_states=True, @@ -339,7 +339,9 @@ def _build_rollout(self, trust_remote_code=False): assert self.world_size % infer_tp == 0, ( f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" ) - rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) + rollout_device_mesh = init_device_mesh( + device_name, mesh_shape=(dp, 
infer_tp), mesh_dim_names=["dp", "infer_tp"] + ) rollout_name = self.config.rollout.name if rollout_name == "hf": from verl.workers.rollout import HFRollout @@ -568,7 +570,7 @@ def generate_sequences(self, prompts: DataProto): offload_fsdp_optimizer(optimizer=self.actor_optimizer) prompts = self.rollout_sharding_manager.preprocess_data(prompts) - print(f' [DEBUG 222] data middle: {len(prompts)}') + print(f" [DEBUG 222] data middle: {len(prompts)}") output = self.rollout.generate_sequences(prompts=prompts) output = self.rollout_sharding_manager.postprocess_data(output) diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 1c0fe125..690c1217 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -19,7 +19,7 @@ import torch ### - +import os def get_version(pkg): try: @@ -28,22 +28,25 @@ def get_version(pkg): return None -package_name = "vllm" -package_version = get_version(package_name) +vllm_package_name = "vllm" +vllm_package_version = get_version(vllm_package_name) +if vllm_package_version is None: + raise PackageNotFoundError( + "To use vllm rollout, please ensure the 'vllm' package is properly installed. 
See " + https://verl.readthedocs.io/en/latets/start/install.html for more details" + ) -### -# package_version = get_version(package_name) -# [SUPPORT AMD:] -if "AMD" in torch.cuda.get_device_name(): +if "ROCM_PATH" in os.environ: import re - package_version = version(package_name) - package_version = re.match(r"(\d+\.\d+\.?\d*)", package_version).group(1) -else: - package_version = get_version(package_name) + match = re.match(r"(\d+\.\d+\.?\d*)", vllm_package_version) + if match: + vllm_package_version = match.group(1) + else: + raise ValueError(f"Warning: Could not parse version format:{vllm_package_version}") ### -if package_version <= "0.6.3": +if vllm_package_version <= "0.6.3": vllm_mode = "customized" from .fire_vllm_rollout import FIREvLLMRollout from .vllm_rollout import vLLMRollout diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 16396c4f..142f280b 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -29,7 +29,8 @@ from verl.utils.device import get_torch_device from .base import BaseShardingManager -from .patch import patched_ds_v3_load_weights, patched_qwen_moe_load_weights + +# from .patch import patched_ds_v3_load_weights, patched_qwen_moe_load_weights logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -66,26 +67,26 @@ def __init__( self.tp_rank = vllm_ps.get_tensor_model_parallel_rank() # Note that torch_random_states may be different on each dp rank - self.torch_random_states = get_torch_device.get_rng_state() + self.torch_random_states = get_torch_device().get_rng_state() # get a random rng states if self.device_mesh is not None: gen_dp_rank = self.device_mesh["dp"].get_local_rank() - get_torch_device.manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states - self.gen_random_states = get_torch_device.get_rng_state() - 
get_torch_device.set_rng_state(self.torch_random_states) + get_torch_device().manual_seed(gen_dp_rank + 1000) # make sure all tp ranks have the same random states + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) else: self.gen_random_states = None @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __enter__(self): - # NOTE: Basically, we only need `get_torch_device.empty_cache()` before vllm wake_up and + # NOTE: Basically, we only need `get_torch_device().empty_cache()` before vllm wake_up and # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator. # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory # to speed up memory allocations. # # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103 - get_torch_device.empty_cache() + get_torch_device().empty_cache() log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger) params = self.module.state_dict() @@ -110,7 +111,7 @@ def __enter__(self): self.update_params(params) log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger) del params - get_torch_device.empty_cache() + get_torch_device().empty_cache() if "tags" in inspect.signature(self.inference_engine.wake_up).parameters: self.inference_engine.wake_up(tags=["kv_cache"]) @@ -119,14 +120,14 @@ def __enter__(self): # TODO: offload FSDP model weights # self.module.cpu() - # get_torch_device.empty_cache() + # get_torch_device().empty_cache() # if torch.distributed.get_rank() == 0: - # print(f'after model to cpu in sharding manager memory allocated: {get_torch_device.memory_allocated() / 1e9}GB, reserved: {get_torch_device.memory_reserved() / 1e9}GB') + # print(f'after model to cpu in sharding manager memory allocated: 
{get_torch_device().memory_allocated() / 1e9}GB, reserved: {get_torch_device.memory_reserved() / 1e9}GB') # important: need to manually set the random states of each tp to be identical. if self.device_mesh is not None: - self.torch_random_states = get_torch_device.get_rng_state() - get_torch_device.set_rng_state(self.gen_random_states) + self.torch_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.gen_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def __exit__(self, exc_type, exc_value, traceback): @@ -141,17 +142,17 @@ def __exit__(self, exc_type, exc_value, traceback): # self.module.to('cuda') # if torch.distributed.get_rank() == 0: - # print(f'after actor module to cuda in sharding manager memory allocated: {get_torch_device.memory_allocated() / 1e9}GB, reserved: {get_torch_device.memory_reserved() / 1e9}GB') + # print(f'after actor module to cuda in sharding manager memory allocated: {get_torch_device().memory_allocated() / 1e9}GB, reserved: {get_torch_device().memory_reserved() / 1e9}GB') self.module.train() # add empty cache after each compute - get_torch_device.empty_cache() + get_torch_device().empty_cache() # restore random states if self.device_mesh is not None: - self.gen_random_states = get_torch_device.get_rng_state() - get_torch_device.set_rng_state(self.torch_random_states) + self.gen_random_states = get_torch_device().get_rng_state() + get_torch_device().set_rng_state(self.torch_random_states) @GPUMemoryLogger(role="fsdp vllm sharding_manager", logger=logger) def preprocess_data(self, data: DataProto) -> DataProto: From 39dcb9f309f03b6d457037f69fc94cddd946e4eb Mon Sep 17 00:00:00 2001 From: icerain-alt <450125138@qq.com> Date: Fri, 15 Aug 2025 03:13:19 +0000 Subject: [PATCH 4/5] Add NPU training launch script --- .../agent/train_qwen25vl_grpo_agent_npu.sh | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 
examples/agent/train_qwen25vl_grpo_agent_npu.sh diff --git a/examples/agent/train_qwen25vl_grpo_agent_npu.sh b/examples/agent/train_qwen25vl_grpo_agent_npu.sh new file mode 100644 index 00000000..55b2f74b --- /dev/null +++ b/examples/agent/train_qwen25vl_grpo_agent_npu.sh @@ -0,0 +1,80 @@ +set -x + +export VLLM_USE_V1=1 +export LLM_AS_A_JUDGE_BASE="http://localhost:18901/v1" +export WANDB_MODE=offline +export WORLD_SIZE=1 +export USE_OPTIMIZED_MODEL=0 + + +unset PYTORCH_CUDA_ALLOC_CONF +PROJECT_NAME="agent_vlagent" +EXPERIMENT_NAME="debug_for_single_node" + +export SAVE_CHECKPOINT_DIR=./verl_checkpoints + +BASEDIR=./datasets_agent +VISUAL_DATASET_TRAIN_0_6_2=${BASEDIR}/data_v0.6.2_reason.parquet +VISUAL_DATASET_TRAIN_0_1_2=${BASEDIR}/data_0.1.2_visual_toolbox_v2.parquet +VISUAL_DATASET_TRAIN_0_8=${BASEDIR}/data_v0.8_visual_toolbox_v2.parquet +VISUAL_DATASET_TEST=${BASEDIR}/seekworld_test.parquet +EUREKA_DATASET_TRAIN=${BASEDIR}/data_thinklite_reasoning_acc.parquet + +REF_MODEL_PATH=./Qwen2.5-VL-7B-Instruct +PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \ + +debug=False \ + +vs_debug=False \ + data.train_files=[${VISUAL_DATASET_TRAIN_0_1_2},${VISUAL_DATASET_TRAIN_0_8},${EUREKA_DATASET_TRAIN}] \ + data.val_files=[${EUREKA_DATASET_TRAIN}] \ + data.train_batch_size=64 \ + data.max_prompt_length=8192 \ + data.max_response_length=20480 \ + data.return_raw_chat=True \ + data.filter_overlong_prompts=True \ + algorithm.adv_estimator=grpo \ + algorithm.kl_ctrl.kl_coef=0.0 \ + actor_rollout_ref.model.path=${REF_MODEL_PATH} \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.ppo_mini_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.actor.use_kl_loss=False \ + actor_rollout_ref.actor.kl_loss_coef=0.0 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0.0 \ + 
actor_rollout_ref.actor.checkpoint.contents=['model','hf_model','optimizer','extra'] \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.max_num_batched_tokens=32768 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.enforce_eager=True \ + actor_rollout_ref.rollout.free_cache_engine=True \ + actor_rollout_ref.rollout.enable_chunked_prefill=False \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.use_torch_compile=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.agent.activate_agent=True \ + actor_rollout_ref.rollout.agent.tool_name_key=env_name \ + actor_rollout_ref.rollout.agent.single_response_max_tokens=10240 \ + actor_rollout_ref.rollout.agent.max_turns=5 \ + actor_rollout_ref.rollout.agent.concurrent_workers=1 \ + actor_rollout_ref.rollout.agent.show_tqdm=True \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=${WORLD_SIZE} \ + trainer.save_freq=32 \ + trainer.test_freq=10000 \ + +trainer.device_name=npu \ + trainer.project_name=${PROJECT_NAME} \ + trainer.experiment_name=${EXPERIMENT_NAME} \ + trainer.default_local_dir=${SAVE_CHECKPOINT_DIR}/${PROJECT_NAME}/${EXPERIMENT_NAME} \ + +trainer.tensorboard_dir=${SAVE_CHECKPOINT_DIR}/logs/tensorboard \ + +trainer.rl_logging_board_dir=${SAVE_CHECKPOINT_DIR}/logs/rl_logging_board \ + trainer.total_epochs=32 2>&1 | tee $(date +%Y%m%d%H%M%S).log \ No newline at end of file From 37066ca5aaff15c0a85b3d12b8def5e7e04883e3 Mon Sep 17 00:00:00 2001 From: zyang6 
Date: Tue, 26 Aug 2025 09:28:00 +0000 Subject: [PATCH 5/5] fix:check diff those that have no actual changes and update npu-readme --- examples/agent/README_NPU.md | 85 +++++++++++++++++++ verl/utils/debug/performance.py | 21 ++--- verl/utils/distributed.py | 3 +- verl/utils/fsdp_utils.py | 3 +- verl/workers/actor/dp_actor.py | 24 ++---- .../envs/visual_agent/mm_search_engine.py | 44 +++++----- verl/workers/fsdp_workers.py | 8 +- verl/workers/rollout/vllm_rollout/__init__.py | 7 +- verl/workers/sharding_manager/fsdp_vllm.py | 2 - 9 files changed, 128 insertions(+), 69 deletions(-) create mode 100644 examples/agent/README_NPU.md diff --git a/examples/agent/README_NPU.md b/examples/agent/README_NPU.md new file mode 100644 index 00000000..c5145bce --- /dev/null +++ b/examples/agent/README_NPU.md @@ -0,0 +1,85 @@ +Add Support for Huawei Ascend Devices on DeepEyes + +# Installation + +## Basic Environment Preparation + +| software | version | +| :-------- | :---------- | +| Python | ==3.10 | +| CANN | ==8.1.RC1 | +| torch | ==2.5.1 | +| torch_npu | ==2.5.1.RC1 | + +## Install vllm & vllm-ascend + +To ensure proper usage of vllm in verl, it is recommended to install vllm & vllm-ascend via source code compilation. + +### Install vllm 0.9.1 + +```bash +git clone -b v0.9.1 https://github.com/vllm-project/vllm.git +``` + +Comment out the dependency on torch in requirements/build.txt + +```bash +cd vllm +pip install -r ./requirements/build.txt +VLLM_TARGET_DEVICE=empty pip install -e . +``` + +### Install vllm-ascend 0.9.1 + +```bash +git clone -b v0.9.1-dev https://github.com/vllm-project/vllm-ascend.git +cd vllm-ascend + +export COMPILE_CUSTOM_KERNELS=1 +pip install -e . --no-build-isolation +``` + +## Install verl + +```bash +cd verl +pip install -r requirements-npu.txt +pip install -e . +``` +# Start Training +We use Qwen-2.5-VL-7B-Instruct as our foundation model for RL training. Qwen-2.5-VL-32B-Instruct is also supported. 
+ +Step 1: Start a vllm serving of Qwen-2.5-72B-Instruct for llm-as-a-judge verification. +```bash +# download Qwen-2.5-72B-Instruct model +huggingface-cli download --resume-download Qwen/Qwen2.5-72B-Instruct --local-dir /path/to/your/local/filedir --local-dir-use-symlinks False + +# start vllm serving +vllm serve /path/to/your/local/filedir \ + --port 18901 \ + --gpu-memory-utilization 0.8 \ + --max-model-len 32768 \ + --tensor-parallel-size 8 \ + --served-model-name "judge" \ + --trust-remote-code \ + --disable-log-requests +``` + +Step 2: Build a ray cluster for all of the training nodes. Prepare data before starting training. Our training dataset can be downloaded from [huggingface](https://huggingface.co/datasets/ChenShawn/DeepEyes-Datasets-47k). + +Step 3: Use one of the following scripts to start training. + +```bash +# your wandb access key here... +wandb login + +# the IP and port for your Qwen-2.5-72B-Instruct vllm serving +export LLM_AS_A_JUDGE_BASE="http://your.vllm.machine.ip:18901/v1" + +# number of training nodes +export WORLD_SIZE=8 + +# config for 7B +bash examples/agent/train_qwen25vl_grpo_agent_npu.sh +``` + diff --git a/verl/utils/debug/performance.py b/verl/utils/debug/performance.py index b849df66..d1a13491 100644 --- a/verl/utils/debug/performance.py +++ b/verl/utils/debug/performance.py @@ -17,8 +17,8 @@ import torch import torch.distributed as dist -from verl.utils.device import get_torch_device from verl.utils.logger.aggregate_logger import DecoratorLoggerBase +from verl.utils.device import get_torch_device def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0): @@ -36,10 +36,10 @@ class GPUMemoryLogger(DecoratorLoggerBase): """_summary_ - + Usage: For example, in actor function, we initialize a GPUMemoryLogger - + ``` from verl.utils.debug.performance import GPUMemoryLogger 
@GPUMemoryLogger(role="actor") @@ -47,22 +47,21 @@ def update_actor(self, batch): # do something return ``` - + """ - + def __init__(self, role: str, logger: logging.Logger = None, level=logging.DEBUG, log_only_rank_0: bool = True): if dist.is_initialized() and dist.get_world_size() > 1: rank = dist.get_rank() else: rank = 0 super().__init__(role, logger, level, rank, log_only_rank_0) - + def __call__(self, decorated_function: callable): def f(*args, **kwargs): return self.log(decorated_function, *args, **kwargs) - return f - + def log(self, func, *args, **kwargs): memory_allocated = get_torch_device().memory_allocated() / 1024**3 memory_reserved = get_torch_device().memory_reserved() / 1024**3 @@ -70,8 +69,6 @@ def log(self, func, *args, **kwargs): message = f"Before {func.__name__}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}" self.logging_function(message) output = func(*args, **kwargs) - message = ( - f"After {func.__name__}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}" - ) + message = f"After {func.__name__}, memory allocated (GB): {memory_allocated}, memory reserved (GB): {memory_reserved}" self.logging_function(message) - return output + return output \ No newline at end of file diff --git a/verl/utils/distributed.py b/verl/utils/distributed.py index 69b2214c..598150c7 100644 --- a/verl/utils/distributed.py +++ b/verl/utils/distributed.py @@ -14,8 +14,7 @@ """Utilities for distributed training.""" import os - -from verl.utils.device import get_torch_device, is_cuda_available +from verl.utils.device import is_cuda_available, get_torch_device def initialize_global_process_group(timeout_second=36000): diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index 195aa1fe..344a05cc 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -28,8 +28,7 @@ from torch.distributed.fsdp._runtime_utils import _lazy_init from torch.distributed.fsdp.wrap import 
size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name - -from verl.utils.device import get_device_name, get_torch_device +from verl.utils.device import get_torch_device, get_device_name def init_fn(x: torch.nn.Module): diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 4c628eb3..b4dfed6c 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -28,17 +28,17 @@ from verl import DataProto from verl.trainer.ppo.core_algos import agg_loss, compute_policy_loss, kl_penalty from verl.utils.debug import GPUMemoryLogger -from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available from verl.utils.py_functional import append_to_dict from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches from verl.utils.torch_functional import logprobs_from_logits from verl.utils.ulysses import gather_outpus_and_unpad, ulysses_pad_and_slice_inputs from verl.workers.actor import BasePPOActor +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available if is_cuda_available: - from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input + from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis elif is_npu_available: - from transformers.integrations.npu_flash_attention import index_first_axis, pad_input, rearrange, unpad_input + from transformers.integrations.npu_flash_attention import pad_input, unpad_input, rearrange, index_first_axis __all__ = ["DataParallelPPOActor"] @@ -275,15 +275,7 @@ def update_policy(self, data: DataProto): temperature = data.meta_info["temperature"] # temperature must be in the data.meta_info to avoid slient error - select_keys = [ - "responses", - "input_ids", - "attention_mask", - "position_ids", - "old_log_probs", - "advantages", - "action_mask", - ] + select_keys = 
["responses", "input_ids", "attention_mask", "position_ids", "old_log_probs", "advantages", "action_mask"] if self.config.use_kl_loss: select_keys.append("ref_log_prob") batch = data.select(batch_keys=select_keys, strict=False).batch @@ -308,9 +300,7 @@ def update_policy(self, data: DataProto): self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu ) num_micro_batches = mini_batch.batch.batch_size[0] // self.config.ppo_micro_batch_size_per_gpu - micro_batches = data.select(select_keys, non_tensor_select_keys, strict=False).chunk( - num_micro_batches - ) + micro_batches = data.select(select_keys, non_tensor_select_keys, strict=False).chunk(num_micro_batches) elif self.config.use_dynamic_bsz: max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) @@ -331,9 +321,7 @@ def update_policy(self, data: DataProto): data = data.to(get_torch_device().current_device()) # actor device is cpu when using offload responses = data["responses"] response_length = responses.size(1) - action_or_attn_mask = ( - data["action_mask"] if "action_mask" in data.keys() else data["attention_mask"] - ) + action_or_attn_mask = data['action_mask'] if 'action_mask' in data.keys() else data['attention_mask'] response_mask = action_or_attn_mask[:, -response_length:] old_log_prob = data["old_log_probs"] diff --git a/verl/workers/agent/envs/visual_agent/mm_search_engine.py b/verl/workers/agent/envs/visual_agent/mm_search_engine.py index 2778b22a..914937f9 100644 --- a/verl/workers/agent/envs/visual_agent/mm_search_engine.py +++ b/verl/workers/agent/envs/visual_agent/mm_search_engine.py @@ -1,14 +1,18 @@ +import re +import os import json +import random +import requests +import numpy as np from io import BytesIO +from PIL import Image +from time import sleep +from playwright.sync_api import sync_playwright, Playwright from duckduckgo_search import DDGS -from 
PIL import Image -from playwright.sync_api import Playwright, sync_playwright -# from verl.utils.dataset.rl_dataset import process_image from verl.workers.agent.tool_envs import ToolBase, extract_tool_call_contents - class MMSearchEngine(ToolBase): name = "mm_search" @@ -27,22 +31,20 @@ def __init__(self, _name, **kwargs): super().__init__(name=self.name) def execute(self, action_string, **kwargs): - self.chatml_history.append( - { - "role": "assistant", - "content": action_string, - } - ) + self.chatml_history.append({ + "role": "assistant", + "content": action_string, + }) answers = extract_tool_call_contents(self.answer_start, self.answer_end, action_string) if answers: # print(f' [DEBUG] found answer in {action_string=}') - return "", 0.0, True, {} + return '', 0.0, True, {} search_list = extract_tool_call_contents(self.search_start, self.search_end, action_string) browse_list = extract_tool_call_contents(self.browse_start, self.browse_end, action_string) if len(search_list) > 0: - search_key = " ".join([item.strip() for item in search_list]) + search_key = ' '.join([item.strip() for item in search_list]) search_results = self.ddgs.text(search_key, max_results=self.top_k) result_text = self.convert_search_to_text(search_results) result_text = f"\n\n{result_text}\n\n" @@ -51,7 +53,7 @@ def execute(self, action_string, **kwargs): elif len(browse_list) > 0: browse_list = [url.strip() for url in browse_list] img_list = [self.get_screenshot_from_url(url) for url in browse_list] - self.multi_modal_data["image"] += img_list + self.multi_modal_data['image'] += img_list prompt_list = [f"Screenshot for website {url}\n" for url in browse_list] prompt_text = "\n\n".join(prompt_list) @@ -60,11 +62,11 @@ def execute(self, action_string, **kwargs): "prompt": prompt_text, "multi_modal_data": {"image": img_list}, } - print(f" [DEBUG browser] return {len(img_list)} images for {browse_list=}") + print(f' [DEBUG browser] return {len(img_list)} images for {browse_list=}') return 
obs, 0.0, False, {} else: # print(f' [DEBUG browser] no action_list in {action_string=}') - return "", 0.0, True, {} + return '', 0.0, True, {} def reset(self, raw_prompt, multi_modal_data, origin_multi_modal_data, **kwargs): """ @@ -77,11 +79,7 @@ def reset(self, raw_prompt, multi_modal_data, origin_multi_modal_data, **kwargs) """ self.ddgs = DDGS() self.chatml_history = raw_prompt.tolist() - if ( - origin_multi_modal_data is None - or not isinstance(origin_multi_modal_data, dict) - or "image" not in origin_multi_modal_data.keys() - ): + if origin_multi_modal_data is None or not isinstance(origin_multi_modal_data, dict) or 'image' not in origin_multi_modal_data.keys(): self.multi_modal_data = {"image": []} else: self.multi_modal_data = origin_multi_modal_data @@ -91,11 +89,11 @@ def convert_search_to_text(self, search_results): for result in search_results: docstr = json.dumps(result, ensure_ascii=False, indent=2) search_json_list.append(docstr) - return "\n".join(search_json_list) + return '\n'.join(search_json_list) def get_screenshot_from_url(self, url): def run_single(playwright: Playwright): - chromium = playwright.chromium # or "firefox" or "webkit". + chromium = playwright.chromium # or "firefox" or "webkit". 
browser = chromium.launch() page = browser.new_page() page.goto(url) @@ -106,4 +104,4 @@ def run_single(playwright: Playwright): with sync_playwright() as pw: img_bytes = run_single(pw) img_pil = Image.open(BytesIO(img_bytes)) - return img_pil + return img_pil \ No newline at end of file diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index d42977fe..4ad245ac 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -33,7 +33,6 @@ from verl.utils import hf_processor, hf_tokenizer from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager from verl.utils.debug import log_gpu_memory_usage -from verl.utils.device import get_device_name, get_torch_device, is_cuda_available from verl.utils.flops_counter import FlopsCounter from verl.utils.fs import copy_to_local from verl.utils.fsdp_utils import ( @@ -48,6 +47,7 @@ from verl.utils.import_utils import import_external_libs from verl.utils.model import compute_position_id_with_mask from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager +from verl.utils.device import get_device_name, get_torch_device, is_cuda_available, is_npu_available logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) @@ -339,9 +339,7 @@ def _build_rollout(self, trust_remote_code=False): assert self.world_size % infer_tp == 0, ( f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}" ) - rollout_device_mesh = init_device_mesh( - device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"] - ) + rollout_device_mesh = init_device_mesh(device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]) rollout_name = self.config.rollout.name if rollout_name == "hf": from verl.workers.rollout import HFRollout @@ -570,7 +568,7 @@ def generate_sequences(self, prompts: DataProto): offload_fsdp_optimizer(optimizer=self.actor_optimizer) prompts = 
self.rollout_sharding_manager.preprocess_data(prompts) - print(f" [DEBUG 222] data middle: {len(prompts)}") + print(f' [DEBUG 222] data middle: {len(prompts)}') output = self.rollout.generate_sequences(prompts=prompts) output = self.rollout_sharding_manager.postprocess_data(output) diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 690c1217..50e9dd95 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -18,7 +18,7 @@ # [SUPPORT AMD:] import torch -### +### [SUPPORT ROCM] import os def get_version(pkg): @@ -31,10 +31,7 @@ vllm_package_name = "vllm" vllm_package_version = get_version(vllm_package_name) if vllm_package_version is None: - raise PackageNotFoundError( - "To use vllm rollout, please ensure the 'vllm' package is properly installed. See " - https://verl.readthedocs.io/en/latets/start/install.html for more details" - ) + raise PackageNotFoundError("To use vllm rollout, please ensure the 'vllm' package is properly installed. See https://verl.readthedocs.io/en/latest/start/install.html for more details") if "ROCM_PATH" in os.environ: import re diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py index 142f280b..affcc24a 100644 --- a/verl/workers/sharding_manager/fsdp_vllm.py +++ b/verl/workers/sharding_manager/fsdp_vllm.py @@ -30,8 +30,6 @@ from .base import BaseShardingManager -# from .patch import patched_ds_v3_load_weights, patched_qwen_moe_load_weights - logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))