Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/diffgram/core/directory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from diffgram.file.file import File
from ..regular.regular import refresh_from_dict
import logging
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
from diffgram.tensorflow_diffgram.diffgram_tensorflow_dataset import DiffgramTensorflowDataset
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator
from multiprocessing.pool import ThreadPool as Pool
Expand Down Expand Up @@ -155,6 +154,7 @@ def to_pytorch(self, transform = None):
Transforms the file list inside the dataset into a pytorch dataset.
:return:
"""
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
file_id_list = self.file_id_list
pytorch_dataset = DiffgramPytorchDataset(
project = self.client,
Expand Down
3 changes: 1 addition & 2 deletions sdk/diffgram/core/sliced_directory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from diffgram.core.directory import Directory
from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
from diffgram.tensorflow_diffgram.diffgram_tensorflow_dataset import DiffgramTensorflowDataset
import urllib

Expand Down Expand Up @@ -37,7 +36,7 @@ def to_pytorch(self, transform = None):
Transforms the file list inside the dataset into a pytorch dataset.
:return:
"""

from diffgram.pytorch_diffgram.diffgram_pytorch_dataset import DiffgramPytorchDataset
pytorch_dataset = DiffgramPytorchDataset(
project = self.client,
diffgram_file_id_list = self.file_id_list,
Expand Down
7 changes: 6 additions & 1 deletion sdk/diffgram/pytorch_diffgram/diffgram_pytorch_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from torch.utils.data import Dataset, DataLoader
import torch as torch # type: ignore
from diffgram.core.diffgram_dataset_iterator import DiffgramDatasetIterator

try:
import torch as torch # type: ignore
except ModuleNotFoundError:
raise ModuleNotFoundError(
"'torch' module should be installed to convert the Dataset into torch (pytorch) format"
)

class DiffgramPytorchDataset(DiffgramDatasetIterator, Dataset):

Expand Down