
Incorrect implementation of vae_dyna in the BOSA algorithm #7

@wyuuchen

Description

It appears that the self.vae_dyna model is intended to fit the target-domain state transitions, but in the code it is trained on data obtained by mixing the source and target domains, which does not look correct. The relevant code is in offline_offline/bosa.py.

First, the train function:
```python
def train(self, src_replay_buffer, tar_replay_buffer, batch_size=128, writer=None):
    self.total_it += 1

    if src_replay_buffer.size < batch_size or tar_replay_buffer.size < batch_size:
        return
    batch_src, batch_tar = src_replay_buffer.sample(batch_size), tar_replay_buffer.sample(batch_size)
    batch_mix            = [torch.cat([b_tar, b_src], dim=0) for b_tar, b_src in zip(batch_tar, batch_src)]

    if self.total_it < self.vae_iteration:
        # vae model pretrain
        log_dict = self.vae_models_train(batch_mix)
```

(Here the mixed batch is passed directly into the VAE training.)

Then vae_models_train passes the mixed batch straight into vae_dyna_train():

```python
def vae_models_train(self, batch: TensorBatch) -> Dict[str, float]:
    log_dict = {}
    self.total_it += 1
    loss_dict_vae_policy = self.vae_policy_train(batch)
    loss_dict_vae_dynamics = self.vae_dyna_train(batch)
    log_dict.update(loss_dict_vae_policy)
    log_dict.update(loss_dict_vae_dynamics)
    return log_dict
```
And vae_dyna_train in turn fits the VAE directly on this mixed batch:

```python
def vae_dyna_train(self, batch: TensorBatch) -> Dict[str, float]:
    state, action, next_state, _, _ = batch
    # Variational Auto-Encoder Training
    recon, mean, std = self.vae_dyna(state, action, next_state)
    recon_loss = F.mse_loss(recon, next_state.repeat(int(self.config['vae_dyna_ensemble']), 1, 1))
    KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
    vae_loss = recon_loss + self.vae_dyna_beta * KL_loss

    self.vae_dyna_optimizer.zero_grad()
    vae_loss.backward()
    self.vae_dyna_optimizer.step()
```
However, according to the original paper, this VAE should be learned to fit the target-domain state transitions only.
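If the paper's intent is that the dynamics VAE models only target-domain transitions, one possible fix (a hypothetical sketch, not the repo's actual code; the helper names make_mixed_batch and train_step are illustrative) would be to keep the mixed batch for the other components but route only batch_tar into vae_dyna_train:

```python
# Hypothetical sketch of the proposed fix: the dynamics VAE sees only
# the target-domain batch, while other VAE components may still use the
# mixed batch. Plain lists stand in for tensors to keep this self-contained.

def make_mixed_batch(batch_tar, batch_src):
    # Concatenate target and source samples field by field
    # (states, actions, next_states, ...), mirroring the repo's
    # torch.cat over zipped batches.
    return [t + s for t, s in zip(batch_tar, batch_src)]

def train_step(batch_src, batch_tar, vae_policy_train, vae_dyna_train):
    batch_mix = make_mixed_batch(batch_tar, batch_src)
    log = {}
    log.update(vae_policy_train(batch_mix))  # policy VAE: mixed data (unchanged)
    log.update(vae_dyna_train(batch_tar))    # dynamics VAE: target-domain data only
    return log
```

With this change, the reconstruction and KL losses in vae_dyna_train are computed exclusively on target-domain (state, action, next_state) tuples, matching the stated role of self.vae_dyna.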
