Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@
* Added tfx.v1 Public APIs, please refer to
[API doc](https://www.tensorflow.org/tfx/api_docs/python/tfx/v1) for details.
* Transform component now computes pre-transform and post-transform statistics
by default. This can be disabled by setting `disable_statistics=True` in the
Transform component.
and stores them in new, individual outputs ('pre_transform_schema',
'pre_transform_stats', 'post_transform_schema', 'post_transform_stats',
'post_transform_anomalies'). This can be disabled by setting
`disable_statistics=True` in the Transform component.
* BERT cola and mrpc examples now demonstrate how to calculate statistics for
NLP features.
* TFX CLI now supports
[Vertex Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines/introduction).
Use it with the `--engine=vertex` flag.

## Breaking Changes

* Removed unnecessary default values for required component input Channels.

### For Pipeline Authors

* N/A
Expand All @@ -29,9 +36,6 @@

## Bug Fixes and Other Changes

* TFX CLI now supports
[Vertex Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines/introduction).
use it with `--engine=vertex` flag.
* Forces keyword arguments for AirflowComponent to make it compatible with
Apache Airflow 2.1.0 and later.
* Removed `six` dependency.
Expand All @@ -40,6 +44,7 @@
* Depends on `struct2tensor>=0.31.0,<0.32.0`.
* Depends on `tensorflow>=1.15.2,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3`.
* Depends on `tensorflow-data-validation>=1.0.0,<1.1.0`.
* Depends on `tensorflow-hub>=0.9.0,<0.13`.
* Depends on `tensorflowjs>=3.6.0,<4`.
* Depends on `tensorflow-model-analysis>=0.31.0,<0.32.0`.
* Depends on `tensorflow-serving-api>=1.15,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3`.
Expand Down
2 changes: 1 addition & 1 deletion tfx/components/bulk_inferrer/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class BulkInferrer(base_beam_component.BaseBeamComponent):

def __init__(
self,
examples: types.Channel = None,
examples: types.Channel,
model: Optional[types.Channel] = None,
model_blessing: Optional[types.Channel] = None,
data_spec: Optional[Union[bulk_inferrer_pb2.DataSpec, Dict[Text,
Expand Down
4 changes: 2 additions & 2 deletions tfx/components/evaluator/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class Evaluator(base_beam_component.BaseBeamComponent):

def __init__(
self,
examples: types.Channel = None,
model: types.Channel = None,
examples: types.Channel,
model: Optional[types.Channel] = None,
baseline_model: Optional[types.Channel] = None,
# TODO(b/148618405): deprecate feature_slicing_spec.
feature_slicing_spec: Optional[Union[evaluator_pb2.FeatureSlicingSpec,
Expand Down
4 changes: 2 additions & 2 deletions tfx/components/example_validator/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ class ExampleValidator(base_component.BaseComponent):
EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

def __init__(self,
statistics: types.Channel = None,
schema: types.Channel = None,
statistics: types.Channel,
schema: types.Channel,
exclude_splits: Optional[List[Text]] = None):
"""Construct an ExampleValidator component.

Expand Down
8 changes: 5 additions & 3 deletions tfx/components/pusher/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def __init__(
the same field names as PushDestination proto message.
custom_config: A dict which contains the deployment job parameters to be
passed to Cloud platforms.
custom_executor_spec: Optional custom executor spec. This is experimental
and is subject to change in the future.
custom_executor_spec: Optional custom executor spec. Deprecated (no
compatibility guarantee), please customize component directly.
"""
pushed_model = types.Channel(type=standard_artifacts.PushedModel)
if (push_destination is None and not custom_executor_spec and
Expand All @@ -109,7 +109,9 @@ def __init__(
'custom_executor_spec is supplied that does not require '
'it.')
if custom_executor_spec:
logging.warning('`custom_executor_spec` is going to be deprecated.')
logging.warning(
'`custom_executor_spec` is deprecated. Please customize component directly.'
)
if model is None and infra_blessing is None:
raise ValueError(
'Either one of model or infra_blessing channel should be given. '
Expand Down
2 changes: 1 addition & 1 deletion tfx/components/schema_gen/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SchemaGen(base_component.BaseComponent):

def __init__(
self,
statistics: Optional[types.Channel] = None,
statistics: types.Channel,
infer_feature_shape: Optional[Union[bool,
data_types.RuntimeParameter]] = True,
exclude_splits: Optional[List[Text]] = None):
Expand Down
2 changes: 1 addition & 1 deletion tfx/components/statistics_gen/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class StatisticsGen(base_beam_component.BaseBeamComponent):
EXECUTOR_SPEC = executor_spec.BeamExecutorSpec(executor.Executor)

def __init__(self,
examples: types.Channel = None,
examples: types.Channel,
schema: Optional[types.Channel] = None,
stats_options: Optional[tfdv.StatsOptions] = None,
exclude_splits: Optional[List[Text]] = None):
Expand Down
55 changes: 22 additions & 33 deletions tfx/components/trainer/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from tfx.utils import json_utils


# TODO(b/147702778): update when switch generic executor as default.
class Trainer(base_component.BaseComponent):
"""A TFX component to train a TensorFlow model.

Expand All @@ -43,34 +42,16 @@ class Trainer(base_component.BaseComponent):
code](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/penguin_utils_keras.py)
of the TFX penguin pipeline example.

*Note:* The default executor for this component trains locally. This can be
overriden to enable the model to be trained on other platforms. The [Cloud AI
Platform custom
executor](https://github.com/tensorflow/tfx/tree/master/tfx/extensions/google_cloud_ai_platform/trainer)
provides an example how to implement this.
*Note:* This component trains locally. For cloud distributed training, please
refer to [Cloud AI Platform
Trainer](https://github.com/tensorflow/tfx/tree/master/tfx/extensions/google_cloud_ai_platform/trainer).

## Example 1: Training locally
## Example
```
# Uses user-provided Python function that trains a model using TF.
trainer = Trainer(
module_file=module_file,
transformed_examples=transform.outputs['transformed_examples'],
schema=infer_schema.outputs['schema'],
transform_graph=transform.outputs['transform_graph'],
train_args=proto.TrainArgs(splits=['train'], num_steps=10000),
eval_args=proto.EvalArgs(splits=['eval'], num_steps=5000))
```

## Example 2: Training through a cloud provider
```
from tfx.extensions.google_cloud_ai_platform.trainer import executor as
ai_platform_trainer_executor
# Train using Google Cloud AI Platform.
trainer = Trainer(
custom_executor_spec=executor_spec.ExecutorClassSpec(
ai_platform_trainer_executor.GenericExecutor),
module_file=module_file,
transformed_examples=transform.outputs['transformed_examples'],
examples=transform.outputs['transformed_examples'],
schema=infer_schema.outputs['schema'],
transform_graph=transform.outputs['transform_graph'],
train_args=proto.TrainArgs(splits=['train'], num_steps=10000),
Expand All @@ -92,7 +73,7 @@ class Trainer(base_component.BaseComponent):

def __init__(
self,
examples: types.Channel = None,
examples: Optional[types.Channel] = None,
transformed_examples: Optional[types.Channel] = None,
transform_graph: Optional[types.Channel] = None,
schema: Optional[types.Channel] = None,
Expand All @@ -102,8 +83,9 @@ def __init__(
run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
# TODO(b/147702778): deprecate trainer_fn.
trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None,
train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None,
eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None,
train_args: Optional[Union[trainer_pb2.TrainArgs, Dict[Text,
Any]]] = None,
eval_args: Optional[Union[trainer_pb2.EvalArgs, Dict[Text, Any]]] = None,
custom_config: Optional[Dict[Text, Any]] = None,
custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
"""Construct a Trainer component.
Expand All @@ -112,7 +94,8 @@ def __init__(
examples: A Channel of type `standard_artifacts.Examples`, serving as
the source of examples used in training (required). May be raw or
transformed.
transformed_examples: Deprecated field. Please set 'examples' instead.
transformed_examples: Deprecated (no compatibility guarantee). Please set
'examples' instead.
transform_graph: An optional Channel of type
`standard_artifacts.TransformGraph`, serving as the input transform
graph if present.
Expand Down Expand Up @@ -168,8 +151,8 @@ def trainer_fn(trainer.fn_args_utils.FnArgs,
behavior (when splits is empty) is evaluate on `eval` split.
custom_config: A dict which contains additional training job parameters
that will be passed into user module.
custom_executor_spec: Optional custom executor spec. This is experimental
and is subject to change in the future.
custom_executor_spec: Optional custom executor spec. Deprecated (no
compatibility guarantee), please customize component directly.

Raises:
ValueError:
Expand All @@ -194,7 +177,13 @@ def trainer_fn(trainer.fn_args_utils.FnArgs,
"'transform_graph' must be supplied too.")

if custom_executor_spec:
logging.warning("`custom_executor_spec` is going to be deprecated.")
logging.warning(
"`custom_executor_spec` is deprecated. Please customize component directly."
)
if transformed_examples:
logging.warning(
"`transformed_examples` is deprecated. Please use `examples` instead."
)
examples = examples or transformed_examples
model = types.Channel(type=standard_artifacts.Model)
model_run = types.Channel(type=standard_artifacts.ModelRun)
Expand All @@ -204,8 +193,8 @@ def trainer_fn(trainer.fn_args_utils.FnArgs,
schema=schema,
base_model=base_model,
hyperparameters=hyperparameters,
train_args=train_args,
eval_args=eval_args,
train_args=train_args or trainer_pb2.TrainArgs(),
eval_args=eval_args or trainer_pb2.EvalArgs(),
module_file=module_file,
run_fn=run_fn,
trainer_fn=trainer_fn,
Expand Down
14 changes: 6 additions & 8 deletions tfx/components/trainer/component_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,9 @@ def testConstructFromModuleFile(self):
module_file = '/path/to/module/file'
trainer = component.Trainer(
module_file=module_file,
transformed_examples=self.examples,
examples=self.examples,
transform_graph=self.transform_graph,
schema=self.schema,
train_args=self.train_args,
eval_args=self.eval_args)
schema=self.schema)
self._verify_outputs(trainer)
self.assertEqual(
module_file,
Expand All @@ -72,7 +70,7 @@ def testConstructWithParameter(self):
n_steps = data_types.RuntimeParameter(name='n-steps', ptype=int)
trainer = component.Trainer(
module_file=module_file,
transformed_examples=self.examples,
examples=self.examples,
transform_graph=self.transform_graph,
schema=self.schema,
train_args=dict(splits=['train'], num_steps=n_steps),
Expand All @@ -87,7 +85,7 @@ def testConstructFromTrainerFn(self):
trainer_fn = 'path.to.my_trainer_fn'
trainer = component.Trainer(
trainer_fn=trainer_fn,
transformed_examples=self.examples,
examples=self.examples,
transform_graph=self.transform_graph,
train_args=self.train_args,
eval_args=self.eval_args)
Expand All @@ -102,7 +100,7 @@ def testConstructFromRunFn(self):
run_fn=run_fn,
custom_executor_spec=executor_spec.ExecutorClassSpec(
executor.GenericExecutor),
transformed_examples=self.examples,
examples=self.examples,
transform_graph=self.transform_graph,
train_args=self.train_args,
eval_args=self.eval_args)
Expand Down Expand Up @@ -175,7 +173,7 @@ def testConstructDuplicateUserModule(self):
def testConstructWithHParams(self):
trainer = component.Trainer(
trainer_fn='path.to.my_trainer_fn',
transformed_examples=self.examples,
examples=self.examples,
transform_graph=self.transform_graph,
schema=self.schema,
hyperparameters=self.hyperparameters,
Expand Down
25 changes: 21 additions & 4 deletions tfx/components/transform/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ class Transform(base_beam_component.BaseBeamComponent):

def __init__(
self,
examples: types.Channel = None,
schema: types.Channel = None,
examples: types.Channel,
schema: types.Channel,
module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None,
preprocessing_fn: Optional[Union[Text,
data_types.RuntimeParameter]] = None,
splits_config: transform_pb2.SplitsConfig = None,
splits_config: Optional[transform_pb2.SplitsConfig] = None,
analyzer_cache: Optional[types.Channel] = None,
materialize: bool = True,
disable_analyzer_cache: bool = False,
Expand Down Expand Up @@ -175,6 +175,18 @@ def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
transformed_examples = types.Channel(type=standard_artifacts.Examples)
transformed_examples.matching_channel_name = 'examples'

(pre_transform_schema, pre_transform_stats, post_transform_schema,
post_transform_stats, post_transform_anomalies) = (None,) * 5
if not disable_statistics:
pre_transform_schema = types.Channel(type=standard_artifacts.Schema)
post_transform_schema = types.Channel(type=standard_artifacts.Schema)
pre_transform_stats = types.Channel(
type=standard_artifacts.ExampleStatistics)
post_transform_stats = types.Channel(
type=standard_artifacts.ExampleStatistics)
post_transform_anomalies = types.Channel(
type=standard_artifacts.ExampleAnomalies)

if disable_analyzer_cache:
updated_analyzer_cache = None
if analyzer_cache:
Expand All @@ -196,7 +208,12 @@ def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
analyzer_cache=analyzer_cache,
updated_analyzer_cache=updated_analyzer_cache,
custom_config=json_utils.dumps(custom_config),
disable_statistics=int(disable_statistics))
disable_statistics=int(disable_statistics),
pre_transform_schema=pre_transform_schema,
pre_transform_stats=pre_transform_stats,
post_transform_schema=post_transform_schema,
post_transform_stats=post_transform_stats,
post_transform_anomalies=post_transform_anomalies)
super(Transform, self).__init__(spec=spec)

if udf_utils.should_package_user_modules():
Expand Down
27 changes: 25 additions & 2 deletions tfx/components/transform/component_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def setUp(self):
def _verify_outputs(self,
transform,
materialize=True,
disable_analyzer_cache=False):
disable_analyzer_cache=False,
disable_statistics=False):
self.assertEqual(
standard_artifacts.TransformGraph.TYPE_NAME, transform.outputs[
standard_component_specs.TRANSFORM_GRAPH_KEY].type_name)
Expand All @@ -65,6 +66,28 @@ def _verify_outputs(self,
standard_artifacts.TransformCache.TYPE_NAME, transform.outputs[
standard_component_specs.UPDATED_ANALYZER_CACHE_KEY].type_name)

if disable_statistics:
for key in [
standard_component_specs.PRE_TRANSFORM_STATS_KEY,
standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY,
standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY,
standard_component_specs.POST_TRANSFORM_STATS_KEY,
standard_component_specs.POST_TRANSFORM_SCHEMA_KEY
]:
self.assertNotIn(key, transform.outputs.keys())
else:
for key, name in [(standard_component_specs.PRE_TRANSFORM_STATS_KEY,
standard_artifacts.ExampleStatistics.TYPE_NAME),
(standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY,
standard_artifacts.Schema.TYPE_NAME),
(standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY,
standard_artifacts.ExampleAnomalies.TYPE_NAME),
(standard_component_specs.POST_TRANSFORM_STATS_KEY,
standard_artifacts.ExampleStatistics.TYPE_NAME),
(standard_component_specs.POST_TRANSFORM_SCHEMA_KEY,
standard_artifacts.Schema.TYPE_NAME)]:
self.assertEqual(name, transform.outputs[key].type_name)

def test_construct_from_module_file(self):
module_file = '/path/to/preprocessing.py'
transform = component.Transform(
Expand Down Expand Up @@ -196,7 +219,7 @@ def test_construct_with_stats_disabled(self):
preprocessing_fn='my_preprocessing_fn',
disable_statistics=True,
)
self._verify_outputs(transform)
self._verify_outputs(transform, disable_statistics=True)
self.assertEqual(
True,
bool(transform.spec.exec_properties[
Expand Down
Loading