From 8293b543fe40a786f50a71b9e4b373b18c84fd33 Mon Sep 17 00:00:00 2001 From: Avasam Date: Wed, 18 Oct 2023 21:52:11 -0400 Subject: [PATCH 1/6] Fix and run protobuf generation scripts --- scripts/generate_proto_stubs.sh | 12 +- scripts/sync_tensorflow_protobuf_stubs.sh | 54 +- stubs/protobuf/METADATA.toml | 2 +- stubs/tensorflow/METADATA.toml | 1 + .../compiler/xla/autotune_results_pb2.pyi | 94 +++ .../compiler/xla/service/hlo_pb2.pyi | 122 ++- .../compiler/xla/service/metrics_pb2.pyi | 73 ++ .../tensorflow/compiler/xla/xla_data_pb2.pyi | 151 ++-- .../tensorflow/core/example/example_pb2.pyi | 28 +- .../tensorflow/core/example/feature_pb2.pyi | 4 +- .../tensorflow/core/framework/api_def_pb2.pyi | 6 +- .../core/framework/dataset_options_pb2.pyi | 15 +- .../tensorflow/core/framework/dataset_pb2.pyi | 105 +++ .../core/framework/full_type_pb2.pyi | 76 +- .../core/framework/function_pb2.pyi | 12 +- .../tensorflow/core/framework/graph_pb2.pyi | 10 +- .../framework/graph_transfer_info_pb2.pyi | 2 +- .../core/framework/node_def_pb2.pyi | 10 +- .../tensorflow/core/framework/op_def_pb2.pyi | 4 +- .../optimized_function_graph_pb2.pyi | 83 +++ .../tensorflow/core/framework/summary_pb2.pyi | 2 +- .../tensorflow/core/framework/tensor_pb2.pyi | 10 +- .../core/framework/tensor_shape_pb2.pyi | 8 +- .../core/framework/tensor_slice_pb2.pyi | 2 +- .../tensorflow/core/framework/types_pb2.pyi | 24 +- .../core/framework/versions_pb2.pyi | 6 +- .../core/protobuf/bfc_memory_map_pb2.pyi | 162 +--- .../tensorflow/core/protobuf/cluster_pb2.pyi | 22 +- .../protobuf/composite_tensor_variant_pb2.pyi | 2 +- .../tensorflow/core/protobuf/config_pb2.pyi | 260 +++---- .../protobuf/core_platform_payloads_pb2.pyi | 2 +- .../core/protobuf/data_service_pb2.pyi | 8 +- .../tensorflow/core/protobuf/debug_pb2.pyi | 6 +- .../core/protobuf/device_filters_pb2.pyi | 12 +- .../core/protobuf/fingerprint_pb2.pyi | 12 +- .../core/protobuf/meta_graph_pb2.pyi | 28 +- .../core/protobuf/named_tensor_pb2.pyi | 2 +- .../core/protobuf/rewriter_config_pb2.pyi | 12 +- .../core/protobuf/rpc_options_pb2.pyi | 10 + .../core/protobuf/saved_object_graph_pb2.pyi | 28 +- .../tensorflow/core/protobuf/saver_pb2.pyi | 2 +- .../core/protobuf/service_config_pb2.pyi | 10 +- .../tensorflow/core/protobuf/snapshot_pb2.pyi | 25 + .../tensorflow/core/protobuf/struct_pb2.pyi | 10 +- .../core/protobuf/tensor_bundle_pb2.pyi | 8 +- .../core/protobuf/tensorflow_server_pb2.pyi | 6 +- .../protobuf/tpu/compilation_result_pb2.pyi | 2 +- .../tpu/optimization_parameters_pb2.pyi | 60 +- .../core/protobuf/tpu/topology_pb2.pyi | 2 +- .../tpu/tpu_embedding_configuration_pb2.pyi | 8 +- .../core/protobuf/verifier_config_pb2.pyi | 2 +- .../tensorflow/core/util/event_pb2.pyi | 37 +- .../tensorflow/core/util/test_log_pb2.pyi | 584 +-------------- .../python/keras/protobuf/versions_pb2.pyi | 8 +- .../tsl/protobuf/autotuning_pb2.pyi | 250 +++++++ .../tsl/protobuf/bfc_memory_map_pb2.pyi | 162 ++++ .../tsl/protobuf/coordination_config_pb2.pyi | 112 +++ .../tsl/protobuf/coordination_service_pb2.pyi | 699 ++++++++++++++++++ .../distributed_runtime_payloads_pb2.pyi | 85 +++ .../tensorflow/tsl/protobuf/dnn_pb2.pyi | 440 +++++++++++ .../tsl/protobuf/error_codes_pb2.pyi | 20 +- .../tsl/protobuf/rpc_options_pb2.pyi | 72 ++ .../tensorflow/tsl/protobuf/test_log_pb2.pyi | 575 ++++++++++++++ 63 files changed, 3446 insertions(+), 1215 deletions(-) create mode 100644 stubs/tensorflow/tensorflow/compiler/xla/autotune_results_pb2.pyi create mode 100644 
stubs/tensorflow/tensorflow/compiler/xla/service/metrics_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/core/framework/dataset_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/core/framework/optimized_function_graph_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/core/protobuf/rpc_options_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/tsl/protobuf/autotuning_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/tsl/protobuf/bfc_memory_map_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/tsl/protobuf/coordination_config_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/tsl/protobuf/coordination_service_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/tsl/protobuf/distributed_runtime_payloads_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/tsl/protobuf/dnn_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/tsl/protobuf/rpc_options_pb2.pyi create mode 100644 stubs/tensorflow/tensorflow/tsl/protobuf/test_log_pb2.pyi diff --git a/scripts/generate_proto_stubs.sh b/scripts/generate_proto_stubs.sh index 118bc15d26f1..aab56d4ff92a 100755 --- a/scripts/generate_proto_stubs.sh +++ b/scripts/generate_proto_stubs.sh @@ -13,7 +13,7 @@ set -ex -o pipefail # Update these two variables when rerunning script PROTOBUF_VERSION=21.8 PYTHON_PROTOBUF_VERSION=4.21.8 -MYPY_PROTOBUF_VERSION=v3.5.0 +MYPY_PROTOBUF_VERSION=3.5.0 if uname -a | grep Darwin; then # brew install coreutils wget @@ -48,7 +48,7 @@ source "$VENV/bin/activate" pip install -r "$REPO_ROOT/requirements-tests.txt" # for black and isort # Install mypy-protobuf -pip install "git+https://github.com/dropbox/mypy-protobuf@$MYPY_PROTOBUF_VERSION" +pip install mypy-protobuf=="$MYPY_PROTOBUF_VERSION" # Remove existing pyi find "$REPO_ROOT/stubs/protobuf/" -name '*_pb2.pyi' -delete @@ -76,5 +76,9 @@ protoc_install/bin/protoc --proto_path="$PYTHON_PROTOBUF_DIR/src" --mypy_out="re isort "$REPO_ROOT/stubs/protobuf" black "$REPO_ROOT/stubs/protobuf" -sed --in-place="" "s/mypy-protobuf [^\"]*/mypy-protobuf ${MYPY_PROTOBUF_VERSION}/" "$REPO_ROOT/stubs/protobuf/METADATA.toml" -sed --in-place="" "s/version = .*$/version = \"$(echo ${PYTHON_PROTOBUF_VERSION} | cut -d. -f1-2)\.\*\"/" "$REPO_ROOT/stubs/protobuf/METADATA.toml" +sed --in-place="" \ + "s/extra_description = .*$/extra_description = \"Generated with aid from [mypy-protobuf==$MYPY_PROTOBUF_VERSION](https:\/\/github.com\/nipunn1313\/mypy-protobuf\/tree\/v$MYPY_PROTOBUF_VERSION)\"/" \ + "$REPO_ROOT/stubs/protobuf/METADATA.toml" +sed --in-place="" \ + "s/version = .*$/version = \"$(echo ${PYTHON_PROTOBUF_VERSION} | cut -d. -f1-2)\.\*\"/" \ + "$REPO_ROOT/stubs/protobuf/METADATA.toml" diff --git a/scripts/sync_tensorflow_protobuf_stubs.sh b/scripts/sync_tensorflow_protobuf_stubs.sh index a68ea0c2e3f7..9a997a040de2 100755 --- a/scripts/sync_tensorflow_protobuf_stubs.sh +++ b/scripts/sync_tensorflow_protobuf_stubs.sh @@ -3,54 +3,52 @@ set -euxo pipefail # Partly based on scripts/generate_proto_stubs.sh. -# Generates the protobuf stubs for the given tensorflow version using -# mypy-protobuf. Should be run like ./sync_tensorflow_protobuf_stubs.sh +# Generates the protobuf stubs for the given tensorflow version using mypy-protobuf. # Generally, new minor versions are a good time to update the stubs. -cd "$(dirname "$0")" > /dev/null -cd ../stubs/tensorflow REPO_ROOT="$(realpath "$(dirname "${BASH_SOURCE[0]}")"/..)" - # This version should be consistent with the version in tensorflow's METADATA.toml. 
-TENSORFLOW_VERSION=2.11.0 +TENSORFLOW_VERSION=2.12.1 # Latest mypy-protobuf has dependency on protobuf >4, which is incompatible at runtime # with tensorflow. However, the stubs produced do still work with tensorflow. So after # installing mypy-protobuf, before running stubtest on tensorflow you should downgrade # protobuf<4. -MYPY_PROTOBUF_VERSION=3.4.0 +MYPY_PROTOBUF_VERSION=3.5.0 pip install mypy-protobuf=="$MYPY_PROTOBUF_VERSION" -mkdir repository +cd "$(dirname "$0")" > /dev/null +cd ../stubs/tensorflow +mkdir -p repository pushd repository &> /dev/null - git clone https://github.com/tensorflow/tensorflow.git + # If the script fails halfway, it's nice to be able to re-run it immediately + if [ ! -d "tensorflow" ] ; then + git clone --depth 1 --branch v"$TENSORFLOW_VERSION" https://github.com/tensorflow/tensorflow.git + fi pushd tensorflow &> /dev/null - git checkout v"$TENSORFLOW_VERSION" - # Folders here cover the more commonly used protobufs externally and # their dependencies. Tensorflow has more protobufs and can be added if requested. protoc --mypy_out "relax_strict_optional_primitives:$REPO_ROOT/stubs/tensorflow" \ + tensorflow/compiler/xla/*.proto \ + tensorflow/compiler/xla/service/*.proto \ + tensorflow/core/example/*.proto \ + tensorflow/core/framework/*.proto \ tensorflow/core/protobuf/*.proto \ tensorflow/core/protobuf/tpu/*.proto \ - tensorflow/core/framework/*.proto \ tensorflow/core/util/*.proto \ - tensorflow/core/example/*.proto \ tensorflow/python/keras/protobuf/*.proto \ - tensorflow/tsl/protobuf/*.proto \ - tensorflow/compiler/xla/*.proto \ - tensorflow/compiler/xla/service/*.proto + tensorflow/tsl/protobuf/*.proto popd &> /dev/null popd &> /dev/null -rm -rf repository/ # These protos exist in a folder with protos used in python, but are not # included in the python wheel. They are likely only used for other # language builds. stubtest was used to identify them by looking for # ModuleNotFoundError. 
-rm tensorflow/core/protobuf/coordination_service_pb2.pyi \ - tensorflow/compiler/xla/service/hlo_execution_profile_data_pb2.pyi \ +rm tensorflow/compiler/xla/service/hlo_execution_profile_data_pb2.pyi \ tensorflow/compiler/xla/service/hlo_profile_printer_data_pb2.pyi \ tensorflow/compiler/xla/service/test_compilation_environment_pb2.pyi \ + tensorflow/compiler/xla/xla_pb2.pyi \ tensorflow/core/protobuf/autotuning_pb2.pyi \ tensorflow/core/protobuf/conv_autotuning_pb2.pyi \ tensorflow/core/protobuf/critical_section_pb2.pyi \ @@ -58,8 +56,20 @@ rm tensorflow/core/protobuf/coordination_service_pb2.pyi \ tensorflow/core/protobuf/master_pb2.pyi \ tensorflow/core/protobuf/master_service_pb2.pyi \ tensorflow/core/protobuf/replay_log_pb2.pyi \ + tensorflow/core/protobuf/tpu/compile_metadata_pb2.pyi \ tensorflow/core/protobuf/worker_pb2.pyi \ tensorflow/core/protobuf/worker_service_pb2.pyi \ - tensorflow/core/protobuf/tpu/compile_metadata_pb2.pyi \ - tensorflow/core/util/example_proto_fast_parsing_test_pb2.pyi \ - tensorflow/compiler/xla/xla_pb2.pyi + tensorflow/core/util/example_proto_fast_parsing_test_pb2.pyi + +ruff "$REPO_ROOT/stubs/tensorflow/tensorflow" --exit-non-zero-on-fix --fix-only +black "$REPO_ROOT/stubs/tensorflow/tensorflow" + +sed --in-place="" \ + "s/extra_description = .*$/extra_description = \"Partially generated with aid from [mypy-protobuf==$MYPY_PROTOBUF_VERSION](https:\/\/github.com\/nipunn1313\/mypy-protobuf\/tree\/v$MYPY_PROTOBUF_VERSION)\"/" \ + "$REPO_ROOT/stubs/tensorflow/METADATA.toml" +sed --in-place="" \ + "s/version = .*$/version = \"$(echo ${TENSORFLOW_VERSION} | cut -d. -f1-2)\.\*\"/" \ + "$REPO_ROOT/stubs/tensorflow/METADATA.toml" + +# Cleanup last. If the script fails halfway, it's nice to be able to re-run it immediately +rm -rf repository/ diff --git a/stubs/protobuf/METADATA.toml b/stubs/protobuf/METADATA.toml index 57a243aaa5d1..e0a0fb1034b7 100644 --- a/stubs/protobuf/METADATA.toml +++ b/stubs/protobuf/METADATA.toml @@ -1,6 +1,6 @@ version = "4.24.*" upstream_repository = "https://github.com/protocolbuffers/protobuf" -extra_description = "Generated with aid from [mypy-protobuf](https://github.com/nipunn1313/mypy-protobuf) v3.5.0" +extra_description = "Generated with aid from [mypy-protobuf==3.5.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.5.0)" partial_stub = true [tool.stubtest] diff --git a/stubs/tensorflow/METADATA.toml b/stubs/tensorflow/METADATA.toml index 4b9c51c19820..b4e63e7999ed 100644 --- a/stubs/tensorflow/METADATA.toml +++ b/stubs/tensorflow/METADATA.toml @@ -2,6 +2,7 @@ version = "2.12.*" upstream_repository = "https://github.com/tensorflow/tensorflow" # requires a version of numpy with a `py.typed` file requires = ["numpy>=1.20", "types-protobuf"] +extra_description = "Partially generated with aid from [mypy-protobuf==3.5.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.5.0)" partial_stub = true [tool.stubtest] diff --git a/stubs/tensorflow/tensorflow/compiler/xla/autotune_results_pb2.pyi b/stubs/tensorflow/tensorflow/compiler/xla/autotune_results_pb2.pyi new file mode 100644 index 000000000000..8fc21a775889 --- /dev/null +++ b/stubs/tensorflow/tensorflow/compiler/xla/autotune_results_pb2.pyi @@ -0,0 +1,94 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import sys +import tensorflow.tsl.protobuf.autotuning_pb2 + +if sys.version_info >= (3, 8): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing_extensions.final +class AutotuneResults(google.protobuf.message.Message): + """A collection of algorithms for particular dot/convs. Usually this is "the + best" algorithm for the particular dot/conv, although that's not strictly + required. + + Users don't interact with this proto directly. It's used internally to + facilitate ahead-of-time autotuning -- The string used by + xla::{Serialize,Load}AutotuneResults is, internally, a serialization of this + proto. + + LINT.IfChange + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing_extensions.final + class Entry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DEVICE_FIELD_NUMBER: builtins.int + HLO_FIELD_NUMBER: builtins.int + RESULT_FIELD_NUMBER: builtins.int + device: builtins.str + hlo: builtins.str + @property + def result(self) -> tensorflow.tsl.protobuf.autotuning_pb2.AutotuneResult: + """nb: These results are always tied to a particular version of + cublas/cudnn, but this is *especially* true for cublasLt results. For + cublasLt gemms, the result is an index into the list of candidate + algorithms returned by cublasLt. Different version of cublasLt -> + different list of algos -> different interpretation of results! + """ + def __init__( + self, + *, + device: builtins.str | None = ..., + hlo: builtins.str | None = ..., + result: tensorflow.tsl.protobuf.autotuning_pb2.AutotuneResult | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["result", b"result"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["device", b"device", "hlo", b"hlo", "result", b"result"]) -> None: ... + + VERSION_FIELD_NUMBER: builtins.int + DOTS_FIELD_NUMBER: builtins.int + CONVS_FIELD_NUMBER: builtins.int + version: builtins.int + @property + def dots(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AutotuneResults.Entry]: ... + @property + def convs(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AutotuneResults.Entry]: ... + def __init__( + self, + *, + version: builtins.int | None = ..., + dots: collections.abc.Iterable[global___AutotuneResults.Entry] | None = ..., + convs: collections.abc.Iterable[global___AutotuneResults.Entry] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["convs", b"convs", "dots", b"dots", "version", b"version"]) -> None: ... 
+ +global___AutotuneResults = AutotuneResults diff --git a/stubs/tensorflow/tensorflow/compiler/xla/service/hlo_pb2.pyi b/stubs/tensorflow/tensorflow/compiler/xla/service/hlo_pb2.pyi index b72e95f3e71d..f1df9142232e 100644 --- a/stubs/tensorflow/tensorflow/compiler/xla/service/hlo_pb2.pyi +++ b/stubs/tensorflow/tensorflow/compiler/xla/service/hlo_pb2.pyi @@ -57,21 +57,21 @@ class _CustomCallApiVersionEnumTypeWrapper(google.protobuf.internal.enum_type_wr API_VERSION_UNSPECIFIED: _CustomCallApiVersion.ValueType # 0 API_VERSION_ORIGINAL: _CustomCallApiVersion.ValueType # 1 """The first version of the API, with the following signatures: - + CPU: void do_custom_call(void* out, const void** in); - + GPU: void do_custom_call(CUstream stream, void** buffers, const char* opaque, size_t opaque_len); """ API_VERSION_STATUS_RETURNING: _CustomCallApiVersion.ValueType # 2 """When the ability to return success/failure status was added: - + CPU: void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status); - + GPU: void do_custom_call(CUstream stream, void** buffers, const char* opaque, size_t opaque_len, @@ -83,17 +83,38 @@ class _CustomCallApiVersionEnumTypeWrapper(google.protobuf.internal.enum_type_wr CPUs and GPUs. For GPUs, the behaviors invoked by API_VERSION_STATUS_RETURNING and API_VERSION_STATUS_RETURNING_UNIFIED are the same. - + CPU: void do_custom_call(void* out, const void** in, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); - + GPU: void do_custom_call(CUstream stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); """ + API_VERSION_TYPED_FFI: _CustomCallApiVersion.ValueType # 4 + """Api version implementing XLA runtime custom call calling convention. These + custom calls can be registered as an XLA runtime custom call (1) or as XLA + runtime FFI binding (2). + + This type of custom call uses custom ABI to pass type information along + with custom call arguments. Also it passes buffer arguments together with + data type, sizes and strides. + + Example: (XLA runtime custom call) + + absl::Status DoCustomCall(StridedMemrefView arg, float attr); + + CustomCall::Bind("custom_call") + .Arg() + .Attr("attr") + .To(DoCustomCall); + + (1) xla/runtime/custom_call.h + (2) xla/runtime/ffi/ffi.h + """ class CustomCallApiVersion(_CustomCallApiVersion, metaclass=_CustomCallApiVersionEnumTypeWrapper): """The version of the API used by the custom call function. The signatures for @@ -142,6 +163,27 @@ GPU: const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); """ +API_VERSION_TYPED_FFI: CustomCallApiVersion.ValueType # 4 +"""Api version implementing XLA runtime custom call calling convention. These +custom calls can be registered as an XLA runtime custom call (1) or as XLA +runtime FFI binding (2). + +This type of custom call uses custom ABI to pass type information along +with custom call arguments. Also it passes buffer arguments together with +data type, sizes and strides. + +Example: (XLA runtime custom call) + + absl::Status DoCustomCall(StridedMemrefView arg, float attr); + + CustomCall::Bind("custom_call") + .Arg() + .Attr("attr") + .To(DoCustomCall); + +(1) xla/runtime/custom_call.h +(2) xla/runtime/ffi/ffi.h +""" global___CustomCallApiVersion = CustomCallApiVersion class _Kind: @@ -174,7 +216,7 @@ global___Kind = Kind @typing_extensions.final class HloInstructionProto(google.protobuf.message.Message): """Serialization of HloInstruction. 
- Next ID: 80 + Next ID: 81 """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -254,7 +296,7 @@ class HloInstructionProto(google.protobuf.message.Message): CHOLESKY_OPTIONS_FIELD_NUMBER: builtins.int PARAMETER_REPLICATION_FIELD_NUMBER: builtins.int CUSTOM_CALL_HAS_SIDE_EFFECT_FIELD_NUMBER: builtins.int - CUSTOM_CALL_OUTPUT_OPERAND_ALIASING_FIELD_NUMBER: builtins.int + OUTPUT_OPERAND_ALIASING_FIELD_NUMBER: builtins.int CUSTOM_CALL_SCHEDULE_FIELD_NUMBER: builtins.int DELTA_FIELD_NUMBER: builtins.int INDICES_ARE_SORTED_FIELD_NUMBER: builtins.int @@ -263,6 +305,7 @@ class HloInstructionProto(google.protobuf.message.Message): RNG_ALGORITHM_FIELD_NUMBER: builtins.int COMPARISON_TYPE_FIELD_NUMBER: builtins.int IS_CROSS_PROGRAM_PREFETCH_FIELD_NUMBER: builtins.int + CROSS_PROGRAM_PREFETCH_INDEX_FIELD_NUMBER: builtins.int PADDING_TYPE_FIELD_NUMBER: builtins.int CUSTOM_CALL_API_VERSION_FIELD_NUMBER: builtins.int ASYNC_GROUP_ID_FIELD_NUMBER: builtins.int @@ -424,9 +467,9 @@ class HloInstructionProto(google.protobuf.message.Message): kCustomCall. """ @property - def custom_call_output_operand_aliasing(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.CustomCallOutputOperandAliasing]: - """A list of CustomCallOutputOperandAliasing pairs that specifies aliasing - buffers between output and operands for kCustomCall. + def output_operand_aliasing(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.OutputOperandAliasing]: + """A list of OutputOperandAliasing pairs that specifies aliasing buffers + between output and operands for kCustomCall and kFusion. """ custom_call_schedule: global___CustomCallSchedule.ValueType """Specifies the desired schedule for the custom-call. The field is only @@ -450,7 +493,10 @@ class HloInstructionProto(google.protobuf.message.Message): comparison_type: builtins.str """The comparison type used for kCompare.""" is_cross_program_prefetch: builtins.bool - """Specifies if this is a cross-program-prefetch, used by kCopyStart.""" + """Specifies if this is a cross-program-prefetch, used by kCopyStart. + Deprecated and replaced by optional_cross_program_prefetch_index. 
+ """ + cross_program_prefetch_index: builtins.int padding_type: tensorflow.compiler.xla.xla_data_pb2.PaddingType.ValueType """If a convolution is dynamic, a dynamic padding type will be specified.""" custom_call_api_version: global___CustomCallApiVersion.ValueType @@ -526,7 +572,7 @@ class HloInstructionProto(google.protobuf.message.Message): cholesky_options: tensorflow.compiler.xla.xla_data_pb2.CholeskyOptions | None = ..., parameter_replication: tensorflow.compiler.xla.xla_data_pb2.ParameterReplication | None = ..., custom_call_has_side_effect: builtins.bool | None = ..., - custom_call_output_operand_aliasing: collections.abc.Iterable[tensorflow.compiler.xla.xla_data_pb2.CustomCallOutputOperandAliasing] | None = ..., + output_operand_aliasing: collections.abc.Iterable[tensorflow.compiler.xla.xla_data_pb2.OutputOperandAliasing] | None = ..., custom_call_schedule: global___CustomCallSchedule.ValueType | None = ..., delta: builtins.int | None = ..., indices_are_sorted: builtins.bool | None = ..., @@ -535,13 +581,15 @@ class HloInstructionProto(google.protobuf.message.Message): rng_algorithm: tensorflow.compiler.xla.xla_data_pb2.RandomAlgorithm.ValueType | None = ..., comparison_type: builtins.str | None = ..., is_cross_program_prefetch: builtins.bool | None = ..., + cross_program_prefetch_index: builtins.int | None = ..., padding_type: tensorflow.compiler.xla.xla_data_pb2.PaddingType.ValueType | None = ..., custom_call_api_version: global___CustomCallApiVersion.ValueType | None = ..., async_group_id: builtins.int | None = ..., async_execution_thread: builtins.str | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "async_group_id", b"async_group_id", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_output_operand_aliasing", b"custom_call_output_operand_aliasing", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ... 
+ def HasField(self, field_name: typing_extensions.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "async_group_id", b"async_group_id", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", 
"sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index"]) -> typing_extensions.Literal["cross_program_prefetch_index"] | None: ... global___HloInstructionProto = HloInstructionProto @@ -657,13 +705,13 @@ class HloInputOutputAliasProto(google.protobuf.message.Message): (described by parameter number and a ShapeIndex of the parameter) and an output (described by a ShapeIndex of the root instruction). For example: - + entry = { output_shape_index={1}, parameter_number=0, parameter_shape_index={1, 2}, } - + This entry indicates that the first paremter's {1, 2} element is aliased with the {1} element of the root instruction. """ @@ -712,29 +760,29 @@ class DynamicParameterBindingProto(google.protobuf.message.Message): @typing_extensions.final class Binding(google.protobuf.message.Message): - """A list of bindings which indicates that the `target_dim_num` in + """A list of bindings which indicates that the `target_param_dim_num` in the subshape `target_param_index` of parameter `target_param_num` is a dynamic dimension and its real dynamic size is represented by `dynamic_param_index` in parameter `dynamic_param_num`. - + As an example, imagine we have a program: - + ENTRY main { a = f32[] parameter(0) b = f32[10] parameter(1) ROOT root = (f32[], f32[10]) tuple(%a, %b) } - + Let's say 'b' (param index 1) is a dynamic shape whose input has an upperbound of 10 and real size is determined at runtime.'a' represents the real size of b's first dimension. - + In this case, the fields are set in the following way: dynamic_param_num = 1 dynamic_param_index = {} target_param_num = 0 target_param_index = {} - target_param_dim = 0 + target_param_dim_num = 0 """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -780,16 +828,19 @@ class CrossProgramPrefetch(google.protobuf.message.Message): PARAMETER_FIELD_NUMBER: builtins.int INDEX_FIELD_NUMBER: builtins.int + OFFSET_FIELD_NUMBER: builtins.int parameter: builtins.int @property def index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... + offset: builtins.int def __init__( self, *, parameter: builtins.int | None = ..., index: collections.abc.Iterable[builtins.int] | None = ..., + offset: builtins.int | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["index", b"index", "parameter", b"parameter"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["index", b"index", "offset", b"offset", "parameter", b"parameter"]) -> None: ... 
global___CrossProgramPrefetch = CrossProgramPrefetch @@ -803,7 +854,7 @@ class HloModuleProto(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _ProfileTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HloModuleProto._ProfileType.ValueType], builtins.type): # noqa: F821 + class _ProfileTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HloModuleProto._ProfileType.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor INVALID: HloModuleProto._ProfileType.ValueType # 0 FLAG: HloModuleProto._ProfileType.ValueType # 1 @@ -940,28 +991,22 @@ class LogicalBufferProto(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - COMPUTATION_NAME_FIELD_NUMBER: builtins.int INSTRUCTION_NAME_FIELD_NUMBER: builtins.int INSTRUCTION_ID_FIELD_NUMBER: builtins.int SHAPE_INDEX_FIELD_NUMBER: builtins.int - computation_name: builtins.str - """NOTE: module_name isn't necessary, since all LogicalBuffers are - associated with a single HloModule. - TODO(b/239098765): Remove instruction_name and computation_name. - """ instruction_name: builtins.str + """TODO(b/239098765): Remove instruction_name and computation_name.""" instruction_id: builtins.int @property def shape_index(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ... def __init__( self, *, - computation_name: builtins.str | None = ..., instruction_name: builtins.str | None = ..., instruction_id: builtins.int | None = ..., shape_index: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... - def ClearField(self, field_name: typing_extensions.Literal["computation_name", b"computation_name", "instruction_id", b"instruction_id", "instruction_name", b"instruction_name", "shape_index", b"shape_index"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["instruction_id", b"instruction_id", "instruction_name", b"instruction_name", "shape_index", b"shape_index"]) -> None: ... ID_FIELD_NUMBER: builtins.int SIZE_FIELD_NUMBER: builtins.int @@ -1076,7 +1121,7 @@ class HeapSimulatorTrace(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _KindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HeapSimulatorTrace.Event._Kind.ValueType], builtins.type): # noqa: F821 + class _KindEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[HeapSimulatorTrace.Event._Kind.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor ALLOC: HeapSimulatorTrace.Event._Kind.ValueType # 0 """A memory region was allocated for the buffer.""" @@ -1488,14 +1533,10 @@ class XlaRuntimeExecutableProto(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor HLO_MODULE_PROTO_FIELD_NUMBER: builtins.int - ENTRY_FUNC_ATTRS_FIELD_NUMBER: builtins.int OBJ_FILE_FIELD_NUMBER: builtins.int MLIR_MODULE_FIELD_NUMBER: builtins.int @property def hlo_module_proto(self) -> global___HloModuleProto: ... - @property - def entry_func_attrs(self) -> global___EntryFunctionAttributes: - """XLA-specific attributes of the executable's entry function.""" obj_file: builtins.bytes """TODO(b/232263665)): Serialized executable has to know what APIs it has to be linked with, including the version. 
For example Gpu executable must be @@ -1509,11 +1550,10 @@ class XlaRuntimeExecutableProto(google.protobuf.message.Message): self, *, hlo_module_proto: global___HloModuleProto | None = ..., - entry_func_attrs: global___EntryFunctionAttributes | None = ..., obj_file: builtins.bytes | None = ..., mlir_module: builtins.str | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["entry_func_attrs", b"entry_func_attrs", "hlo_module_proto", b"hlo_module_proto"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["entry_func_attrs", b"entry_func_attrs", "hlo_module_proto", b"hlo_module_proto", "mlir_module", b"mlir_module", "obj_file", b"obj_file"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["hlo_module_proto", b"hlo_module_proto"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["hlo_module_proto", b"hlo_module_proto", "mlir_module", b"mlir_module", "obj_file", b"obj_file"]) -> None: ... global___XlaRuntimeExecutableProto = XlaRuntimeExecutableProto diff --git a/stubs/tensorflow/tensorflow/compiler/xla/service/metrics_pb2.pyi b/stubs/tensorflow/tensorflow/compiler/xla/service/metrics_pb2.pyi new file mode 100644 index 000000000000..f39a6bd921d6 --- /dev/null +++ b/stubs/tensorflow/tensorflow/compiler/xla/service/metrics_pb2.pyi @@ -0,0 +1,73 @@ +""" +@generated by mypy-protobuf. Do not edit manually! +isort:skip_file +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.duration_pb2 +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import google.protobuf.timestamp_pb2 +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing_extensions.final +class CompilationLogEntry(google.protobuf.message.Message): + """Defines XLA compilation metrics.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class _CompilationStage: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _CompilationStageEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[CompilationLogEntry._CompilationStage.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + UNSPECIFIED: CompilationLogEntry._CompilationStage.ValueType # 0 + END_TO_END: CompilationLogEntry._CompilationStage.ValueType # 1 + HLO_PASSES: CompilationLogEntry._CompilationStage.ValueType # 2 + CODE_GENERATION: CompilationLogEntry._CompilationStage.ValueType # 3 + BACKEND_PASSES: CompilationLogEntry._CompilationStage.ValueType # 4 + + class CompilationStage(_CompilationStage, metaclass=_CompilationStageEnumTypeWrapper): + """Defines compilation stages for which metrics are collected.""" + + UNSPECIFIED: CompilationLogEntry.CompilationStage.ValueType # 0 + END_TO_END: CompilationLogEntry.CompilationStage.ValueType # 1 + HLO_PASSES: CompilationLogEntry.CompilationStage.ValueType # 2 + CODE_GENERATION: CompilationLogEntry.CompilationStage.ValueType # 3 + BACKEND_PASSES: CompilationLogEntry.CompilationStage.ValueType # 4 + + TIMESTAMP_FIELD_NUMBER: builtins.int + STAGE_FIELD_NUMBER: builtins.int + DURATION_FIELD_NUMBER: builtins.int + TASK_INDEX_FIELD_NUMBER: builtins.int + @property + def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: + """Time when the event captured by this log entry occurred.""" + stage: 
global___CompilationLogEntry.CompilationStage.ValueType + """Compilation stage recorded by this log entry.""" + @property + def duration(self) -> google.protobuf.duration_pb2.Duration: + """Duration of the given compilation stage.""" + task_index: builtins.int + """Task index from which this log entry was recorded.""" + def __init__( + self, + *, + timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + stage: global___CompilationLogEntry.CompilationStage.ValueType | None = ..., + duration: google.protobuf.duration_pb2.Duration | None = ..., + task_index: builtins.int | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["duration", b"duration", "timestamp", b"timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["duration", b"duration", "stage", b"stage", "task_index", b"task_index", "timestamp", b"timestamp"]) -> None: ... + +global___CompilationLogEntry = CompilationLogEntry diff --git a/stubs/tensorflow/tensorflow/compiler/xla/xla_data_pb2.pyi b/stubs/tensorflow/tensorflow/compiler/xla/xla_data_pb2.pyi index 2a10b0fff410..12557de1d5b6 100644 --- a/stubs/tensorflow/tensorflow/compiler/xla/xla_data_pb2.pyi +++ b/stubs/tensorflow/tensorflow/compiler/xla/xla_data_pb2.pyi @@ -54,7 +54,7 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._ U64: _PrimitiveType.ValueType # 9 F16: _PrimitiveType.ValueType # 10 """Floating-point values of fixed width. - + Note: if f16s are not natively supported on the device, they will be converted to f16 from f32 at arbirary points in the computation. """ @@ -65,6 +65,22 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._ and 7 bits for the mantissa. """ F64: _PrimitiveType.ValueType # 12 + F8E5M2: _PrimitiveType.ValueType # 19 + """FP8 dtypes, as described in this paper: https://arxiv.org/abs/2209.05433 + + F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the + existing IEEE types. + + F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only + Finite and NaN values are supported. Unlike IEEE types, infinities are not + supported. NaN is represented when the exponent and mantissa bits are all + 1s. All other values are finite. + + Support for these dtypes is under development. They do not yet work + properly in most cases. + TODO(b/259609697): Fully support FP8. + """ + F8E4M3FN: _PrimitiveType.ValueType # 20 C64: _PrimitiveType.ValueType # 15 """Complex values of fixed width. Paired F32 (real, imag), as in std::complex. @@ -76,7 +92,7 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._ sub-shapes. They are used for things like returning multiple values from a computation; e.g. a computation that returns weights and biases may have a signature that results in a tuple like (f32[784x2000], f32[2000]) - + If a shape proto has the tuple element type, it may not have any entries in the dimensions field. """ @@ -84,7 +100,7 @@ class _PrimitiveTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._ """An opaque type used for passing context-specific data to a custom operation. Shapes of this primitive type will have empty dimensions and tuple_shapes fields. - + (OPAQUE would be a better name for this identifier, but that conflicts with a macro defined in windows.h.) 
""" @@ -97,7 +113,7 @@ class PrimitiveType(_PrimitiveType, metaclass=_PrimitiveTypeEnumTypeWrapper): """Primitive types are the individual values that can be held in rectangular multidimensional arrays. A description of the rectangular multidimensional array dimensions / primitive type is given by Shape, below. - + LINT.IfChange """ @@ -128,6 +144,22 @@ floating-point format, but uses 1 bit for the sign, 8 bits for the exponent and 7 bits for the mantissa. """ F64: PrimitiveType.ValueType # 12 +F8E5M2: PrimitiveType.ValueType # 19 +"""FP8 dtypes, as described in this paper: https://arxiv.org/abs/2209.05433 + +F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the +existing IEEE types. + +F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only +Finite and NaN values are supported. Unlike IEEE types, infinities are not +supported. NaN is represented when the exponent and mantissa bits are all +1s. All other values are finite. + +Support for these dtypes is under development. They do not yet work +properly in most cases. +TODO(b/259609697): Fully support FP8. +""" +F8E4M3FN: PrimitiveType.ValueType # 20 C64: PrimitiveType.ValueType # 15 """Complex values of fixed width. Paired F32 (real, imag), as in std::complex. @@ -421,24 +453,28 @@ global___TileProto = TileProto class LayoutProto(google.protobuf.message.Message): """A layout describes how the array is placed in (1D) memory space. This includes the minor-to-major ordering of dimensions within a shape. - + Clients must specify the layouts of input Literals to the computation. Layouts specified in interior operations which take Shapes (for example, Convert) are ignored. - + See the XLA documentation for more information on shapes and layouts. - + LINT.IfChange """ DESCRIPTOR: google.protobuf.descriptor.Descriptor DIM_LEVEL_TYPES_FIELD_NUMBER: builtins.int + DIM_UNIQUE_FIELD_NUMBER: builtins.int + DIM_ORDERED_FIELD_NUMBER: builtins.int MINOR_TO_MAJOR_FIELD_NUMBER: builtins.int TILES_FIELD_NUMBER: builtins.int - ELEMENT_SIZE_IN_BITS_FIELD_NUMBER: builtins.int MEMORY_SPACE_FIELD_NUMBER: builtins.int + INDEX_PRIMITIVE_TYPE_FIELD_NUMBER: builtins.int + POINTER_PRIMITIVE_TYPE_FIELD_NUMBER: builtins.int PHYSICAL_SHAPE_FIELD_NUMBER: builtins.int + DYNAMIC_SHAPE_METADATA_PREFIX_BYTES_FIELD_NUMBER: builtins.int @property def dim_level_types(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[global___DimLevelType.ValueType]: """The dimension level type list for this array, specifying the way in which @@ -446,6 +482,16 @@ class LayoutProto(google.protobuf.message.Message): array is assumed to be dense. """ @property + def dim_unique(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: + """Whether each dimension is unique or ordered. Each of the following lists + must be empty, or have one entry for each entry of dim_level_types. If + either list is empty, all dimensions are assumed to be unique and ordered, + respectively. Entries in this list may not be false for some DimLevelType + values (such as DIM_DENSE in particular). + """ + @property + def dim_ordered(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: ... + @property def minor_to_major(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: """Sequence of dimension numbers, from minor (fastest varying index) to major (slowest varying index). This field is required. 
@@ -454,21 +500,23 @@ class LayoutProto(google.protobuf.message.Message): def tiles(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TileProto]: """A sequence of tiles, starting from the tile that's applied first to the Shape. - + TODO(b/119839262): implement tiling in each backend or add Unimplemented error. """ - element_size_in_bits: builtins.int - """Bit size of each element. If the size is bigger than what the element - type requires, the value is stored in the least significant - bits and the additional most significant bits are filled with 0's. - - TODO(b/119839262): implement in each backend or add Unimplemented error. - """ memory_space: builtins.int """Memory space where this array resides. The integer field is interpreted in a backend-specific manner. """ + index_primitive_type: global___PrimitiveType.ValueType + """The integer types to be used for indices and pointers. These fields must + not be used unless the layout represents a sparse array. The PrimitiveType + must correspond to an unsigned integer (U8, U16, U32, or U64). + If not provided, the compiler will use the largest unsigned integer + that is naturally supported by the target device (U32 or U64 in currently + supported devices). + """ + pointer_primitive_type: global___PrimitiveType.ValueType @property def physical_shape(self) -> global___ShapeProto: """The physical, on-device shape used to represent the shape this layout @@ -476,18 +524,27 @@ class LayoutProto(google.protobuf.message.Message): The layout(s) contained within the physical shape should not also contain a physical shape. """ + dynamic_shape_metadata_prefix_bytes: builtins.int + """The dynamic shape metadata size in bytes in front of the shape data. The + field may be non-zero for a static shape whose associated buffer is for a + dynamic shape, e.g. a result of SliceToDynamic. + """ def __init__( self, *, dim_level_types: collections.abc.Iterable[global___DimLevelType.ValueType] | None = ..., + dim_unique: collections.abc.Iterable[builtins.bool] | None = ..., + dim_ordered: collections.abc.Iterable[builtins.bool] | None = ..., minor_to_major: collections.abc.Iterable[builtins.int] | None = ..., tiles: collections.abc.Iterable[global___TileProto] | None = ..., - element_size_in_bits: builtins.int | None = ..., memory_space: builtins.int | None = ..., + index_primitive_type: global___PrimitiveType.ValueType | None = ..., + pointer_primitive_type: global___PrimitiveType.ValueType | None = ..., physical_shape: global___ShapeProto | None = ..., + dynamic_shape_metadata_prefix_bytes: builtins.int | None = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["physical_shape", b"physical_shape"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["dim_level_types", b"dim_level_types", "element_size_in_bits", b"element_size_in_bits", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "tiles", b"tiles"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["dim_level_types", b"dim_level_types", "dim_ordered", b"dim_ordered", "dim_unique", b"dim_unique", "dynamic_shape_metadata_prefix_bytes", b"dynamic_shape_metadata_prefix_bytes", "index_primitive_type", b"index_primitive_type", "memory_space", b"memory_space", "minor_to_major", b"minor_to_major", "physical_shape", b"physical_shape", "pointer_primitive_type", b"pointer_primitive_type", "tiles", b"tiles"]) -> None: ... 
global___LayoutProto = LayoutProto @@ -495,12 +552,12 @@ global___LayoutProto = LayoutProto class ShapeProto(google.protobuf.message.Message): """A shape describes the number of dimensions in the array, the size of each dimension, and the primitive component type. - + Tuples are a special case in that they have rank zero and have tuple_shapes defined. - + See the XLA documentation for more information on shapes and layouts. - + LINT.IfChange """ @@ -520,7 +577,7 @@ class ShapeProto(google.protobuf.message.Message): to N-1 for an N-dimensional array. The first element of 'dimensions' is the size of dimension 0, the second element is the size of dimension 1, and so forth. Empty list indicates a scalar. - + If the respective element in 'is_dimension_dynamic' is true then the value in this field represents an upper bound on the size of the dimension. """ @@ -605,7 +662,7 @@ global___ComputationStats = ComputationStats @typing_extensions.final class OpMetadata(google.protobuf.message.Message): """Symbolization metadata for HLO Instructions. - + This metadata is used for debugging XLA code generation, as well as performance profiling of XLA-generated executables. """ @@ -655,7 +712,7 @@ class OpMetadata(google.protobuf.message.Message): PROFILE_INFO_FIELD_NUMBER: builtins.int op_type: builtins.str """The framework op name that generated this XLA op. - + Frameworks that build on top of XLA should mirror the names of their ops back to users by specifying the op_type. In this way, even if the framework's "ops" are implemented as multiple XLA HLO Ops, they can be @@ -664,13 +721,13 @@ class OpMetadata(google.protobuf.message.Message): """ op_name: builtins.str """The user-specified name of the op. - + This name is often unique within a computation. Note: some frameworks add auto-generated names if the user does not provide one. """ source_file: builtins.str """Indicate a file and line that this op is associated to in a user's program. - + e.g. it could be the file and line of user code that generated the op. """ source_line: builtins.int @@ -848,7 +905,7 @@ class ChannelHandle(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _ChannelTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ChannelHandle._ChannelType.ValueType], builtins.type): # noqa: F821 + class _ChannelTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ChannelHandle._ChannelType.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor CHANNEL_TYPE_INVALID: ChannelHandle._ChannelType.ValueType # 0 """Invalid primitive type to serve as default.""" @@ -941,7 +998,7 @@ class LiteralProto(google.protobuf.message.Message): """Literals are used when the server and client need to exchange materialized data / results. Literals are also used to describe constants used in computations. - + Transfers to/from the client are encoded in literal form, and the structure of the repeated fields is implied by the shape. """ @@ -965,6 +1022,8 @@ class LiteralProto(google.protobuf.message.Message): BF16S_FIELD_NUMBER: builtins.int U16S_FIELD_NUMBER: builtins.int S16S_FIELD_NUMBER: builtins.int + F8E5M2S_FIELD_NUMBER: builtins.int + F8E4M3FNS_FIELD_NUMBER: builtins.int SPARSE_INDICES_FIELD_NUMBER: builtins.int @property def shape(self) -> global___ShapeProto: ... 
@@ -997,9 +1056,11 @@ class LiteralProto(google.protobuf.message.Message): bf16s: builtins.bytes u16s: builtins.bytes s16s: builtins.bytes + f8e5m2s: builtins.bytes + f8e4m3fns: builtins.bytes @property def sparse_indices(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: - """Next = 19""" + """Next = 21""" def __init__( self, *, @@ -1020,10 +1081,12 @@ class LiteralProto(google.protobuf.message.Message): bf16s: builtins.bytes | None = ..., u16s: builtins.bytes | None = ..., s16s: builtins.bytes | None = ..., + f8e5m2s: builtins.bytes | None = ..., + f8e4m3fns: builtins.bytes | None = ..., sparse_indices: collections.abc.Iterable[builtins.int] | None = ..., ) -> None: ... def HasField(self, field_name: typing_extensions.Literal["shape", b"shape"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "preds", b"preds", "s16s", b"s16s", "s32s", b"s32s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u32s", b"u32s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["bf16s", b"bf16s", "c128s", b"c128s", "c64s", b"c64s", "f16s", b"f16s", "f32s", b"f32s", "f64s", b"f64s", "f8e4m3fns", b"f8e4m3fns", "f8e5m2s", b"f8e5m2s", "preds", b"preds", "s16s", b"s16s", "s32s", b"s32s", "s64s", b"s64s", "s8s", b"s8s", "shape", b"shape", "sparse_indices", b"sparse_indices", "tuple_literals", b"tuple_literals", "u16s", b"u16s", "u32s", b"u32s", "u64s", b"u64s", "u8s", b"u8s"]) -> None: ... global___LiteralProto = LiteralProto @@ -1096,7 +1159,7 @@ global___WindowDimension = WindowDimension @typing_extensions.final class Window(google.protobuf.message.Message): """Describes the windowing in an operation such as convolution. - + The window is moved across a base area and for each position of the window a computation is performed. The field below describes the window and the movement of the window across a base area. @@ -1119,7 +1182,7 @@ global___Window = Window @typing_extensions.final class GatherDimensionNumbers(google.protobuf.message.Message): """Describes the dimension numbers for a gather operation. - + See https://www.tensorflow.org/performance/xla/operation_semantics#gather for more details. """ @@ -1136,9 +1199,9 @@ class GatherDimensionNumbers(google.protobuf.message.Message): interior of a dynamic-slice from the input tensor, the starting indices for which were computed from output_gather_dims (see the operation semantic for how this is defined) and the start_indices tensor. - + The window indices for a specific output index Out is computed as: - + i = 0 for (k : [0, input_tensor_shape.rank)) window_indices[k] = @@ -1173,7 +1236,7 @@ global___GatherDimensionNumbers = GatherDimensionNumbers @typing_extensions.final class ScatterDimensionNumbers(google.protobuf.message.Message): """Describes the dimension numbers for a scatter operation. - + All the fields are similar to the corresponding fields in GatherDimensionNumbers. Differences are noted below. 
""" @@ -1307,7 +1370,7 @@ class TriangularSolveOptions(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _TransposeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TriangularSolveOptions._Transpose.ValueType], builtins.type): # noqa: F821 + class _TransposeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[TriangularSolveOptions._Transpose.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor TRANSPOSE_INVALID: TriangularSolveOptions._Transpose.ValueType # 0 NO_TRANSPOSE: TriangularSolveOptions._Transpose.ValueType # 1 @@ -1415,7 +1478,7 @@ class OpSharding(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[OpSharding._Type.ValueType], builtins.type): # noqa: F821 + class _TypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[OpSharding._Type.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor REPLICATED: OpSharding._Type.ValueType # 0 """This sharding is replicated across all devices (implies maximal, @@ -1573,7 +1636,7 @@ class PrecisionConfig(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _PrecisionEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[PrecisionConfig._Precision.ValueType], builtins.type): # noqa: F821 + class _PrecisionEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[PrecisionConfig._Precision.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor DEFAULT: PrecisionConfig._Precision.ValueType # 0 HIGH: PrecisionConfig._Precision.ValueType # 1 @@ -1613,7 +1676,7 @@ class ParameterReplication(google.protobuf.message.Message): def replicated_at_leaf_buffers(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.bool]: """A list of boolean values for the flattened leaf buffers. Each value indicates whether the corresponding leaf buffer is replicated. - + If this field is empty, it means no buffer is replicated. Otherwise, the number of elements in this field must match the number of leaf buffers in the HLO instruction's shape. @@ -1631,7 +1694,7 @@ global___ParameterReplication = ParameterReplication class WhileLoopBackendConfig(google.protobuf.message.Message): """A backend-config for kWhile loops that stores the loop's trip count, if it is known. - + This is useful for backends that can implement a `for i in 0..N` loop more efficiently than a `while` loop. For example, on GPUs, we can implement a `for i in 0..N` loop by enqueueing the kernels for the loop body N times, @@ -1671,9 +1734,9 @@ class WhileLoopBackendConfig(google.protobuf.message.Message): global___WhileLoopBackendConfig = WhileLoopBackendConfig @typing_extensions.final -class CustomCallOutputOperandAliasing(google.protobuf.message.Message): - """Specifies a pair of output/operand buffers for kCustomCall that alias each - other. 
+class OutputOperandAliasing(google.protobuf.message.Message): + """Specifies a pair of output/operand buffers that alias each other for + kCustomCall and kFusion """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -1695,4 +1758,4 @@ class CustomCallOutputOperandAliasing(google.protobuf.message.Message): ) -> None: ... def ClearField(self, field_name: typing_extensions.Literal["operand_index", b"operand_index", "operand_shape_index", b"operand_shape_index", "output_shape_index", b"output_shape_index"]) -> None: ... -global___CustomCallOutputOperandAliasing = CustomCallOutputOperandAliasing +global___OutputOperandAliasing = OutputOperandAliasing diff --git a/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi b/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi index 4fb6132fdf08..d726ac63a599 100644 --- a/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi +++ b/stubs/tensorflow/tensorflow/core/example/example_pb2.pyi @@ -30,7 +30,7 @@ class Example(google.protobuf.message.Message): format, so any configuration that describes data with rank-2 or above should keep this in mind. If you flatten a matrix into a FloatList it should be stored as [ row 0 ... row 1 ... row M-1 ] - + An Example for a movie recommendation application: features { feature { @@ -81,7 +81,7 @@ class Example(google.protobuf.message.Message): }} } } - + A conformant Example data set obeys the following conventions: - If a Feature K exists in one example with data type T, it must be of type T in all other examples when present. It may be omitted. @@ -116,7 +116,7 @@ class SequenceExample(google.protobuf.message.Message): associated with a repeated set of Features (a FeatureList). A FeatureList thus represents the values of a feature identified by its key over time / frames. - + Below is a SequenceExample for a movie recommendation application recording a sequence of ratings by a user. The time-independent features ("locale", "age", "favorites") describing the user are part of the context. The sequence @@ -127,7 +127,7 @@ class SequenceExample(google.protobuf.message.Message): namely "movie_ratings", "movie_names", and "actors" have a feature value for both movies. Note, that "actors" is itself a bytes_list with multiple strings per movie. - + context: { feature: { key : "locale" @@ -201,9 +201,9 @@ class SequenceExample(google.protobuf.message.Message): } } } - + A conformant SequenceExample data set obeys the following conventions: - + Context: - All conformant context features K must obey the same conventions as a conformant Example's features (see above). @@ -228,21 +228,21 @@ class SequenceExample(google.protobuf.message.Message): number of Feature messages, so that the ith element in each FeatureList is part of the ith frame (or time step). 
Examples of conformant and non-conformant examples' FeatureLists: - + Conformant FeatureLists: feature_lists: { feature_list: { key: "movie_ratings" value: { feature: { float_list: { value: [ 4.5 ] } } feature: { float_list: { value: [ 5.0 ] } } } } } - + Non-conformant FeatureLists (mismatched types): feature_lists: { feature_list: { key: "movie_ratings" value: { feature: { float_list: { value: [ 4.5 ] } } feature: { int64_list: { value: [ 5 ] } } } } } - + Conditionally conformant FeatureLists, the parser configuration determines if the feature sizes must match: feature_lists: { feature_list: { @@ -250,7 +250,7 @@ class SequenceExample(google.protobuf.message.Message): value: { feature: { float_list: { value: [ 4.5 ] } } feature: { float_list: { value: [ 5.0, 6.0 ] } } } } } - + Conformant pair of SequenceExample feature_lists: { feature_list: { key: "movie_ratings" @@ -264,7 +264,7 @@ class SequenceExample(google.protobuf.message.Message): feature: { float_list: { value: [ 5.0 ] } } feature: { float_list: { value: [ 2.0 ] } } } } } - + Conformant pair of SequenceExample feature_lists: { feature_list: { key: "movie_ratings" @@ -276,7 +276,7 @@ class SequenceExample(google.protobuf.message.Message): key: "movie_ratings" value: { } } } - + Conditionally conformant pair of SequenceExample, the parser configuration determines if the second feature_lists is consistent (zero-length) or invalid (missing "movie_ratings"): @@ -287,7 +287,7 @@ class SequenceExample(google.protobuf.message.Message): } } and: feature_lists: { } - + Non-conformant pair of SequenceExample (mismatched types) feature_lists: { feature_list: { key: "movie_ratings" @@ -301,7 +301,7 @@ class SequenceExample(google.protobuf.message.Message): feature: { int64_list: { value: [ 5 ] } } feature: { int64_list: { value: [ 2 ] } } } } } - + Conditionally conformant pair of SequenceExample; the parser configuration determines if the feature sizes must match: feature_lists: { feature_list: { diff --git a/stubs/tensorflow/tensorflow/core/example/feature_pb2.pyi b/stubs/tensorflow/tensorflow/core/example/feature_pb2.pyi index 1739ce2aef40..9db61139e6fd 100644 --- a/stubs/tensorflow/tensorflow/core/example/feature_pb2.pyi +++ b/stubs/tensorflow/tensorflow/core/example/feature_pb2.pyi @@ -187,10 +187,10 @@ global___Features = Features @typing_extensions.final class FeatureList(google.protobuf.message.Message): """Containers for sequential data. - + A FeatureList contains lists of Features. These may hold zero or more Feature values. - + FeatureLists are organized into categories by name. The FeatureLists message contains the mapping from name to FeatureList. """ diff --git a/stubs/tensorflow/tensorflow/core/framework/api_def_pb2.pyi b/stubs/tensorflow/tensorflow/core/framework/api_def_pb2.pyi index 09fdb8369aec..00f258ae3824 100644 --- a/stubs/tensorflow/tensorflow/core/framework/api_def_pb2.pyi +++ b/stubs/tensorflow/tensorflow/core/framework/api_def_pb2.pyi @@ -29,12 +29,12 @@ class ApiDef(google.protobuf.message.Message): to all client languages, and another set per client language. The per-client-language ApiDefs will inherit values from the common ApiDefs which it can either replace or modify. - + We separate the API definition from the OpDef so we can evolve the API while remaining backwards compatible when interpreting old graphs. Overrides go in an "api_def.pbtxt" file with a text-format ApiDefs message. - + WARNING: Be *very* careful changing the API for any existing op -- you can change the semantics of existing code. 
These changes may need to wait until a major release of TensorFlow to avoid breaking @@ -47,7 +47,7 @@ class ApiDef(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _VisibilityEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ApiDef._Visibility.ValueType], builtins.type): # noqa: F821 + class _VisibilityEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ApiDef._Visibility.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor DEFAULT_VISIBILITY: ApiDef._Visibility.ValueType # 0 """Normally this is "VISIBLE" unless you are inheriting a diff --git a/stubs/tensorflow/tensorflow/core/framework/dataset_options_pb2.pyi b/stubs/tensorflow/tensorflow/core/framework/dataset_options_pb2.pyi index 509ca03684d0..e9165c3d60c8 100644 --- a/stubs/tensorflow/tensorflow/core/framework/dataset_options_pb2.pyi +++ b/stubs/tensorflow/tensorflow/core/framework/dataset_options_pb2.pyi @@ -132,7 +132,7 @@ class CardinalityOptions(google.protobuf.message.Message): ValueType = typing.NewType("ValueType", builtins.int) V: typing_extensions.TypeAlias = ValueType - class _ComputeLevelEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[CardinalityOptions._ComputeLevel.ValueType], builtins.type): # noqa: F821 + class _ComputeLevelEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[CardinalityOptions._ComputeLevel.ValueType], builtins.type): DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor CARDINALITY_COMPUTE_UNSPECIFIED: CardinalityOptions._ComputeLevel.ValueType # 0 CARDINALITY_COMPUTE_LOW: CardinalityOptions._ComputeLevel.ValueType # 1 @@ -293,8 +293,8 @@ global___ThreadingOptions = ThreadingOptions class Options(google.protobuf.message.Message): """Message stored with Dataset objects to control how datasets are processed and optimized. - - next: 8 + + next: 9 """ DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -306,6 +306,7 @@ class Options(google.protobuf.message.Message): SLACK_FIELD_NUMBER: builtins.int THREADING_OPTIONS_FIELD_NUMBER: builtins.int EXTERNAL_STATE_POLICY_FIELD_NUMBER: builtins.int + SYMBOLIC_CHECKPOINT_FIELD_NUMBER: builtins.int deterministic: builtins.bool @property def autotune_options(self) -> global___AutotuneOptions: @@ -321,6 +322,7 @@ class Options(google.protobuf.message.Message): def threading_options(self) -> global___ThreadingOptions: """The threading options associated with the dataset.""" external_state_policy: global___ExternalStatePolicy.ValueType + symbolic_checkpoint: builtins.bool def __init__( self, *, @@ -331,14 +333,17 @@ class Options(google.protobuf.message.Message): slack: builtins.bool | None = ..., threading_options: global___ThreadingOptions | None = ..., external_state_policy: global___ExternalStatePolicy.ValueType | None = ..., + symbolic_checkpoint: builtins.bool | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["autotune_options", b"autotune_options", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "slack", b"slack", "threading_options", b"threading_options"]) -> builtins.bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal["autotune_options", b"autotune_options", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "slack", b"slack", "threading_options", b"threading_options"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["autotune_options", b"autotune_options", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["autotune_options", b"autotune_options", "deterministic", b"deterministic", "distribute_options", b"distribute_options", "external_state_policy", b"external_state_policy", "optimization_options", b"optimization_options", "optional_deterministic", b"optional_deterministic", "optional_external_state_policy", b"optional_external_state_policy", "optional_slack", b"optional_slack", "optional_symbolic_checkpoint", b"optional_symbolic_checkpoint", "slack", b"slack", "symbolic_checkpoint", b"symbolic_checkpoint", "threading_options", b"threading_options"]) -> None: ... @typing.overload def WhichOneof(self, oneof_group: typing_extensions.Literal["optional_deterministic", b"optional_deterministic"]) -> typing_extensions.Literal["deterministic"] | None: ... @typing.overload def WhichOneof(self, oneof_group: typing_extensions.Literal["optional_external_state_policy", b"optional_external_state_policy"]) -> typing_extensions.Literal["external_state_policy"] | None: ... @typing.overload def WhichOneof(self, oneof_group: typing_extensions.Literal["optional_slack", b"optional_slack"]) -> typing_extensions.Literal["slack"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["optional_symbolic_checkpoint", b"optional_symbolic_checkpoint"]) -> typing_extensions.Literal["symbolic_checkpoint"] | None: ... global___Options = Options diff --git a/stubs/tensorflow/tensorflow/core/framework/dataset_pb2.pyi b/stubs/tensorflow/tensorflow/core/framework/dataset_pb2.pyi new file mode 100644 index 000000000000..432d6f4f57a5 --- /dev/null +++ b/stubs/tensorflow/tensorflow/core/framework/dataset_pb2.pyi @@ -0,0 +1,105 @@ +""" +@generated by mypy-protobuf. Do not edit manually! 
+isort:skip_file +""" +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.message +import sys +import tensorflow.core.framework.tensor_pb2 +import tensorflow.core.framework.tensor_shape_pb2 +import tensorflow.core.framework.types_pb2 + +if sys.version_info >= (3, 8): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +@typing_extensions.final +class CompressedComponentMetadata(google.protobuf.message.Message): + """This file contains protocol buffers for working with tf.data Datasets. + + Metadata describing a compressed component of a dataset element. + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DTYPE_FIELD_NUMBER: builtins.int + TENSOR_SHAPE_FIELD_NUMBER: builtins.int + UNCOMPRESSED_BYTES_FIELD_NUMBER: builtins.int + dtype: tensorflow.core.framework.types_pb2.DataType.ValueType + """The dtype of the component tensor.""" + @property + def tensor_shape(self) -> tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto: + """The shape of the component tensor.""" + @property + def uncompressed_bytes(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: + """The amount of uncompressed tensor data. + - For string tensors, there is an element for each string indicating the + size of the string. + - For all other tensors, there is a single element indicating the size of + the tensor. + """ + def __init__( + self, + *, + dtype: tensorflow.core.framework.types_pb2.DataType.ValueType | None = ..., + tensor_shape: tensorflow.core.framework.tensor_shape_pb2.TensorShapeProto | None = ..., + uncompressed_bytes: collections.abc.Iterable[builtins.int] | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["tensor_shape", b"tensor_shape"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["dtype", b"dtype", "tensor_shape", b"tensor_shape", "uncompressed_bytes", b"uncompressed_bytes"]) -> None: ... + +global___CompressedComponentMetadata = CompressedComponentMetadata + +@typing_extensions.final +class CompressedElement(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATA_FIELD_NUMBER: builtins.int + COMPONENT_METADATA_FIELD_NUMBER: builtins.int + VERSION_FIELD_NUMBER: builtins.int + data: builtins.bytes + """Compressed tensor bytes for all components of the element.""" + @property + def component_metadata(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___CompressedComponentMetadata]: + """Metadata for the components of the element.""" + version: builtins.int + """Version of the CompressedElement. CompressedElements may be stored on disk + and read back by later versions of code, so we store a version number to + help readers understand which version they are reading. When you add a new + field to this proto, you need to increment kCompressedElementVersion in + tensorflow/core/data/compression_utils.cc. + """ + def __init__( + self, + *, + data: builtins.bytes | None = ..., + component_metadata: collections.abc.Iterable[global___CompressedComponentMetadata] | None = ..., + version: builtins.int | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["component_metadata", b"component_metadata", "data", b"data", "version", b"version"]) -> None: ... 
+
+global___CompressedElement = CompressedElement
+
+@typing_extensions.final
+class UncompressedElement(google.protobuf.message.Message):
+    """An uncompressed dataset element."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    COMPONENTS_FIELD_NUMBER: builtins.int
+    @property
+    def components(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.core.framework.tensor_pb2.TensorProto]: ...
+    def __init__(
+        self,
+        *,
+        components: collections.abc.Iterable[tensorflow.core.framework.tensor_pb2.TensorProto] | None = ...,
+    ) -> None: ...
+    def ClearField(self, field_name: typing_extensions.Literal["components", b"components"]) -> None: ...
+
+global___UncompressedElement = UncompressedElement
diff --git a/stubs/tensorflow/tensorflow/core/framework/full_type_pb2.pyi b/stubs/tensorflow/tensorflow/core/framework/full_type_pb2.pyi
index 6e2670d8582d..5f02828424b5 100644
--- a/stubs/tensorflow/tensorflow/core/framework/full_type_pb2.pyi
+++ b/stubs/tensorflow/tensorflow/core/framework/full_type_pb2.pyi
@@ -32,7 +32,7 @@ class _FullTypeIdEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._Enu

     Type variables may serve as placeholder for any other type ID in
     type templates.
-
+
     Examples:
       TFT_DATASET[TFT_VAR["T"]] is a Dataset returning a type indicated by "T".
       TFT_TENSOR[TFT_VAR["T"]] is a Tensor of an element type indicated by "T".
@@ -51,20 +51,20 @@ class _FullTypeIdEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._Enu
     """The algebraic product type. This is an algebraic type that may be used
     just for logical grouping. Not to be confused with TFT_TUPLE which
     describes a concrete object of several elements.
-
+
     Example:
       TFT_DATASET[TFT_PRODUCT[TFT_TENSOR[TFT_INT32], TFT_TENSOR[TFT_FLOAT64]]]
         is a Dataset producing two tensors, an integer one and a float one.
     """
     TFT_NAMED: _FullTypeId.ValueType  # 4
     """Represents a named field, with the name stored in the attribute.
-
+
     Parametrization:
       TFT_NAMED[<type>]{<name>}
       * <type> is the type of the field
       * <name> is the field name, as string (though it can theoretically be an
         int as well)
-
+
     Example:
       TFT_RECORD[
         TFT_NAMED[TFT_TENSOR[TFT_INT32]]{'foo'},
@@ -76,7 +76,7 @@ class _FullTypeIdEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._Enu
     TFT_FOR_EACH: _FullTypeId.ValueType  # 20
     """Template definition. Expands the variables by repeating a template as
     arguments of container.
-
+
     Parametrization:
       TFT_FOR_EACH[,