Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6336,6 +6336,8 @@ paths:
summary: Delete Variable
description: Delete a variable entry.
operationId: delete_variable
security:
- OAuth2PasswordBearer: []
parameters:
- name: variable_key
in: path
Expand Down Expand Up @@ -6376,6 +6378,8 @@ paths:
summary: Get Variable
description: Get a variable entry.
operationId: get_variable
security:
- OAuth2PasswordBearer: []
parameters:
- name: variable_key
in: path
Expand Down Expand Up @@ -6420,6 +6424,8 @@ paths:
summary: Patch Variable
description: Update a variable by key.
operationId: patch_variable
security:
- OAuth2PasswordBearer: []
parameters:
- name: variable_key
in: path
Expand Down Expand Up @@ -6487,6 +6493,8 @@ paths:
summary: Get Variables
description: Get all Variables entries.
operationId: get_variables
security:
- OAuth2PasswordBearer: []
parameters:
- name: limit
in: query
Expand Down Expand Up @@ -6550,6 +6558,8 @@ paths:
summary: Post Variable
description: Create a variable.
operationId: post_variable
security:
- OAuth2PasswordBearer: []
requestBody:
required: true
content:
Expand Down Expand Up @@ -6593,6 +6603,8 @@ paths:
summary: Bulk Variables
description: Bulk create, update, and delete variables.
operationId: bulk_variables
security:
- OAuth2PasswordBearer: []
requestBody:
required: true
content:
Expand Down
9 changes: 7 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
VariableResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_variable
from airflow.api_fastapi.core_api.services.public.variables import BulkVariableService
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.models.variable import Variable
Expand All @@ -49,6 +50,7 @@
"/{variable_key}",
status_code=status.HTTP_204_NO_CONTENT,
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_variable("DELETE"))],
)
def delete_variable(
variable_key: str,
Expand All @@ -64,6 +66,7 @@ def delete_variable(
@variables_router.get(
"/{variable_key}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_variable("GET"))],
)
def get_variable(
variable_key: str,
Expand All @@ -82,6 +85,7 @@ def get_variable(

@variables_router.get(
"",
dependencies=[Depends(requires_access_variable("GET"))],
)
def get_variables(
limit: QueryLimit,
Expand Down Expand Up @@ -124,6 +128,7 @@ def get_variables(
status.HTTP_404_NOT_FOUND,
]
),
dependencies=[Depends(requires_access_variable("PUT"))],
)
def patch_variable(
variable_key: str,
Expand Down Expand Up @@ -164,7 +169,7 @@ def patch_variable(
"",
status_code=status.HTTP_201_CREATED,
responses=create_openapi_http_exception_doc([status.HTTP_409_CONFLICT]),
dependencies=[Depends(action_logging())],
dependencies=[Depends(action_logging()), Depends(requires_access_variable("POST"))],
)
def post_variable(
post_body: VariableBody,
Expand All @@ -186,7 +191,7 @@ def post_variable(
return variable


@variables_router.patch("")
@variables_router.patch("", dependencies=[Depends(requires_access_variable("PUT"))])
def bulk_variables(
request: BulkBody[VariableBody],
session: SessionDep,
Expand Down
17 changes: 17 additions & 0 deletions airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
DagAccessEntity,
DagDetails,
PoolDetails,
VariableDetails,
)
from airflow.configuration import conf
from airflow.utils.jwt_signer import JWTSigner, get_signing_key
Expand Down Expand Up @@ -130,6 +131,22 @@ def callback():
return inner


def requires_access_variable(method: ResourceMethod) -> Callable[[Request, BaseUser | None], None]:
def inner(
request: Request,
user: Annotated[BaseUser | None, Depends(get_user)] = None,
) -> None:
variable_key: str | None = request.path_params.get("variable_key")

_requires_access(
is_authorized_callback=lambda: get_auth_manager().is_authorized_variable(
method=method, details=VariableDetails(key=variable_key), user=user
),
)

return inner


def _requires_access(
*,
is_authorized_callback: Callable[[], bool],
Expand Down
68 changes: 68 additions & 0 deletions tests/api_fastapi/core_api/routes/public/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ def test_delete_should_respond_204(self, test_client, session):
variables = session.query(Variable).all()
assert len(variables) == 3

def test_delete_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.delete(f"/public/variables/{TEST_VARIABLE_KEY}")
assert response.status_code == 401

def test_delete_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.delete(f"/public/variables/{TEST_VARIABLE_KEY}")
assert response.status_code == 403

def test_delete_should_respond_404(self, test_client):
response = test_client.delete(f"/public/variables/{TEST_VARIABLE_KEY}")
assert response.status_code == 404
Expand Down Expand Up @@ -163,6 +171,14 @@ def test_get_should_respond_200(self, test_client, session, key, expected_respon
assert response.status_code == 200
assert response.json() == expected_response

def test_get_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.get(f"/public/variables/{TEST_VARIABLE_KEY}")
assert response.status_code == 401

def test_get_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.get(f"/public/variables/{TEST_VARIABLE_KEY}")
assert response.status_code == 403

def test_get_should_respond_404(self, test_client):
response = test_client.get(f"/public/variables/{TEST_VARIABLE_KEY}")
assert response.status_code == 404
Expand Down Expand Up @@ -220,6 +236,14 @@ def test_should_respond_200(
assert body["total_entries"] == expected_total_entries
assert [variable["key"] for variable in body["variables"]] == expected_keys

def test_get_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.get("/public/variables")
assert response.status_code == 401

def test_get_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.get("/public/variables")
assert response.status_code == 403


class TestPatchVariable(TestVariableEndpoint):
@pytest.mark.enable_redact
Expand Down Expand Up @@ -303,6 +327,20 @@ def test_patch_should_respond_400(self, test_client):
body = response.json()
assert body["detail"] == "Invalid body, key from request body doesn't match uri parameter"

def test_patch_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.patch(
f"/public/variables/{TEST_VARIABLE_KEY}",
json={"key": TEST_VARIABLE_KEY, "value": "some_value", "description": None},
)
assert response.status_code == 401

def test_patch_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.patch(
f"/public/variables/{TEST_VARIABLE_KEY}",
json={"key": TEST_VARIABLE_KEY, "value": "some_value", "description": None},
)
assert response.status_code == 403

def test_patch_should_respond_404(self, test_client):
response = test_client.patch(
f"/public/variables/{TEST_VARIABLE_KEY}",
Expand Down Expand Up @@ -378,6 +416,28 @@ def test_post_should_respond_201(self, test_client, session, body, expected_resp
assert response.status_code == 201
assert response.json() == expected_response

def test_post_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.post(
"/public/variables",
json={
"key": "new variable key",
"value": "new variable value",
"description": "new variable description",
},
)
assert response.status_code == 401

def test_post_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.post(
"/public/variables",
json={
"key": "new variable key",
"value": "new variable value",
"description": "new variable description",
},
)
assert response.status_code == 403

def test_post_should_respond_409_when_key_exists(self, test_client, session):
self.create_variables()
# Attempting to post a variable with an existing key
Expand Down Expand Up @@ -927,3 +987,11 @@ def test_bulk_variables(self, test_client, actions, expected_results):
response_data = response.json()
for key, value in expected_results.items():
assert response_data[key] == value

def test_bulk_variables_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.patch("/public/variables", json={})
assert response.status_code == 401

def test_bulk_variables_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.patch("/public/variables", json={})
assert response.status_code == 403