Hi, I am trying to reproduce the results. It works correctly with PyTorch 1.5, but with PyTorch 1.10 parsing the computation graph with torch.jit fails, and the manual parse_graph fallback takes up roughly twice as much GPU memory.
Output with PyTorch Version 1.10.0a0+0aef44c (nvcr.io/nvidia/pytorch:21.10-py3 docker container):
Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
Parsing Computation Graph with torch.jit failed, revert to manual parse_graph function
Building Division Tree
Getting Max Terms
Solving Optimal for Each Max Term
100%|████████████████████████████████████████████████████| 330/330 [00:02<00:00, 138.06it/s]
Solving optimal gradient checkpointing takes 2.7020 s
/opt/conda/lib/python3.8/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Parsed graph forward check passed
Run graph forward check passed
Parsed graph backward check passed
Run graph backward check passed
/opt/conda/lib/python3.8/site-packages/torch/cuda/memory.py:271: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
warnings.warn(
100%|█████████████████████████████████████████████████████| 100/100 [00:24<00:00, 4.12it/s]
100%|█████████████████████████████████████████████████████| 100/100 [00:30<00:00, 3.25it/s]
Average Iteration Time: Checkpointing 0.3082 s, Regular 0.2427 s, overhead 26.99%
Average Peak Memory: Checkpointing 5251.0508 MB, Regular 8157.9248 MB, Memory Cut off 35.63%
Average Intermediate Tensors: Checkpointing 1023.0098 MB, Regular 3929.8838 MB, Memory Cut off 73.97%
Output after commenting out the "try" at https://github.com/lordfjw/OptimalGradCheckpointing/blob/main/benchmark.py#L167:
Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
Traceback (most recent call last):
File "benchmark.py", line 212, in <module>
main(arch, device)
File "benchmark.py", line 168, in main
G, source, target = parse_computation_graph(net, inputs)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 34, in parse_computation_graph
computation_graph, input_node_ids, output_node_ids = parse_raw_computation_graph_from_jit(module, inputs)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 55, in parse_raw_computation_graph_from_jit
computation_graph, _, input_node_ids, output_node_ids = build_computation_graph_recursively(module, inputs, inputs_nodes_ids=None, outputs_nodes_ids=None, cur_node_idx=None)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 412, in build_computation_graph_recursively
internal_node_dicts = [parse_node_str(n) for n in graph_nodes]
File "/working_dir/OptimalGradCheckpointing/graph.py", line 412, in <listcomp>
internal_node_dicts = [parse_node_str(n) for n in graph_nodes]
File "/working_dir/OptimalGradCheckpointing/graph.py", line 162, in parse_node_str
shape = [int(s) for s in shape_str.split(', ')]
File "/working_dir/OptimalGradCheckpointing/graph.py", line 162, in <listcomp>
shape = [int(s) for s in shape_str.split(', ')]
ValueError: invalid literal for int() with base 10: 'strides=[2048'
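For what it's worth, the ValueError suggests that newer PyTorch versions append `strides=[...]` (and `requires_grad`/`device` annotations) to the tensor type strings that `parse_node_str` in graph.py splits on commas. A minimal sketch of a more tolerant parser that keeps only the leading integer dimensions (`parse_shape_from_type_str` is a hypothetical helper name, not the repo's function; the example type string is assumed based on the error message):

```python
def parse_shape_from_type_str(type_str):
    """Extract the leading dimension list from a TorchScript tensor type
    string, ignoring the strides/device annotations that newer PyTorch
    versions append, e.g.
    'Float(32, 3, 224, 224, strides=[150528, 50176, 224, 1], device=cuda:0)'.
    """
    # Take everything between the first '(' and the last ')'
    inner = type_str[type_str.index('(') + 1 : type_str.rindex(')')]
    shape = []
    for part in inner.split(', '):
        if part.isdigit():   # keep plain integer dims only
            shape.append(int(part))
        else:                # stop at 'strides=[...', 'requires_grad=...', etc.
            break
    return shape
```

This would also still handle the older 1.5-style strings, where the parentheses contain only the dimensions.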
Output with PyTorch Version 1.5.0a0+8f84ded (nvcr.io/nvidia/pytorch:20.03-py3 docker container):
Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
/opt/conda/lib/python3.6/site-packages/torch/tensor.py:746: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
Building Division Tree
Getting Max Terms
Solving Optimal for Each Max Term
100%|██████████████████████████████████████████████████| 350/350 [00:03<00:00, 95.77it/s]
Solving optimal gradient checkpointing takes 4.1945 s
/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Parsed graph forward check passed
Run graph forward check passed
Parsed graph backward check passed
Run graph backward check passed
100%|█████████████████████████████████████████████████████| 100/100 [00:19<00:00, 5.02it/s]
100%|█████████████████████████████████████████████████████| 100/100 [00:26<00:00, 3.80it/s]
Average Iteration Time: Checkpointing 0.2623 s, Regular 0.1983 s, overhead 32.28%
Average Peak Memory: Checkpointing 1524.6592 MB, Regular 4306.1680 MB, Memory Cut off 64.59%
Average Intermediate Tensors: Checkpointing 1145.3750 MB, Regular 3926.8838 MB, Memory Cut off 70.83%
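For reference, the reported percentages appear to follow these formulas (inferred from the numbers in the logs above, not read from the repo source):

```python
# "overhead" = relative slowdown of checkpointing vs. regular training
ckpt_time, reg_time = 0.2623, 0.1983
overhead = (ckpt_time - reg_time) / reg_time * 100   # ~32.28 %

# "Memory Cut off" = relative peak-memory saving of checkpointing
ckpt_peak, reg_peak = 1524.6592, 4306.1680
peak_cut = (1 - ckpt_peak / reg_peak) * 100          # ~64.59 %
```

By these formulas, the jump in checkpointing peak memory from ~1525 MB (1.5) to ~5251 MB (1.10) is what collapses the memory cut-off from 64.59% to 35.63%.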