From 0d1d8625d6594c05ff131fc9fab76341c848ac08 Mon Sep 17 00:00:00 2001
From: Ban Kawas
Date: Mon, 22 Mar 2021 20:42:18 -0700
Subject: [PATCH] Remove/Resolve import duplicates + add `from_dict` classmethod to `MemoryNetworkInput`

Summary:
- fix import errors (remove duplicates + resolve path for train_and_evaluate_generic)
- add `from_dict` classmethod to `MemoryNetworkInput`

Reviewed By: kaiwenw

Differential Revision: D27134600

fbshipit-source-id: 5f9579e542b374de4438677e4e75b91adc10f9f2
---
 reagent/core/types.py                              | 17 +++++++++++++++++
 .../gym/preprocessors/trainer_preprocessor.py      | 21 ++++++++++-----------
 .../workflow/model_managers/model_manager.py       |  3 +--
 3 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/reagent/core/types.py b/reagent/core/types.py
index 09c81d75c..2b9a041b3 100644
--- a/reagent/core/types.py
+++ b/reagent/core/types.py
@@ -891,6 +891,23 @@ class MemoryNetworkInput(BaseInput):
     valid_step: Optional[torch.Tensor] = None
     extras: ExtraData = field(default_factory=ExtraData)
 
+    @classmethod
+    def from_dict(cls, d):
+        return cls(
+            state=FeatureData(
+                float_features=d["state"],
+            ),
+            next_state=FeatureData(
+                float_features=d["next_state"],
+            ),
+            action=d["action"],
+            reward=d["reward"],
+            time_diff=d["time_diff"],
+            not_terminal=d["not_terminal"],
+            step=d["step"],
+            extras=ExtraData.from_dict(d),
+        )
+
     def __len__(self):
         if len(self.state.float_features.size()) == 2:
             return self.state.float_features.size()[0]
diff --git a/reagent/gym/preprocessors/trainer_preprocessor.py b/reagent/gym/preprocessors/trainer_preprocessor.py
index 857369a9d..72fde8a5c 100644
--- a/reagent/gym/preprocessors/trainer_preprocessor.py
+++ b/reagent/gym/preprocessors/trainer_preprocessor.py
@@ -28,7 +28,6 @@
 ONLINE_MAKER_MAP = {}
 REPLAY_BUFFER_MAKER_MAP = {}
 
-
 def make_trainer_preprocessor(
     trainer: Trainer,
     device: torch.device,
@@ -344,16 +343,16 @@ def __call__(self, batch):
            stacked_not_terminal[-1] = scalar_fields["not_terminal"]
            scalar_fields["not_terminal"] = stacked_not_terminal
 
-        return rlt.MemoryNetworkInput(
-            state=rlt.FeatureData(float_features=vector_fields["state"]),
-            next_state=rlt.FeatureData(float_features=vector_fields["next_state"]),
-            action=vector_fields["action"],
-            reward=scalar_fields["reward"],
-            not_terminal=scalar_fields["not_terminal"],
-            step=None,
-            time_diff=None,
-        )
-
+        dict_batch = {
+            "state": vector_fields["state"],
+            "next_state": vector_fields["next_state"],
+            "action": vector_fields["action"],
+            "reward": scalar_fields["reward"],
+            "not_terminal": scalar_fields["not_terminal"],
+            "step": None,
+            "time_diff": None,
+        }
+        return rlt.MemoryNetworkInput.from_dict(dict_batch)
 
 def get_possible_actions_for_gym(batch_size: int, num_actions: int) -> rlt.FeatureData:
     """
diff --git a/reagent/workflow/model_managers/model_manager.py b/reagent/workflow/model_managers/model_manager.py
index bd7e42a48..fcc6f5eeb 100644
--- a/reagent/workflow/model_managers/model_manager.py
+++ b/reagent/workflow/model_managers/model_manager.py
@@ -9,8 +9,7 @@
 from reagent.core.dataclasses import dataclass
 from reagent.core.parameters import NormalizationData
 from reagent.core.registry_meta import RegistryMeta
-from reagent.core.tensorboardX import summary_writer_context
-from reagent.training import ReAgentLightningModule, Trainer
+from reagent.training import Trainer
 from reagent.workflow.data import ReAgentDataModule
 from reagent.workflow.types import (
     Dataset,
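
For context, a minimal usage sketch of the new `from_dict` classmethod, mirroring the `dict_batch` that `trainer_preprocessor.py` now builds. It is not part of the patch: the import alias, tensor shapes, and the assumption that `ExtraData.from_dict` tolerates a dict carrying no extra-data keys are illustrative only.

```python
# Hypothetical example, assuming the ReAgent package is installed and this
# patch is applied. Shapes (seq_len, batch_size, state_dim, action_dim) are
# made-up values chosen purely for illustration.
import torch

import reagent.core.types as rlt

seq_len, batch_size, state_dim, action_dim = 5, 2, 4, 3

# Same keys as the dict_batch assembled in trainer_preprocessor.py;
# "step" and "time_diff" are left as None, exactly as in the patch.
dict_batch = {
    "state": torch.randn(seq_len, batch_size, state_dim),
    "next_state": torch.randn(seq_len, batch_size, state_dim),
    "action": torch.randn(seq_len, batch_size, action_dim),
    "reward": torch.randn(seq_len, batch_size),
    "not_terminal": torch.ones(seq_len, batch_size),
    "step": None,
    "time_diff": None,
}

# Build the training input via the new classmethod instead of the constructor.
batch = rlt.MemoryNetworkInput.from_dict(dict_batch)
```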