From 1d2a6be9b2bbd915367747a0eb7c936353a8f473 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 4 Feb 2025 20:48:33 +0100 Subject: [PATCH 01/12] refactor: Implemented ElasticsearchSQLCursor --- .../elasticsearch/hooks/elasticsearch.py | 85 +++++++++++++++++-- .../elasticsearch/hooks/test_elasticsearch.py | 68 ++++++++++++++- 2 files changed, 143 insertions(+), 10 deletions(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index ab1bc433d94a4..b69fb461a679d 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -18,7 +18,7 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Iterable, Mapping from urllib import parse from airflow.hooks.base import BaseHook @@ -42,6 +42,71 @@ def connect( return ESConnection(host, port, user, password, scheme, **kwargs) +class ElasticsearchSQLCursor: + """ A PEP 249-like Cursor class for Elasticsearch SQL API """ + + def __init__(self, es: Elasticsearch, options: dict[str, Any]): + self.es = es + self.body = { + "fetch_size": options.get("fetch_size", 100), + "field_multi_value_leniency": options.get("field_multi_value_leniency", False), + } + self._response: ObjectApiResponse | None = None + + @property + def response(self) -> ObjectApiResponse: + return self._response or {} + + @response.setter + def response(self, value): + self._response = value + + @property + def cursor(self): + return self.response.get("cursor") + + @property + def rows(self): + return self.response.get("rows", []) + + @property + def rowcount(self) -> int: + return len(self.rows) + + @property + def description(self) -> list[tuple]: + return self.response.get("columns", []) + + def execute(self, statement: str, params: Iterable | Mapping[str, Any] | None = None) -> ObjectApiResponse: + self.body["query"] = statement + if params: + self.body["params"] = params + self.response = self.es.sql.query(body=self.body) + if self.cursor: + self.body["cursor"] = self.cursor + else: + self.body.pop("cursor", None) + return self.response + + def fetchone(self): + if self.rows: + return self.rows[0] + return None + + def fetchmany(self, size: int | None = None): + raise NotImplementedError() + + def fetchall(self): + results = self.rows + while self.cursor: + self.execute(query=self.body["query"]) + results.extend(self.rows) + return results + + def close(self): + pass + + class ESConnection: """wrapper class for elasticsearch.Elasticsearch.""" @@ -67,9 +132,14 @@ def __init__( else: self.es = Elasticsearch(self.url, **self.kwargs) - def execute_sql(self, query: str) -> ObjectApiResponse: - sql_query = {"query": query} - return self.es.sql.query(body=sql_query) + def cursor(self) -> ElasticsearchSQLCursor: + return ElasticsearchSQLCursor(self.es, self.kwargs) + + def close(self): + self.es.close() + + def execute_sql(self, query: str, params: Iterable | Mapping[str, Any] | None = None) -> ObjectApiResponse: + return self.cursor().execute(query, params) class ElasticsearchSQLHook(DbApiHook): @@ -104,11 +174,10 @@ def get_conn(self) -> ESConnection: "scheme": conn.schema or "http", } - if conn.extra_dejson.get("http_compress", False): - conn_args["http_compress"] = bool(["http_compress"]) + conn_args.update(conn.extra_dejson) - if conn.extra_dejson.get("timeout", False): - conn_args["timeout"] = conn.extra_dejson["timeout"] + if conn_args.get("http_compress", False): + conn_args["http_compress"] = bool(conn_args["http_compress"]) return connect(**conn_args) diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index ea34f2532de41..ce46460bae60b 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -20,14 +20,17 @@ from unittest import mock from unittest.mock import MagicMock +import pytest from elasticsearch import Elasticsearch from airflow.models import Connection from airflow.providers.elasticsearch.hooks.elasticsearch import ( ElasticsearchPythonHook, + ElasticsearchSQLCursor, ElasticsearchSQLHook, ESConnection, ) +from elasticsearch._sync.client import SqlClient class TestElasticsearchSQLHookConn: @@ -48,10 +51,71 @@ def test_get_conn(self, mock_connect): mock_connect.assert_called_with(host="localhost", port=9200, scheme="http", user=None, password=None) +class TestElasticsearchSQLCursor: + rows = [ + [1, "Stallone", "Sylvester", "78"], + [2, "Statham", "Jason", "57"], + [3, "Li", "Jet", "61"], + [4, "Lundgren", "Dolph", "66"], + [5, "Norris", "Chuck", "84"], + ] + columns = [ + {'name': 'index', 'type': 'long'}, + {'name': 'name', 'type': 'text'}, + {'name': 'firstname', 'type': 'text'}, + {'name': 'age', 'type': 'long'}, + ] + response = { + "columns": columns, + "rows": rows + } + + def setup_method(self): + sql = MagicMock(spec=SqlClient) + sql.query.side_effect = lambda body: self.response + self.es = MagicMock(sql=sql, spec=Elasticsearch) + + def test_execute(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + + assert cursor.execute("SELECT * FROM hollywood.actors") == self.response + + def test_rowcount(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.rowcount == len(self.rows) + + def test_description(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.description == self.columns + + def test_fetchone(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.fetchone() == self.rows[0] + + def test_fetchmany(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + with pytest.raises(NotImplementedError): + cursor.fetchmany() + + def test_fetchall(self): + cursor = ElasticsearchSQLCursor(es=self.es, options={}) + cursor.execute("SELECT * FROM hollywood.actors") + + assert cursor.fetchall() == self.rows + + class TestElasticsearchSQLHook: def setup_method(self): - self.cur = mock.MagicMock(rowcount=0) - self.conn = mock.MagicMock() + self.cur = mock.MagicMock(rowcount=0, spec=ElasticsearchSQLCursor) + self.conn = mock.MagicMock(spec=ESConnection) self.conn.cursor.return_value = self.cur conn = self.conn From 3ba5254018bb8907fb8b966dcfd0517862a1d32d Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 09:56:34 +0100 Subject: [PATCH 02/12] refactor: Fixed some issue with ElasticsearchSQLCursor when used through the hook --- .../elasticsearch/hooks/elasticsearch.py | 7 +- .../elasticsearch/hooks/test_elasticsearch.py | 126 ++++++++++-------- 2 files changed, 74 insertions(+), 59 deletions(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index b69fb461a679d..fc04600af2df2 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -75,7 +75,7 @@ def rowcount(self) -> int: @property def description(self) -> list[tuple]: - return self.response.get("columns", []) + return [(column["name"], column["type"]) for column in self.response.get("columns", [])] def execute(self, statement: str, params: Iterable | Mapping[str, Any] | None = None) -> ObjectApiResponse: self.body["query"] = statement @@ -138,6 +138,9 @@ def cursor(self) -> ElasticsearchSQLCursor: def close(self): self.es.close() + def commit(self): + pass + def execute_sql(self, query: str, params: Iterable | Mapping[str, Any] | None = None) -> ObjectApiResponse: return self.cursor().execute(query, params) @@ -154,13 +157,13 @@ class ElasticsearchSQLHook(DbApiHook): conn_name_attr = "elasticsearch_conn_id" default_conn_name = "elasticsearch_default" + connector = ESConnection conn_type = "elasticsearch" hook_name = "Elasticsearch" def __init__(self, schema: str = "http", connection: AirflowConnection | None = None, *args, **kwargs): super().__init__(*args, **kwargs) self.schema = schema - self._connection = connection def get_conn(self) -> ESConnection: """Return an elasticsearch connection object.""" diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index ce46460bae60b..55299671e2cd9 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -18,10 +18,12 @@ from __future__ import annotations from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch, create_autospec import pytest +from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from elasticsearch import Elasticsearch +from kgb import spy_on, SpyAgency from airflow.models import Connection from airflow.providers.elasticsearch.hooks.elasticsearch import ( @@ -33,6 +35,24 @@ from elasticsearch._sync.client import SqlClient +ROWS = [ + [1, "Stallone", "Sylvester", "78"], + [2, "Statham", "Jason", "57"], + [3, "Li", "Jet", "61"], + [4, "Lundgren", "Dolph", "66"], + [5, "Norris", "Chuck", "84"], +] +RESPONSE = { + "columns": [ + {'name': 'index', 'type': 'long'}, + {'name': 'name', 'type': 'text'}, + {'name': 'firstname', 'type': 'text'}, + {'name': 'age', 'type': 'long'}, + ], + "rows": ROWS +} + + class TestElasticsearchSQLHookConn: def setup_method(self): self.connection = Connection(host="localhost", port=9200, schema="http") @@ -52,51 +72,33 @@ def test_get_conn(self, mock_connect): class TestElasticsearchSQLCursor: - rows = [ - [1, "Stallone", "Sylvester", "78"], - [2, "Statham", "Jason", "57"], - [3, "Li", "Jet", "61"], - [4, "Lundgren", "Dolph", "66"], - [5, "Norris", "Chuck", "84"], - ] - columns = [ - {'name': 'index', 'type': 'long'}, - {'name': 'name', 'type': 'text'}, - {'name': 'firstname', 'type': 'text'}, - {'name': 'age', 'type': 'long'}, - ] - response = { - "columns": columns, - "rows": rows - } - def setup_method(self): sql = MagicMock(spec=SqlClient) - sql.query.side_effect = lambda body: self.response + sql.query.side_effect = lambda body: RESPONSE self.es = MagicMock(sql=sql, spec=Elasticsearch) def test_execute(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) - assert cursor.execute("SELECT * FROM hollywood.actors") == self.response + assert cursor.execute("SELECT * FROM hollywood.actors") == RESPONSE def test_rowcount(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) cursor.execute("SELECT * FROM hollywood.actors") - assert cursor.rowcount == len(self.rows) + assert cursor.rowcount == len(ROWS) def test_description(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) cursor.execute("SELECT * FROM hollywood.actors") - assert cursor.description == self.columns + assert cursor.description == [("index", "long"), ("name", "text"), ("firstname", "text"), ("age", "long")] def test_fetchone(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) cursor.execute("SELECT * FROM hollywood.actors") - assert cursor.fetchone() == self.rows[0] + assert cursor.fetchone() == ROWS[0] def test_fetchmany(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) @@ -109,13 +111,20 @@ def test_fetchall(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) cursor.execute("SELECT * FROM hollywood.actors") - assert cursor.fetchall() == self.rows + assert cursor.fetchall() == ROWS class TestElasticsearchSQLHook: def setup_method(self): - self.cur = mock.MagicMock(rowcount=0, spec=ElasticsearchSQLCursor) - self.conn = mock.MagicMock(spec=ESConnection) + sql = MagicMock(spec=SqlClient) + sql.query.side_effect = lambda body: RESPONSE + es = MagicMock(sql=sql, spec=Elasticsearch) + self.cur = ElasticsearchSQLCursor(es=es, options={}) + self.spy_agency = SpyAgency() + self.spy_agency.spy_on(self.cur.close, call_original=True) + self.spy_agency.spy_on(self.cur.execute, call_original=True) + self.spy_agency.spy_on(self.cur.fetchall, call_original=True) + self.conn = MagicMock(spec=ESConnection) self.conn.cursor.return_value = self.cur conn = self.conn @@ -128,55 +137,58 @@ def get_conn(self): self.db_hook = UnitTestElasticsearchSQLHook() def test_get_first_record(self): - statement = "SQL" - result_sets = [("row1",), ("row2",)] - self.cur.fetchone.return_value = result_sets[0] + statement = "SELECT * FROM hollywood.actors" + + assert self.db_hook.get_first(statement) == ROWS[0] - assert result_sets[0] == self.db_hook.get_first(statement) self.conn.close.assert_called_once_with() - self.cur.close.assert_called_once_with() - self.cur.execute.assert_called_once_with(statement) + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) def test_get_records(self): - statement = "SQL" - result_sets = [("row1",), ("row2",)] - self.cur.fetchall.return_value = result_sets + statement = "SELECT * FROM hollywood.actors" + + assert self.db_hook.get_records(statement) == ROWS - assert result_sets == self.db_hook.get_records(statement) self.conn.close.assert_called_once_with() - self.cur.close.assert_called_once_with() - self.cur.execute.assert_called_once_with(statement) + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) def test_get_pandas_df(self): - statement = "SQL" - column = "col" - result_sets = [("row1",), ("row2",)] - self.cur.description = [(column,)] - self.cur.fetchall.return_value = result_sets + statement = "SELECT * FROM hollywood.actors" df = self.db_hook.get_pandas_df(statement) - assert column == df.columns[0] + assert list(df.columns) == ["index", "name", "firstname", "age"] + assert df.values.tolist() == ROWS - assert result_sets[0][0] == df.values.tolist()[0][0] - assert result_sets[1][0] == df.values.tolist()[1][0] + self.conn.close.assert_called_once_with() + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) - self.cur.execute.assert_called_once_with(statement) + def test_run(self): + statement = "SELECT * FROM hollywood.actors" + + assert self.db_hook.run(statement, handler=fetch_all_handler) == ROWS + + self.conn.close.assert_called_once_with() + self.spy_agency.assert_spy_called(self.cur.close) + self.spy_agency.assert_spy_called(self.cur.execute) @mock.patch("airflow.providers.elasticsearch.hooks.elasticsearch.Elasticsearch") def test_execute_sql_query(self, mock_es): mock_es_sql_client = MagicMock() - mock_es_sql_client.query.return_value = { - "columns": [{"name": "id"}, {"name": "first_name"}], - "rows": [[1, "John"], [2, "Jane"]], - } + mock_es_sql_client.query.return_value = RESPONSE mock_es.return_value.sql = mock_es_sql_client es_connection = ESConnection(host="localhost", port=9200) - response = es_connection.execute_sql("SELECT * FROM index1") - mock_es_sql_client.query.assert_called_once_with(body={"query": "SELECT * FROM index1"}) - - assert response["rows"] == [[1, "John"], [2, "Jane"]] - assert response["columns"] == [{"name": "id"}, {"name": "first_name"}] + response = es_connection.execute_sql("SELECT * FROM hollywood.actors") + mock_es_sql_client.query.assert_called_once_with(body={ + "fetch_size": 100, + "field_multi_value_leniency": False, + "query": "SELECT * FROM hollywood.actors", + }) + + assert response == RESPONSE class MockElasticsearch: From 02028c623a82fe52e1a434b51ff750e78343e16b Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 10:33:43 +0100 Subject: [PATCH 03/12] refactor: Set response to None when cursor is closed --- .../src/airflow/providers/elasticsearch/hooks/elasticsearch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index fc04600af2df2..a124180fde921 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -104,7 +104,7 @@ def fetchall(self): return results def close(self): - pass + self._response = None class ESConnection: From f276abdf638f815635911cc855adf98f2f524df4 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 11:12:32 +0100 Subject: [PATCH 04/12] refactor: Reformatted test --- .../elasticsearch/hooks/test_elasticsearch.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index 55299671e2cd9..c620168db9799 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -18,22 +18,21 @@ from __future__ import annotations from unittest import mock -from unittest.mock import MagicMock, patch, create_autospec +from unittest.mock import MagicMock import pytest -from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from elasticsearch import Elasticsearch -from kgb import spy_on, SpyAgency +from elasticsearch._sync.client import SqlClient +from kgb import SpyAgency from airflow.models import Connection +from airflow.providers.common.sql.hooks.handlers import fetch_all_handler from airflow.providers.elasticsearch.hooks.elasticsearch import ( ElasticsearchPythonHook, ElasticsearchSQLCursor, ElasticsearchSQLHook, ESConnection, ) -from elasticsearch._sync.client import SqlClient - ROWS = [ [1, "Stallone", "Sylvester", "78"], @@ -44,10 +43,10 @@ ] RESPONSE = { "columns": [ - {'name': 'index', 'type': 'long'}, - {'name': 'name', 'type': 'text'}, - {'name': 'firstname', 'type': 'text'}, - {'name': 'age', 'type': 'long'}, + {"name": "index", "type": "long"}, + {"name": "name", "type": "text"}, + {"name": "firstname", "type": "text"}, + {"name": "age", "type": "long"}, ], "rows": ROWS } @@ -92,7 +91,12 @@ def test_description(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) cursor.execute("SELECT * FROM hollywood.actors") - assert cursor.description == [("index", "long"), ("name", "text"), ("firstname", "text"), ("age", "long")] + assert cursor.description == [ + ("index", "long"), + ("name", "text"), + ("firstname", "text"), + ("age", "long"), + ] def test_fetchone(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) @@ -182,11 +186,13 @@ def test_execute_sql_query(self, mock_es): es_connection = ESConnection(host="localhost", port=9200) response = es_connection.execute_sql("SELECT * FROM hollywood.actors") - mock_es_sql_client.query.assert_called_once_with(body={ - "fetch_size": 100, - "field_multi_value_leniency": False, - "query": "SELECT * FROM hollywood.actors", - }) + mock_es_sql_client.query.assert_called_once_with( + body={ + "fetch_size": 100, + "field_multi_value_leniency": False, + "query": "SELECT * FROM hollywood.actors", + } + ) assert response == RESPONSE From 0be1b26afba6a167b176a45a4f50c176195c792c Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 11:17:50 +0100 Subject: [PATCH 05/12] refactor: Reformatted ElasticsearchSQLCursor --- .../providers/elasticsearch/hooks/elasticsearch.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index a124180fde921..022c78ae99130 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -17,8 +17,9 @@ # under the License. from __future__ import annotations +from collections.abc import Iterable, Mapping from functools import cached_property -from typing import TYPE_CHECKING, Any, Iterable, Mapping +from typing import TYPE_CHECKING, Any from urllib import parse from airflow.hooks.base import BaseHook @@ -43,7 +44,7 @@ def connect( class ElasticsearchSQLCursor: - """ A PEP 249-like Cursor class for Elasticsearch SQL API """ + """A PEP 249-like Cursor class for Elasticsearch SQL API""" def __init__(self, es: Elasticsearch, options: dict[str, Any]): self.es = es @@ -77,7 +78,9 @@ def rowcount(self) -> int: def description(self) -> list[tuple]: return [(column["name"], column["type"]) for column in self.response.get("columns", [])] - def execute(self, statement: str, params: Iterable | Mapping[str, Any] | None = None) -> ObjectApiResponse: + def execute( + self, statement: str, params: Iterable | Mapping[str, Any] | None = None + ) -> ObjectApiResponse: self.body["query"] = statement if params: self.body["params"] = params @@ -141,7 +144,9 @@ def close(self): def commit(self): pass - def execute_sql(self, query: str, params: Iterable | Mapping[str, Any] | None = None) -> ObjectApiResponse: + def execute_sql( + self, query: str, params: Iterable | Mapping[str, Any] | None = None + ) -> ObjectApiResponse: return self.cursor().execute(query, params) From 3f5fbd3ffe51b87e0f6d3155137a70bd9bd37cbd Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 12:07:48 +0100 Subject: [PATCH 06/12] refactor: options parameter as dict is unsafe so changed it to kwargs in ElasticsearchSQLCursor --- .../airflow/providers/elasticsearch/hooks/elasticsearch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index 022c78ae99130..91755dfb30cad 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -46,10 +46,10 @@ def connect( class ElasticsearchSQLCursor: """A PEP 249-like Cursor class for Elasticsearch SQL API""" - def __init__(self, es: Elasticsearch, options: dict[str, Any]): + def __init__(self, es: Elasticsearch, **options): self.es = es self.body = { - "fetch_size": options.get("fetch_size", 100), + "fetch_size": options.get("fetch_size", 10000), "field_multi_value_leniency": options.get("field_multi_value_leniency", False), } self._response: ObjectApiResponse | None = None @@ -136,7 +136,7 @@ def __init__( self.es = Elasticsearch(self.url, **self.kwargs) def cursor(self) -> ElasticsearchSQLCursor: - return ElasticsearchSQLCursor(self.es, self.kwargs) + return ElasticsearchSQLCursor(self.es, **self.kwargs) def close(self): self.es.close() From d453bdf455aa7ba9d002889ef0c8a0d620a729fa Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 12:08:33 +0100 Subject: [PATCH 07/12] refactor: fixed assertion on fetch_size --- .../provider_tests/elasticsearch/hooks/test_elasticsearch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index c620168db9799..3cb94a15a7641 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -188,7 +188,7 @@ def test_execute_sql_query(self, mock_es): response = es_connection.execute_sql("SELECT * FROM hollywood.actors") mock_es_sql_client.query.assert_called_once_with( body={ - "fetch_size": 100, + "fetch_size": 10000, "field_multi_value_leniency": False, "query": "SELECT * FROM hollywood.actors", } From 8fb794395217a7288d126248379bb1a862fbb94c Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 12:09:37 +0100 Subject: [PATCH 08/12] refactor: renamed options parameter to kwargs --- .../airflow/providers/elasticsearch/hooks/elasticsearch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index 91755dfb30cad..6cdca1644db6b 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -46,11 +46,11 @@ def connect( class ElasticsearchSQLCursor: """A PEP 249-like Cursor class for Elasticsearch SQL API""" - def __init__(self, es: Elasticsearch, **options): + def __init__(self, es: Elasticsearch, **kwargs): self.es = es self.body = { - "fetch_size": options.get("fetch_size", 10000), - "field_multi_value_leniency": options.get("field_multi_value_leniency", False), + "fetch_size": kwargs.get("fetch_size", 10000), + "field_multi_value_leniency": kwargs.get("field_multi_value_leniency", False), } self._response: ObjectApiResponse | None = None From 945ad7a16fd9fe969eae26f2ff137ce684d898bf Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 12:52:10 +0100 Subject: [PATCH 09/12] refactor: Reformatted TestElasticsearchSQLHookConn --- .../provider_tests/elasticsearch/hooks/test_elasticsearch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index 3cb94a15a7641..9996a85addb66 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -48,7 +48,7 @@ {"name": "firstname", "type": "text"}, {"name": "age", "type": "long"}, ], - "rows": ROWS + "rows": ROWS, } From 78cdf14e3b280fb48770d76eb347040938e61eb0 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 12:53:52 +0100 Subject: [PATCH 10/12] refactor: Ignore type for response property --- .../src/airflow/providers/elasticsearch/hooks/elasticsearch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index 6cdca1644db6b..81d5a6d1e78be 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -56,7 +56,7 @@ def __init__(self, es: Elasticsearch, **kwargs): @property def response(self) -> ObjectApiResponse: - return self._response or {} + return self._response or {} # type: ignore @response.setter def response(self, value): From 844d9f2865cad73b39e4c79e49d5127a44af362d Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 14:29:07 +0100 Subject: [PATCH 11/12] refactor: Fixed pagination and make sure this feature is unit tested --- .../elasticsearch/hooks/elasticsearch.py | 2 +- .../elasticsearch/hooks/test_elasticsearch.py | 20 +++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index 81d5a6d1e78be..e86323bf16a6b 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -102,7 +102,7 @@ def fetchmany(self, size: int | None = None): def fetchall(self): results = self.rows while self.cursor: - self.execute(query=self.body["query"]) + self.execute(statement=self.body["query"]) results.extend(self.rows) return results diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index 9996a85addb66..eff679be10728 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -41,7 +41,7 @@ [4, "Lundgren", "Dolph", "66"], [5, "Norris", "Chuck", "84"], ] -RESPONSE = { +RESPONSE_WITHOUT_CURSOR = { "columns": [ {"name": "index", "type": "long"}, {"name": "name", "type": "text"}, @@ -50,6 +50,11 @@ ], "rows": ROWS, } +RESPONSE = {**RESPONSE_WITHOUT_CURSOR, **{"cursor": "e7f8QwXUruW2mIebzudH4BwAA//8DAA=="}} +RESPONSES = [ + RESPONSE, + RESPONSE_WITHOUT_CURSOR, +] class TestElasticsearchSQLHookConn: @@ -73,7 +78,7 @@ def test_get_conn(self, mock_connect): class TestElasticsearchSQLCursor: def setup_method(self): sql = MagicMock(spec=SqlClient) - sql.query.side_effect = lambda body: RESPONSE + sql.query.side_effect = RESPONSES self.es = MagicMock(sql=sql, spec=Elasticsearch) def test_execute(self): @@ -115,13 +120,16 @@ def test_fetchall(self): cursor = ElasticsearchSQLCursor(es=self.es, options={}) cursor.execute("SELECT * FROM hollywood.actors") - assert cursor.fetchall() == ROWS + records = cursor.fetchall() + + assert len(records) == 10 + assert records == ROWS class TestElasticsearchSQLHook: def setup_method(self): sql = MagicMock(spec=SqlClient) - sql.query.side_effect = lambda body: RESPONSE + sql.query.side_effect = RESPONSES es = MagicMock(sql=sql, spec=Elasticsearch) self.cur = ElasticsearchSQLCursor(es=es, options={}) self.spy_agency = SpyAgency() @@ -181,7 +189,7 @@ def test_run(self): @mock.patch("airflow.providers.elasticsearch.hooks.elasticsearch.Elasticsearch") def test_execute_sql_query(self, mock_es): mock_es_sql_client = MagicMock() - mock_es_sql_client.query.return_value = RESPONSE + mock_es_sql_client.query.return_value = RESPONSE_WITHOUT_CURSOR mock_es.return_value.sql = mock_es_sql_client es_connection = ESConnection(host="localhost", port=9200) @@ -194,7 +202,7 @@ def test_execute_sql_query(self, mock_es): } ) - assert response == RESPONSE + assert response == RESPONSE_WITHOUT_CURSOR class MockElasticsearch: From 8b12bd44e82dc2e5006d1866d74caeb56c499080 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 5 Feb 2025 14:58:37 +0100 Subject: [PATCH 12/12] refactor: Decreased fetch_size to 1000 as 10000 could lead to timeouts --- .../src/airflow/providers/elasticsearch/hooks/elasticsearch.py | 2 +- .../provider_tests/elasticsearch/hooks/test_elasticsearch.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py index e86323bf16a6b..582e4abdb9e12 100644 --- a/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py +++ b/providers/elasticsearch/src/airflow/providers/elasticsearch/hooks/elasticsearch.py @@ -49,7 +49,7 @@ class ElasticsearchSQLCursor: def __init__(self, es: Elasticsearch, **kwargs): self.es = es self.body = { - "fetch_size": kwargs.get("fetch_size", 10000), + "fetch_size": kwargs.get("fetch_size", 1000), "field_multi_value_leniency": kwargs.get("field_multi_value_leniency", False), } self._response: ObjectApiResponse | None = None diff --git a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py index eff679be10728..953e7dd50ef72 100644 --- a/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py +++ b/providers/elasticsearch/tests/provider_tests/elasticsearch/hooks/test_elasticsearch.py @@ -196,7 +196,7 @@ def test_execute_sql_query(self, mock_es): response = es_connection.execute_sql("SELECT * FROM hollywood.actors") mock_es_sql_client.query.assert_called_once_with( body={ - "fetch_size": 10000, + "fetch_size": 1000, "field_multi_value_leniency": False, "query": "SELECT * FROM hollywood.actors", }