From b999f7e9fb894b69427f5f010d3c06f468666f22 Mon Sep 17 00:00:00 2001 From: Ana Gainaru Date: Sat, 7 Mar 2026 09:13:08 -0500 Subject: [PATCH] Separating the loop functions over dataloaders for training or infering --- src/config/configuration.py | 1 + src/driver/continuous_monitor.py | 2 +- src/model/torch_model_harness.py | 9 +++++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/config/configuration.py b/src/config/configuration.py index 3bc3ece..357f38a 100644 --- a/src/config/configuration.py +++ b/src/config/configuration.py @@ -106,6 +106,7 @@ class TrainCfg: class DataCfg: name: str path: str + batch_size: int = 0 @dataclass(frozen=True) diff --git a/src/driver/continuous_monitor.py b/src/driver/continuous_monitor.py index 217287c..517aa7f 100644 --- a/src/driver/continuous_monitor.py +++ b/src/driver/continuous_monitor.py @@ -132,7 +132,7 @@ def _process_stream(self) -> None: Raises: StopIteration: When the data loader is exhausted """ - train_loader, val_loader = self.modelHarness.get_cur_data_loaders() + train_loader, val_loader = self.modelHarness.get_cur_loop_loaders() for batch_idx, batch in tqdm( enumerate(val_loader), diff --git a/src/model/torch_model_harness.py b/src/model/torch_model_harness.py index fee7a53..35d9459 100644 --- a/src/model/torch_model_harness.py +++ b/src/model/torch_model_harness.py @@ -55,6 +55,7 @@ def update_data_stream(self) -> None: def get_cur_data_loaders(self) -> Tuple[DataLoader, DataLoader]: """ Returns a training and validation dataloader compatible with the model input + that will be used for continual learning """ raise NotImplementedError @@ -88,6 +89,14 @@ def _unpack(self, batch: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: x, y = batch return x, y + @torch.no_grad() + def get_cur_loop_loaders(self) -> Tuple[DataLoader, DataLoader]: + """ + Returns a training and validation dataloader compatible with the model input + that will be used to loop over for inference + """ + return self.get_cur_data_loaders() + @staticmethod def _to_scalar(x: Tensor | float) -> float: if isinstance(x, torch.Tensor):