From c8fea7ec76e27fb3034d4a232a6113311e065219 Mon Sep 17 00:00:00 2001
From: HELSON
Date: Mon, 27 Mar 2023 14:27:10 +0800
Subject: [PATCH 1/2] [fx] meta registration compatibility

---
 colossalai/fx/_compatibility.py               | 19 +++++--
 ...ta_registrations.py => _meta_regist_12.py} |  0
 colossalai/fx/_meta_regist_13.py              | 57 +++++++++++++++++++
 colossalai/fx/profiler/tensor.py              |  1 +
 4 files changed, 73 insertions(+), 4 deletions(-)
 rename colossalai/fx/{_meta_registrations.py => _meta_regist_12.py} (100%)
 create mode 100644 colossalai/fx/_meta_regist_13.py

diff --git a/colossalai/fx/_compatibility.py b/colossalai/fx/_compatibility.py
index 126403270301..6caad920d2ae 100644
--- a/colossalai/fx/_compatibility.py
+++ b/colossalai/fx/_compatibility.py
@@ -2,11 +2,22 @@
 
 import torch
+import warnings
 
-try:
-    from . import _meta_registrations
-    META_COMPATIBILITY = True
-except:
+TORCH_MAJOR = int(torch.__version__.split('.')[0])
+TORCH_MINOR = int(torch.__version__.split('.')[1])
+
+if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
     META_COMPATIBILITY = False
+elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
+    from . import _meta_regist_12
+    META_COMPATIBILITY = True
+elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
+    from . import _meta_regist_13
+    META_COMPATIBILITY = True
+elif TORCH_MAJOR == 2:
+    from . import _meta_regist_13
+    META_COMPATIBILITY = True
+    warnings.warn("Colossalai is not tested with torch2.0 yet!!!")
 
 
 def compatibility(is_backward_compatible: bool = False) -> Callable:
diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_regist_12.py
similarity index 100%
rename from colossalai/fx/_meta_registrations.py
rename to colossalai/fx/_meta_regist_12.py
diff --git a/colossalai/fx/_meta_regist_13.py b/colossalai/fx/_meta_regist_13.py
new file mode 100644
index 000000000000..6caa87c449ab
--- /dev/null
+++ b/colossalai/fx/_meta_regist_13.py
@@ -0,0 +1,57 @@
+import torch
+from torch._meta_registrations import register_meta
+from torch._prims_common import check
+
+aten = torch.ops.aten
+
+
+# since we fix the torch version to 1.13.1, we have to add unimplemented meta ops
+# all these functions are from here https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
+@register_meta([aten.convolution_backward.default])
+def meta_convolution_backward(
+    grad_output_,
+    input_,
+    weight_,
+    bias_sizes_opt,
+    stride,
+    padding,
+    dilation,
+    transposed,
+    output_padding,
+    groups,
+    output_mask,
+):
+    # High level logic taken from slow_conv3d_backward_cpu which should
+    # be representative of all convolution_backward impls
+    backend_grad_input = None
+    backend_grad_weight = None
+    backend_grad_bias = None
+
+    if output_mask[0]:
+        backend_grad_input = grad_output_.new_empty(input_.size())
+    if output_mask[1]:
+        backend_grad_weight = grad_output_.new_empty(weight_.size())
+    if output_mask[2]:
+        backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
+
+    return (backend_grad_input, backend_grad_weight, backend_grad_bias)
+
+
+@register_meta(aten._adaptive_avg_pool2d_backward.default)
+def meta__adaptive_avg_pool2d_backward(grad_out, self):
+    ndim = grad_out.ndim
+    for i in range(1, ndim):
+        check(
+            grad_out.size(i) > 0,
+            lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
+size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
+        )
+    check(
+        ndim == 3 or ndim == 4,
+        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
+    )
+    check(
+        self.dtype == grad_out.dtype,
+        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
+    )
+    return self.new_empty(self.shape)
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 2ee5e5c47750..7298dff8737e 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -80,6 +80,7 @@ def unwrap(x):
         kwargs['device'] = torch.device('meta')
 
         # run aten for backend=CPU but actually on backend=Meta
+        print("FUNC", func.__name__)
         out = func(*args, **kwargs)
 
         # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy

From de9f2d4a07d0edc357bef8525ae06ad863b7357e Mon Sep 17 00:00:00 2001
From: HELSON
Date: Mon, 27 Mar 2023 14:30:00 +0800
Subject: [PATCH 2/2] fix error

---
 colossalai/fx/profiler/tensor.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 7298dff8737e..2ee5e5c47750 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -80,7 +80,6 @@ def unwrap(x):
         kwargs['device'] = torch.device('meta')
 
         # run aten for backend=CPU but actually on backend=Meta
-        print("FUNC", func.__name__)
         out = func(*args, **kwargs)
 
         # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy