Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/controllers/console/apikey.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from flask_restx import Resource, fields, marshal_with
from flask_restx._http import HTTPStatus
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden

from extensions.ext_database import db
Expand Down Expand Up @@ -34,7 +34,7 @@


def _get_resource(resource_id, tenant_id, resource_model):
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
resource = session.execute(
select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
).scalar_one_or_none()
Expand Down
4 changes: 2 additions & 2 deletions api/controllers/console/explore/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from flask import request
from pydantic import BaseModel, Field, TypeAdapter, model_validator
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound

from controllers.common.schema import register_schema_models
Expand Down Expand Up @@ -74,7 +74,7 @@ def get(self, installed_app):
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
pagination = WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
Expand Down
4 changes: 2 additions & 2 deletions api/controllers/console/workspace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from functools import wraps
from typing import ParamSpec, TypeVar

from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden

from extensions.ext_database import db
Expand All @@ -24,7 +24,7 @@ def decorated(*args: P.args, **kwargs: P.kwargs):
user = current_user
tenant_id = current_tenant_id

with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
permission = (
session.query(TenantPluginPermission)
.where(
Expand Down
4 changes: 2 additions & 2 deletions api/controllers/console/workspace/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator, model_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

from configs import dify_config
from constants.languages import supported_language
Expand Down Expand Up @@ -562,7 +562,7 @@ def post(self):

user_email = current_user.email
else:
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
Expand Down
26 changes: 13 additions & 13 deletions api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import Forbidden

from configs import dify_config
Expand Down Expand Up @@ -1019,7 +1019,7 @@ def put(self):

# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=payload.provider_id
Expand All @@ -1034,7 +1034,7 @@ def put(self):
)

# Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.update_provider(
tenant_id=current_tenant_id,
Expand All @@ -1061,7 +1061,7 @@ def delete(self):
payload = MCPProviderDeletePayload.model_validate(console_ns.payload or {})
_, current_tenant_id = current_account_with_tenant()

with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id)

Expand All @@ -1079,7 +1079,7 @@ def post(self):
provider_id = payload.provider_id
_, tenant_id = current_account_with_tenant()

with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
db_provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
if not db_provider:
Expand All @@ -1100,7 +1100,7 @@ def post(self):
sse_read_timeout=provider_entity.sse_read_timeout,
):
# Update credentials in new transaction
with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.update_provider_credentials(
provider_id=provider_id,
Expand All @@ -1118,17 +1118,17 @@ def post(self):
resource_metadata_url=e.resource_metadata_url,
scope_hint=e.scope_hint,
)
with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
response = service.execute_auth_actions(auth_result)
return response
except MCPRefreshTokenError as e:
with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to refresh token, please try to authorize again: {e}") from e
except (MCPError, ValueError) as e:
with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
service.clear_provider_credentials(provider_id=provider_id, tenant_id=tenant_id)
raise ValueError(f"Failed to connect to MCP server: {e}") from e
Expand All @@ -1141,7 +1141,7 @@ class ToolMCPDetailApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
provider = service.get_provider(provider_id=provider_id, tenant_id=tenant_id)
return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
Expand All @@ -1155,7 +1155,7 @@ class ToolMCPListAllApi(Resource):
def get(self):
_, tenant_id = current_account_with_tenant()

with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
# Skip sensitive data decryption for list view to improve performance
tools = service.list_providers(tenant_id=tenant_id, include_sensitive=False)
Expand All @@ -1170,7 +1170,7 @@ class ToolMCPUpdateApi(Resource):
@account_initialization_required
def get(self, provider_id):
_, tenant_id = current_account_with_tenant()
with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
service = MCPToolManageService(session=session)
tools = service.list_provider_tools(
tenant_id=tenant_id,
Expand All @@ -1188,7 +1188,7 @@ def get(self):
authorization_code = query.code

# Create service instance for handle_callback
with Session(db.engine) as session, session.begin():
with sessionmaker(db.engine).begin() as session:
mcp_service = MCPToolManageService(session=session)
# handle_callback now returns state data and tokens
state_data, tokens = handle_callback(state_key, authorization_code)
Expand Down
5 changes: 2 additions & 3 deletions api/controllers/console/workspace/trigger_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from flask_restx import Resource
from graphon.model_runtime.utils.encoders import jsonable_encoder
from pydantic import BaseModel, model_validator
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import BadRequest, Forbidden

from configs import dify_config
Expand Down Expand Up @@ -375,7 +375,7 @@ def post(self, subscription_id: str):
assert user.current_tenant_id is not None

try:
with Session(db.engine) as session:
with sessionmaker(db.engine).begin() as session:
# Delete trigger provider subscription
TriggerProviderService.delete_trigger_provider(
session=session,
Expand All @@ -388,7 +388,6 @@ def post(self, subscription_id: str):
tenant_id=user.current_tenant_id,
subscription_id=subscription_id,
)
session.commit()
return {"result": "success"}
except ValueError as e:
raise BadRequest(str(e))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def client(flask_app_with_containers):
return_value=(MagicMock(id="u1"), "t1"),
autospec=True,
)
@patch("controllers.console.workspace.tool_providers.Session", autospec=True)
@patch("controllers.console.workspace.tool_providers.sessionmaker", autospec=True)
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True)
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client):
Expand All @@ -88,7 +88,7 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
create_result.id = "provider-1"
svc.create_provider.return_value = create_result
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
mock_session.return_value.__enter__.return_value = MagicMock()
mock_session.return_value.begin.return_value.__enter__.return_value = MagicMock()
# Patch MCPToolManageService constructed inside controller
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc, autospec=True):
payload = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,14 @@ def test_delete_subscription(self, app):
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls,
patch("controllers.console.workspace.trigger_providers.sessionmaker") as mock_session_cls,
patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"),
patch(
"controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription"
),
):
mock_db.engine = MagicMock()
mock_session_cls.return_value.__enter__.return_value = mock_session
mock_session_cls.return_value.begin.return_value.__enter__.return_value = mock_session

result = method(api, "sub1")

Expand All @@ -327,14 +327,14 @@ def test_delete_subscription_value_error(self, app):
app.test_request_context("/"),
patch("controllers.console.workspace.trigger_providers.current_user", mock_user()),
patch("controllers.console.workspace.trigger_providers.db") as mock_db,
patch("controllers.console.workspace.trigger_providers.Session") as session_cls,
patch("controllers.console.workspace.trigger_providers.sessionmaker") as session_cls,
patch(
"controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider",
side_effect=ValueError("bad"),
),
):
mock_db.engine = MagicMock()
session_cls.return_value.__enter__.return_value = MagicMock()
session_cls.return_value.begin.return_value.__enter__.return_value = MagicMock()

with pytest.raises(BadRequest):
method(api, "sub1")
Expand Down
Loading