diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 259efb6cd4940..e6e27e65b3f34 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -6336,6 +6336,8 @@ paths: summary: Delete Variable description: Delete a variable entry. operationId: delete_variable + security: + - OAuth2PasswordBearer: [] parameters: - name: variable_key in: path @@ -6376,6 +6378,8 @@ paths: summary: Get Variable description: Get a variable entry. operationId: get_variable + security: + - OAuth2PasswordBearer: [] parameters: - name: variable_key in: path @@ -6420,6 +6424,8 @@ paths: summary: Patch Variable description: Update a variable by key. operationId: patch_variable + security: + - OAuth2PasswordBearer: [] parameters: - name: variable_key in: path @@ -6487,6 +6493,8 @@ paths: summary: Get Variables description: Get all Variables entries. operationId: get_variables + security: + - OAuth2PasswordBearer: [] parameters: - name: limit in: query @@ -6550,6 +6558,8 @@ paths: summary: Post Variable description: Create a variable. operationId: post_variable + security: + - OAuth2PasswordBearer: [] requestBody: required: true content: @@ -6593,6 +6603,8 @@ paths: summary: Bulk Variables description: Bulk create, update, and delete variables. operationId: bulk_variables + security: + - OAuth2PasswordBearer: [] requestBody: required: true content: diff --git a/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow/api_fastapi/core_api/routes/public/variables.py index e8898ce69570a..b1b9d232b85d3 100644 --- a/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow/api_fastapi/core_api/routes/public/variables.py @@ -38,6 +38,7 @@ VariableResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import requires_access_variable from airflow.api_fastapi.core_api.services.public.variables import BulkVariableService from airflow.api_fastapi.logging.decorators import action_logging from airflow.models.variable import Variable @@ -49,6 +50,7 @@ "/{variable_key}", status_code=status.HTTP_204_NO_CONTENT, responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_variable("DELETE"))], ) def delete_variable( variable_key: str, @@ -64,6 +66,7 @@ def delete_variable( @variables_router.get( "/{variable_key}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_variable("GET"))], ) def get_variable( variable_key: str, @@ -82,6 +85,7 @@ def get_variable( @variables_router.get( "", + dependencies=[Depends(requires_access_variable("GET"))], ) def get_variables( limit: QueryLimit, @@ -124,6 +128,7 @@ def get_variables( status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_variable("PUT"))], ) def patch_variable( variable_key: str, @@ -164,7 +169,7 @@ def patch_variable( "", status_code=status.HTTP_201_CREATED, responses=create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]), - dependencies=[Depends(action_logging())], + dependencies=[Depends(action_logging()), Depends(requires_access_variable("POST"))], ) def post_variable( post_body: VariableBody, @@ -186,7 +191,7 @@ def post_variable( return variable -@variables_router.patch("") +@variables_router.patch("", dependencies=[Depends(requires_access_variable("PUT"))]) def bulk_variables( request: BulkBody[VariableBody], session: SessionDep, diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 2b267beaa8e0b..b1da04d6791b7 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -30,6 +30,7 @@ DagAccessEntity, DagDetails, PoolDetails, + VariableDetails, ) from airflow.configuration import conf from airflow.utils.jwt_signer import JWTSigner, get_signing_key @@ -130,6 +131,22 @@ def callback(): return inner +def requires_access_variable(method: ResourceMethod) -> Callable[[Request, BaseUser | None], None]: + def inner( + request: Request, + user: Annotated[BaseUser | None, Depends(get_user)] = None, + ) -> None: + variable_key: str | None = request.path_params.get("variable_key") + + _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_variable( + method=method, details=VariableDetails(key=variable_key), user=user + ), + ) + + return inner + + def _requires_access( *, is_authorized_callback: Callable[[], bool], diff --git a/tests/api_fastapi/core_api/routes/public/test_variables.py b/tests/api_fastapi/core_api/routes/public/test_variables.py index c5af61fc035fa..b9837f4d25312 100644 --- a/tests/api_fastapi/core_api/routes/public/test_variables.py +++ b/tests/api_fastapi/core_api/routes/public/test_variables.py @@ -107,6 +107,14 @@ def test_delete_should_respond_204(self, test_client, session): variables = session.query(Variable).all() assert len(variables) == 3 + def test_delete_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.delete(f"/public/variables/{TEST_VARIABLE_KEY}") + assert response.status_code == 401 + + def test_delete_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.delete(f"/public/variables/{TEST_VARIABLE_KEY}") + assert response.status_code == 403 + def test_delete_should_respond_404(self, test_client): response = test_client.delete(f"/public/variables/{TEST_VARIABLE_KEY}") assert response.status_code == 404 @@ -163,6 +171,14 @@ def test_get_should_respond_200(self, test_client, session, key, expected_respon assert response.status_code == 200 assert response.json() == expected_response + def test_get_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"/public/variables/{TEST_VARIABLE_KEY}") + assert response.status_code == 401 + + def test_get_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"/public/variables/{TEST_VARIABLE_KEY}") + assert response.status_code == 403 + def test_get_should_respond_404(self, test_client): response = test_client.get(f"/public/variables/{TEST_VARIABLE_KEY}") assert response.status_code == 404 @@ -220,6 +236,14 @@ def test_should_respond_200( assert body["total_entries"] == expected_total_entries assert [variable["key"] for variable in body["variables"]] == expected_keys + def test_get_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get("/public/variables") + assert response.status_code == 401 + + def test_get_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get("/public/variables") + assert response.status_code == 403 + class TestPatchVariable(TestVariableEndpoint): @pytest.mark.enable_redact @@ -303,6 +327,20 @@ def test_patch_should_respond_400(self, test_client): body = response.json() assert body["detail"] == "Invalid body, key from request body doesn't match uri parameter" + def test_patch_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch( + f"/public/variables/{TEST_VARIABLE_KEY}", + json={"key": TEST_VARIABLE_KEY, "value": "some_value", "description": None}, + ) + assert response.status_code == 401 + + def test_patch_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch( + f"/public/variables/{TEST_VARIABLE_KEY}", + json={"key": TEST_VARIABLE_KEY, "value": "some_value", "description": None}, + ) + assert response.status_code == 403 + def test_patch_should_respond_404(self, test_client): response = test_client.patch( f"/public/variables/{TEST_VARIABLE_KEY}", @@ -378,6 +416,28 @@ def test_post_should_respond_201(self, test_client, session, body, expected_resp assert response.status_code == 201 assert response.json() == expected_response + def test_post_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post( + "/public/variables", + json={ + "key": "new variable key", + "value": "new variable value", + "description": "new variable description", + }, + ) + assert response.status_code == 401 + + def test_post_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post( + "/public/variables", + json={ + "key": "new variable key", + "value": "new variable value", + "description": "new variable description", + }, + ) + assert response.status_code == 403 + def test_post_should_respond_409_when_key_exists(self, test_client, session): self.create_variables() # Attempting to post a variable with an existing key @@ -927,3 +987,11 @@ def test_bulk_variables(self, test_client, actions, expected_results): response_data = response.json() for key, value in expected_results.items(): assert response_data[key] == value + + def test_bulk_variables_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch("/public/variables", json={}) + assert response.status_code == 401 + + def test_bulk_variables_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch("/public/variables", json={}) + assert response.status_code == 403