diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py
index 5e205b16d5e9e..f09c8bb7002f8 100644
--- a/airflow/hooks/mysql_hook.py
+++ b/airflow/hooks/mysql_hook.py
@@ -56,11 +56,16 @@ def run(self, sql):
         cur.close()
         conn.close()
 
-    def insert_rows(self, table, rows):
+    def insert_rows(self, table, rows, target_fields=None):
         """
         A generic way to insert a set of tuples into a table,
         the whole set of inserts is treated as one transaction
         """
+        if target_fields:
+            target_fields = ", ".join(target_fields)
+            target_fields = "({})".format(target_fields)
+        else:
+            target_fields = ''
         conn = self.get_conn()
         cur = conn.cursor()
         for row in rows:
@@ -73,8 +78,10 @@ def insert_rows(self, table, rows):
             else:
                 l.append(str(cell))
             values = tuple(l)
-            sql = "INSERT INTO {0} VALUES ({1});".format(
-                table, ",".join(values))
+            sql = "INSERT INTO {0} {1} VALUES ({2});".format(
+                table,
+                target_fields,
+                ",".join(values))
             cur.execute(sql)
         conn.commit()
         cur.close()
diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py
index 0caca64dca1b9..2b5fe43cad40f 100644
--- a/airflow/operators/hive_stats_operator.py
+++ b/airflow/operators/hive_stats_operator.py
@@ -7,7 +7,6 @@
 from airflow.utils import apply_defaults
 
 
-
 class HiveStatsCollectionOperator(BaseOperator):
     """
     Gathers partition statistics using a dynmically generated Presto
@@ -35,7 +34,9 @@ class HiveStatsCollectionOperator(BaseOperator):
     :type col_blacklist: list
     :param assignment_func: a function that receives a column name and
         a type, and returns a dict of metric names and an Presto expressions.
-        If None is returned, the global defaults are applied.
+        If None is returned, the global defaults are applied. If an
+        empty dictionary is returned, no stats are computed for that
+        column.
     :type assignment_func: function
     """
 
@@ -68,6 +69,7 @@ def __init__(
         self.mysql_conn_id = mysql_conn_id
         self.assignment_func = assignment_func
         self.ds = '{{ ds }}'
+        self.dttm = '{{ execution_date.isoformat() }}'
 
     def get_default_exprs(self, col, col_type):
         if col in self.col_blacklist:
@@ -100,7 +102,9 @@ def execute(self, context=None):
             d = {}
             if self.assignment_func:
                 d = self.assignment_func(col, col_type)
-            if not d:
+                if d is None:
+                    d = self.get_default_exprs(col, col_type)
+            else:
                 d = self.get_default_exprs(col, col_type)
             exprs.update(d)
         exprs.update(self.extra_exprs)
@@ -142,6 +146,19 @@ def execute(self, context=None):
         logging.info("Pivoting and loading cells into the Airflow db")
         rows = [
-            (self.ds, self.table, part_json) + (r[0][0], r[0][1], r[1])
+            (self.ds, self.dttm, self.table, part_json) +
+            (r[0][0], r[0][1], r[1])
             for r in zip(exprs, row)]
-        mysql.insert_rows(table='hive_stats', rows=rows)
+        mysql.insert_rows(
+            table='hive_stats',
+            rows=rows,
+            target_fields=[
+                'ds',
+                'dttm',
+                'table_name',
+                'partition_repr',
+                'col',
+                'metric',
+                'value',
+            ]
+        )
 
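# ---------------------------------------------------------------------------
# Illustration (not part of the patch): a minimal sketch of what the new
# `target_fields` argument to MySqlHook.insert_rows() does to the generated
# SQL. The `build_insert_sql` helper below is hypothetical and only mirrors
# the string formatting added in the hook above.
def build_insert_sql(table, values, target_fields=None):
    # Render the optional column list, e.g. "(ds, dttm, table_name, ...)";
    # with no target_fields the column list is simply omitted.
    fields = "({})".format(", ".join(target_fields)) if target_fields else ""
    return "INSERT INTO {0} {1} VALUES ({2});".format(
        table, fields, ",".join(values))


# Example using the columns HiveStatsCollectionOperator now passes; the
# sample values are made up for illustration.
print(build_insert_sql(
    table='hive_stats',
    values=("'2015-01-01'", "'2015-01-01T00:00:00'", "'my_db.my_table'",
            "'{}'", "'id'", "'non_null'", "42"),
    target_fields=['ds', 'dttm', 'table_name', 'partition_repr',
                   'col', 'metric', 'value'],
))
# -> INSERT INTO hive_stats (ds, dttm, table_name, partition_repr, col,
#    metric, value) VALUES ('2015-01-01','2015-01-01T00:00:00', ...);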