From 5dd33613a5ff29a7a3a1bd7b7c329f24329c752b Mon Sep 17 00:00:00 2001 From: tfx-team Date: Tue, 25 May 2021 19:13:33 -0700 Subject: [PATCH 1/6] Create individual artifacts for statistics generation outputs for Transform component. PiperOrigin-RevId: 375846159 --- RELEASE.md | 6 +- tfx/components/transform/component.py | 19 ++- tfx/components/transform/component_test.py | 27 +++- tfx/components/transform/executor.py | 152 +++++++++++++++--- tfx/components/transform/executor_test.py | 87 ++++++++-- tfx/components/transform/labels.py | 6 + .../taxi_pipeline_regression_e2e_test.py | 8 +- .../expected_full_taxi_pipeline_job.json | 25 +++ tfx/types/standard_component_specs.py | 18 +++ 9 files changed, 309 insertions(+), 39 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 1d77640faa..d8a619d83c 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -5,8 +5,10 @@ * Added tfx.v1 Public APIs, please refer to [API doc](https://www.tensorflow.org/tfx/api_docs/python/tfx/v1) for details. * Transform component now computes pre-transform and post-transform statistics - by default. This can be disabled by setting `disable_statistics=True` in the - Transform component. + and stores them in new, individual outputs ('pre_transform_schema', + 'pre_transform_stats', 'post_transform_schema', 'post_transform_stats', + 'post_transform_anomalies'). This can be disabled by setting + `disable_statistics=True` in the Transform component. * BERT cola and mrpc examples now demonstrate how to calculate statistics for NLP features. 
diff --git a/tfx/components/transform/component.py b/tfx/components/transform/component.py index fa588eb932..562f80dead 100644 --- a/tfx/components/transform/component.py +++ b/tfx/components/transform/component.py @@ -175,6 +175,18 @@ def preprocessing_fn(inputs: Dict[Text, Any], custom_config: transformed_examples = types.Channel(type=standard_artifacts.Examples) transformed_examples.matching_channel_name = 'examples' + (pre_transform_schema, pre_transform_stats, post_transform_schema, + post_transform_stats, post_transform_anomalies) = (None,) * 5 + if not disable_statistics: + pre_transform_schema = types.Channel(type=standard_artifacts.Schema) + post_transform_schema = types.Channel(type=standard_artifacts.Schema) + pre_transform_stats = types.Channel( + type=standard_artifacts.ExampleStatistics) + post_transform_stats = types.Channel( + type=standard_artifacts.ExampleStatistics) + post_transform_anomalies = types.Channel( + type=standard_artifacts.ExampleAnomalies) + if disable_analyzer_cache: updated_analyzer_cache = None if analyzer_cache: @@ -196,7 +208,12 @@ def preprocessing_fn(inputs: Dict[Text, Any], custom_config: analyzer_cache=analyzer_cache, updated_analyzer_cache=updated_analyzer_cache, custom_config=json_utils.dumps(custom_config), - disable_statistics=int(disable_statistics)) + disable_statistics=int(disable_statistics), + pre_transform_schema=pre_transform_schema, + pre_transform_stats=pre_transform_stats, + post_transform_schema=post_transform_schema, + post_transform_stats=post_transform_stats, + post_transform_anomalies=post_transform_anomalies) super(Transform, self).__init__(spec=spec) if udf_utils.should_package_user_modules(): diff --git a/tfx/components/transform/component_test.py b/tfx/components/transform/component_test.py index d5326139c6..5b85be4e12 100644 --- a/tfx/components/transform/component_test.py +++ b/tfx/components/transform/component_test.py @@ -46,7 +46,8 @@ def setUp(self): def _verify_outputs(self, transform, 
materialize=True, - disable_analyzer_cache=False): + disable_analyzer_cache=False, + disable_statistics=False): self.assertEqual( standard_artifacts.TransformGraph.TYPE_NAME, transform.outputs[ standard_component_specs.TRANSFORM_GRAPH_KEY].type_name) @@ -65,6 +66,28 @@ def _verify_outputs(self, standard_artifacts.TransformCache.TYPE_NAME, transform.outputs[ standard_component_specs.UPDATED_ANALYZER_CACHE_KEY].type_name) + if disable_statistics: + for key in [ + standard_component_specs.PRE_TRANSFORM_STATS_KEY, + standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY, + standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY, + standard_component_specs.POST_TRANSFORM_STATS_KEY, + standard_component_specs.POST_TRANSFORM_SCHEMA_KEY + ]: + self.assertNotIn(key, transform.outputs.keys()) + else: + for key, name in [(standard_component_specs.PRE_TRANSFORM_STATS_KEY, + standard_artifacts.ExampleStatistics.TYPE_NAME), + (standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY, + standard_artifacts.Schema.TYPE_NAME), + (standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY, + standard_artifacts.ExampleAnomalies.TYPE_NAME), + (standard_component_specs.POST_TRANSFORM_STATS_KEY, + standard_artifacts.ExampleStatistics.TYPE_NAME), + (standard_component_specs.POST_TRANSFORM_SCHEMA_KEY, + standard_artifacts.Schema.TYPE_NAME)]: + self.assertEqual(name, transform.outputs[key].type_name) + def test_construct_from_module_file(self): module_file = '/path/to/preprocessing.py' transform = component.Transform( @@ -196,7 +219,7 @@ def test_construct_with_stats_disabled(self): preprocessing_fn='my_preprocessing_fn', disable_statistics=True, ) - self._verify_outputs(transform) + self._verify_outputs(transform, disable_statistics=True) self.assertEqual( True, bool(transform.spec.exec_properties[ diff --git a/tfx/components/transform/executor.py b/tfx/components/transform/executor.py index f7a451f07f..c2b358b0c3 100644 --- a/tfx/components/transform/executor.py +++ b/tfx/components/transform/executor.py @@ 
-89,6 +89,15 @@ # Beam extra pip package prefix. _BEAM_EXTRA_PACKAGE_PREFIX = '--extra_package=' +# Stats output filename keys. +_ANOMALIES_FILE = 'SchemaDiff.pb' +_STATS_FILE = 'FeatureStats.pb' +_SCHEMA_FILE = 'Schema.pb' + +_ANOMALIES_KEY = 'anomalies' +_SCHEMA_KEY = 'schema' +_STATS_KEY = 'stats' + # TODO(b/122478841): Move it to a common place that is shared across components. class _Status(object): @@ -384,6 +393,32 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]], transform_output = artifact_utils.get_single_uri( output_dict[standard_component_specs.TRANSFORM_GRAPH_KEY]) + disable_statistics = bool( + exec_properties.get(standard_component_specs.DISABLE_STATISTICS_KEY, 0)) + + label_component_key_list = [ + (labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL, + standard_component_specs.PRE_TRANSFORM_STATS_KEY), + (labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL, + standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY), + (labels.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL, + standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY), + (labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL, + standard_component_specs.POST_TRANSFORM_STATS_KEY), + (labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL, + standard_component_specs.POST_TRANSFORM_SCHEMA_KEY) + ] + stats_output_paths = {} + if not disable_statistics: + for label, component_key in label_component_key_list: + if component_key in output_dict: + stats_output_paths[label] = artifact_utils.get_single_uri( + output_dict[component_key]) + if stats_output_paths and len(stats_output_paths) != len( + label_component_key_list): + raise ValueError( + 'Either all stats_output_paths should be specified or none.') + temp_path = os.path.join(transform_output, _TEMP_DIR_IN_TRANSFORM_OUTPUT) absl.logging.debug('Using temp path %s for tft.beam', temp_path) @@ -427,9 +462,6 @@ def _GetCachePath(label, params_dict): force_tf_compat_v1 = bool( exec_properties.get(standard_component_specs.FORCE_TF_COMPAT_V1_KEY, 0)) - disable_statistics = 
bool( - exec_properties.get(standard_component_specs.DISABLE_STATISTICS_KEY, 0)) - # Make sure user packages get propagated to the remote Beam worker. user_module_key = exec_properties.get( standard_component_specs.MODULE_PATH_KEY, None) @@ -481,6 +513,8 @@ def _GetCachePath(label, params_dict): materialize_output_paths, labels.TEMP_OUTPUT_LABEL: str(temp_path), } + + label_outputs.update(stats_output_paths) cache_output = _GetCachePath( standard_component_specs.UPDATED_ANALYZER_CACHE_KEY, output_dict) if cache_output is not None: @@ -597,7 +631,8 @@ def _ReadMetadata(self, data_format: int, Optional[beam.pvalue.PDone], Optional[beam.pvalue.PDone]]) def _GenerateAndMaybeValidateStats( - pcoll: beam.pvalue.PCollection, stats_output_path: Text, + pcoll: beam.pvalue.PCollection, stats_output_loc: Union[Text, Dict[Text, + Text]], stats_options: tfdv.StatsOptions, enable_validation: bool ) -> Tuple[beam.pvalue.PDone, Optional[beam.pvalue.PDone], Optional[beam.pvalue.PDone]]: @@ -605,7 +640,9 @@ def _GenerateAndMaybeValidateStats( Args: pcoll: PCollection of examples. - stats_output_path: path where statistics is written to. + stats_output_loc: path to the statistics folder to write all results to + or a dictionary keyed with individual paths for 'schema', 'stats', and + 'anomalies'. stats_options: An instance of `tfdv.StatsOptions()` used when computing statistics. enable_validation: Whether to enable stats validation. @@ -616,6 +653,16 @@ def _GenerateAndMaybeValidateStats( schema is not present or validation is not enabled, the corresponding values are replaced with Nones. 
""" + if isinstance(stats_output_loc, dict): + stats_output_path = stats_output_loc[_STATS_KEY] + schema_output_path = stats_output_loc[_SCHEMA_KEY] + anomalies_output_path = stats_output_loc.get(_ANOMALIES_KEY) + else: + stats_output_path = stats_output_loc + stats_output_dir = os.path.dirname(stats_output_loc) + schema_output_path = os.path.join(stats_output_dir, _SCHEMA_FILE) + anomalies_output_path = os.path.join(stats_output_dir, _ANOMALIES_FILE) + generated_stats = ( pcoll | 'FilterInternalColumn' >> beam.Map(Executor._FilterInternalColumn) @@ -629,8 +676,6 @@ def _GenerateAndMaybeValidateStats( if stats_options.schema is None: return (stats_result, None, None) - stats_dir = os.path.dirname(stats_output_path) - schema_output_path = os.path.join(stats_dir, 'Schema.pb') # TODO(b/186867968): See if we should switch to common libraries. schema_result = ( pcoll.pipeline @@ -644,14 +689,13 @@ def _GenerateAndMaybeValidateStats( if not enable_validation: return (stats_result, schema_result, None) - validation_output_path = os.path.join(stats_dir, 'SchemaDiff.pb') # TODO(b/186867968): See if we should switch to common libraries. validation_result = ( generated_stats | 'ValidateStatistics' >> beam.Map( lambda stats: tfdv.validate_statistics(stats, stats_options.schema)) | 'WriteValidation' >> beam.io.WriteToText( - validation_output_path, + anomalies_output_path, append_trailing_newlines=False, shard_name_template='', # To force unsharded output. coder=beam.coders.ProtoCoder(anomalies_pb2.Anomalies))) @@ -917,6 +961,16 @@ def Transform(self, inputs: Mapping[Text, Any], outputs: Mapping[Text, Any], - labels.TRANSFORM_MATERIALIZE_OUTPUT_PATHS_LABEL: Paths to transform materialization. - labels.TEMP_OUTPUT_LABEL: A path to temporary directory. + - labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL: A path to the output + pre-transform schema, optional. + - labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL: A path to the output + pre-transform statistics, optional. 
+ - labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL: A path to the output + post-transform schema, optional. + - labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL: A path to the output + post-transform statistics, optional. + - labels.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL: A path to the + output post-transform anomalies, optional. status_file: Where the status should be written (not yet implemented) """ del status_file # unused @@ -961,6 +1015,22 @@ def Transform(self, inputs: Mapping[Text, Any], outputs: Mapping[Text, Any], force_tf_compat_v1 = value_utils.GetSoleValue( inputs, labels.FORCE_TF_COMPAT_V1_LABEL) + stats_labels_list = [ + labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL, + labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL, + labels.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL, + labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL, + labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL + ] + stats_output_paths = {} + for label in stats_labels_list: + value = value_utils.GetSoleValue(outputs, label, strict=False) + if value: + stats_output_paths[label] = value + if stats_output_paths and len(stats_output_paths) != len(stats_labels_list): + raise ValueError('Either all stats_output_paths should be' + ' specified or none.') + absl.logging.debug('Force tf.compat.v1: %s', force_tf_compat_v1) absl.logging.debug('Analyze data patterns: %s', list(enumerate(analyze_data_paths))) @@ -1040,7 +1110,7 @@ def Transform(self, inputs: Mapping[Text, Any], outputs: Mapping[Text, Any], raw_examples_data_format, temp_path, input_cache_dir, output_cache_dir, disable_statistics, per_set_stats_output_paths, materialization_format, - len(analyze_data_paths)) + len(analyze_data_paths), stats_output_paths) # TODO(b/122478841): Writes status to status file. 
# pylint: disable=expression-not-assigned, no-value-for-parameter @@ -1055,7 +1125,8 @@ def _RunBeamImpl(self, analyze_data_list: List[_Dataset], output_cache_dir: Optional[Text], disable_statistics: bool, per_set_stats_output_paths: Sequence[Text], materialization_format: Optional[Text], - analyze_paths_count: int) -> _Status: + analyze_paths_count: int, + stats_output_paths: Dict[Text, Text]) -> _Status: """Perform data preprocessing with TFT. Args: @@ -1080,6 +1151,10 @@ def _RunBeamImpl(self, analyze_data_list: List[_Dataset], data or None if materialization is not enabled. analyze_paths_count: An integer, the number of paths that should be used for analysis. + stats_output_paths: A dictionary specifying the output paths to use when + computing statistics. If the dictionary is empty, the stats will be + placed within the transform_output_path to preserve backwards + compatibility. Returns: Status of the execution. @@ -1233,10 +1308,6 @@ def _RunBeamImpl(self, analyze_data_list: List[_Dataset], if (not disable_statistics and not self._IsDataFormatProto(raw_examples_data_format)): # Aggregated feature stats before transformation. 
- pre_transform_feature_stats_path = os.path.join( - transform_output_path, - tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH) - if self._IsDataFormatSequenceExample(raw_examples_data_format): schema_proto = None else: @@ -1272,11 +1343,29 @@ def _ExtractRawExampleBatches(record_batch): pre_transform_stats_options.experimental_use_sketch_based_topk_uniques = ( self._TfdvUseSketchBasedTopKUniques()) + if stats_output_paths: + pre_transform_feature_stats_loc = { + _STATS_KEY: + os.path.join( + stats_output_paths[ + labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL], + _STATS_FILE), + _SCHEMA_KEY: + os.path.join( + stats_output_paths[ + labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL], + _SCHEMA_FILE) + } + else: + pre_transform_feature_stats_loc = os.path.join( + transform_output_path, + tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH) + (stats_input | 'FlattenAnalysisDatasets' >> beam.Flatten(pipeline=pipeline) | 'GenerateStats[FlattenedAnalysisDataset]' >> self._GenerateAndMaybeValidateStats( - pre_transform_feature_stats_path, + pre_transform_feature_stats_loc, stats_options=pre_transform_stats_options, enable_validation=False)) @@ -1309,10 +1398,6 @@ def _ExtractRawExampleBatches(record_batch): dataset.transformed | 'ExtractRecordBatches[{}]'.format(infix) >> beam.Keys()) - post_transform_feature_stats_path = os.path.join( - transform_output_path, - tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH) - post_transform_stats_options = _InvokeStatsOptionsUpdaterFn( stats_options_updater_fn, stats_options_util.StatsType.POST_TRANSFORM, @@ -1321,6 +1406,31 @@ def _ExtractRawExampleBatches(record_batch): post_transform_stats_options.experimental_use_sketch_based_topk_uniques = ( self._TfdvUseSketchBasedTopKUniques()) + + if stats_output_paths: + post_transform_feature_stats_loc = { + _STATS_KEY: + os.path.join( + stats_output_paths[ + labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL], + _STATS_FILE), + _SCHEMA_KEY: + os.path.join( + stats_output_paths[ + 
labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL], + _SCHEMA_FILE), + _ANOMALIES_KEY: + os.path.join( + stats_output_paths[ + labels + .POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL], + _ANOMALIES_FILE) + } + else: + post_transform_feature_stats_loc = os.path.join( + transform_output_path, + tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH) + ([ dataset.transformed_and_standardized for dataset in transform_data_list @@ -1331,7 +1441,7 @@ def _ExtractRawExampleBatches(record_batch): completion=beam.pvalue.AsSingleton(completed_transform)) | 'GenerateAndValidateStats[FlattenedTransformedDatasets]' >> self._GenerateAndMaybeValidateStats( - post_transform_feature_stats_path, + post_transform_feature_stats_loc, stats_options=post_transform_stats_options, enable_validation=True)) diff --git a/tfx/components/transform/executor_test.py b/tfx/components/transform/executor_test.py index 1c2535df1a..df7314ebcf 100644 --- a/tfx/components/transform/executor_test.py +++ b/tfx/components/transform/executor_test.py @@ -128,6 +128,22 @@ def _make_base_do_params(self, source_data_dir, output_data_dir): self._updated_analyzer_cache_artifact.uri = os.path.join( self._output_data_dir, 'CACHE') + self._pre_transform_schema = standard_artifacts.Schema() + self._pre_transform_schema.uri = os.path.join(output_data_dir, + 'pre_transform_schema', '0') + self._pre_transform_stats = standard_artifacts.ExampleStatistics() + self._pre_transform_stats.uri = os.path.join(output_data_dir, + 'pre_transform_stats', '0') + self._post_transform_schema = standard_artifacts.Schema() + self._post_transform_schema.uri = os.path.join(output_data_dir, + 'post_transform_schema', '0') + self._post_transform_stats = standard_artifacts.ExampleStatistics() + self._post_transform_stats.uri = os.path.join(output_data_dir, + 'post_transform_stats', '0') + self._post_transform_anomalies = standard_artifacts.ExampleAnomalies() + self._post_transform_anomalies.uri = os.path.join( + output_data_dir, 
'post_transform_anomalies', '0') + self._output_dict = { standard_component_specs.TRANSFORM_GRAPH_KEY: [ self._transformed_output @@ -138,6 +154,21 @@ def _make_base_do_params(self, source_data_dir, output_data_dir): standard_component_specs.UPDATED_ANALYZER_CACHE_KEY: [ self._updated_analyzer_cache_artifact ], + standard_component_specs.PRE_TRANSFORM_STATS_KEY: [ + self._pre_transform_stats + ], + standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY: [ + self._pre_transform_schema + ], + standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY: [ + self._post_transform_anomalies + ], + standard_component_specs.POST_TRANSFORM_STATS_KEY: [ + self._post_transform_stats + ], + standard_component_specs.POST_TRANSFORM_SCHEMA_KEY: [ + self._post_transform_schema + ], } # Create exec properties skeleton. @@ -212,18 +243,46 @@ def _verify_transform_outputs(self, self.assertEqual(examples_eval_count, transformed_eval_count) self.assertGreater(transformed_train_count, transformed_eval_count) - path_to_pre_transform_statistics = os.path.join( - self._transformed_output.uri, - tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH) - path_to_post_transform_statistics = os.path.join( - self._transformed_output.uri, - tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH) if disable_statistics: - self.assertFalse(fileio.exists(path_to_pre_transform_statistics)) - self.assertFalse(fileio.exists(path_to_post_transform_statistics)) + self.assertFalse( + fileio.exists( + os.path.join(self._pre_transform_schema.uri, 'Schema.pb'))) + self.assertFalse( + fileio.exists( + os.path.join(self._post_transform_schema.uri, 'Schema.pb'))) + self.assertFalse( + fileio.exists( + os.path.join(self._pre_transform_stats.uri, 'FeatureStats.pb'))) + self.assertFalse( + fileio.exists( + os.path.join(self._post_transform_stats.uri, 'FeatureStats.pb'))) + self.assertFalse( + fileio.exists( + os.path.join(self._post_transform_anomalies.uri, + 'SchemaDiff.pb'))) + else: - 
self.assertTrue(fileio.exists(path_to_pre_transform_statistics)) - self.assertTrue(fileio.exists(path_to_post_transform_statistics)) + expected_outputs.extend([ + 'pre_transform_schema', 'pre_transform_stats', + 'post_transform_schema', 'post_transform_stats', + 'post_transform_anomalies' + ]) + self.assertTrue( + fileio.exists( + os.path.join(self._pre_transform_schema.uri, 'Schema.pb'))) + self.assertTrue( + fileio.exists( + os.path.join(self._post_transform_schema.uri, 'Schema.pb'))) + self.assertTrue( + fileio.exists( + os.path.join(self._pre_transform_stats.uri, 'FeatureStats.pb'))) + self.assertTrue( + fileio.exists( + os.path.join(self._post_transform_stats.uri, 'FeatureStats.pb'))) + self.assertTrue( + fileio.exists( + os.path.join(self._post_transform_anomalies.uri, + 'SchemaDiff.pb'))) # Depending on `materialize` and `store_cache`, check that # expected outputs are exactly correct. If either flag is False, its @@ -296,6 +355,14 @@ def test_do_with_statistics_disabled(self): standard_component_specs.MODULE_FILE_KEY] = self._module_file self._exec_properties[ standard_component_specs.DISABLE_STATISTICS_KEY] = True + for key in [ + standard_component_specs.PRE_TRANSFORM_STATS_KEY, + standard_component_specs.PRE_TRANSFORM_SCHEMA_KEY, + standard_component_specs.POST_TRANSFORM_ANOMALIES_KEY, + standard_component_specs.POST_TRANSFORM_STATS_KEY, + standard_component_specs.POST_TRANSFORM_SCHEMA_KEY + ]: + self._output_dict.pop(key) self._transform_executor.Do(self._input_dict, self._output_dict, self._exec_properties) self._verify_transform_outputs(disable_statistics=True) diff --git a/tfx/components/transform/labels.py b/tfx/components/transform/labels.py index 5798e98ff6..87b7e52052 100644 --- a/tfx/components/transform/labels.py +++ b/tfx/components/transform/labels.py @@ -45,6 +45,12 @@ TRANSFORM_METADATA_OUTPUT_PATH_LABEL = 'transform_output_path' CACHE_OUTPUT_PATH_LABEL = 'cache_output_path' TEMP_OUTPUT_LABEL = 'temp_path' 
+PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL = 'pre_transform_output_schema_path' +PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL = 'pre_transform_output_stats_path' +POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL = 'post_transform_output_schema_path' +POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL = 'post_transform_output_stats_path' +POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL = ( + 'post_transform_output_anomalies_path') # Examples File Format FORMAT_TFRECORD = 'FORMAT_TFRECORD' diff --git a/tfx/experimental/pipeline_testing/examples/chicago_taxi_pipeline/taxi_pipeline_regression_e2e_test.py b/tfx/experimental/pipeline_testing/examples/chicago_taxi_pipeline/taxi_pipeline_regression_e2e_test.py index b51bdf9464..d80faf058f 100644 --- a/tfx/experimental/pipeline_testing/examples/chicago_taxi_pipeline/taxi_pipeline_regression_e2e_test.py +++ b/tfx/experimental/pipeline_testing/examples/chicago_taxi_pipeline/taxi_pipeline_regression_e2e_test.py @@ -136,11 +136,13 @@ def testStubbedTaxiPipelineBeam(self): artifact_count = len(artifacts) executions = m.store.get_executions() execution_count = len(executions) - # Artifact count is greater by 3 due to extra artifacts produced by + # Artifact count is greater by 7 due to extra artifacts produced by # Evaluator(blessing and evaluation), Trainer(model and model_run) and - # Transform(example, graph, cache) minus Resolver which doesn't generate + # Transform(example, graph, cache, pre_transform_statistics, + # pre_transform_schema, post_transform_statistics, post_transform_schema, + # post_transform_anomalies) minus Resolver which doesn't generate # new artifact. 
- self.assertEqual(artifact_count, execution_count + 3) + self.assertEqual(artifact_count, execution_count + 7) self.assertLen(self.taxi_pipeline.components, execution_count) for execution in executions: diff --git a/tfx/orchestration/kubeflow/v2/testdata/expected_full_taxi_pipeline_job.json b/tfx/orchestration/kubeflow/v2/testdata/expected_full_taxi_pipeline_job.json index d66cae66ab..d0a41bb3fa 100644 --- a/tfx/orchestration/kubeflow/v2/testdata/expected_full_taxi_pipeline_job.json +++ b/tfx/orchestration/kubeflow/v2/testdata/expected_full_taxi_pipeline_job.json @@ -369,6 +369,31 @@ "Transform": { "outputDefinitions": { "artifacts": { + "pre_transform_schema": { + "artifactType": { + "instanceSchema": "title: tfx.Schema\ntype: object\n" + } + }, + "pre_transform_stats": { + "artifactType": { + "instanceSchema": "title: tfx.ExampleStatistics\ntype: object\nproperties:\n span:\n type: integer\n description: Span for an artifact.\n split_names:\n type: string\n description: JSON-encoded list of splits for an artifact. Empty string means artifact has no split.\n" + } + }, + "post_transform_stats": { + "artifactType": { + "instanceSchema": "title: tfx.ExampleStatistics\ntype: object\nproperties:\n span:\n type: integer\n description: Span for an artifact.\n split_names:\n type: string\n description: JSON-encoded list of splits for an artifact. Empty string means artifact has no split.\n" + } + }, + "post_transform_schema": { + "artifactType": { + "instanceSchema": "title: tfx.Schema\ntype: object\n" + } + }, + "post_transform_anomalies": { + "artifactType": { + "instanceSchema": "title: tfx.ExampleAnomalies\ntype: object\nproperties:\n span:\n type: integer\n description: Span for an artifact.\n split_names:\n type: string\n description: JSON-encoded list of splits for an artifact. 
Empty string means artifact has no split.\n" + } + }, "updated_analyzer_cache": { "artifactType": { "instanceSchema": "title: tfx.TransformCache\ntype: object\n" diff --git a/tfx/types/standard_component_specs.py b/tfx/types/standard_component_specs.py index bc917a2bdc..b0f0ee5296 100644 --- a/tfx/types/standard_component_specs.py +++ b/tfx/types/standard_component_specs.py @@ -98,6 +98,11 @@ TRANSFORMED_EXAMPLES_KEY = 'transformed_examples' UPDATED_ANALYZER_CACHE_KEY = 'updated_analyzer_cache' DISABLE_STATISTICS_KEY = 'disable_statistics' +PRE_TRANSFORM_SCHEMA_KEY = 'pre_transform_schema' +PRE_TRANSFORM_STATS_KEY = 'pre_transform_stats' +POST_TRANSFORM_SCHEMA_KEY = 'post_transform_schema' +POST_TRANSFORM_STATS_KEY = 'post_transform_stats' +POST_TRANSFORM_ANOMALIES_KEY = 'post_transform_anomalies' class BulkInferrerSpec(ComponentSpec): @@ -419,4 +424,17 @@ class TransformSpec(ComponentSpec): UPDATED_ANALYZER_CACHE_KEY: ChannelParameter( type=standard_artifacts.TransformCache, optional=True), + PRE_TRANSFORM_SCHEMA_KEY: + ChannelParameter(type=standard_artifacts.Schema, optional=True), + PRE_TRANSFORM_STATS_KEY: + ChannelParameter( + type=standard_artifacts.ExampleStatistics, optional=True), + POST_TRANSFORM_SCHEMA_KEY: + ChannelParameter(type=standard_artifacts.Schema, optional=True), + POST_TRANSFORM_STATS_KEY: + ChannelParameter( + type=standard_artifacts.ExampleStatistics, optional=True), + POST_TRANSFORM_ANOMALIES_KEY: + ChannelParameter( + type=standard_artifacts.ExampleAnomalies, optional=True) } From dd9fecac4302255e61525e25d1bf81b46fe3c46a Mon Sep 17 00:00:00 2001 From: jiyongjung Date: Wed, 26 May 2021 00:19:30 -0700 Subject: [PATCH 2/6] Fixes broken vertex e2e test for the template. 
PiperOrigin-RevId: 375881341 --- .../taxi/e2e_tests/vertex_e2e_test.py | 47 ++++++++++--------- .../templates/taxi/kubeflow_v2_runner.py | 2 +- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/tfx/experimental/templates/taxi/e2e_tests/vertex_e2e_test.py b/tfx/experimental/templates/taxi/e2e_tests/vertex_e2e_test.py index e48db33b01..2fa2ee5d03 100644 --- a/tfx/experimental/templates/taxi/e2e_tests/vertex_e2e_test.py +++ b/tfx/experimental/templates/taxi/e2e_tests/vertex_e2e_test.py @@ -20,13 +20,11 @@ from absl import logging import docker from google.cloud import storage -from kfp.pipeline_spec import pipeline_spec_pb2 from kfp.v2.google import client import tensorflow as tf -from tfx.experimental.templates.taxi.e2e_tests import test_utils +from tfx.experimental.templates import test_utils from tfx.orchestration import test_utils as orchestration_test_utils from tfx.orchestration.kubeflow.v2 import test_utils as kubeflow_v2_test_utils -from tfx.utils import io_utils from tfx.utils import retry @@ -47,17 +45,22 @@ class TaxiTemplateKubeflowV2E2ETest(test_utils.BaseEndToEndTest): _REPO_BASE = os.environ['KFP_E2E_SRC'] _GCP_PROJECT_ID = os.environ['KFP_E2E_GCP_PROJECT_ID'] - _GCP_REGION = os.environ.get['KFP_E2E_GCP_REGION'] - _API_KEY = os.environ.get['CAIP_E2E_API_KEY'] + _GCP_REGION = os.environ['KFP_E2E_GCP_REGION'] _BUCKET_NAME = os.environ['KFP_E2E_BUCKET_NAME'] def setUp(self): super().setUp() random_id = orchestration_test_utils.random_id() + if ':' not in self._BASE_CONTAINER_IMAGE: + self._base_container_image = '{}:{}'.format(self._BASE_CONTAINER_IMAGE, + random_id) + self._prepare_base_container_image() + else: + self._base_container_image = self._BASE_CONTAINER_IMAGE self._target_container_image = 'gcr.io/{}/{}:{}'.format( self._GCP_PROJECT_ID, 'taxi-template-vertex-e2e-test', random_id) # Overriding the pipeline name to - self._pipeline_name = 'taxi_template_vertex_e2e_test_{}'.format( + self._pipeline_name = 
'taxi-template-vertex-e2e-{}'.format( random_id) def tearDown(self): @@ -89,6 +92,10 @@ def _delete_base_container_image(self): def _delete_target_container_image(self): self._delete_docker_image(self._target_container_image) + def _prepare_base_container_image(self): + orchestration_test_utils.build_docker_image(self._base_container_image, + self._REPO_BASE) + def _create_pipeline(self): self._runCli([ 'pipeline', @@ -99,7 +106,18 @@ def _create_pipeline(self): 'kubeflow_v2_runner.py', '--build-image', '--build-base-image', - self._CONTAINER_IMAGE, + self._base_container_image, + ]) + + def _update_pipeline(self): + self._runCli([ + 'pipeline', + 'update', + '--engine', + 'vertex', + '--pipeline_path', + 'kubeflow_v2_runner.py', + '--build-image', ]) def _run_pipeline(self): @@ -114,8 +132,6 @@ def _run_pipeline(self): self._GCP_PROJECT_ID, '--region', self._GCP_REGION, - '--api_key', - self._GCP_API_KEY, ]) run_id = self._parse_run_id(result) self._wait_until_completed(run_id) @@ -131,8 +147,7 @@ def _parse_run_id(self, output: str): def _wait_until_completed(self, run_id: str): vertex_client = client.AIPlatformClient( project_id=self._GCP_PROJECT_ID, - region=self._GCP_REGION, - api_key=self._GCP_API_KEY) + region=self._GCP_REGION) kubeflow_v2_test_utils.poll_job_status(vertex_client, run_id, self._TIME_OUT, self._POLLING_INTERVAL_IN_SECONDS) @@ -174,16 +189,6 @@ def testPipeline(self): # Create a pipeline with only one component. self._create_pipeline() - # Extract the compiled pipeline spec. - kubeflow_v2_pb = pipeline_spec_pb2.PipelineJob() - io_utils.parse_json_file( - file_name=os.path.join(os.getcwd(), 'pipeline.json'), - message=kubeflow_v2_pb) - # There should be one step in the compiled pipeline. - self.assertLen(kubeflow_v2_pb.pipeline_spec['tasks'], 1) - - self._run_pipeline() - # Update the pipeline to include all components. 
updated_pipeline_file = self._addAllComponents() logging.info('Updated %s to add all components to the pipeline.', diff --git a/tfx/experimental/templates/taxi/kubeflow_v2_runner.py b/tfx/experimental/templates/taxi/kubeflow_v2_runner.py index 433307fe53..dc5edeacf3 100644 --- a/tfx/experimental/templates/taxi/kubeflow_v2_runner.py +++ b/tfx/experimental/templates/taxi/kubeflow_v2_runner.py @@ -48,7 +48,7 @@ # _DATA_PATH = 'data'. Note that Dataflow does not support use container as a # dependency currently, so this means CsvExampleGen cannot be used with Dataflow # (step 8 in the template notebook). -_DATA_PATH = 'gs://{}/tfx-template/data/'.format(configs.GCS_BUCKET_NAME) +_DATA_PATH = 'gs://{}/tfx-template/data/taxi/'.format(configs.GCS_BUCKET_NAME) def run(): From 5e769c3f9a5c775fe9c5d1fb97727b0b2042746c Mon Sep 17 00:00:00 2001 From: tfx-team Date: Mon, 24 May 2021 15:13:17 -0700 Subject: [PATCH 3/6] Update tuner in penguin kfp PiperOrigin-RevId: 375568040 --- tfx/examples/penguin/penguin_pipeline_kubeflow_gcp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfx/examples/penguin/penguin_pipeline_kubeflow_gcp.py b/tfx/examples/penguin/penguin_pipeline_kubeflow_gcp.py index e7274496cc..cfc13ea7bc 100644 --- a/tfx/examples/penguin/penguin_pipeline_kubeflow_gcp.py +++ b/tfx/examples/penguin/penguin_pipeline_kubeflow_gcp.py @@ -193,7 +193,7 @@ def create_pipeline( # vs distributed training per trial # ... 
-> DistributingCloudTunerA -> CAIP job Y -> master,worker1,2,3 # -> DistributingCloudTunerB -> CAIP job Z -> master,worker1,2,3 - tuner = tfx.components.Tuner( + tuner = tfx.extensions.google_cloud_ai_platform.Tuner( module_file=module_file, examples=transform.outputs['transformed_examples'], transform_graph=transform.outputs['transform_graph'], From 6b2c325b7e2e281df15470ee4e0e50841564cd53 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Tue, 1 Jun 2021 10:48:32 -0700 Subject: [PATCH 4/6] Clean up the default value for required field PiperOrigin-RevId: 376869705 --- RELEASE.md | 2 + tfx/components/bulk_inferrer/component.py | 2 +- tfx/components/evaluator/component.py | 4 +- tfx/components/example_validator/component.py | 4 +- tfx/components/pusher/component.py | 8 ++- tfx/components/schema_gen/component.py | 2 +- tfx/components/statistics_gen/component.py | 2 +- tfx/components/trainer/component.py | 55 ++++++++----------- tfx/components/trainer/component_test.py | 14 ++--- tfx/components/transform/component.py | 6 +- tfx/components/tuner/component.py | 10 ++-- .../bulk_inferrer/component.py | 4 +- 12 files changed, 52 insertions(+), 61 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index d8a619d83c..ca8e410e37 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -14,6 +14,8 @@ ## Breaking Changes +* Removed unnecessary default values for required component input Channels.
+ ### For Pipeline Authors * N/A diff --git a/tfx/components/bulk_inferrer/component.py b/tfx/components/bulk_inferrer/component.py index 4bd0a97959..ace3634b86 100644 --- a/tfx/components/bulk_inferrer/component.py +++ b/tfx/components/bulk_inferrer/component.py @@ -56,7 +56,7 @@ class BulkInferrer(base_beam_component.BaseBeamComponent): def __init__( self, - examples: types.Channel = None, + examples: types.Channel, model: Optional[types.Channel] = None, model_blessing: Optional[types.Channel] = None, data_spec: Optional[Union[bulk_inferrer_pb2.DataSpec, Dict[Text, diff --git a/tfx/components/evaluator/component.py b/tfx/components/evaluator/component.py index 0cf3f081df..1bdb90cc44 100644 --- a/tfx/components/evaluator/component.py +++ b/tfx/components/evaluator/component.py @@ -52,8 +52,8 @@ class Evaluator(base_beam_component.BaseBeamComponent): def __init__( self, - examples: types.Channel = None, - model: types.Channel = None, + examples: types.Channel, + model: Optional[types.Channel] = None, baseline_model: Optional[types.Channel] = None, # TODO(b/148618405): deprecate feature_slicing_spec. feature_slicing_spec: Optional[Union[evaluator_pb2.FeatureSlicingSpec, diff --git a/tfx/components/example_validator/component.py b/tfx/components/example_validator/component.py index 12103ca73e..b6244404be 100644 --- a/tfx/components/example_validator/component.py +++ b/tfx/components/example_validator/component.py @@ -67,8 +67,8 @@ class ExampleValidator(base_component.BaseComponent): EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor) def __init__(self, - statistics: types.Channel = None, - schema: types.Channel = None, + statistics: types.Channel, + schema: types.Channel, exclude_splits: Optional[List[Text]] = None): """Construct an ExampleValidator component. 
diff --git a/tfx/components/pusher/component.py b/tfx/components/pusher/component.py index a0986be11d..e167987174 100644 --- a/tfx/components/pusher/component.py +++ b/tfx/components/pusher/component.py @@ -99,8 +99,8 @@ def __init__( the same field names as PushDestination proto message. custom_config: A dict which contains the deployment job parameters to be passed to Cloud platforms. - custom_executor_spec: Optional custom executor spec. This is experimental - and is subject to change in the future. + custom_executor_spec: Optional custom executor spec. Deprecated (no + compatibility guarantee), please customize component directly. """ pushed_model = types.Channel(type=standard_artifacts.PushedModel) if (push_destination is None and not custom_executor_spec and @@ -109,7 +109,9 @@ def __init__( 'custom_executor_spec is supplied that does not require ' 'it.') if custom_executor_spec: - logging.warning('`custom_executor_spec` is going to be deprecated.') + logging.warning( + '`custom_executor_spec` is deprecated. Please customize component directly.' + ) if model is None and infra_blessing is None: raise ValueError( 'Either one of model or infra_blessing channel should be given. 
' diff --git a/tfx/components/schema_gen/component.py b/tfx/components/schema_gen/component.py index cf1f317adc..e154b20e48 100644 --- a/tfx/components/schema_gen/component.py +++ b/tfx/components/schema_gen/component.py @@ -63,7 +63,7 @@ class SchemaGen(base_component.BaseComponent): def __init__( self, - statistics: Optional[types.Channel] = None, + statistics: types.Channel, infer_feature_shape: Optional[Union[bool, data_types.RuntimeParameter]] = True, exclude_splits: Optional[List[Text]] = None): diff --git a/tfx/components/statistics_gen/component.py b/tfx/components/statistics_gen/component.py index deaff65def..019f1129df 100644 --- a/tfx/components/statistics_gen/component.py +++ b/tfx/components/statistics_gen/component.py @@ -51,7 +51,7 @@ class StatisticsGen(base_beam_component.BaseBeamComponent): EXECUTOR_SPEC = executor_spec.BeamExecutorSpec(executor.Executor) def __init__(self, - examples: types.Channel = None, + examples: types.Channel, schema: Optional[types.Channel] = None, stats_options: Optional[tfdv.StatsOptions] = None, exclude_splits: Optional[List[Text]] = None): diff --git a/tfx/components/trainer/component.py b/tfx/components/trainer/component.py index 5d6859021a..e13654cd16 100644 --- a/tfx/components/trainer/component.py +++ b/tfx/components/trainer/component.py @@ -32,7 +32,6 @@ from tfx.utils import json_utils -# TODO(b/147702778): update when switch generic executor as default. class Trainer(base_component.BaseComponent): """A TFX component to train a TensorFlow model. @@ -43,34 +42,16 @@ class Trainer(base_component.BaseComponent): code](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/penguin_utils_keras.py) of the TFX penguin pipeline example. - *Note:* The default executor for this component trains locally. This can be - overriden to enable the model to be trained on other platforms. 
The [Cloud AI - Platform custom - executor](https://github.com/tensorflow/tfx/tree/master/tfx/extensions/google_cloud_ai_platform/trainer) - provides an example how to implement this. + *Note:* This component trains locally. For cloud distributed training, please + refer to [Cloud AI Platform + Trainer](https://github.com/tensorflow/tfx/tree/master/tfx/extensions/google_cloud_ai_platform/trainer). - ## Example 1: Training locally + ## Example ``` # Uses user-provided Python function that trains a model using TF. trainer = Trainer( module_file=module_file, - transformed_examples=transform.outputs['transformed_examples'], - schema=infer_schema.outputs['schema'], - transform_graph=transform.outputs['transform_graph'], - train_args=proto.TrainArgs(splits=['train'], num_steps=10000), - eval_args=proto.EvalArgs(splits=['eval'], num_steps=5000)) - ``` - - ## Example 2: Training through a cloud provider - ``` - from tfx.extensions.google_cloud_ai_platform.trainer import executor as - ai_platform_trainer_executor - # Train using Google Cloud AI Platform. - trainer = Trainer( - custom_executor_spec=executor_spec.ExecutorClassSpec( - ai_platform_trainer_executor.GenericExecutor), - module_file=module_file, - transformed_examples=transform.outputs['transformed_examples'], + examples=transform.outputs['transformed_examples'], schema=infer_schema.outputs['schema'], transform_graph=transform.outputs['transform_graph'], train_args=proto.TrainArgs(splits=['train'], num_steps=10000), @@ -92,7 +73,7 @@ class Trainer(base_component.BaseComponent): def __init__( self, - examples: types.Channel = None, + examples: Optional[types.Channel] = None, transformed_examples: Optional[types.Channel] = None, transform_graph: Optional[types.Channel] = None, schema: Optional[types.Channel] = None, @@ -102,8 +83,9 @@ def __init__( run_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None, # TODO(b/147702778): deprecate trainer_fn. 
trainer_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None, - train_args: Union[trainer_pb2.TrainArgs, Dict[Text, Any]] = None, - eval_args: Union[trainer_pb2.EvalArgs, Dict[Text, Any]] = None, + train_args: Optional[Union[trainer_pb2.TrainArgs, Dict[Text, + Any]]] = None, + eval_args: Optional[Union[trainer_pb2.EvalArgs, Dict[Text, Any]]] = None, custom_config: Optional[Dict[Text, Any]] = None, custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None): """Construct a Trainer component. @@ -112,7 +94,8 @@ def __init__( examples: A Channel of type `standard_artifacts.Examples`, serving as the source of examples used in training (required). May be raw or transformed. - transformed_examples: Deprecated field. Please set 'examples' instead. + transformed_examples: Deprecated (no compatibility guarantee). Please set + 'examples' instead. transform_graph: An optional Channel of type `standard_artifacts.TransformGraph`, serving as the input transform graph if present. @@ -168,8 +151,8 @@ def trainer_fn(trainer.fn_args_utils.FnArgs, behavior (when splits is empty) is evaluate on `eval` split. custom_config: A dict which contains addtional training job parameters that will be passed into user module. - custom_executor_spec: Optional custom executor spec. This is experimental - and is subject to change in the future. + custom_executor_spec: Optional custom executor spec. Deprecated (no + compatibility guarantee), please customize component directly. Raises: ValueError: @@ -194,7 +177,13 @@ def trainer_fn(trainer.fn_args_utils.FnArgs, "'transform_graph' must be supplied too.") if custom_executor_spec: - logging.warning("`custom_executor_spec` is going to be deprecated.") + logging.warning( + "`custom_executor_spec` is deprecated. Please customize component directly." + ) + if transformed_examples: + logging.warning( + "`transformed_examples` is deprecated. Please use `examples` instead." 
+ ) examples = examples or transformed_examples model = types.Channel(type=standard_artifacts.Model) model_run = types.Channel(type=standard_artifacts.ModelRun) @@ -204,8 +193,8 @@ def trainer_fn(trainer.fn_args_utils.FnArgs, schema=schema, base_model=base_model, hyperparameters=hyperparameters, - train_args=train_args, - eval_args=eval_args, + train_args=train_args or trainer_pb2.TrainArgs(), + eval_args=eval_args or trainer_pb2.EvalArgs(), module_file=module_file, run_fn=run_fn, trainer_fn=trainer_fn, diff --git a/tfx/components/trainer/component_test.py b/tfx/components/trainer/component_test.py index 4ecce71da8..19bd7419a9 100644 --- a/tfx/components/trainer/component_test.py +++ b/tfx/components/trainer/component_test.py @@ -57,11 +57,9 @@ def testConstructFromModuleFile(self): module_file = '/path/to/module/file' trainer = component.Trainer( module_file=module_file, - transformed_examples=self.examples, + examples=self.examples, transform_graph=self.transform_graph, - schema=self.schema, - train_args=self.train_args, - eval_args=self.eval_args) + schema=self.schema) self._verify_outputs(trainer) self.assertEqual( module_file, @@ -72,7 +70,7 @@ def testConstructWithParameter(self): n_steps = data_types.RuntimeParameter(name='n-steps', ptype=int) trainer = component.Trainer( module_file=module_file, - transformed_examples=self.examples, + examples=self.examples, transform_graph=self.transform_graph, schema=self.schema, train_args=dict(splits=['train'], num_steps=n_steps), @@ -87,7 +85,7 @@ def testConstructFromTrainerFn(self): trainer_fn = 'path.to.my_trainer_fn' trainer = component.Trainer( trainer_fn=trainer_fn, - transformed_examples=self.examples, + examples=self.examples, transform_graph=self.transform_graph, train_args=self.train_args, eval_args=self.eval_args) @@ -102,7 +100,7 @@ def testConstructFromRunFn(self): run_fn=run_fn, custom_executor_spec=executor_spec.ExecutorClassSpec( executor.GenericExecutor), - transformed_examples=self.examples, + 
examples=self.examples, transform_graph=self.transform_graph, train_args=self.train_args, eval_args=self.eval_args) @@ -175,7 +173,7 @@ def testConstructDuplicateUserModule(self): def testConstructWithHParams(self): trainer = component.Trainer( trainer_fn='path.to.my_trainer_fn', - transformed_examples=self.examples, + examples=self.examples, transform_graph=self.transform_graph, schema=self.schema, hyperparameters=self.hyperparameters, diff --git a/tfx/components/transform/component.py b/tfx/components/transform/component.py index 562f80dead..e4c60bfb71 100644 --- a/tfx/components/transform/component.py +++ b/tfx/components/transform/component.py @@ -91,12 +91,12 @@ class Transform(base_beam_component.BaseBeamComponent): def __init__( self, - examples: types.Channel = None, - schema: types.Channel = None, + examples: types.Channel, + schema: types.Channel, module_file: Optional[Union[Text, data_types.RuntimeParameter]] = None, preprocessing_fn: Optional[Union[Text, data_types.RuntimeParameter]] = None, - splits_config: transform_pb2.SplitsConfig = None, + splits_config: Optional[transform_pb2.SplitsConfig] = None, analyzer_cache: Optional[types.Channel] = None, materialize: bool = True, disable_analyzer_cache: bool = False, diff --git a/tfx/components/tuner/component.py b/tfx/components/tuner/component.py index e140e0c19c..f7fb1f1415 100644 --- a/tfx/components/tuner/component.py +++ b/tfx/components/tuner/component.py @@ -65,13 +65,13 @@ class Tuner(base_component.BaseComponent): EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor) def __init__(self, - examples: types.Channel = None, + examples: types.Channel, schema: Optional[types.Channel] = None, transform_graph: Optional[types.Channel] = None, module_file: Optional[Text] = None, tuner_fn: Optional[Text] = None, - train_args: trainer_pb2.TrainArgs = None, - eval_args: trainer_pb2.EvalArgs = None, + train_args: Optional[trainer_pb2.TrainArgs] = None, + eval_args: Optional[trainer_pb2.EvalArgs] = 
None, tune_args: Optional[tuner_pb2.TuneArgs] = None, custom_config: Optional[Dict[Text, Any]] = None): """Construct a Tuner component. @@ -116,8 +116,8 @@ def tuner_fn(fn_args: FnArgs) -> TunerFnResult: transform_graph=transform_graph, module_file=module_file, tuner_fn=tuner_fn, - train_args=train_args, - eval_args=eval_args, + train_args=train_args or trainer_pb2.TrainArgs(), + eval_args=eval_args or trainer_pb2.EvalArgs(), tune_args=tune_args, best_hyperparameters=best_hyperparameters, custom_config=json_utils.dumps(custom_config), diff --git a/tfx/extensions/google_cloud_ai_platform/bulk_inferrer/component.py b/tfx/extensions/google_cloud_ai_platform/bulk_inferrer/component.py index 3b71581a98..d79d58f69a 100644 --- a/tfx/extensions/google_cloud_ai_platform/bulk_inferrer/component.py +++ b/tfx/extensions/google_cloud_ai_platform/bulk_inferrer/component.py @@ -84,14 +84,14 @@ class CloudAIBulkInferrerComponent(base_component.BaseComponent): def __init__( self, - examples: types.Channel = None, + examples: types.Channel, model: Optional[types.Channel] = None, model_blessing: Optional[types.Channel] = None, data_spec: Optional[Union[bulk_inferrer_pb2.DataSpec, Dict[Text, Any]]] = None, output_example_spec: Optional[Union[bulk_inferrer_pb2.OutputExampleSpec, Dict[Text, Any]]] = None, - custom_config: Dict[Text, Any] = None): + custom_config: Optional[Dict[Text, Any]] = None): """Construct an BulkInferrer component. Args: From 44a1262e57b506cabeb13f171fda4b0d51489c4b Mon Sep 17 00:00:00 2001 From: dhruvesht Date: Fri, 28 May 2021 09:19:44 -0700 Subject: [PATCH 5/6] Adjusting the release logs for the vertex pipelines PiperOrigin-RevId: 376369379 --- RELEASE.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index ca8e410e37..2c361c85b9 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,6 +11,9 @@ `disable_statistics=True` in the Transform component. 
* BERT cola and mrpc examples now demonstrate how to calculate statistics for NLP features. +* TFX CLI now supports + [Vertex Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines/introduction). + Use it with `--engine=vertex` flag. ## Breaking Changes @@ -33,9 +36,6 @@ ## Bug Fixes and Other Changes -* TFX CLI now supports - [Vertex Pipelines](https://cloud.google.com/vertex-ai/docs/pipelines/introduction). - use it with `--engine=vertex` flag. * Forces keyword arguments for AirflowComponent to make it compatible with Apache Airflow 2.1.0 and later. * Removed `six` dependency. From 19b0c8bd332272259b500063acfede3bedd684bb Mon Sep 17 00:00:00 2001 From: jjong Date: Wed, 2 Jun 2021 03:50:33 -0700 Subject: [PATCH 6/6] Updates tensorflow-hub version constraint release note. FYI: #3812 #3813 PiperOrigin-RevId: 377024694 --- RELEASE.md | 1 + tfx/dependencies.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/RELEASE.md b/RELEASE.md index 2c361c85b9..54d674e5a2 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -44,6 +44,7 @@ * Depends on `struct2tensor>=0.31.0,<0.32.0`. * Depends on `tensorflow>=1.15.2,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3`. * Depends on `tensorflow-data-validation>=1.0.0,<1.1.0`. +* Depends on `tensorflow-hub>=0.9.0,<0.13`. * Depends on `tensorflowjs>=3.6.0,<4`. * Depends on `tensorflow-model-analysis>=0.31.0,<0.32.0`. * Depends on `tensorflow-serving-api>=1.15,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3`. diff --git a/tfx/dependencies.py b/tfx/dependencies.py index 27b780bf5f..650b355918 100644 --- a/tfx/dependencies.py +++ b/tfx/dependencies.py @@ -88,7 +88,7 @@ def make_required_install_packages(): 'pyarrow>=1,<3', 'pyyaml>=3.12,<6', 'tensorflow>=1.15.2,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,<3', - 'tensorflow-hub>=0.9.0,<0.10', + 'tensorflow-hub>=0.9.0,<0.13', 'tensorflow-data-validation' + select_constraint( default='>=1.0.0,<1.1.0', nightly='>=1.1.0.dev',