diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 9b8408980de709..ce18908e40b6a3 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,4 +1,5 @@ import csv +import hmac import io from collections.abc import Callable from functools import wraps @@ -77,7 +78,7 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R: auth_token = extract_access_token(request) if not auth_token: raise Unauthorized("Authorization header is missing.") - if auth_token != dify_config.ADMIN_API_KEY: + if not hmac.compare_digest(auth_token, dify_config.ADMIN_API_KEY): raise Unauthorized("API key is invalid.") return view(*args, **kwargs) diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index b55cda42448fdb..cd5e3b6b532ee8 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -1,3 +1,4 @@ +import hmac from collections.abc import Callable from functools import wraps from typing import Concatenate @@ -163,7 +164,11 @@ def post(self, oauth_provider_app: OAuthProviderApp): if not payload.code: raise BadRequest("code is required") - if payload.client_secret != oauth_provider_app.client_secret: + if ( + not payload.client_secret + or not oauth_provider_app.client_secret + or not hmac.compare_digest(payload.client_secret, oauth_provider_app.client_secret) + ): raise BadRequest("client_secret is invalid") if payload.redirect_uri not in oauth_provider_app.redirect_uris: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index efa46c9779ab8a..c8f516788dd4d5 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,3 +1,5 @@ +import hmac + from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field @@ -125,7 +127,7 @@ def post(self, id): extension_data_from_db.name = payload.name extension_data_from_db.api_endpoint = payload.api_endpoint - if payload.api_key != HIDDEN_VALUE: + if not hmac.compare_digest(payload.api_key, HIDDEN_VALUE): extension_data_from_db.api_key = payload.api_key return APIBasedExtensionService.save(extension_data_from_db) diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index f086bf186226e4..44dcdbd3d405e5 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -1,3 +1,4 @@ +import hmac import os from typing import Literal @@ -54,7 +55,8 @@ def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse if tenant_count > 0: raise AlreadySetupError() - if payload.password != os.environ.get("INIT_PASSWORD"): + init_password = os.environ.get("INIT_PASSWORD") + if not init_password or not hmac.compare_digest(payload.password, init_password): session["is_init_validated"] = False raise InitValidateFailedError() diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 874fd8a7e375e3..5dd29268228f21 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -2,6 +2,7 @@ from collections.abc import Callable from functools import wraps from hashlib import sha1 +from hmac import compare_digest as hmac_compare_digest from hmac import new as hmac_new from flask import abort, request @@ -19,7 +20,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # get header 'X-Inner-Api-Key' inner_api_key = request.headers.get("X-Inner-Api-Key") - if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: + if ( + not inner_api_key + or not dify_config.INNER_API_KEY + or not hmac_compare_digest(inner_api_key, dify_config.INNER_API_KEY) + ): abort(401) return view(*args, **kwargs) @@ -35,7 +40,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # get header 'X-Inner-Api-Key' inner_api_key = request.headers.get("X-Inner-Api-Key") - if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: + if ( + not inner_api_key + or not dify_config.INNER_API_KEY + or not hmac_compare_digest(inner_api_key, dify_config.INNER_API_KEY) + ): abort(401) return view(*args, **kwargs) @@ -69,10 +78,8 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R: signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) signature_base64 = b64encode(signature.digest()).decode("utf-8") - if signature_base64 != token: - return view(*args, **kwargs) - - kwargs["user"] = db.session.get(EndUser, user_id) + if hmac_compare_digest(signature_base64, token): + kwargs["user"] = db.session.get(EndUser, user_id) return view(*args, **kwargs) @@ -87,7 +94,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # get header 'X-Inner-Api-Key' inner_api_key = request.headers.get("X-Inner-Api-Key") - if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN: + if ( + not inner_api_key + or not dify_config.INNER_API_KEY_FOR_PLUGIN + or not hmac_compare_digest(inner_api_key, dify_config.INNER_API_KEY_FOR_PLUGIN) + ): abort(404) return view(*args, **kwargs) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index bc59eaca635aef..b7db0f95724d3c 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,3 +1,4 @@ +import hmac import json from typing import cast @@ -55,7 +56,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non # Check for admin API key authentication first if dify_config.ADMIN_API_KEY_ENABLE and auth_token: admin_api_key = dify_config.ADMIN_API_KEY - if admin_api_key and admin_api_key == auth_token: + if admin_api_key and hmac.compare_digest(admin_api_key, auth_token): workspace_id = request.headers.get("X-WORKSPACE-ID") if workspace_id: tenant_account_join = db.session.execute( diff --git a/api/libs/token.py b/api/libs/token.py index a34db7076496b8..777df2dbd786bc 100644 --- a/api/libs/token.py +++ b/api/libs/token.py @@ -1,3 +1,4 @@ +import hmac import logging import re from datetime import UTC, datetime, timedelta @@ -191,7 +192,7 @@ def check_csrf_token(request: Request, user_id: str): # since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required. if dify_config.ADMIN_API_KEY_ENABLE: auth_token = extract_access_token(request) - if auth_token and auth_token == dify_config.ADMIN_API_KEY: + if auth_token and dify_config.ADMIN_API_KEY and hmac.compare_digest(auth_token, dify_config.ADMIN_API_KEY): return def _unauthorized(): diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index deb26438a8f7f7..e60db130878aa6 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -1,4 +1,5 @@ import hashlib +import hmac import json import logging from collections.abc import Mapping @@ -792,13 +793,17 @@ def _merge_credentials_with_masked( # Check if client_id is masked and unchanged final_client_id = client_id - if existing_masked.get("client_id") and client_id == existing_masked["client_id"]: + if existing_masked.get("client_id") and hmac.compare_digest(client_id, existing_masked["client_id"]): # Use existing decrypted value final_client_id = existing_decrypted.get("client_id", client_id) # Check if client_secret is masked and unchanged final_client_secret = client_secret - if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]: + if ( + existing_masked.get("client_secret") + and client_secret is not None + and hmac.compare_digest(client_secret, existing_masked["client_secret"]) + ): # Use existing decrypted value final_client_secret = existing_decrypted.get("client_secret", client_secret) diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index 1c4c6a899f63a0..dd111532fd4fac 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -185,6 +185,38 @@ def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine): with pytest.raises(ValueError): method(api, "b1", "disable") + def test_patch_binding_scoped_to_current_tenant(self, app, patch_tenant, mock_engine): + """Verify that the patch query includes tenant_id to prevent IDOR attacks.""" + + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=True) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class, + patch("controllers.console.datasets.data_source.db.session.add"), + patch("controllers.console.datasets.data_source.db.session.commit"), + ): + mock_session = MagicMock() + mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + method(api, "b1", "enable") + + # Inspect the SELECT statement passed to session.execute + call_args = mock_session.execute.call_args + stmt = call_args[0][0] + + # Structurally verify tenant_id filtering instead of brittle SQL string matching + where_clause = stmt.whereclause + assert where_clause is not None, "The query must have a WHERE clause" + + # Check that tenant_id appears in the WHERE predicates + where_str = str(where_clause.compile(compile_kwargs={"literal_binds": True})) + assert "tenant_id" in where_str, "The patch query must filter by tenant_id to prevent IDOR vulnerabilities" + class TestDataSourceNotionListApi: @pytest.fixture diff --git a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py index efe1841f08e94b..84648334c3a4d6 100644 --- a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py +++ b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py @@ -249,8 +249,8 @@ def protected_view(**kwargs): headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key} ): with patch.object(dify_config, "INNER_API", True): - with patch("controllers.inner_api.wraps.db.session.get") as mock_get: - mock_get.return_value = mock_user + with patch("controllers.inner_api.wraps.db.session") as mock_session: + mock_session.get.return_value = mock_user result = protected_view() # Assert