Problem with PyTorch Version 1.10 #3

@karinaodm

Description

Hi, I am trying to reproduce the results. Everything works correctly with PyTorch 1.5, but with PyTorch 1.10, parsing the computation graph with torch.jit fails, and the manual parse_graph fallback uses roughly twice as much GPU memory.
Output with PyTorch Version 1.10.0a0+0aef44c (nvcr.io/nvidia/pytorch:21.10-py3 docker container):

Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
Parsing Computation Graph with torch.jit failed, revert to manual parse_graph function
Building Division Tree
Getting Max Terms
Solving Optimal for Each Max Term
100%|████████████████████████████████████████████████████| 330/330 [00:02<00:00, 138.06it/s]
Solving optimal gradient checkpointing takes 2.7020 s
/opt/conda/lib/python3.8/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Parsed graph forward check passed
Run graph forward check passed
Parsed graph backward check passed
Run graph backward check passed
/opt/conda/lib/python3.8/site-packages/torch/cuda/memory.py:271: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
100%|█████████████████████████████████████████████████████| 100/100 [00:24<00:00,  4.12it/s]
100%|█████████████████████████████████████████████████████| 100/100 [00:30<00:00,  3.25it/s]
Average Iteration Time: Checkpointing 0.3082 s, Regular 0.2427 s, overhead 26.99%
Average Peak Memory: Checkpointing 5251.0508 MB, Regular 8157.9248 MB, Memory Cut off 35.63%
Average Intermediate Tensors: Checkpointing 1023.0098 MB, Regular 3929.8838 MB, Memory Cut off 73.97%

Output after commenting out the "try" at https://github.com/lordfjw/OptimalGradCheckpointing/blob/main/benchmark.py#L167:

Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
Traceback (most recent call last):
  File "benchmark.py", line 212, in <module>
    main(arch, device)
  File "benchmark.py", line 168, in main
    G, source, target = parse_computation_graph(net, inputs)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 34, in parse_computation_graph
    computation_graph, input_node_ids, output_node_ids = parse_raw_computation_graph_from_jit(module, inputs)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 55, in parse_raw_computation_graph_from_jit
    computation_graph, _, input_node_ids, output_node_ids = build_computation_graph_recursively(module, inputs, inputs_nodes_ids=None, outputs_nodes_ids=None, cur_node_idx=None)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 412, in build_computation_graph_recursively
    internal_node_dicts = [parse_node_str(n) for n in graph_nodes]
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 412, in <listcomp>
    internal_node_dicts = [parse_node_str(n) for n in graph_nodes]
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 162, in parse_node_str
    shape = [int(s) for s in shape_str.split(', ')]
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 162, in <listcomp>
    shape = [int(s) for s in shape_str.split(', ')]
ValueError: invalid literal for int() with base 10: 'strides=[2048'
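The traceback suggests the root cause: in newer PyTorch versions, the JIT graph's tensor type strings append extra metadata such as `strides=[...]` after the dimension sizes (e.g. `Float(32, 3, 224, 224, strides=[150528, 50176, 224, 1], requires_grad=0, device=cuda:0)`), so naively splitting the string on `', '` and calling `int()` on every token fails. A minimal sketch of a more tolerant shape parser, which could replace the shape-splitting logic in `parse_node_str` (graph.py line 162) — `parse_shape` here is a hypothetical helper, not a function from the repo:

```python
def parse_shape(type_str):
    """Extract the leading dimension sizes from a JIT tensor type string.

    Handles both the old format, e.g. 'Float(32, 3, 224, 224)', and the
    newer format that appends strides and other metadata, e.g.
    'Float(32, 3, 224, 224, strides=[...], requires_grad=0, device=cuda:0)'.
    """
    # Take everything between the outermost parentheses.
    inner = type_str[type_str.index('(') + 1 : type_str.rindex(')')]
    dims = []
    for tok in inner.split(', '):
        # Dimension sizes come first; stop at the first non-integer token
        # such as 'strides=[150528' or 'requires_grad=0'.
        if not tok.lstrip('-').isdigit():
            break
        dims.append(int(tok))
    return dims
```

This only addresses the jit-parsing crash; the higher memory usage of the manual parse_graph path under 1.10 would still need to be investigated separately.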

Output with PyTorch Version 1.5.0a0+8f84ded (nvcr.io/nvidia/pytorch:20.03-py3 docker container):

Processing resnet101, Input size (32, 3, 224, 224)--------------------
Parsing Computation Graph
/opt/conda/lib/python3.6/site-packages/torch/tensor.py:746: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
  warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
Building Division Tree
Getting Max Terms
Solving Optimal for Each Max Term
100%|██████████████████████████████████████████████████| 350/350 [00:03<00:00, 95.77it/s]
Solving optimal gradient checkpointing takes 4.1945 s
/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py:25: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Parsed graph forward check passed
Run graph forward check passed
Parsed graph backward check passed
Run graph backward check passed
100%|█████████████████████████████████████████████████████| 100/100 [00:19<00:00,  5.02it/s]
100%|█████████████████████████████████████████████████████| 100/100 [00:26<00:00,  3.80it/s]
Average Iteration Time: Checkpointing 0.2623 s, Regular 0.1983 s, overhead 32.28%
Average Peak Memory: Checkpointing 1524.6592 MB, Regular 4306.1680 MB, Memory Cut off 64.59%
Average Intermediate Tensors: Checkpointing 1145.3750 MB, Regular 3926.8838 MB, Memory Cut off 70.83%
