run_segment = optimal_grad_checkpointing(model, inp)
run_segment, optimizer = apex.amp.initialize(run_segment, optimizer, opt_level="02", verbosity=0)
...
output = run_segment(images)
output = run_segment(images)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/apex/amp/_initialize.py", line 197, in new_fwd
**applier(kwargs, input_caster))
File "/working_dir/OptimalGradCheckpointing/graph.py", line 911, in forward
return graph_forward(x, **self.info_dict)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 838, in graph_forward
output = checkpoint(segment_checkpoint_forward(op), input)
File "/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 155, in checkpoint
return CheckpointFunction.apply(function, preserve, *args)
File "/opt/conda/lib/python3.6/site-packages/torch/utils/checkpoint.py", line 74, in forward
outputs = run_function(*args)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 807, in custom_forward
outputs = segment(*inputs)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 911, in forward
return graph_forward(x, **self.info_dict)
File "/working_dir/OptimalGradCheckpointing/graph.py", line 840, in graph_forward
output = op(input)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 349, in forward
return self._conv_forward(input, self.weight)
File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 346, in _conv_forward
self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same
It would be useful if Optimal Gradient Checkpointing could be combined with mixed-precision training via apex.amp or torch.cuda.amp.
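I use code like the snippet at the top of this report, i.e. I call apex.amp.initialize on the module returned by optimal_grad_checkpointing (note: the opt_level there should read "O2" with a capital letter O; apex only accepts "O0"–"O3"):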
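and get the error shown in the traceback above: `RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same`. The traceback shows that amp's patched forward (apex/amp/_initialize.py, new_fwd) casts the input to half. However, the Conv2d reached through graph_forward inside the checkpointed segment still has float32 weights, so the O2 cast apparently does not reach the modules held in the segment graph.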
Supporting this combination would be valuable, because gradient checkpointing and mixed precision both reduce GPU memory usage during training and could be used together.