class ElasticsearchSQLCursor:
    """
    A PEP 249-like cursor on top of the Elasticsearch SQL API.

    The cursor keeps a single request ``body`` dict that is sent to
    ``es.sql.query`` and remembers the last response. Result paging is driven
    by the ``cursor`` token found in that response.

    Known deviations from PEP 249 (kept for backward compatibility):

    * ``description`` yields ``(name, type)`` pairs, not 7-item sequences.
    * ``fetchone`` returns the first row of the current page and does not
      advance, so repeated calls return the same row.
    * ``fetchmany`` is not implemented.

    :param es: Elasticsearch client whose ``sql.query`` method is called.
    :param kwargs: Only ``fetch_size`` (default ``1000``) and
        ``field_multi_value_leniency`` (default ``False``) are read; any other
        keyword is ignored.
    """

    def __init__(self, es: Elasticsearch, **kwargs):
        self.es = es
        self.body: dict[str, Any] = {
            "fetch_size": kwargs.get("fetch_size", 1000),
            "field_multi_value_leniency": kwargs.get("field_multi_value_leniency", False),
        }
        self._response: ObjectApiResponse | None = None

    @property
    def response(self) -> ObjectApiResponse:
        """Return the last response, or an empty mapping when there is none."""
        return self._response or {}  # type: ignore

    @response.setter
    def response(self, value):
        self._response = value

    @property
    def cursor(self):
        """Return the paging token of the last response, or ``None``."""
        return self.response.get("cursor")

    @property
    def rows(self):
        """Return the rows of the last response (current page only)."""
        return self.response.get("rows", [])

    @property
    def rowcount(self) -> int:
        """Return the number of rows in the current page."""
        return len(self.rows)

    @property
    def description(self) -> list[tuple]:
        """Return ``(name, type)`` for every column of the last response."""
        return [(column["name"], column["type"]) for column in self.response.get("columns", [])]

    def execute(
        self, statement: str, params: Iterable | Mapping[str, Any] | None = None
    ) -> ObjectApiResponse:
        """
        Run ``statement`` as a new query and return the first page of results.

        :param statement: SQL text sent as the ``query`` field of the request.
        :param params: Optional values for the statement's placeholders.
        :return: The response of ``es.sql.query`` for the first page.
        """
        self.body["query"] = statement
        # A new statement must not inherit paging state from a previous one:
        # a request carrying a cursor continues the earlier result set instead
        # of running the new query.
        self.body.pop("cursor", None)
        if params:
            self.body["params"] = params
        else:
            # Do not re-send the parameters of an earlier statement.
            self.body.pop("params", None)
        return self._query()

    def _query(self) -> ObjectApiResponse:
        """Send the current body and record the paging token for the next page."""
        self.response = self.es.sql.query(body=self.body)
        if self.cursor:
            self.body["cursor"] = self.cursor
        else:
            self.body.pop("cursor", None)
        return self.response

    def fetchone(self):
        """Return the first row of the current page, or ``None`` if it is empty."""
        if self.rows:
            return self.rows[0]
        return None

    def fetchmany(self, size: int | None = None):
        """Not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def fetchall(self):
        """
        Return the rows of the current page plus all following pages.

        Pages are requested until a response carries no ``cursor``. The rows
        are collected in a new list so the stored responses are not modified.
        """
        # Copy: ``self.rows`` is the list owned by the response object.
        results = list(self.rows)
        while self.cursor:
            self._query()
            results.extend(self.rows)
        return results

    def close(self):
        """Forget the last response."""
        self._response = None
class TestElasticsearchSQLCursor:
    """Unit tests for ``ElasticsearchSQLCursor`` backed by a mocked SQL client."""

    STATEMENT = "SELECT * FROM hollywood.actors"

    def setup_method(self):
        sql = MagicMock(spec=SqlClient)
        # Give every test private copies of the canned pages. Both pages share
        # the module-level ROWS list, so handing out the originals would let a
        # cursor that touches its rows leak changes into the other tests.
        sql.query.side_effect = [
            {**page, "rows": [list(row) for row in page["rows"]]} for page in RESPONSES
        ]
        self.es = MagicMock(sql=sql, spec=Elasticsearch)

    def _executed_cursor(self):
        """Return a cursor on which the test statement has been executed once."""
        cursor = ElasticsearchSQLCursor(es=self.es)
        cursor.execute(self.STATEMENT)
        return cursor

    def test_execute(self):
        cursor = ElasticsearchSQLCursor(es=self.es)

        assert cursor.execute(self.STATEMENT) == RESPONSE

    def test_rowcount(self):
        assert self._executed_cursor().rowcount == len(ROWS)

    def test_description(self):
        assert self._executed_cursor().description == [
            ("index", "long"),
            ("name", "text"),
            ("firstname", "text"),
            ("age", "long"),
        ]

    def test_fetchone(self):
        assert self._executed_cursor().fetchone() == ROWS[0]

    def test_fetchmany(self):
        cursor = self._executed_cursor()

        with pytest.raises(NotImplementedError):
            cursor.fetchmany()

    def test_fetchall(self):
        records = self._executed_cursor().fetchall()

        # Two pages are served (the first one carries a cursor), so the result
        # is both pages back to back.
        assert len(records) == 2 * len(ROWS)
        assert records == ROWS + ROWS
@mock.patch("airflow.providers.elasticsearch.hooks.elasticsearch.Elasticsearch")
def test_execute_sql_query(self, mock_es):
    """``ESConnection.execute_sql`` sends the cursor's default body and returns the raw response."""
    statement = "SELECT * FROM hollywood.actors"
    expected_body = {
        "fetch_size": 1000,
        "field_multi_value_leniency": False,
        "query": statement,
    }
    sql_client = MagicMock()
    sql_client.query.return_value = RESPONSE_WITHOUT_CURSOR
    mock_es.return_value.sql = sql_client

    connection = ESConnection(host="localhost", port=9200)
    result = connection.execute_sql(statement)

    sql_client.query.assert_called_once_with(body=expected_body)
    assert result == RESPONSE_WITHOUT_CURSOR