diff --git a/applications/Chat/coati/dataset/prompt_dataset.py b/applications/Chat/coati/dataset/prompt_dataset.py
index 5858052c836a..0bdcbbc5928e 100644
--- a/applications/Chat/coati/dataset/prompt_dataset.py
+++ b/applications/Chat/coati/dataset/prompt_dataset.py
@@ -35,14 +35,14 @@ def __init__(self,
             logger.info(f"Limiting dataset to {max_datasets_size} examples.")
             list_data_dict = list_data_dict[:max_datasets_size]
 
-        for data_dict in list_data_dict:
-            token = tokenizer(data_dict["instruction"],
-                              return_tensors='pt',
-                              max_length=max_length,
-                              padding='max_length',
-                              truncation=True)
-            for k, tensor in token.items():
-                self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
+        instructions = [data_dict["instruction"] for data_dict in list_data_dict]
+        tokens = tokenizer(instructions,
+                           return_tensors='pt',
+                           max_length=max_length,
+                           padding='max_length',
+                           truncation=True)
+        for k, tensor in tokens.items():
+            self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
 
     def __len__(self):
         return len(self.keyed_prompt["input_ids"])
diff --git a/applications/Chat/coati/dataset/sft_dataset.py b/applications/Chat/coati/dataset/sft_dataset.py
index 3e2453468bbc..3702d00cc609 100644
--- a/applications/Chat/coati/dataset/sft_dataset.py
+++ b/applications/Chat/coati/dataset/sft_dataset.py
@@ -74,21 +74,18 @@ def __getitem__(self, idx):
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
 
 
-def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:
+def _tokenize_fn(strings: Sequence[str],
+                 tokenizer: transformers.PreTrainedTokenizer,
+                 max_length: int
+                 ) -> Dict[str, torch.Tensor]:
     """Tokenize a list of strings."""
-    tokenized_list = [
-        tokenizer(
-            text,
-            return_tensors="pt",
-            padding="longest",
-            max_length=max_length,
-            truncation=True,
-        ) for text in strings
-    ]
-    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
-    input_ids_lens = labels_lens = [
-        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
-    ]
+    tokenized_list = tokenizer(
+        strings, return_tensors="pt", padding="longest",
+        max_length=max_length, truncation=True
+    )
+    input_ids = labels = tokenized_list["input_ids"]
+    input_ids_lens = labels_lens = \
+        tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
     return dict(
         input_ids=input_ids,
         labels=labels,
@@ -105,7 +102,10 @@ def preprocess(
 ) -> Dict:
     """Preprocess the data by tokenizing."""
     examples = [s + t for s, t in zip(sources, targets)]
-    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)]
+    examples_tokenized, sources_tokenized = [
+        _tokenize_fn(strings, tokenizer, max_length)
+        for strings in (examples, sources)
+    ]
     input_ids = examples_tokenized["input_ids"]
     labels = copy.deepcopy(input_ids)
     for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
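Both dataset changes above swap a per-string tokenization loop for a single batched tokenizer call, letting the tokenizer pad and truncate everything in one pass. A minimal sketch of the pattern, assuming a stock HuggingFace GPT-2 tokenizer (the texts and `max_length` below are illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token    # GPT-2 ships without a pad token

texts = ["first instruction", "second instruction"]

# One batched call returns stacked tensors instead of a list of
# per-string encodings.
batch = tokenizer(texts,
                  return_tensors="pt",
                  max_length=96,
                  padding="max_length",
                  truncation=True)

# batch["input_ids"] has shape (len(texts), 96); unbind() splits it back
# into per-example tensors, as prompt_dataset.py now does.
per_example = batch["input_ids"].unbind()
assert len(per_example) == len(texts)
```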
diff --git a/applications/Chat/coati/experience_maker/naive.py b/applications/Chat/coati/experience_maker/naive.py
index 94546eeb28e7..e5bb029e63d0 100644
--- a/applications/Chat/coati/experience_maker/naive.py
+++ b/applications/Chat/coati/experience_maker/naive.py
@@ -1,5 +1,6 @@
 import torch
-from coati.models.utils import compute_reward, normalize
+from coati.models.generation import generate_with_actor
+from coati.models.utils import calc_action_log_probs, compute_reward, normalize
 
 from .base import Experience, ExperienceMaker
 
@@ -16,13 +17,16 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
         self.initial_model.eval()
         self.reward_model.eval()
 
-        sequences, attention_mask, action_mask = self.actor.generate(input_ids,
+        sequences, attention_mask, action_mask = generate_with_actor(self.actor,
+                                                                     input_ids,
                                                                      return_action_mask=True,
                                                                      **generate_kwargs)
         num_actions = action_mask.size(1)
-        action_log_probs = self.actor(sequences, num_actions, attention_mask)
-        base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
+        actor_output = self.actor(sequences, attention_mask)
+        action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
+        base_model_output = self.initial_model(sequences, attention_mask)
+        base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
         value = self.critic(sequences, action_mask, attention_mask)
         r = self.reward_model(sequences, attention_mask)
         reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
 
diff --git a/applications/Chat/coati/models/base/__init__.py b/applications/Chat/coati/models/base/__init__.py
index fe4152f2b760..c5f748a0c85a 100644
--- a/applications/Chat/coati/models/base/__init__.py
+++ b/applications/Chat/coati/models/base/__init__.py
@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch.nn as nn
 
 from .actor import Actor
@@ -5,10 +7,10 @@
 from .reward_model import RewardModel
 
 
-def get_base_model(model: nn.Module) -> nn.Module:
+def get_base_model(model: Union[Actor, Critic, RewardModel]) -> nn.Module:
     """Get the base model of our wrapper classes.
-    For Actor, it's base model is ``actor.model`` and it's usually a ``transformers.PreTrainedModel``.
-    For Critic and RewardModel, it's base model is itself.
+    For Actor, Critic and RewardModel, return ``model.model``,
+    which is usually a ``transformers.PreTrainedModel``.
 
     Args:
         model (nn.Module): model to get base model from
@@ -16,9 +18,9 @@ def get_base_model(model: nn.Module) -> nn.Module:
     Returns:
         nn.Module: the base model
     """
-    if isinstance(model, Actor):
-        return model.get_base_model()
-    return model
+    assert isinstance(model, (Actor, Critic, RewardModel)), \
+        f'Expect Actor, Critic or RewardModel, got {type(model)}, use unwrap_model first.'
+    return model.model
 
 
 __all__ = ['Actor', 'Critic', 'RewardModel', 'get_base_model']
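The reworked `get_base_model` is deliberately strict: it accepts only the coati wrapper classes and always returns the wrapped `.model`. A sketch of the intended call pattern, assuming an `Actor` built from pretrained GPT-2:

```python
from coati.models.base import get_base_model
from coati.models.gpt import GPTActor

actor = GPTActor(pretrained='gpt2')

# Peels off exactly one coati wrapper layer and returns the underlying
# transformers.PreTrainedModel.
hf_model = get_base_model(actor)

# Anything else (an already-unwrapped model, or a DDP/ZeroDDP wrapper)
# trips the assert, so strategy wrappers must be removed with
# strategy.unwrap_model first.
```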
diff --git a/applications/Chat/coati/models/base/actor.py b/applications/Chat/coati/models/base/actor.py
index 71fbf7bbae7d..2034d5cc81d4 100644
--- a/applications/Chat/coati/models/base/actor.py
+++ b/applications/Chat/coati/models/base/actor.py
@@ -1,12 +1,9 @@
-from typing import Optional, Tuple, Union
+from typing import Optional
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
-from ..generation import generate
 from ..lora import LoRAModule
-from ..utils import log_probs_from_logits
 
 
 class Actor(LoRAModule):
@@ -24,42 +21,16 @@ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
         self.model = model
         self.convert_to_lora()
 
-    @torch.no_grad()
-    def generate(
-        self,
-        input_ids: torch.Tensor,
-        return_action_mask: bool = True,
-        **kwargs
-    ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
-        sequences = generate(self.model, input_ids, **kwargs)
-        attention_mask = None
-        pad_token_id = kwargs.get('pad_token_id', None)
-        if pad_token_id is not None:
-            attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
-        if not return_action_mask:
-            return sequences, attention_mask, None
-        input_len = input_ids.size(1)
-        eos_token_id = kwargs.get('eos_token_id', None)
-        if eos_token_id is None:
-            action_mask = torch.ones_like(sequences, dtype=torch.bool)
-        else:
-            # left padding may be applied, only mask action
-            action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
-            action_mask = F.pad(action_mask, (1 + input_len, -1), value=True)    # include eos token and input
-            action_mask[:, :input_len] = False
-        action_mask = action_mask[:, 1:]
-        return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
-
     def forward(self,
-                sequences: torch.LongTensor,
-                num_actions: int,
-                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Returns action log probs
+                input_ids: torch.LongTensor,
+                attention_mask: Optional[torch.Tensor] = None,
+                **model_kwargs,    # HACK: `generate` method may pass more kwargs
+                ) -> torch.Tensor:
+        """Returns model output.
         """
-        output = self.model(sequences, attention_mask=attention_mask)
-        logits = output['logits']
-        log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
-        return log_probs[:, -num_actions:]
-
-    def get_base_model(self):
-        return self.model
+        output = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            **model_kwargs
+        )
+        return output
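With generation and log-prob slicing moved out, `Actor.forward` becomes a plain pass-through to the wrapped model. A sketch of what callers now receive, using a deliberately tiny GPT-2 config (sizes are illustrative):

```python
import torch
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from coati.models.gpt import GPTActor

config = GPT2Config(n_embd=64, n_layer=2, n_head=2)
actor = GPTActor(config=config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)

# forward returns the raw HuggingFace output; action log probs are now
# computed separately via calc_action_log_probs.
output = actor(input_ids, attention_mask=attention_mask)
print(output['logits'].shape)    # torch.Size([2, 16, 50257])
```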
""" - output = self.model(sequences, attention_mask=attention_mask) - logits = output['logits'] - log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) - return log_probs[:, -num_actions:] - - def get_base_model(self): - return self.model + output = self.model( + input_ids, + attention_mask=attention_mask, + **model_kwargs + ) + return output diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py index f57c9458a271..0156e2284e52 100644 --- a/applications/Chat/coati/models/generation.py +++ b/applications/Chat/coati/models/generation.py @@ -1,8 +1,10 @@ -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.distributed as dist import torch.nn as nn +import torch.nn.functional as F + try: from transformers.generation_logits_process import ( @@ -55,9 +57,8 @@ def sample(model: nn.Module, unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) for _ in range(input_ids.size(1), max_length): - model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else { - 'input_ids': input_ids - } + model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) \ + if prepare_inputs_fn is not None else {'input_ids': input_ids} outputs = model(**model_inputs) next_token_logits = outputs['logits'][:, -1, :] @@ -144,3 +145,35 @@ def generate(model: nn.Module, raise NotImplementedError else: raise ValueError("Unsupported generation mode") + + +@torch.no_grad() +def generate_with_actor(actor_model: nn.Module, + input_ids: torch.Tensor, + return_action_mask: bool = True, + **kwargs + ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], + Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: + """Generate token sequence with actor model. Refer to `generate` for more details. + """ + # generate sequences + sequences = generate(actor_model, input_ids, **kwargs) + + # calculate auxiliary tensors + attention_mask = None + pad_token_id = kwargs.get('pad_token_id', None) + if pad_token_id is not None: + attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device) + if not return_action_mask: + return sequences, attention_mask, None + input_len = input_ids.size(1) + eos_token_id = kwargs.get('eos_token_id', None) + if eos_token_id is None: + action_mask = torch.ones_like(sequences, dtype=torch.bool) + else: + # left padding may be applied, only mask action + action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0 + action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input + action_mask[:, :input_len] = False + action_mask = action_mask[:, 1:] + return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):] diff --git a/applications/Chat/coati/models/utils.py b/applications/Chat/coati/models/utils.py index 0ff13181fcd2..b9f15f894a1f 100644 --- a/applications/Chat/coati/models/utils.py +++ b/applications/Chat/coati/models/utils.py @@ -46,6 +46,25 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.T return log_probs_labels.squeeze(-1) +def calc_action_log_probs(output: torch.Tensor, + sequences: torch.LongTensor, + num_actions: int + ) -> torch.Tensor: + """Calculate action log probs. + + Args: + output (torch.Tensor): Output tensor of Actor.forward. + sequences (torch.LongTensor): Input sequences. + num_actions (int): Number of actions. + + Returns: + torch.Tensor: Action log probs. 
+ """ + logits = output['logits'] + log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:]) + return log_probs[:, -num_actions:] + + def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor: tensor = tensor * mask tensor = tensor.sum(dim=dim) diff --git a/applications/Chat/coati/trainer/ppo.py b/applications/Chat/coati/trainer/ppo.py index fe5ae48d9c2f..e2e44e62533e 100644 --- a/applications/Chat/coati/trainer/ppo.py +++ b/applications/Chat/coati/trainer/ppo.py @@ -3,8 +3,9 @@ import torch import torch.nn as nn from coati.experience_maker import Experience, NaiveExperienceMaker -from coati.models.base import Actor, Critic +from coati.models.base import Actor, Critic, get_base_model from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss +from coati.models.utils import calc_action_log_probs from coati.replay_buffer import NaiveReplayBuffer from torch import Tensor from torch.optim import Optimizer @@ -165,7 +166,8 @@ def training_step(self, experience: Experience) -> Dict[str, float]: self.critic.train() # policy loss num_actions = experience.action_mask.size(1) - action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask) + actor_output = self.actor(experience.sequences, attention_mask=experience.attention_mask) + action_log_probs = calc_action_log_probs(actor_output, experience.sequences, num_actions) actor_loss = self.actor_loss_fn(action_log_probs, experience.action_log_probs, experience.advantages, @@ -175,8 +177,8 @@ def training_step(self, experience: Experience) -> Dict[str, float]: if self.ptx_coef != 0: batch = next(iter(self.pretrain_dataloader)) batch = to_device(batch, self.device) - ptx_log_probs = self.actor.get_base_model()(batch['input_ids'], - attention_mask=batch['attention_mask'])['logits'] + ptx_log_probs = self.actor(batch['input_ids'], + attention_mask=batch['attention_mask'])['logits'] ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels']) actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef) @@ -200,14 +202,15 @@ def training_step(self, experience: Experience) -> Dict[str, float]: return {'reward': experience.reward.mean().item()} -def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None: - origin_model = strategy.unwrap_model(actor) +def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict: + unwrapper_model = strategy.unwrap_model(actor) + hf_model = get_base_model(unwrapper_model) new_kwargs = {**generate_kwargs} # use huggingface models method directly - if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'): - new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation + if 'prepare_inputs_fn' not in generate_kwargs and hasattr(hf_model, 'prepare_inputs_for_generation'): + new_kwargs['prepare_inputs_fn'] = hf_model.prepare_inputs_for_generation - if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'): - new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation + if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(hf_model, '_update_model_kwargs_for_generation'): + new_kwargs['update_model_kwargs_fn'] = hf_model._update_model_kwargs_for_generation return new_kwargs diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py index 
diff --git a/applications/Chat/coati/trainer/strategies/base.py b/applications/Chat/coati/trainer/strategies/base.py
index bd30422022ae..06f81f21ab26 100644
--- a/applications/Chat/coati/trainer/strategies/base.py
+++ b/applications/Chat/coati/trainer/strategies/base.py
@@ -4,7 +4,6 @@
 
 import torch
 import torch.nn as nn
-from coati.models.base import Actor, get_base_model
 from coati.replay_buffer import ReplayBuffer
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
@@ -69,21 +68,16 @@ def prepare(
             Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
         """
 
-        def prepare_model(model: nn.Module):
-            if isinstance(model, Actor):
-                return Actor(self.setup_model(model.get_base_model()))
-            return self.setup_model(model)
-
         rets = []
         for arg in models_or_model_optim_pairs:
             if isinstance(arg, tuple):
                 assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                 model, optimizer = arg
-                model = prepare_model(model)
-                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
+                model = self.setup_model(model)
+                optimizer = self.setup_optimizer(optimizer, model)
                 rets.append((model, optimizer))
             elif isinstance(arg, nn.Module):
-                rets.append(prepare_model(arg))
+                rets.append(self.setup_model(arg))
             else:
                 raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
 
@@ -93,16 +87,15 @@ def prepare_model(model: nn.Module):
 
     @staticmethod
     def unwrap_model(model: nn.Module) -> nn.Module:
-        """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
-        For Actor, it will unwrap `actor.model`.
+        """Get the unwrapped model from a wrapped model made by Strategy.prepare.
 
         Args:
             model (nn.Module): the model to unwrap
 
         Returns:
-            nn.Module: the original model (usually a huggingface model)
+            nn.Module: the original model
         """
-        return get_base_model(model)
+        return model
 
     @abstractmethod
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
@@ -133,4 +126,4 @@ def save_pretrained(self,
 
     @abstractmethod
     def get_model_state_dict_shard(self, model: nn.Module, **config):
-        pass
\ No newline at end of file
+        pass
diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py
index 88268b677eb2..fafd0918deaf 100644
--- a/applications/Chat/coati/trainer/strategies/colossalai.py
+++ b/applications/Chat/coati/trainer/strategies/colossalai.py
@@ -5,7 +5,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.optim as optim
-from coati.models.base import get_base_model
 from torch.optim import Optimizer
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
@@ -153,14 +152,13 @@ def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
         if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
             return
-        base_model = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
+            assert isinstance(model, ZeroDDP)
             # for stage 3, state_dict() method should be called on every rank
-            state_dict = base_model.state_dict(only_rank_0=only_rank0)
+            state_dict = model.state_dict(only_rank_0=only_rank0)
         else:
             # only_rank0 is false or rank == 0
-            state_dict = base_model.state_dict()
+            state_dict = model.state_dict()
         if only_rank0 and dist.get_rank() != 0:
             return
         torch.save(state_dict, path)
@@ -172,11 +170,10 @@ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         torch.save(optimizer.state_dict(), path)
 
     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
         if self.stage == 3:
-            assert isinstance(base_model, ZeroDDP)
-            return base_model.module
-        return base_model
+            assert isinstance(model, ZeroDDP)
+            return model.module
+        return model
 
     def save_pretrained(self,
                         model: nn.Module,
@@ -196,5 +193,5 @@ def get_model_state_dict_shard(self, model: nn.Module, **config):
         #     if isinstance(module, LoraLinear):
         #         module.merge_weights = True
         #         module.eval()
-        base_model: ZeroDDP = get_base_model(model)
-        yield from base_model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
+        assert isinstance(model, ZeroDDP)
+        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
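Unwrapping is now strictly layered: `Strategy.unwrap_model` strips only the wrapper added by `prepare` (DDP or ZeroDDP; the naive strategy returns the model unchanged), while `get_base_model` strips the coati wrapper. A sketch of the resulting chain for a DDP-prepared actor, with `strategy` and `actor` assumed from context:

```python
from coati.models.base import get_base_model

# Layering after prepare (DDP case): DDP(Actor(PreTrainedModel))
actor = strategy.prepare(actor)

wrapper = strategy.unwrap_model(actor)    # -> Actor
hf_model = get_base_model(wrapper)        # -> transformers.PreTrainedModel
```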
diff --git a/applications/Chat/coati/trainer/strategies/ddp.py b/applications/Chat/coati/trainer/strategies/ddp.py
index a1fecb36373f..713d7b90c6f0 100644
--- a/applications/Chat/coati/trainer/strategies/ddp.py
+++ b/applications/Chat/coati/trainer/strategies/ddp.py
@@ -69,8 +69,8 @@ def setup_sampler(self, dataset) -> DistributedSampler:
         return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
 
     def unwrap_model(self, model: nn.Module) -> nn.Module:
-        base_model: DDP = super().unwrap_model(model)
-        return base_model.module
+        assert isinstance(model, DDP)
+        return model.module
 
     def save_pretrained(self,
                         model: nn.Module,
diff --git a/applications/Chat/coati/trainer/strategies/naive.py b/applications/Chat/coati/trainer/strategies/naive.py
index 972deebeaa0d..202c480e06d9 100644
--- a/applications/Chat/coati/trainer/strategies/naive.py
+++ b/applications/Chat/coati/trainer/strategies/naive.py
@@ -58,14 +58,13 @@ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
                           collate_fn=replay_buffer.collate_fn)
 
     def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
-        base_model = get_base_model(model)
-        state_dict = base_model.state_dict()
+        state_dict = model.state_dict()
         torch.save(state_dict, path)
 
     def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
-        base_model = get_base_model(model)
+        unwrapped_model = self.unwrap_model(model)
         state_dict = torch.load(path, map_location=map_location)
-        base_model.load_state_dict(state_dict, strict=strict)
+        unwrapped_model.load_state_dict(state_dict, strict=strict)
 
     def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
         torch.save(optimizer.state_dict(), path)
diff --git a/applications/Chat/examples/test_ci.sh b/applications/Chat/examples/test_ci.sh
index 2fa6c6052f8d..ac3a9b507864 100755
--- a/applications/Chat/examples/test_ci.sh
+++ b/applications/Chat/examples/test_ci.sh
@@ -121,6 +121,14 @@ torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
     --rm_pretrain 'gpt2' \
     --rm_path ${BASE}/rm_ckpt_gpt.pt \
     --save_path ${BASE}/actor_checkpoint_prompts.pt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py --prompt_dataset $PROMPT_PATH --pretrain_dataset $PRETRAIN_DATASET \
+    --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
+    --update_timesteps 2 --max_epochs 1 --train_batch_size 2 \
+    --pretrain 'gpt2' --model gpt2 \
+    --rm_pretrain 'gpt2' \
+    --rm_path ${BASE}/rm_ckpt_gpt.pt \
+    --save_path ${BASE}/actor_checkpoint_prompts.pt
 
 rm -rf ${BASE}/rm_ckpt_gpt.pt
 rm -rf ${BASE}/actor_checkpoint_prompts.pt
diff --git a/applications/Chat/tests/test_checkpoint.py b/applications/Chat/tests/test_checkpoint.py
index 4c05a3431699..d93a5c94d8ea 100644
--- a/applications/Chat/tests/test_checkpoint.py
+++ b/applications/Chat/tests/test_checkpoint.py
@@ -6,6 +6,7 @@
 import torch
 import torch.distributed as dist
 from coati.models.gpt import GPTActor
+from coati.models.utils import calc_action_log_probs
 from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy
 from transformers.models.gpt2.configuration_gpt2 import GPT2Config
 
@@ -43,7 +44,8 @@ def run_test_checkpoint(strategy):
     def run_step():
         data = get_data(BATCH_SIZE)
         action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
-        action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
+        actor_output = actor(data['input_ids'], data['attention_mask'])
+        action_log_probs = calc_action_log_probs(actor_output, data['input_ids'], action_mask.size(1))
         loss = action_log_probs.sum()
         strategy.backward(loss, actor, actor_optim)
         strategy.optimizer_step(actor_optim)
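The save/load round trip the test exercises goes through the strategy rather than `get_base_model`. A sketch, assuming `strategy`, `actor` and `rank` as in the full test (the path and flags are illustrative):

```python
import os

path = os.path.join('/tmp', f'actor_ckpt_{rank}.pt')
strategy.save_model(actor, path, only_rank0=False)
strategy.load_model(actor, path, strict=False)
```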