Skip to content
3 changes: 2 additions & 1 deletion api/controllers/console/admin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import hmac
import io
from collections.abc import Callable
from functools import wraps
Expand Down Expand Up @@ -77,7 +78,7 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
auth_token = extract_access_token(request)
if not auth_token:
raise Unauthorized("Authorization header is missing.")
if auth_token != dify_config.ADMIN_API_KEY:
if not hmac.compare_digest(auth_token, dify_config.ADMIN_API_KEY):
raise Unauthorized("API key is invalid.")

return view(*args, **kwargs)
Expand Down
7 changes: 6 additions & 1 deletion api/controllers/console/auth/oauth_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hmac
from collections.abc import Callable
from functools import wraps
from typing import Concatenate
Expand Down Expand Up @@ -163,7 +164,11 @@ def post(self, oauth_provider_app: OAuthProviderApp):
if not payload.code:
raise BadRequest("code is required")

if payload.client_secret != oauth_provider_app.client_secret:
if (
not payload.client_secret
or not oauth_provider_app.client_secret
or not hmac.compare_digest(payload.client_secret, oauth_provider_app.client_secret)
):
raise BadRequest("client_secret is invalid")

if payload.redirect_uri not in oauth_provider_app.redirect_uris:
Expand Down
4 changes: 3 additions & 1 deletion api/controllers/console/extension.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hmac

from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field
Expand Down Expand Up @@ -125,7 +127,7 @@ def post(self, id):
extension_data_from_db.name = payload.name
extension_data_from_db.api_endpoint = payload.api_endpoint

if payload.api_key != HIDDEN_VALUE:
if not hmac.compare_digest(payload.api_key, HIDDEN_VALUE):
extension_data_from_db.api_key = payload.api_key

return APIBasedExtensionService.save(extension_data_from_db)
Expand Down
4 changes: 3 additions & 1 deletion api/controllers/console/init_validate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hmac
import os
from typing import Literal

Expand Down Expand Up @@ -54,7 +55,8 @@ def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse
if tenant_count > 0:
raise AlreadySetupError()

if payload.password != os.environ.get("INIT_PASSWORD"):
init_password = os.environ.get("INIT_PASSWORD")
if not init_password or not hmac.compare_digest(payload.password, init_password):
session["is_init_validated"] = False
raise InitValidateFailedError()

Expand Down
25 changes: 18 additions & 7 deletions api/controllers/inner_api/wraps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import compare_digest as hmac_compare_digest
from hmac import new as hmac_new

from flask import abort, request
Expand All @@ -19,7 +20,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R:

# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
if (
not inner_api_key
or not dify_config.INNER_API_KEY
or not hmac_compare_digest(inner_api_key, dify_config.INNER_API_KEY)
):
abort(401)

return view(*args, **kwargs)
Expand All @@ -35,7 +40,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R:

# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
if (
not inner_api_key
or not dify_config.INNER_API_KEY
or not hmac_compare_digest(inner_api_key, dify_config.INNER_API_KEY)
):
abort(401)

return view(*args, **kwargs)
Expand Down Expand Up @@ -69,10 +78,8 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1)
signature_base64 = b64encode(signature.digest()).decode("utf-8")

if signature_base64 != token:
return view(*args, **kwargs)

kwargs["user"] = db.session.get(EndUser, user_id)
if hmac_compare_digest(signature_base64, token):
kwargs["user"] = db.session.get(EndUser, user_id)

return view(*args, **kwargs)

Expand All @@ -87,7 +94,11 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> R:

# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
if (
not inner_api_key
or not dify_config.INNER_API_KEY_FOR_PLUGIN
or not hmac_compare_digest(inner_api_key, dify_config.INNER_API_KEY_FOR_PLUGIN)
):
abort(404)

return view(*args, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion api/extensions/ext_login.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hmac
import json
from typing import cast

Expand Down Expand Up @@ -55,7 +56,7 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
# Check for admin API key authentication first
if dify_config.ADMIN_API_KEY_ENABLE and auth_token:
admin_api_key = dify_config.ADMIN_API_KEY
if admin_api_key and admin_api_key == auth_token:
if admin_api_key and hmac.compare_digest(admin_api_key, auth_token):
workspace_id = request.headers.get("X-WORKSPACE-ID")
if workspace_id:
tenant_account_join = db.session.execute(
Expand Down
3 changes: 2 additions & 1 deletion api/libs/token.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hmac
import logging
import re
from datetime import UTC, datetime, timedelta
Expand Down Expand Up @@ -191,7 +192,7 @@ def check_csrf_token(request: Request, user_id: str):
# since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required.
if dify_config.ADMIN_API_KEY_ENABLE:
auth_token = extract_access_token(request)
if auth_token and auth_token == dify_config.ADMIN_API_KEY:
if auth_token and dify_config.ADMIN_API_KEY and hmac.compare_digest(auth_token, dify_config.ADMIN_API_KEY):
return

def _unauthorized():
Expand Down
9 changes: 7 additions & 2 deletions api/services/tools/mcp_tools_manage_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import hashlib
import hmac
import json
import logging
from collections.abc import Mapping
Expand Down Expand Up @@ -792,13 +793,17 @@ def _merge_credentials_with_masked(

# Check if client_id is masked and unchanged
final_client_id = client_id
if existing_masked.get("client_id") and client_id == existing_masked["client_id"]:
if existing_masked.get("client_id") and hmac.compare_digest(client_id, existing_masked["client_id"]):
# Use existing decrypted value
final_client_id = existing_decrypted.get("client_id", client_id)

# Check if client_secret is masked and unchanged
final_client_secret = client_secret
if existing_masked.get("client_secret") and client_secret == existing_masked["client_secret"]:
if (
existing_masked.get("client_secret")
and client_secret is not None
and hmac.compare_digest(client_secret, existing_masked["client_secret"])
):
# Use existing decrypted value
final_client_secret = existing_decrypted.get("client_secret", client_secret)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,38 @@ def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine):
with pytest.raises(ValueError):
method(api, "b1", "disable")

def test_patch_binding_scoped_to_current_tenant(self, app, patch_tenant, mock_engine):
"""Verify that the patch query includes tenant_id to prevent IDOR attacks."""

api = DataSourceApi()
method = unwrap(api.patch)

binding = MagicMock(id="b1", disabled=True)

with (
app.test_request_context("/"),
patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class,
patch("controllers.console.datasets.data_source.db.session.add"),
patch("controllers.console.datasets.data_source.db.session.commit"),
):
mock_session = MagicMock()
mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session
mock_session.execute.return_value.scalar_one_or_none.return_value = binding

method(api, "b1", "enable")

# Inspect the SELECT statement passed to session.execute
call_args = mock_session.execute.call_args
stmt = call_args[0][0]

# Structurally verify tenant_id filtering instead of brittle SQL string matching
where_clause = stmt.whereclause
assert where_clause is not None, "The query must have a WHERE clause"

# Check that tenant_id appears in the WHERE predicates
where_str = str(where_clause.compile(compile_kwargs={"literal_binds": True}))
assert "tenant_id" in where_str, "The patch query must filter by tenant_id to prevent IDOR vulnerabilities"


class TestDataSourceNotionListApi:
@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def protected_view(**kwargs):
headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key}
):
with patch.object(dify_config, "INNER_API", True):
with patch("controllers.inner_api.wraps.db.session.get") as mock_get:
mock_get.return_value = mock_user
with patch("controllers.inner_api.wraps.db.session") as mock_session:
mock_session.get.return_value = mock_user
result = protected_view()

# Assert
Expand Down
Loading