Hello 👋
I am having issues using `mup.get_coord_data` because some of my modules return dataclass objects. Currently only `dict`, `list`, `tuple` and tensors are supported. It would be great, and fairly easy, to also support dataclasses.
I think that the only code to modify would be `get_stat`:

```python
def get_stat(d, x, fdict):
    if isinstance(x, (tuple, list)):
        for i, _x in enumerate(x):
            _d = copy(d)
            _d['module'] += f'[{i}]'
            get_stat(_d, _x, fdict)
    elif isinstance(x, dict):
        for name, _x in x.items():
            _d = copy(d)
            _d['module'] += f'[{name}]'
            get_stat(_d, _x, fdict)
    elif isinstance(x, torch.Tensor):
        _d = copy(d)
        for fname, f in fdict.items():
            _d[fname] = f(x).item()
        records.append(_d)
    elif x is None:
        pass
    else:
        raise NotImplementedError(f'Unexpected output type: {type(x)}')
```
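To make the proposal concrete, here is a rough, standalone sketch of what the extra branch could look like. It is not the library's code: `records` is passed in explicitly instead of coming from the enclosing scope, `leaf_types` is a hypothetical stand-in for `torch.Tensor` so the example runs without torch, the `.field_name` suffix is just one possible naming choice, and the `Output` dataclass is made up for the demo.

```python
from copy import copy
from dataclasses import dataclass, fields, is_dataclass


def get_stat(d, x, fdict, records, leaf_types=(int, float)):
    """Recursion from coord_check.py with one extra dataclass branch (sketch)."""
    if isinstance(x, (tuple, list)):
        for i, _x in enumerate(x):
            _d = copy(d)
            _d['module'] += f'[{i}]'
            get_stat(_d, _x, fdict, records, leaf_types)
    elif isinstance(x, dict):
        for name, _x in x.items():
            _d = copy(d)
            _d['module'] += f'[{name}]'
            get_stat(_d, _x, fdict, records, leaf_types)
    elif is_dataclass(x) and not isinstance(x, type):
        # Proposed branch: walk the declared fields of a dataclass instance.
        # fields() + getattr avoids the deep copy that dataclasses.asdict makes.
        for field in fields(x):
            _d = copy(d)
            _d['module'] += f'.{field.name}'
            get_stat(_d, getattr(x, field.name), fdict, records, leaf_types)
    elif isinstance(x, leaf_types):
        # Stand-in for the torch.Tensor branch (the real code calls f(x).item()).
        _d = copy(d)
        for fname, f in fdict.items():
            _d[fname] = f(x)
        records.append(_d)
    elif x is None:
        pass
    else:
        raise NotImplementedError(f'Unexpected output type: {type(x)}')


@dataclass
class Output:  # hypothetical module output, for the demo only
    hidden: float
    aux: list
    mask: object = None


records = []
get_stat({'module': 'block'}, Output(hidden=-2.0, aux=[3.0, -4.0]), {'abs': abs}, records)
for r in records:
    print(r)
```

Nested lists, dicts and `None` fields inside the dataclass go through the existing branches, so only the one `elif` would need to be added.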
I can do a PR for that.
The `get_stat` snippet quoted above is from `mup/mup/coord_check.py`, lines 129 to 148 at commit 1981497.