diff --git a/bert_pytorch/dataset/log_dataset.py b/bert_pytorch/dataset/log_dataset.py index 1944b66..41e455f 100644 --- a/bert_pytorch/dataset/log_dataset.py +++ b/bert_pytorch/dataset/log_dataset.py @@ -125,10 +125,10 @@ def collate_fn(self, batch, percentile=100, dynamical_pad=True): output["time_input"].append(time_input) output["time_label"].append(time_label) - output["bert_input"] = torch.tensor(output["bert_input"], dtype=torch.long) - output["bert_label"] = torch.tensor(output["bert_label"], dtype=torch.long) - output["time_input"] = torch.tensor(output["time_input"], dtype=torch.float) - output["time_label"] = torch.tensor(output["time_label"], dtype=torch.float) + output["bert_input"] = torch.from_numpy(np.array(output["bert_input"])) + output["bert_label"] = torch.from_numpy(np.array(output["bert_label"])) + output["time_input"] = torch.from_numpy(np.array(output["time_input"])) + output["time_label"] = torch.from_numpy(np.array(output["time_label"])) return output