diff --git a/sdks/python/apache_beam/runners/runner.py b/sdks/python/apache_beam/runners/runner.py index 1d0f700c66c3..f08ab93b501b 100644 --- a/sdks/python/apache_beam/runners/runner.py +++ b/sdks/python/apache_beam/runners/runner.py @@ -117,13 +117,18 @@ class PipelineRunner(object): """ def run(self, transform, options=None): - """Run the given transform with this runner. + """Run the given transform or callable with this runner. """ # Imported here to avoid circular dependencies. # pylint: disable=wrong-import-order, wrong-import-position + from apache_beam import PTransform + from apache_beam.pvalue import PBegin from apache_beam.pipeline import Pipeline p = Pipeline(runner=self, options=options) - p | transform + if isinstance(transform, PTransform): + p | transform + else: + transform(PBegin(p)) return p.run() def run_pipeline(self, pipeline): diff --git a/sdks/python/apache_beam/runners/runner_test.py b/sdks/python/apache_beam/runners/runner_test.py index aa615cfdd196..9ab226faf828 100644 --- a/sdks/python/apache_beam/runners/runner_test.py +++ b/sdks/python/apache_beam/runners/runner_test.py @@ -128,6 +128,20 @@ def test_run_api(self): my_metric_value = result.metrics().query()['counters'][0].committed self.assertEqual(my_metric_value, 111) + def test_run_api_with_callable(self): + my_metric = Metrics.counter('namespace', 'my_metric') + + def fn(start): + return (start + | beam.Create([1, 10, 100]) + | beam.Map(lambda x: my_metric.inc(x))) + runner = DirectRunner() + result = runner.run(fn) + result.wait_until_finish() + # Use counters to assert the pipeline actually ran. + my_metric_value = result.metrics().query()['counters'][0].committed + self.assertEqual(my_metric_value, 111) + if __name__ == '__main__': unittest.main()