12 changes: 12 additions & 0 deletions stubs/tensorflow/@tests/stubtest_allowlist.txt
@@ -16,7 +16,9 @@ tensorflow.Graph.__getattr__
tensorflow.Operation.__getattr__
tensorflow.Variable.__getattr__
tensorflow.keras.layers.Layer.__getattr__
tensorflow.python.feature_column.feature_column_v2.SharedEmbeddingColumnCreator.__getattr__
tensorflow.GradientTape.__getattr__

# Internal undocumented API
tensorflow.RaggedTensor.__init__
# Has an undocumented extra argument that tf.Variable which acts like subclass
@@ -69,3 +71,13 @@ tensorflow.keras.layers.*.compute_output_shape
# The pb2.pyi stubs generated by mypy-protobuf diverge from the runtime in many ways. These stubs
# are mainly tested in mypy-protobuf.
.*_pb2.*

# These use namedtuple at runtime but NamedTuple in the stubs, and the two disagree about
# the name of __new__'s first argument (cls vs cls_).
tensorflow.io.RaggedFeature.__new__
tensorflow.io.FixedLenSequenceFeature.__new__
tensorflow.io.FixedLenFeature.__new__
tensorflow.io.SparseFeature.__new__

# Metaclass inconsistency. The runtime metaclass is defined in a C++ extension and is undocumented.
tensorflow.io.TFRecordWriter
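The four `__new__` allowlist entries come from a real runtime/stub divergence. The runtime classes are built with `collections.namedtuple`, whose generated `__new__` mangles its first parameter so it cannot collide with a field name, while the `typing.NamedTuple` declarations in the stubs spell it `cls`. The exact mangled spelling varies (the comment above says `cls_`; current CPython generates `_cls`), but either way the names disagree. A minimal, hedged sketch of what stubtest sees (class and field names here are illustrative, not from the PR):

```python
# Hedged sketch: inspect the __new__ that collections.namedtuple generates.
# "RuntimeFeature" is an illustrative stand-in for tf.io.FixedLenFeature etc.
import collections
import inspect

RuntimeFeature = collections.namedtuple("RuntimeFeature", ["shape", "dtype"])

# The first parameter is the mangled class argument, not plain "cls",
# which is why stubtest flags the NamedTuple-based stubs.
print(inspect.signature(RuntimeFeature.__new__))  # e.g. (_cls, shape, dtype)
```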
2 changes: 1 addition & 1 deletion stubs/tensorflow/tensorflow/__init__.pyi
@@ -9,7 +9,7 @@ from typing import Any, NoReturn, TypeVar, overload
from typing_extensions import ParamSpec, Self, TypeAlias

import numpy
-from tensorflow import initializers as initializers, keras as keras, math as math
+from tensorflow import feature_column as feature_column, initializers as initializers, io as io, keras as keras, math as math
from tensorflow._aliases import ContainerGradients, ContainerTensors, ContainerTensorsLike, Gradients, TensorLike

# Explicit import of DType is covered by the wildcard, but
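The doubled `io as io` / `feature_column as feature_column` spelling is the stub convention for an explicit re-export: a plain `from tensorflow import io` in a stub would be treated as a private import. A small hedged sketch of what the re-export enables for type checkers:

```python
# With the re-exports in tensorflow/__init__.pyi, attribute access on the
# top-level package now type-checks; values below are illustrative.
import tensorflow as tf

options = tf.io.TFRecordOptions(compression_type="GZIP")  # via `io as io`
age = tf.feature_column.numeric_column("age")             # via `feature_column as feature_column`
```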
95 changes: 95 additions & 0 deletions stubs/tensorflow/tensorflow/feature_column/__init__.pyi
@@ -0,0 +1,95 @@
from collections.abc import Callable, Iterable, Sequence

import tensorflow as tf
from tensorflow import _ShapeLike
from tensorflow.python.feature_column import feature_column_v2 as fc, sequence_feature_column as seq_fc

def numeric_column(
    key: str,
    shape: _ShapeLike = (1,),
    default_value: float | None = None,
    dtype: tf.DType = ...,
    normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
) -> fc.NumericColumn: ...
def bucketized_column(source_column: fc.NumericColumn, boundaries: list[float] | tuple[float, ...]) -> fc.BucketizedColumn: ...
def embedding_column(
    categorical_column: fc.CategoricalColumn,
    dimension: int,
    combiner: fc._Combiners = "mean",
    initializer: Callable[[_ShapeLike], tf.Tensor] | None = None,
    ckpt_to_load_from: str | None = None,
    tensor_name_in_ckpt: str | None = None,
    max_norm: float | None = None,
    trainable: bool = True,
    use_safe_embedding_lookup: bool = True,
) -> fc.EmbeddingColumn: ...
def shared_embeddings(
    categorical_columns: Iterable[fc.CategoricalColumn],
    dimension: int,
    combiner: fc._Combiners = "mean",
    initializer: Callable[[_ShapeLike], tf.Tensor] | None = None,
    shared_embedding_collection_name: str | None = None,
    ckpt_to_load_from: str | None = None,
    tensor_name_in_ckpt: str | None = None,
    max_norm: float | None = None,
    trainable: bool = True,
    use_safe_embedding_lookup: bool = True,
) -> list[fc.SharedEmbeddingColumn]: ...
def categorical_column_with_identity(
    key: str, num_buckets: int, default_value: int | None = None
) -> fc.IdentityCategoricalColumn: ...
def categorical_column_with_hash_bucket(key: str, hash_bucket_size: int, dtype: tf.DType = ...) -> fc.HashedCategoricalColumn: ...
def categorical_column_with_vocabulary_file(
    key: str,
    vocabulary_file: str,
    vocabulary_size: int | None = None,
    dtype: tf.DType = ...,
    default_value: str | int | None = None,
    num_oov_buckets: int = 0,
    file_format: str | None = None,
) -> fc.VocabularyFileCategoricalColumn: ...
def categorical_column_with_vocabulary_list(
    key: str,
    vocabulary_list: Sequence[str] | Sequence[int],
    dtype: tf.DType | None = None,
    default_value: str | int | None = -1,
    num_oov_buckets: int = 0,
) -> fc.VocabularyListCategoricalColumn: ...
def indicator_column(categorical_column: fc.CategoricalColumn) -> fc.IndicatorColumn: ...
def weighted_categorical_column(
    categorical_column: fc.CategoricalColumn, weight_feature_key: str, dtype: tf.DType = ...
) -> fc.WeightedCategoricalColumn: ...
def crossed_column(
    keys: Iterable[str | fc.CategoricalColumn], hash_bucket_size: int, hash_key: int | None = None
) -> fc.CrossedColumn: ...
def sequence_numeric_column(
    key: str,
    shape: _ShapeLike = (1,),
    default_value: float = 0.0,
    dtype: tf.DType = ...,
    normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None = None,
) -> seq_fc.SequenceNumericColumn: ...
def sequence_categorical_column_with_identity(
    key: str, num_buckets: int, default_value: int | None = None
) -> fc.SequenceCategoricalColumn: ...
def sequence_categorical_column_with_hash_bucket(
    key: str, hash_bucket_size: int, dtype: tf.DType = ...
) -> fc.SequenceCategoricalColumn: ...
def sequence_categorical_column_with_vocabulary_file(
    key: str,
    vocabulary_file: str,
    vocabulary_size: int | None = None,
    num_oov_buckets: int = 0,
    default_value: str | int | None = None,
    dtype: tf.DType = ...,
) -> fc.SequenceCategoricalColumn: ...
def sequence_categorical_column_with_vocabulary_list(
    key: str,
    vocabulary_list: Sequence[str] | Sequence[int],
    dtype: tf.DType | None = None,
    default_value: str | int | None = -1,
    num_oov_buckets: int = 0,
) -> fc.SequenceCategoricalColumn: ...
def make_parse_example_spec(
    feature_columns: Iterable[fc.FeatureColumn],
) -> dict[str, tf.io.FixedLenFeature | tf.io.VarLenFeature]: ...
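A short, hedged sketch of how these signatures compose (feature names and boundaries are made up for illustration); `make_parse_example_spec` is the bridge to the `tf.io` feature classes typed in the next file:

```python
import tensorflow as tf

# Columns whose return types the stubs above pin down.
age = tf.feature_column.numeric_column("age")                                # fc.NumericColumn
age_buckets = tf.feature_column.bucketized_column(age, [18.0, 35.0, 65.0])   # fc.BucketizedColumn
city = tf.feature_column.categorical_column_with_vocabulary_list(
    "city", ["sf", "nyc"], num_oov_buckets=1
)                                                                            # fc.VocabularyListCategoricalColumn
city_emb = tf.feature_column.embedding_column(city, dimension=8)             # fc.EmbeddingColumn

# Maps columns to parsing specs, matching the declared return type
# dict[str, tf.io.FixedLenFeature | tf.io.VarLenFeature].
spec = tf.feature_column.make_parse_example_spec([age, city])
```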
106 changes: 106 additions & 0 deletions stubs/tensorflow/tensorflow/io/__init__.pyi
@@ -0,0 +1,106 @@
from _typeshed import Incomplete
from collections.abc import Iterable, Mapping
from types import TracebackType
from typing import NamedTuple
from typing_extensions import Literal, Self, TypeAlias

from tensorflow import _DTypeLike, _ShapeLike, _TensorCompatible
from tensorflow._aliases import TensorLike
from tensorflow.io import gfile as gfile

_FeatureSpecs: TypeAlias = Mapping[str, FixedLenFeature | FixedLenSequenceFeature | VarLenFeature | RaggedFeature | SparseFeature]

_CompressionTypes: TypeAlias = Literal["ZLIB", "GZIP", "", 0, 1, 2] | None
_CompressionLevels: TypeAlias = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | None
_MemoryLevels: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9] | None

class TFRecordOptions:
    compression_type: _CompressionTypes | TFRecordOptions
    flush_mode: int | None  # The exact values allowed come from zlib
    input_buffer_size: int | None
    output_buffer_size: int | None
    window_bits: int | None
    compression_level: _CompressionLevels
    compression_method: str | None
    mem_level: _MemoryLevels
    compression_strategy: int | None  # The exact values allowed come from zlib

    def __init__(
        self,
        compression_type: _CompressionTypes | TFRecordOptions = None,
        flush_mode: int | None = None,
        input_buffer_size: int | None = None,
        output_buffer_size: int | None = None,
        window_bits: int | None = None,
        compression_level: _CompressionLevels = None,
        compression_method: str | None = None,
        mem_level: _MemoryLevels = None,
        compression_strategy: int | None = None,
    ) -> None: ...
    @classmethod
    def get_compression_type_string(cls, options: _CompressionTypes | TFRecordOptions) -> str: ...

class TFRecordWriter:
    def __init__(self, path: str, options: _CompressionTypes | TFRecordOptions | None = None) -> None: ...
    def write(self, record: bytes) -> None: ...
    def flush(self) -> None: ...
    def close(self) -> None: ...
    def __enter__(self) -> Self: ...
    def __exit__(
        self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
    ) -> None: ...

# The concrete default values are also omitted here (left as `...`) because pytype crashes
# when a default is present, as reported in this
# [issue](https://github.com/google/pytype/issues/1410#issue-1669793588). After the next
# pytype release the defaults can be added back.
class FixedLenFeature(NamedTuple):
    shape: _ShapeLike
    dtype: _DTypeLike
    default_value: _TensorCompatible | None = ...

class FixedLenSequenceFeature(NamedTuple):
    shape: _ShapeLike
    dtype: _DTypeLike
    allow_missing: bool = ...
    default_value: _TensorCompatible | None = ...

class VarLenFeature(NamedTuple):
    dtype: _DTypeLike

class SparseFeature(NamedTuple):
    index_key: str | list[str]
    value_key: str
    dtype: _DTypeLike
    size: int | list[int]
    already_sorted: bool = ...

class RaggedFeature(NamedTuple):
    # Mypy doesn't support nested NamedTuples, but at runtime they actually do use
    # nested collections.namedtuple.
    class RowSplits(NamedTuple):  # type: ignore[misc]
        key: str

    class RowLengths(NamedTuple):  # type: ignore[misc]
        key: str

    class RowStarts(NamedTuple):  # type: ignore[misc]
        key: str

    class RowLimits(NamedTuple):  # type: ignore[misc]
        key: str

    class ValueRowIds(NamedTuple):  # type: ignore[misc]
        key: str

    class UniformRowLength(NamedTuple):  # type: ignore[misc]
        length: int

    dtype: _DTypeLike
    value_key: str | None = ...
    partitions: tuple[RowSplits | RowLengths | RowStarts | RowLimits | ValueRowIds | UniformRowLength, ...] = ...  # type: ignore[name-defined]
    row_splits_dtype: _DTypeLike = ...
    validate: bool = ...

def parse_example(
    serialized: _TensorCompatible, features: _FeatureSpecs, example_names: Iterable[str] | None = None, name: str | None = None
) -> dict[str, TensorLike]: ...
def __getattr__(name: str) -> Incomplete: ...
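For orientation, a hedged end-to-end sketch exercising the classes typed above (the path and feature name are illustrative):

```python
import tensorflow as tf

# TFRecordWriter is a context manager, matching __enter__/__exit__ above;
# TFRecordOptions covers the _CompressionTypes alias.
options = tf.io.TFRecordOptions(compression_type="GZIP")
example = tf.train.Example(
    features=tf.train.Features(
        feature={"age": tf.train.Feature(int64_list=tf.train.Int64List(value=[42]))}
    )
)
with tf.io.TFRecordWriter("/tmp/data.tfrecord", options) as writer:
    writer.write(example.SerializeToString())

# parse_example takes a _FeatureSpecs mapping and returns tensors keyed by name.
serialized = tf.constant([example.SerializeToString()])
parsed = tf.io.parse_example(serialized, {"age": tf.io.FixedLenFeature([1], tf.int64)})
```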
11 changes: 11 additions & 0 deletions stubs/tensorflow/tensorflow/io/gfile.pyi
@@ -0,0 +1,11 @@
from _typeshed import Incomplete, StrOrBytesPath
from collections.abc import Iterable

def rmtree(path: StrOrBytesPath) -> None: ...
def isdir(path: StrOrBytesPath) -> bool: ...
def listdir(path: StrOrBytesPath) -> list[str]: ...
def exists(path: StrOrBytesPath) -> bool: ...
def copy(src: StrOrBytesPath, dst: StrOrBytesPath, overwrite: bool = False) -> None: ...
def makedirs(path: StrOrBytesPath) -> None: ...
def glob(pattern: str | bytes | Iterable[str | bytes]) -> list[str]: ...
def __getattr__(name: str) -> Incomplete: ...
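A small hedged usage sketch of the helpers typed above; `tf.io.gfile` mirrors `os`/`os.path`-style calls but works on any filesystem TensorFlow supports (local paths, `gs://`, `hdfs://`, ...). Paths are illustrative:

```python
import tensorflow as tf

# Create a directory if needed, then enumerate matching files.
if not tf.io.gfile.exists("/tmp/records"):
    tf.io.gfile.makedirs("/tmp/records")
for path in tf.io.gfile.glob("/tmp/records/*.tfrecord"):
    print(path, tf.io.gfile.isdir(path))
```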
Empty file.