Can't use run_segment with apex.amp #4

@karinaodm

I use code like this:

run_segment = optimal_grad_checkpointing(model, inp)
run_segment, optimizer = apex.amp.initialize(run_segment, optimizer, opt_level="O2", verbosity=0)
...
output = run_segment(images)

and get this error:

output = run_segment(images)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
    **applier(kwargs, input_caster))
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 911, in forward
    return graph_forward(x, **self.info_dict)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 838, in graph_forward
    output = checkpoint(segment_checkpoint_forward(op), input)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 155, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
  File "/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 74, in forward
    outputs = run_function(*args)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 807, in custom_forward
    outputs = segment(*inputs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 911, in forward
    return graph_forward(x, **self.info_dict)
  File "/working_dir/OptimalGradCheckpointing/graph.py", line 840, in graph_forward
    output = op(input)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 349, in forward
    return self._conv_forward(input, self.weight)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 346, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

It would be useful to be able to combine Optimal Gradient Checkpointing with apex.amp or torch.cuda.amp.
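
For reference, here is a minimal sketch of the combination I have in mind, using only stock PyTorch: plain torch.utils.checkpoint under torch.autocast, with neither apex.amp nor OptimalGradCheckpointing involved. The `segment` module is a hypothetical stand-in for one checkpointed segment, and I am assuming a PyTorch version that provides `torch.autocast` and the `use_reentrant` argument. The weights stay in float32 and autocast casts per operation, so the Half/Float mismatch above should not occur in this setup. I have not checked whether the same holds for the `run_segment` returned by `optimal_grad_checkpointing`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 on GPU; bfloat16 on CPU, where autocast supports it
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Hypothetical stand-in for one checkpointed segment of the model
segment = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).to(device)
images = torch.randn(2, 3, 16, 16, device=device)

# Weights remain float32; autocast casts the conv inputs and weights per call.
with torch.autocast(device_type=device, dtype=amp_dtype):
    # The non-reentrant checkpoint recomputes the segment during backward.
    output = checkpoint(segment, images, use_reentrant=False)

# Reduce in float32 and backpropagate through the checkpointed segment.
output.float().sum().backward()
```

A real float16 training loop would also need torch.cuda.amp.GradScaler for loss scaling; it is left out here to keep the sketch short.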
