From 14b169bb2fea5c86f3bc8ca2d4579c0aa2e9f343 Mon Sep 17 00:00:00 2001 From: Sajithkumar Ganesan Date: Fri, 21 Jun 2024 22:42:22 +0100 Subject: [PATCH] #42 Fix Deprecation warnings --- bert_pytorch/dataset/log_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bert_pytorch/dataset/log_dataset.py b/bert_pytorch/dataset/log_dataset.py index 1944b66..41e455f 100644 --- a/bert_pytorch/dataset/log_dataset.py +++ b/bert_pytorch/dataset/log_dataset.py @@ -125,10 +125,10 @@ def collate_fn(self, batch, percentile=100, dynamical_pad=True): output["time_input"].append(time_input) output["time_label"].append(time_label) - output["bert_input"] = torch.tensor(output["bert_input"], dtype=torch.long) - output["bert_label"] = torch.tensor(output["bert_label"], dtype=torch.long) - output["time_input"] = torch.tensor(output["time_input"], dtype=torch.float) - output["time_label"] = torch.tensor(output["time_label"], dtype=torch.float) + output["bert_input"] = torch.from_numpy(np.array(output["bert_input"])) + output["bert_label"] = torch.from_numpy(np.array(output["bert_label"])) + output["time_input"] = torch.from_numpy(np.array(output["time_input"])) + output["time_label"] = torch.from_numpy(np.array(output["time_label"])) return output