diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index 5712505ae2ff..5301c87b9836 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -7,7 +7,7 @@ from .colo_tensor import _convert_output -WHITE_LIST_FUNCS = {torch.Tensor.__getitem__} +WHITE_LIST_FUNCS = {torch.Tensor.__getitem__, torch.Tensor.is_floating_point} def is_no_hook_op(func) -> bool: