From be9711ff27e081738b663d89d121879ab0e70b01 Mon Sep 17 00:00:00 2001 From: lucylq Date: Mon, 29 Sep 2025 14:32:46 -0700 Subject: [PATCH] Move tensor layout into exir TensorLayout is used by the data_serializer.py abstract class, which is part of //executorch/exir:lib https://www.internalfb.com/code/fbsource/[721dbb621c0b8c460715e5c49bc9c12757e8ccc8]/fbcode/executorch/exir/_serialize/data_serializer.py?lines=6 Shouldn't add extension/flat_tensor:flat_tensor_schema as a dependency to executorch/exir:lib. Note the C++ equivalent is in runtime/core, so I think this makes sense. Differential Revision: [D83504588](https://our.internmc.facebook.com/intern/diff/D83504588/) [ghstack-poisoned] --- exir/TARGETS | 10 +++++++++ exir/_serialize/TARGETS | 1 + exir/_serialize/_serialize.py | 2 +- exir/_serialize/data_serializer.py | 2 +- exir/tensor_layout.py | 21 +++++++++++++++++++ extension/flat_tensor/serialize/TARGETS | 3 +++ .../flat_tensor/serialize/flat_tensor.fbs | 2 ++ .../serialize/flat_tensor_schema.py | 9 +------- extension/flat_tensor/test/test_serialize.py | 2 +- 9 files changed, 41 insertions(+), 11 deletions(-) create mode 100644 exir/tensor_layout.py diff --git a/exir/TARGETS b/exir/TARGETS index 853d5e199ba..402e9a21bd1 100644 --- a/exir/TARGETS +++ b/exir/TARGETS @@ -79,6 +79,16 @@ runtime.python_library( ], ) +runtime.python_library( + name = "tensor_layout", + srcs = [ + "tensor_layout.py", + ], + deps = [ + ":scalar_type", + ] +) + runtime.python_library( name = "memory", srcs = [ diff --git a/exir/_serialize/TARGETS b/exir/_serialize/TARGETS index 83a2d4957ce..7163da25ff7 100644 --- a/exir/_serialize/TARGETS +++ b/exir/_serialize/TARGETS @@ -65,5 +65,6 @@ runtime.python_library( deps = [ "//executorch/exir:schema", "//executorch/exir:tensor", + "//executorch/exir:tensor_layout", ], ) diff --git a/exir/_serialize/_serialize.py b/exir/_serialize/_serialize.py index e2147458545..06e81997654 100644 --- a/exir/_serialize/_serialize.py +++ b/exir/_serialize/_serialize.py @@ -16,12 +16,12 @@ DataEntry, DataPayload, DataSerializer, - TensorLayout, ) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.emit import EmitterOutput from executorch.exir.schema import Tensor, TensorDataLocation +from executorch.exir.tensor_layout import TensorLayout def serialize_for_executorch( diff --git a/exir/_serialize/data_serializer.py b/exir/_serialize/data_serializer.py index e828b4d0ae3..cee34506b66 100644 --- a/exir/_serialize/data_serializer.py +++ b/exir/_serialize/data_serializer.py @@ -3,7 +3,7 @@ from typing import Dict, Optional, Sequence from executorch.exir._serialize._cord import Cord -from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout +from executorch.exir.tensor_layout import TensorLayout @dataclass diff --git a/exir/tensor_layout.py b/exir/tensor_layout.py new file mode 100644 index 00000000000..f8f77ebeea3 --- /dev/null +++ b/exir/tensor_layout.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from dataclasses import dataclass +from typing import List + +from executorch.exir.scalar_type import ScalarType + + +# Note: keep this in sync with the TensorLayout definition in +# executorch/extension/flat_tensor/serialize/flat_tensor.fbs +@dataclass +class TensorLayout: + scalar_type: ScalarType + sizes: List[int] + dim_order: List[int] diff --git a/extension/flat_tensor/serialize/TARGETS b/extension/flat_tensor/serialize/TARGETS index 229f6930f4e..b9ccadf9f23 100644 --- a/extension/flat_tensor/serialize/TARGETS +++ b/extension/flat_tensor/serialize/TARGETS @@ -13,6 +13,9 @@ runtime.python_library( visibility = [ "//executorch/...", ], + deps = [ + "//executorch/exir:tensor_layout", + ] ) runtime.python_library( diff --git a/extension/flat_tensor/serialize/flat_tensor.fbs b/extension/flat_tensor/serialize/flat_tensor.fbs index abf331697d6..4b71e13e2c4 100644 --- a/extension/flat_tensor/serialize/flat_tensor.fbs +++ b/extension/flat_tensor/serialize/flat_tensor.fbs @@ -7,6 +7,8 @@ namespace flat_tensor_flatbuffer; file_identifier "FT01"; file_extension "ptd"; +// Note: keep this in sync with the python definition in +// executorch/exir/tensor_layout.py table TensorLayout { scalar_type: executorch_flatbuffer.ScalarType; diff --git a/extension/flat_tensor/serialize/flat_tensor_schema.py b/extension/flat_tensor/serialize/flat_tensor_schema.py index 53b0fe98ea9..2fcf2c6eb81 100644 --- a/extension/flat_tensor/serialize/flat_tensor_schema.py +++ b/extension/flat_tensor/serialize/flat_tensor_schema.py @@ -9,18 +9,11 @@ from dataclasses import dataclass from typing import List, Optional -from executorch.exir.scalar_type import ScalarType +from executorch.exir.tensor_layout import TensorLayout # Note: check executorch/extension/data_format/flat_tensor.fbs for explanations of these fields. -@dataclass -class TensorLayout: - scalar_type: ScalarType - sizes: List[int] - dim_order: List[int] - - @dataclass class DataSegment: offset: int diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index 13402e60a65..726a8845c2e 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -22,7 +22,7 @@ from executorch.exir._serialize.padding import aligned_size from executorch.exir.schema import ScalarType -from executorch.extension.flat_tensor.serialize.flat_tensor_schema import TensorLayout +from executorch.exir.tensor_layout import TensorLayout from executorch.extension.flat_tensor.serialize.serialize import ( _deserialize_to_flat_tensor,