Hi,
thank you so much for providing this code!
Using the automatic computation graph parser, I was able to use the optimal gradient checkpoints during model training without writing much additional code. Now I can train with almost 2x larger batches, which is very helpful for my application!
I just had a few minor issues when running the code, and I want to mention them quickly here in case anyone else runs into the same:
- In graph.py, line 429: The assertion `len(input_node_names) == len(inputs_nodes_ids)` fails because the list `inputs_nodes_ids` contains `None`. It works after removing all `None` entries from the list (`inputs_nodes_ids = [i for i in inputs_nodes_ids if i is not None]`). However, I'm not sure whether doing this could have any adverse effects?
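For illustration, here is the workaround on made-up ids (only the list comprehension is from graph.py; the values are invented):

```python
# Made-up example ids; in graph.py these come from the parsed graph
inputs_nodes_ids = [0, None, 3, None, 7]
# The workaround: drop the None entries so the length check can pass,
# keeping the remaining ids in their original order
inputs_nodes_ids = [i for i in inputs_nodes_ids if i is not None]
print(inputs_nodes_ids)  # [0, 3, 7]
```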
- In graph.py, line 159: Parsing the shape string fails because `node_type` looks something like `"Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cuda:0)"` (I used PyTorch 1.9).
Quick fix:
```python
if 'strides' in node_type:
    shape_str = node_type.split('(')[-1].split(', strides')[0]
else:
    shape_str = node_type.split('(')[-1].split(')')[0]
```
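An alternative sketch that handles both type-string formats in one go (`parse_shape` is my own hypothetical helper, not part of graph.py):

```python
import re

def parse_shape(node_type):
    # Hypothetical helper: grab the leading integer dimensions from a
    # TorchScript type string, with or without the newer "strides=..." part.
    m = re.search(r'\((\d+(?:,\s*\d+)*)', node_type)
    return tuple(int(d) for d in m.group(1).split(',')) if m else None

print(parse_shape("Float(2, 1024, strides=[1024, 1], requires_grad=0, device=cuda:0)"))  # (2, 1024)
print(parse_shape("Float(2, 1024)"))  # (2, 1024)
```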
- I also had to add a few lines of code to the `get_python_module_from_node_op` function (in graph.py) in order to handle `prim::ListUnpack`, `aten::constant_pad_nd`, and `aten::squeeze`, but this was straightforward based on the examples in your code :)
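In case it helps, here is a rough sketch of the kind of mapping I added (names and structure are hypothetical; the actual resolution logic in graph.py is organized differently):

```python
# Hypothetical op-name table recording which callables the new ops map to;
# the real get_python_module_from_node_op dispatches differently.
EXTRA_OPS = {
    'prim::ListUnpack': 'no module needed (tuple unpacking)',
    'aten::constant_pad_nd': 'torch.nn.functional.pad',
    'aten::squeeze': 'torch.squeeze',
}

def is_handled(op_kind):
    # True if the op kind now has an entry in the table.
    return op_kind in EXTRA_OPS

print(is_handled('aten::squeeze'))  # True
```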
Thank you again for making my life easier!