-
-
Notifications
You must be signed in to change notification settings - Fork 2k
tensorflow feature columns #10052
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
JelleZijlstra
merged 13 commits into
python:main
from
hmc-cs-mdrissi:md/tensorflow_feature_columns
Apr 27, 2023
Merged
tensorflow feature columns #10052
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
8d79557
Add feature column type stubs.
4a42a7b
tweak
2869dc4
mypy fixes
c373157
ignore a few mypy errors
cd82246
fix all stubtest errors
a4bdaa7
ignore one pyright error
4a43fa8
workaround type checker bugs
722ba53
fix few more defaults for pytype
fcfd666
Update stubs/tensorflow/tensorflow/io/__init__.pyi
JelleZijlstra a6e114a
small cleanups
8b44211
Merge branch 'master' into md/tensorflow_feature_columns
c1ffa95
fix 1 last mypy error
bd05a43
retrigger flaky ci check
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| from collections.abc import Callable, Iterable, Sequence | ||
|
|
||
| import tensorflow as tf | ||
| from tensorflow import _ShapeLike | ||
| from tensorflow.python.feature_column import feature_column_v2 as fc, sequence_feature_column as seq_fc | ||
|
|
||
| def numeric_column( | ||
| key: str, | ||
| shape: _ShapeLike = (1,), | ||
| default_value: float | None = None, | ||
| dtype: tf.DType = ..., | ||
| normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None = None, | ||
| ) -> fc.NumericColumn: ... | ||
| def bucketized_column(source_column: fc.NumericColumn, boundaries: list[float] | tuple[float, ...]) -> fc.BucketizedColumn: ... | ||
| def embedding_column( | ||
| categorical_column: fc.CategoricalColumn, | ||
| dimension: int, | ||
| combiner: fc._Combiners = "mean", | ||
| initializer: Callable[[_ShapeLike], tf.Tensor] | None = None, | ||
| ckpt_to_load_from: str | None = None, | ||
| tensor_name_in_ckpt: str | None = None, | ||
| max_norm: float | None = None, | ||
| trainable: bool = True, | ||
| use_safe_embedding_lookup: bool = True, | ||
| ) -> fc.EmbeddingColumn: ... | ||
| def shared_embeddings( | ||
| categorical_columns: Iterable[fc.CategoricalColumn], | ||
| dimension: int, | ||
| combiner: fc._Combiners = "mean", | ||
| initializer: Callable[[_ShapeLike], tf.Tensor] | None = None, | ||
| shared_embedding_collection_name: str | None = None, | ||
| ckpt_to_load_from: str | None = None, | ||
| tensor_name_in_ckpt: str | None = None, | ||
| max_norm: float | None = None, | ||
| trainable: bool = True, | ||
| use_safe_embedding_lookup: bool = True, | ||
| ) -> list[fc.SharedEmbeddingColumn]: ... | ||
| def categorical_column_with_identity( | ||
| key: str, num_buckets: int, default_value: int | None = None | ||
| ) -> fc.IdentityCategoricalColumn: ... | ||
| def categorical_column_with_hash_bucket(key: str, hash_bucket_size: int, dtype: tf.DType = ...) -> fc.HashedCategoricalColumn: ... | ||
| def categorical_column_with_vocabulary_file( | ||
| key: str, | ||
| vocabulary_file: str, | ||
| vocabulary_size: int | None = None, | ||
| dtype: tf.DType = ..., | ||
| default_value: str | int | None = None, | ||
| num_oov_buckets: int = 0, | ||
| file_format: str | None = None, | ||
| ) -> fc.VocabularyFileCategoricalColumn: ... | ||
| def categorical_column_with_vocabulary_list( | ||
| key: str, | ||
| vocabulary_list: Sequence[str] | Sequence[int], | ||
| dtype: tf.DType | None = None, | ||
| default_value: str | int | None = -1, | ||
| num_oov_buckets: int = 0, | ||
| ) -> fc.VocabularyListCategoricalColumn: ... | ||
| def indicator_column(categorical_column: fc.CategoricalColumn) -> fc.IndicatorColumn: ... | ||
| def weighted_categorical_column( | ||
| categorical_column: fc.CategoricalColumn, weight_feature_key: str, dtype: tf.DType = ... | ||
| ) -> fc.WeightedCategoricalColumn: ... | ||
| def crossed_column( | ||
| keys: Iterable[str | fc.CategoricalColumn], hash_bucket_size: int, hash_key: int | None = None | ||
| ) -> fc.CrossedColumn: ... | ||
| def sequence_numeric_column( | ||
| key: str, | ||
| shape: _ShapeLike = (1,), | ||
| default_value: float = 0.0, | ||
| dtype: tf.DType = ..., | ||
| normalizer_fn: Callable[[tf.Tensor], tf.Tensor] | None = None, | ||
| ) -> seq_fc.SequenceNumericColumn: ... | ||
| def sequence_categorical_column_with_identity( | ||
| key: str, num_buckets: int, default_value: int | None = None | ||
| ) -> fc.SequenceCategoricalColumn: ... | ||
| def sequence_categorical_column_with_hash_bucket( | ||
| key: str, hash_bucket_size: int, dtype: tf.DType = ... | ||
| ) -> fc.SequenceCategoricalColumn: ... | ||
| def sequence_categorical_column_with_vocabulary_file( | ||
| key: str, | ||
| vocabulary_file: str, | ||
| vocabulary_size: int | None = None, | ||
| num_oov_buckets: int = 0, | ||
| default_value: str | int | None = None, | ||
| dtype: tf.DType = ..., | ||
| ) -> fc.SequenceCategoricalColumn: ... | ||
| def sequence_categorical_column_with_vocabulary_list( | ||
| key: str, | ||
| vocabulary_list: Sequence[str] | Sequence[int], | ||
| dtype: tf.DType | None = None, | ||
| default_value: str | int | None = -1, | ||
| num_oov_buckets: int = 0, | ||
| ) -> fc.SequenceCategoricalColumn: ... | ||
| def make_parse_example_spec( | ||
| feature_columns: Iterable[fc.FeatureColumn], | ||
| ) -> dict[str, tf.io.FixedLenFeature | tf.io.VarLenFeature]: ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| from _typeshed import Incomplete | ||
| from collections.abc import Iterable, Mapping | ||
| from types import TracebackType | ||
| from typing import NamedTuple | ||
| from typing_extensions import Literal, Self, TypeAlias | ||
|
|
||
| from tensorflow import _DTypeLike, _ShapeLike, _TensorCompatible | ||
| from tensorflow._aliases import TensorLike | ||
| from tensorflow.io import gfile as gfile | ||
|
|
||
| _FeatureSpecs: TypeAlias = Mapping[str, FixedLenFeature | FixedLenSequenceFeature | VarLenFeature | RaggedFeature | SparseFeature] | ||
|
|
||
| _CompressionTypes: TypeAlias = Literal["ZLIB", "GZIP", "", 0, 1, 2] | None | ||
| _CompressionLevels: TypeAlias = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9] | None | ||
| _MemoryLevels: TypeAlias = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9] | None | ||
|
|
||
| class TFRecordOptions: | ||
| compression_type: _CompressionTypes | TFRecordOptions | ||
| flush_mode: int | None # The exact allowed values come from zlib | ||
| input_buffer_size: int | None | ||
| output_buffer_size: int | None | ||
| window_bits: int | None | ||
| compression_level: _CompressionLevels | ||
| compression_method: str | None | ||
| mem_level: _MemoryLevels | ||
| compression_strategy: int | None # The exact allowed values come from zlib | ||
|
|
||
| def __init__( | ||
| self, | ||
| compression_type: _CompressionTypes | TFRecordOptions = None, | ||
| flush_mode: int | None = None, | ||
| input_buffer_size: int | None = None, | ||
| output_buffer_size: int | None = None, | ||
| window_bits: int | None = None, | ||
| compression_level: _CompressionLevels = None, | ||
| compression_method: str | None = None, | ||
| mem_level: _MemoryLevels = None, | ||
| compression_strategy: int | None = None, | ||
| ) -> None: ... | ||
| @classmethod | ||
| def get_compression_type_string(cls, options: _CompressionTypes | TFRecordOptions) -> str: ... | ||
|
|
||
| class TFRecordWriter: | ||
| def __init__(self, path: str, options: _CompressionTypes | TFRecordOptions | None = None) -> None: ... | ||
| def write(self, record: bytes) -> None: ... | ||
| def flush(self) -> None: ... | ||
| def close(self) -> None: ... | ||
| def __enter__(self) -> Self: ... | ||
| def __exit__( | ||
| self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None | ||
| ) -> None: ... | ||
|
|
||
| # The default values below are omitted (written as `...`) because pytype crashes when a | ||
| # default is present, as reported in this [issue](https://github.com/google/pytype/issues/1410#issue-1669793588). | ||
| # After the next release the defaults can be added back. | ||
| class FixedLenFeature(NamedTuple): | ||
| shape: _ShapeLike | ||
| dtype: _DTypeLike | ||
| default_value: _TensorCompatible | None = ... | ||
|
|
||
| class FixedLenSequenceFeature(NamedTuple): | ||
| shape: _ShapeLike | ||
| dtype: _DTypeLike | ||
| allow_missing: bool = ... | ||
| default_value: _TensorCompatible | None = ... | ||
|
|
||
| class VarLenFeature(NamedTuple): | ||
| dtype: _DTypeLike | ||
|
|
||
| class SparseFeature(NamedTuple): | ||
| index_key: str | list[str] | ||
| value_key: str | ||
| dtype: _DTypeLike | ||
| size: int | list[int] | ||
| already_sorted: bool = ... | ||
|
|
||
| class RaggedFeature(NamedTuple): | ||
| # Mypy doesn't support nested NamedTuples, but the runtime implementation actually does | ||
| # use nested collections.namedtuple, hence the ignore comments on the classes below. | ||
| class RowSplits(NamedTuple): # type: ignore[misc] | ||
| key: str | ||
|
|
||
| class RowLengths(NamedTuple): # type: ignore[misc] | ||
| key: str | ||
|
|
||
| class RowStarts(NamedTuple): # type: ignore[misc] | ||
| key: str | ||
|
|
||
| class RowLimits(NamedTuple): # type: ignore[misc] | ||
| key: str | ||
|
|
||
| class ValueRowIds(NamedTuple): # type: ignore[misc] | ||
| key: str | ||
|
|
||
| class UniformRowLength(NamedTuple): # type: ignore[misc] | ||
| length: int | ||
| dtype: _DTypeLike | ||
| value_key: str | None = ... | ||
| partitions: tuple[RowSplits | RowLengths | RowStarts | RowLimits | ValueRowIds | UniformRowLength, ...] = ... # type: ignore[name-defined] | ||
| row_splits_dtype: _DTypeLike = ... | ||
| validate: bool = ... | ||
|
|
||
| def parse_example( | ||
| serialized: _TensorCompatible, features: _FeatureSpecs, example_names: Iterable[str] | None = None, name: str | None = None | ||
| ) -> dict[str, TensorLike]: ... | ||
| def __getattr__(name: str) -> Incomplete: ... | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| from _typeshed import Incomplete, StrOrBytesPath | ||
| from collections.abc import Iterable | ||
|
|
||
| def rmtree(path: StrOrBytesPath) -> None: ... | ||
| def isdir(path: StrOrBytesPath) -> bool: ... | ||
| def listdir(path: StrOrBytesPath) -> list[str]: ... | ||
| def exists(path: StrOrBytesPath) -> bool: ... | ||
| def copy(src: StrOrBytesPath, dst: StrOrBytesPath, overwrite: bool = False) -> None: ... | ||
| def makedirs(path: StrOrBytesPath) -> None: ... | ||
| def glob(pattern: str | bytes | Iterable[str | bytes]) -> list[str]: ... | ||
| def __getattr__(name: str) -> Incomplete: ... |
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.