Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions reagent/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,23 @@ class MemoryNetworkInput(BaseInput):
valid_step: Optional[torch.Tensor] = None
extras: ExtraData = field(default_factory=ExtraData)

@classmethod
def from_dict(cls, d):
return cls(
state=FeatureData(
float_features=d["state"],
),
next_state=FeatureData(
float_features=d["next_state"],
),
action=d["action"],
reward=d["reward"],
time_diff=d["time_diff"],
not_terminal=d["not_terminal"],
step=d["step"],
extras=ExtraData.from_dict(d),
)

def __len__(self):
if len(self.state.float_features.size()) == 2:
return self.state.float_features.size()[0]
Expand Down
21 changes: 10 additions & 11 deletions reagent/gym/preprocessors/trainer_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
ONLINE_MAKER_MAP = {}
REPLAY_BUFFER_MAKER_MAP = {}


def make_trainer_preprocessor(
trainer: Trainer,
device: torch.device,
Expand Down Expand Up @@ -344,16 +343,16 @@ def __call__(self, batch):
stacked_not_terminal[-1] = scalar_fields["not_terminal"]
scalar_fields["not_terminal"] = stacked_not_terminal

return rlt.MemoryNetworkInput(
state=rlt.FeatureData(float_features=vector_fields["state"]),
next_state=rlt.FeatureData(float_features=vector_fields["next_state"]),
action=vector_fields["action"],
reward=scalar_fields["reward"],
not_terminal=scalar_fields["not_terminal"],
step=None,
time_diff=None,
)

dict_batch = {
"state": vector_fields["state"],
"next_state": vector_fields["next_state"],
"action": vector_fields["action"],
"reward": scalar_fields["reward"],
"not_terminal": scalar_fields["not_terminal"],
"step": None,
"time_diff": None,
}
return rlt.MemoryNetworkInput.from_dict(dict_batch)

def get_possible_actions_for_gym(batch_size: int, num_actions: int) -> rlt.FeatureData:
"""
Expand Down
3 changes: 1 addition & 2 deletions reagent/workflow/model_managers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from reagent.core.dataclasses import dataclass
from reagent.core.parameters import NormalizationData
from reagent.core.registry_meta import RegistryMeta
from reagent.core.tensorboardX import summary_writer_context
from reagent.training import ReAgentLightningModule, Trainer
from reagent.training import Trainer
from reagent.workflow.data import ReAgentDataModule
from reagent.workflow.types import (
Dataset,
Expand Down