diff --git a/airflow/contrib/operators/r_operator.py b/airflow/contrib/operators/r_operator.py
new file mode 100644
index 0000000000000..2db3a6087274d
--- /dev/null
+++ b/airflow/contrib/operators/r_operator.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from builtins import bytes
+import os
+from tempfile import NamedTemporaryFile
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+from airflow.utils.file import TemporaryDirectory
+from airflow.utils.operator_helpers import context_to_airflow_vars
+
+import rpy2.robjects as robjects
+
+try:
+    from rpy2.rinterface import RRuntimeError
+except ImportError:
+    # rpy2 >= 3.0 no longer exports the exception from rpy2.rinterface
+    from rpy2.rinterface_lib.embedded import RRuntimeError
+
+
+class ROperator(BaseOperator):
+    """
+    Execute an R script or command in-process through rpy2
+
+    ``execute`` returns whatever R's ``source()`` returns; if
+    BaseOperator.do_xcom_push is True that value is pushed to an XCom when
+    the R command completes. An R runtime error fails the task.
+
+    :param r_command: The command or a reference to an R script (must have
+        a '.r' or '.R' extension) to be executed (templated)
+    :type r_command: string
+    :param env: Optional dict of environment variables and their (string)
+        values to set (templated). Unlike `BashOperator`, this does not
+        replace the current environment, although it can be used to override
+        existing values. Values can be read in R with `Sys.getenv()`.
+    :type env: dict
+    :param output_encoding: encoding output from R (default: 'utf-8')
+    :type output_encoding: string
+
+    """
+
+    template_fields = ('r_command', 'env',)
+    template_ext = ('.r', '.R')
+    ui_color = '#C8D5E6'
+
+    @apply_defaults
+    def __init__(
+            self,
+            r_command,
+            env=None,
+            output_encoding='utf-8',
+            *args, **kwargs):
+
+        super(ROperator, self).__init__(*args, **kwargs)
+        self.r_command = r_command
+        # A literal `{}` default would be one dict shared by every instance
+        self.env = {} if env is None else env
+        self.output_encoding = output_encoding
+
+    def execute(self, context):
+        """
+        Write the R command to a temporary script and ``source()`` it
+
+        :param context: task context; exported to R as environment variables
+        :return: the value returned by R's ``source()``
+        :raises AirflowException: if R reports a runtime error
+        """
+
+        # Export additional environment variables
+        os.environ.update(self.env)
+
+        # Export context as environment variables
+        airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
+        self.log.info('Exporting the following env vars:\n%s',
+                      '\n'.join("{}={}".format(k, v)
+                                for k, v in airflow_context_vars.items()))
+        os.environ.update(airflow_context_vars)
+
+        with TemporaryDirectory(prefix='airflowtmp') as tmp_dir:
+            with NamedTemporaryFile(dir=tmp_dir, prefix=self.task_id) as f:
+
+                f.write(bytes(self.r_command, 'utf_8'))
+                f.flush()
+                fname = f.name
+                script_location = os.path.abspath(fname)
+
+                self.log.info("Temporary script location: %s", script_location)
+                self.log.info("Running command(s):\n%s", self.r_command)
+
+                try:
+                    res = robjects.r.source(fname, echo=False)
+                except RRuntimeError as e:
+                    # Fail the task rather than report success on an R error
+                    self.log.error("Received R error: %s", e)
+                    raise AirflowException("R command failed: {}".format(e))
+
+        # This will be a pickled rpy2.robjects.vectors.ListVector
+        return res
diff --git a/setup.py b/setup.py
index 557c7b5652f2b..d3219ce07b852 100644
--- a/setup.py
+++ b/setup.py
@@ -210,6 +210,7 @@ def write_version(filename=os.path.join(*['airflow',
 pinot = ['pinotdb==0.1.1']
 postgres = ['psycopg2>=2.7.4']
 qds = ['qds-sdk>=1.10.4']
+r = ['rpy2>=2.9.5']
 rabbitmq = ['librabbitmq>=1.6.1']
 redis = ['redis>=2.10.5,<3.0.0']
 salesforce = ['simple-salesforce>=0.72']
@@ -261,7 +262,7 @@ def write_version(filename=os.path.join(*['airflow',
              docker + ssh + kubernetes + celery + redis + gcp_api +
              datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
              druid + pinot + segment + snowflake + elasticsearch +
-             atlas + azure + aws)
+             atlas + azure + aws + r)
 
 # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'(
 if PY3:
@@ -369,6 +370,7 @@ def do_setup():
             'pinot': pinot,
             'postgres': postgres,
             'qds': qds,
+            'r': r,
             'rabbitmq': rabbitmq,
             'redis': redis,
             'salesforce': salesforce,
diff --git a/tests/contrib/operators/test_r_operator.py b/tests/contrib/operators/test_r_operator.py
new file mode 100644
index 0000000000000..135d1ed764f1b
--- /dev/null
+++ b/tests/contrib/operators/test_r_operator.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function, unicode_literals
+
+import unittest
+
+from airflow import configuration, DAG
+from airflow.contrib.operators.r_operator import ROperator
+from airflow.exceptions import AirflowException
+from airflow.models import TaskInstance
+from airflow.utils import timezone
+
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+
+
+class ROperatorTest(unittest.TestCase):
+    """Test the ROperator"""
+
+    def setUp(self):
+        super(ROperatorTest, self).setUp()
+        configuration.load_test_config()
+        self.dag = DAG(
+            'test_roperator_dag',
+            default_args={
+                'owner': 'airflow',
+                'start_date': DEFAULT_DATE
+            },
+            schedule_interval='@once'
+        )
+
+        self.xcom_test_str = 'Hello Airflow'
+        self.task_xcom = ROperator(
+            task_id='test_r_xcom',
+            r_command='cat("Ignored Line\n{}")'.format(self.xcom_test_str),
+            do_xcom_push=True,
+            dag=self.dag
+        )
+
+    def test_xcom_output(self):
+        """Test whether Xcom output is produced using last line"""
+
+        self.task_xcom.do_xcom_push = True
+
+        ti = TaskInstance(
+            task=self.task_xcom,
+            execution_date=timezone.utcnow()
+        )
+
+        ti.run()
+        self.assertIsNotNone(ti.duration)
+
+        # NOTE(review): execute() returns the result of R's source(), not
+        # captured stdout -- confirm this expectation against a real R run
+        self.assertEqual(
+            ti.xcom_pull(task_ids=self.task_xcom.task_id, key='return_value'),
+            self.xcom_test_str
+        )
+
+    def test_xcom_none(self):
+        """Test whether no Xcom output is produced when push=False"""
+
+        self.task_xcom.do_xcom_push = False
+
+        ti = TaskInstance(
+            task=self.task_xcom,
+            execution_date=timezone.utcnow(),
+        )
+
+        ti.run()
+        self.assertIsNotNone(ti.duration)
+        self.assertIsNone(ti.xcom_pull(task_ids=self.task_xcom.task_id))
+
+    def test_command_template(self):
+        """Test whether templating works properly with r_command"""
+
+        task = ROperator(
+            task_id='test_cmd_template',
+            r_command='cat("{{ ds }}")',
+            dag=self.dag
+        )
+
+        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
+        ti.render_templates()
+
+        self.assertEqual(
+            ti.task.r_command,
+            'cat("{}")'.format(DEFAULT_DATE.date().isoformat())
+        )
+
+    def test_r_error_fails_task(self):
+        """Test whether an R runtime error fails the task"""
+
+        task = ROperator(
+            task_id='test_r_error',
+            r_command='stop("boom")',
+            dag=self.dag
+        )
+
+        with self.assertRaises(AirflowException):
+            task.execute(context={})
+
+
+if __name__ == '__main__':
+    unittest.main()