diff --git a/reagent/core/types.py b/reagent/core/types.py index 09c81d75c..2b9a041b3 100644 --- a/reagent/core/types.py +++ b/reagent/core/types.py @@ -891,6 +891,23 @@ class MemoryNetworkInput(BaseInput): valid_step: Optional[torch.Tensor] = None extras: ExtraData = field(default_factory=ExtraData) + @classmethod + def from_dict(cls, d): + return cls( + state=FeatureData( + float_features=d["state"], + ), + next_state=FeatureData( + float_features=d["next_state"], + ), + action=d["action"], + reward=d["reward"], + time_diff=d["time_diff"], + not_terminal=d["not_terminal"], + step=d["step"], + extras=ExtraData.from_dict(d), + ) + def __len__(self): if len(self.state.float_features.size()) == 2: return self.state.float_features.size()[0] diff --git a/reagent/gym/preprocessors/trainer_preprocessor.py b/reagent/gym/preprocessors/trainer_preprocessor.py index 857369a9d..72fde8a5c 100644 --- a/reagent/gym/preprocessors/trainer_preprocessor.py +++ b/reagent/gym/preprocessors/trainer_preprocessor.py @@ -28,7 +28,6 @@ ONLINE_MAKER_MAP = {} REPLAY_BUFFER_MAKER_MAP = {} - def make_trainer_preprocessor( trainer: Trainer, device: torch.device, @@ -344,16 +343,16 @@ def __call__(self, batch): stacked_not_terminal[-1] = scalar_fields["not_terminal"] scalar_fields["not_terminal"] = stacked_not_terminal - return rlt.MemoryNetworkInput( - state=rlt.FeatureData(float_features=vector_fields["state"]), - next_state=rlt.FeatureData(float_features=vector_fields["next_state"]), - action=vector_fields["action"], - reward=scalar_fields["reward"], - not_terminal=scalar_fields["not_terminal"], - step=None, - time_diff=None, - ) - + dict_batch = { + "state": vector_fields["state"], + "next_state": vector_fields["next_state"], + "action": vector_fields["action"], + "reward": scalar_fields["reward"], + "not_terminal": scalar_fields["not_terminal"], + "step": None, + "time_diff": None, + } + return rlt.MemoryNetworkInput.from_dict(dict_batch) def get_possible_actions_for_gym(batch_size: int, num_actions: int) -> rlt.FeatureData: """ diff --git a/reagent/workflow/model_managers/model_manager.py b/reagent/workflow/model_managers/model_manager.py index bd7e42a48..fcc6f5eeb 100644 --- a/reagent/workflow/model_managers/model_manager.py +++ b/reagent/workflow/model_managers/model_manager.py @@ -9,8 +9,7 @@ from reagent.core.dataclasses import dataclass from reagent.core.parameters import NormalizationData from reagent.core.registry_meta import RegistryMeta -from reagent.core.tensorboardX import summary_writer_context -from reagent.training import ReAgentLightningModule, Trainer +from reagent.training import Trainer from reagent.workflow.data import ReAgentDataModule from reagent.workflow.types import ( Dataset,