Description
Hi, while trying to run the GRPO part of the code, I followed the README and ran `bash ./src/scripts/run_grpo_video.sh`. However, the code hangs in `./src/r1-v/src/open_r1/trainer/grpo_trainer.py` at the step that computes:
```python
per_token_logps = self._get_per_token_logps(model, prompt_completion_ids, **prompt_inputs)  # this line deadlocks
per_token_logps = per_token_logps[:, prompt_length - 1 :]
```
More specifically, it hangs inside:
```python
def _get_per_token_logps(self, model, input_ids, **kwargs):
    # logits = model(input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw).logits  # (B, L, V)
    # import pdb
    logits = model(input_ids, **kwargs).logits
    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
    input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
    # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)
```
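It is the line that computes `logits` that hangs. From what I could find, a possible cause is that the constructed inputs contain both videos and images, so the contents of `kwargs` differ from rank to rank, which leads to the deadlock. How can this be fixed?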
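One way to test this hypothesis is to compare, on every rank, which multimodal inputs are actually present right before the forward pass. The sketch below is a hypothetical diagnostic, not part of the repo: `modality_signature`, `find_mismatched_ranks`, and `check_ranks_agree` are illustrative names. It assumes a Qwen2-VL-style processor, where image batches carry keys such as `pixel_values`/`image_grid_thw` and video batches carry `pixel_values_videos`/`video_grid_thw`. If parameters are sharded across ranks (for example with DeepSpeed ZeRO-3 or FSDP), ranks that take different vision-encoder paths can issue different sequences of collectives and block each other, which would look exactly like this hang.

```python
from typing import Any, Dict, List, Tuple


def modality_signature(kwargs: Dict[str, Any]) -> Tuple[str, ...]:
    """Sorted names of the model inputs that are actually present (not None)."""
    return tuple(sorted(k for k, v in kwargs.items() if v is not None))


def find_mismatched_ranks(signatures: List[Tuple[str, ...]]) -> List[int]:
    """Return the ranks whose input signature differs from rank 0's."""
    if not signatures:
        return []
    reference = signatures[0]
    return [rank for rank, sig in enumerate(signatures) if sig != reference]


def check_ranks_agree(kwargs: Dict[str, Any]) -> List[int]:
    """Gather every rank's signature and report mismatches.

    Must be called on ALL ranks (it is itself a collective), e.g. just before
    `model(input_ids, **kwargs)` in `_get_per_token_logps`.
    """
    # Lazy import so the pure-Python helpers above work without torch installed.
    import torch.distributed as dist

    sig = modality_signature(kwargs)
    if not (dist.is_available() and dist.is_initialized()):
        return []  # single process: nothing to compare
    gathered: List[Any] = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, sig)
    return find_mismatched_ranks(gathered)
```

For example, printing `check_ranks_agree(kwargs)` on rank 0 right before the forward call would show whether some ranks receive video inputs while others receive only image inputs in the same step. A non-empty list would support the mixed-modality explanation; an empty list would point elsewhere.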
Environment:
8 AMD GPUs, ROCm 7.0.0.