Skip to content

Commit e1cfdf2

Browse files
committed
Support PyTorch state_dict serialization
1 parent 851ca7a commit e1cfdf2

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

simvue/serialization.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,27 @@ def serialize(self, data, allow_pickle=False):
99
return serializer(data)
1010
return None, None
1111

12+
def _is_torch_tensor(data):
13+
"""
14+
Check if a dictionary is a PyTorch tensor or state dict
15+
"""
16+
module_name = data.__class__.__module__
17+
class_name = data.__class__.__name__
18+
19+
if module_name == 'collections' and class_name == 'OrderedDict':
20+
valid = True
21+
for item in data:
22+
module_name = data[item].__class__.__module__
23+
class_name = data[item].__class__.__name__
24+
if module_name != 'torch' or class_name != 'Tensor':
25+
valid = False
26+
if valid:
27+
return True
28+
elif module_name == 'torch' and class_name == 'Tensor':
29+
return True
30+
31+
return False
32+
1233
def get_serializer(data, allow_pickle):
1334
"""
1435
Determine which serializer to use
@@ -24,7 +45,7 @@ def get_serializer(data, allow_pickle):
2445
return _serialize_numpy_array
2546
elif module_name == 'pandas.core.frame' and class_name == 'DataFrame':
2647
return _serialize_dataframe
27-
elif module_name == 'torch' and class_name == 'Tensor':
48+
elif _is_torch_tensor(data):
2849
return _serialize_torch_tensor
2950
elif allow_pickle:
3051
return _serialize_pickle

0 commit comments

Comments
 (0)