From 1fb4ba41d5ccb9780d3442e7736f6c680c9b96a0 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 23 Feb 2023 00:33:06 +0100 Subject: [PATCH 1/4] fix update_mask in patch variable route --- .../endpoints/variable_endpoint.py | 31 +++++++++++++------ .../endpoints/test_variable_endpoint.py | 16 +++++++--- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index 3111ff18d424d..e1be34b5f438d 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -88,13 +88,19 @@ def get_variables( @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE)]) +@provide_session @action_logging( event=action_event_from_permission( prefix=RESOURCE_EVENT_PREFIX, permission=permissions.ACTION_CAN_EDIT, ), ) -def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response: +def patch_variable( + *, + variable_key: str, + update_mask: UpdateMask = None, + session: Session = NEW_SESSION, +) -> Response: """Update a variable by key.""" try: data = variable_schema.load(get_json_request_dict()) @@ -103,15 +109,22 @@ def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Resp if data["key"] != variable_key: raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter") - + non_update_fields = ["key"] + variable = session.query(Variable).filter_by(key=variable_key).first() if update_mask: - if "key" in update_mask: - raise BadRequest("key is a ready only field") - if "value" not in update_mask: - raise BadRequest("No field to update") - - Variable.set(data["key"], data["val"]) - return variable_schema.dump(data) + update_mask = [i.strip() for i in update_mask] + data_ = {} + for field in update_mask: + if field in data and field not in non_update_fields: + data_[field] = data[field] + else: + raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.") + data = data_ + for key in data: + setattr(variable, key, data[key]) + session.add(variable) + session.commit() + return variable_schema.dump(variable) @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)]) diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py index 35f28c43b3b17..7c6c55d783978 100644 --- a/tests/api_connexion/endpoints/test_variable_endpoint.py +++ b/tests/api_connexion/endpoints/test_variable_endpoint.py @@ -229,10 +229,18 @@ def test_should_update_variable(self, session): environ_overrides={"REMOTE_USER": "test"}, ) assert response.status_code == 200 - assert response.json == { - "key": "var1", - "value": "updated", - } + assert response.json == {"key": "var1", "value": "updated", "description": None} + _check_last_log(session, dag_id=None, event="variable.edit", execution_date=None) + + def test_should_update_variable_with_mask(self, session): + Variable.set("var1", "foo", description="before update") + response = self.client.patch( + "/api/v1/variables/var1?update_mask=description", + json={"key": "var1", "value": "updated", "description": "after_update"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 200 + assert response.json == {"key": "var1", "value": "foo", "description": "after_update"} _check_last_log(session, dag_id=None, event="variable.edit", execution_date=None) def test_should_reject_invalid_update(self): From 3f6fa98ee7a32d9a0bcae5490b31005eef1b01c4 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 23 Feb 2023 08:23:41 +0100 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Tzu-ping Chung --- airflow/api_connexion/endpoints/variable_endpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index e1be34b5f438d..b7198c68683f1 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -112,16 +112,16 @@ def patch_variable( non_update_fields = ["key"] variable = session.query(Variable).filter_by(key=variable_key).first() if update_mask: - update_mask = [i.strip() for i in update_mask] data_ = {} for field in update_mask: + field = field.strip() if field in data and field not in non_update_fields: data_[field] = data[field] else: raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.") data = data_ - for key in data: - setattr(variable, key, data[key]) + for key, val in data.items(): + setattr(variable, key, val) session.add(variable) session.commit() return variable_schema.dump(variable) From b27103af2a38f03d3d920d5e39a3fc9a46d39987 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 24 Feb 2023 23:34:28 +0100 Subject: [PATCH 3/4] create a new method to extract update mask data --- .../endpoints/connection_endpoint.py | 10 +--- .../api_connexion/endpoints/update_mask.py | 34 +++++++++++ .../endpoints/variable_endpoint.py | 10 +--- .../endpoints/test_update_mask.py | 56 +++++++++++++++++++ 4 files changed, 94 insertions(+), 16 deletions(-) create mode 100644 airflow/api_connexion/endpoints/update_mask.py create mode 100644 tests/api_connexion/endpoints/test_update_mask.py diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index 03edf24bd887c..7e2a7dd8da42d 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -26,6 +26,7 @@ from sqlalchemy.orm import Session from airflow.api_connexion import security +from airflow.api_connexion.endpoints.update_mask import extract_update_mask_data from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.connection_schema import ( @@ -132,14 +133,7 @@ def patch_connection( if data.get("conn_id") and connection.conn_id != data["conn_id"]: raise BadRequest(detail="The connection_id cannot be updated.") if update_mask: - update_mask = [i.strip() for i in update_mask] - data_ = {} - for field in update_mask: - if field in data and field not in non_update_fields: - data_[field] = data[field] - else: - raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.") - data = data_ + data = extract_update_mask_data(update_mask, non_update_fields, data) for key in data: setattr(connection, key, data[key]) session.add(connection) diff --git a/airflow/api_connexion/endpoints/update_mask.py b/airflow/api_connexion/endpoints/update_mask.py new file mode 100644 index 0000000000000..38fd255f51b3a --- /dev/null +++ b/airflow/api_connexion/endpoints/update_mask.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, Mapping, Sequence + +from airflow.api_connexion.exceptions import BadRequest + + +def extract_update_mask_data( + update_mask: Sequence[str], non_update_fields: list[str], data: Mapping[str, Any] +) -> Mapping[str, Any]: + extracted_data = {} + for field in update_mask: + field = field.strip() + if field in data and field not in non_update_fields: + extracted_data[field] = data[field] + else: + raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.") + return extracted_data diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index b7198c68683f1..e0f4e5d3bac04 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -25,6 +25,7 @@ from airflow.api_connexion import security from airflow.api_connexion.endpoints.request_dict import get_json_request_dict +from airflow.api_connexion.endpoints.update_mask import extract_update_mask_data from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.variable_schema import variable_collection_schema, variable_schema @@ -112,14 +113,7 @@ def patch_variable( non_update_fields = ["key"] variable = session.query(Variable).filter_by(key=variable_key).first() if update_mask: - data_ = {} - for field in update_mask: - field = field.strip() - if field in data and field not in non_update_fields: - data_[field] = data[field] - else: - raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.") - data = data_ + data = extract_update_mask_data(update_mask, non_update_fields, data) for key, val in data.items(): setattr(variable, key, val) session.add(variable) diff --git a/tests/api_connexion/endpoints/test_update_mask.py b/tests/api_connexion/endpoints/test_update_mask.py new file mode 100644 index 0000000000000..4221f11a1319a --- /dev/null +++ b/tests/api_connexion/endpoints/test_update_mask.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.api_connexion.endpoints.update_mask import extract_update_mask_data +from airflow.api_connexion.exceptions import BadRequest + + +class TestUpdateMask: + def test_should_extract_data(self): + non_update_fields = ["field_1"] + update_mask = ["field_2"] + data = { + "field_1": "value_1", + "field_2": "value_2", + "field_3": "value_3", + } + data = extract_update_mask_data(update_mask, non_update_fields, data) + assert data == {"field_2": "value_2"} + + def test_update_forbid_field_should_raise_exception(self): + non_update_fields = ["field_1"] + update_mask = ["field_1", "field_2"] + data = { + "field_1": "value_1", + "field_2": "value_2", + "field_3": "value_3", + } + with pytest.raises(BadRequest): + extract_update_mask_data(update_mask, non_update_fields, data) + + def test_update_unknown_field_should_raise_exception(self): + non_update_fields = ["field_1"] + update_mask = ["field_2", "field_3"] + data = { + "field_1": "value_1", + "field_2": "value_2", + } + with pytest.raises(BadRequest): + extract_update_mask_data(update_mask, non_update_fields, data) From ae8ef83466354d909e2f7e69e22ae8d7b7fa21e8 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 11 Mar 2023 02:36:29 +0100 Subject: [PATCH 4/4] remove useless commit --- airflow/api_connexion/endpoints/variable_endpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index e0f4e5d3bac04..da8f35fcb8075 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -117,7 +117,6 @@ def patch_variable( for key, val in data.items(): setattr(variable, key, val) session.add(variable) - session.commit() return variable_schema.dump(variable)