Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions agent_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
get_credentials,
has_embedded_credentials,
run_oauth_flow,
run_oauth_flow_async,
)
from agent_core.core.config import (
ConfigRegistry,
Expand Down Expand Up @@ -312,6 +313,7 @@
"get_credentials",
"has_embedded_credentials",
"run_oauth_flow",
"run_oauth_flow_async",
# Config
"ConfigRegistry",
"get_workspace_root",
Expand Down
3 changes: 2 additions & 1 deletion agent_core/core/credentials/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
encode_credential,
generate_credentials_block,
)
from agent_core.core.credentials.oauth_server import run_oauth_flow
from agent_core.core.credentials.oauth_server import run_oauth_flow, run_oauth_flow_async

__all__ = [
"get_credential",
Expand All @@ -17,4 +17,5 @@
"encode_credential",
"generate_credentials_block",
"run_oauth_flow",
"run_oauth_flow_async",
]
186 changes: 142 additions & 44 deletions agent_core/core/credentials/oauth_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@

# HTTPS (for Slack and other providers requiring https redirect URIs)
code, error = run_oauth_flow("https://slack.com/oauth/...", use_https=True)

# Async version with cancellation support (recommended for UI contexts)
code, error = await run_oauth_flow_async("https://provider.com/oauth/...")
"""

import asyncio
import ipaddress
import logging
import os
Expand All @@ -29,7 +33,7 @@
from datetime import datetime, timedelta, timezone
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import urlparse, parse_qs
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -104,58 +108,78 @@ def _cleanup_files(*paths: str) -> None:
pass


class _OAuthCallbackHandler(BaseHTTPRequestHandler):
"""Handler for OAuth callback requests."""

code: Optional[str] = None
state: Optional[str] = None
error: Optional[str] = None

def do_GET(self):
"""Handle GET request from OAuth callback."""
params = parse_qs(urlparse(self.path).query)
_OAuthCallbackHandler.code = params.get("code", [None])[0]
_OAuthCallbackHandler.state = params.get("state", [None])[0]
_OAuthCallbackHandler.error = params.get("error", [None])[0]
def _make_callback_handler(result_holder: Dict[str, Any]):
"""
Create a callback handler class that stores results in the provided dict.

self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
if _OAuthCallbackHandler.code:
self.wfile.write(
b"<h2>Authorization successful!</h2><p>You can close this tab.</p>"
)
else:
self.wfile.write(
f"<h2>Failed</h2><p>{_OAuthCallbackHandler.error}</p>".encode()
)
This avoids class-level state that would be shared across OAuth flows.
"""
class _OAuthCallbackHandler(BaseHTTPRequestHandler):
"""Handler for OAuth callback requests."""

def do_GET(self):
"""Handle GET request from OAuth callback."""
params = parse_qs(urlparse(self.path).query)
result_holder["code"] = params.get("code", [None])[0]
result_holder["state"] = params.get("state", [None])[0]
result_holder["error"] = params.get("error", [None])[0]

self.send_response(200)
self.send_header("Content-Type", "text/html")
self.end_headers()
if result_holder["code"]:
self.wfile.write(
b"<h2>Authorization successful!</h2><p>You can close this tab.</p>"
)
else:
self.wfile.write(
f"<h2>Failed</h2><p>{result_holder['error']}</p>".encode()
)

def log_message(self, format, *args):
"""Suppress default HTTP server logging."""
pass

def log_message(self, format, *args):
"""Suppress default HTTP server logging."""
pass
return _OAuthCallbackHandler


def _serve_until_code(server: HTTPServer, deadline: float) -> None:
def _serve_until_code(
server: HTTPServer,
deadline: float,
result_holder: Dict[str, Any],
cancel_event: Optional[threading.Event] = None,
) -> None:
"""
Handle requests in a loop until we capture the OAuth code/error or timeout.
Handle requests in a loop until we capture the OAuth code/error, timeout, or cancelled.

A single handle_request() can be consumed by TLS handshake failures,
favicon requests, browser pre-connects, etc. Looping ensures the server
stays alive for the actual callback.
"""
while time.time() < deadline:
remaining = max(0.5, deadline - time.time())
server.timeout = min(remaining, 2.0)
# Check for cancellation
if cancel_event and cancel_event.is_set():
logger.debug("[OAUTH] Cancellation requested, stopping server")
break

remaining = max(0.1, deadline - time.time())
# Use shorter timeout (0.5s) for responsive cancellation checking
server.timeout = min(remaining, 0.5)
try:
server.handle_request()
except Exception as e:
logger.debug(f"[OAUTH] handle_request error (will retry): {e}")
if _OAuthCallbackHandler.code or _OAuthCallbackHandler.error:

if result_holder.get("code") or result_holder.get("error"):
break


def run_oauth_flow(
auth_url: str, port: int = 8765, timeout: int = 120, use_https: bool = False
auth_url: str,
port: int = 8765,
timeout: int = 120,
use_https: bool = False,
cancel_event: Optional[threading.Event] = None,
) -> Tuple[Optional[str], Optional[str]]:
"""
Open browser for OAuth, wait for callback.
Expand All @@ -167,17 +191,27 @@ def run_oauth_flow(
use_https: If True, serve HTTPS with a self-signed cert.
Required for providers like Slack that reject http:// redirect URIs.
Default False (plain HTTP — works with Google, Notion, etc.).
cancel_event: Optional threading.Event to signal cancellation.
When set, the OAuth flow will stop and return a cancellation error.

Returns:
Tuple of (code, error_message):
- On success: (authorization_code, None)
- On failure: (None, error_message)
"""
_OAuthCallbackHandler.code = None
_OAuthCallbackHandler.state = None
_OAuthCallbackHandler.error = None
# Check for early cancellation
if cancel_event and cancel_event.is_set():
return None, "OAuth cancelled"

server = HTTPServer(("127.0.0.1", port), _OAuthCallbackHandler)
# Use instance-level result holder instead of class-level state
result_holder: Dict[str, Any] = {"code": None, "state": None, "error": None}
handler_class = _make_callback_handler(result_holder)

try:
server = HTTPServer(("127.0.0.1", port), handler_class)
except OSError as e:
# Port already in use
return None, f"Failed to start OAuth server: {e}"

if use_https:
cert_path = key_path = None
Expand All @@ -198,21 +232,85 @@ def run_oauth_flow(

deadline = time.time() + timeout
thread = threading.Thread(
target=_serve_until_code, args=(server, deadline), daemon=True
target=_serve_until_code,
args=(server, deadline, result_holder, cancel_event),
daemon=True
)
thread.start()

# Check cancellation before opening browser
if cancel_event and cancel_event.is_set():
server.server_close()
return None, "OAuth cancelled"

try:
webbrowser.open(auth_url)
except Exception:
server.server_close()
return None, f"Could not open browser. Visit manually:\n{auth_url}"

thread.join(timeout=timeout)
# Wait for thread with periodic cancellation checks
while thread.is_alive():
thread.join(timeout=0.5)
if cancel_event and cancel_event.is_set():
logger.debug("[OAUTH] Cancellation detected during wait")
break

server.server_close()

if _OAuthCallbackHandler.error:
return None, _OAuthCallbackHandler.error
if _OAuthCallbackHandler.code:
return _OAuthCallbackHandler.code, None
# Check cancellation first
if cancel_event and cancel_event.is_set():
return None, "OAuth cancelled"

if result_holder.get("error"):
return None, result_holder["error"]
if result_holder.get("code"):
return result_holder["code"], None
return None, "OAuth timed out."


async def run_oauth_flow_async(
auth_url: str,
port: int = 8765,
timeout: int = 120,
use_https: bool = False,
) -> Tuple[Optional[str], Optional[str]]:
"""
Async version of run_oauth_flow with proper cancellation support.

This function runs the OAuth flow in a thread executor and properly handles
asyncio task cancellation by signaling the OAuth server to stop.

Args:
auth_url: The full OAuth authorization URL to open.
port: Local port for callback server (default: 8765).
timeout: Seconds to wait for callback (default: 120).
use_https: If True, serve HTTPS with a self-signed cert.

Returns:
Tuple of (code, error_message):
- On success: (authorization_code, None)
- On failure: (None, error_message)

Raises:
asyncio.CancelledError: If the task is cancelled (after signaling OAuth to stop)
"""
cancel_event = threading.Event()
loop = asyncio.get_event_loop()

def run_flow():
return run_oauth_flow(
auth_url=auth_url,
port=port,
timeout=timeout,
use_https=use_https,
cancel_event=cancel_event,
)

try:
return await loop.run_in_executor(None, run_flow)
except asyncio.CancelledError:
# Signal the OAuth server to stop
cancel_event.set()
logger.debug("[OAUTH] Async task cancelled, signaled OAuth server to stop")
raise
20 changes: 10 additions & 10 deletions app/credentials/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ async def login(self, args):
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
from agent_core import run_oauth_flow
code, error = run_oauth_flow(f"https://accounts.google.com/o/oauth2/v2/auth?{urlencode(params)}")
from agent_core import run_oauth_flow_async
code, error = await run_oauth_flow_async(f"https://accounts.google.com/o/oauth2/v2/auth?{urlencode(params)}")
if error: return False, f"Google OAuth failed: {error}"

token_data = {
Expand Down Expand Up @@ -141,8 +141,8 @@ async def invite(self, args):

scopes = "chat:write,channels:read,channels:history,groups:read,groups:history,users:read,files:write,im:read,im:write,im:history"
params = {"client_id": SLACK_SHARED_CLIENT_ID, "scope": scopes, "redirect_uri": REDIRECT_URI_HTTPS, "state": secrets.token_urlsafe(32)}
from agent_core import run_oauth_flow
code, error = run_oauth_flow(f"https://slack.com/oauth/v2/authorize?{urlencode(params)}", use_https=True)
from agent_core import run_oauth_flow_async
code, error = await run_oauth_flow_async(f"https://slack.com/oauth/v2/authorize?{urlencode(params)}", use_https=True)
if error: return False, f"Slack OAuth failed: {error}"

import aiohttp
Expand Down Expand Up @@ -206,8 +206,8 @@ async def invite(self, args):
return False, "CraftOS Notion integration not configured. Set NOTION_SHARED_CLIENT_ID and NOTION_SHARED_CLIENT_SECRET env vars.\nAlternatively, use /notion login <token> with your own integration token."

params = {"client_id": NOTION_SHARED_CLIENT_ID, "redirect_uri": REDIRECT_URI, "response_type": "code", "owner": "user", "state": secrets.token_urlsafe(32)}
from agent_core import run_oauth_flow
code, error = run_oauth_flow(f"https://api.notion.com/v1/oauth/authorize?{urlencode(params)}")
from agent_core import run_oauth_flow_async
code, error = await run_oauth_flow_async(f"https://api.notion.com/v1/oauth/authorize?{urlencode(params)}")
if error: return False, f"Notion OAuth failed: {error}"

import aiohttp
Expand Down Expand Up @@ -264,8 +264,8 @@ async def login(self, args):
return False, "Not configured. Set LINKEDIN_CLIENT_ID and LINKEDIN_CLIENT_SECRET env vars."

params = {"response_type": "code", "client_id": LINKEDIN_CLIENT_ID, "redirect_uri": REDIRECT_URI, "scope": "openid profile email w_member_social", "state": secrets.token_urlsafe(32)}
from agent_core import run_oauth_flow
code, error = run_oauth_flow(f"https://www.linkedin.com/oauth/v2/authorization?{urlencode(params)}")
from agent_core import run_oauth_flow_async
code, error = await run_oauth_flow_async(f"https://www.linkedin.com/oauth/v2/authorization?{urlencode(params)}")
if error: return False, f"LinkedIn OAuth failed: {error}"

import aiohttp
Expand Down Expand Up @@ -818,8 +818,8 @@ async def login(self, args):
"code_challenge": code_challenge,
"code_challenge_method": "S256",
}
from agent_core import run_oauth_flow
code, error = run_oauth_flow(
from agent_core import run_oauth_flow_async
code, error = await run_oauth_flow_async(
f"https://login.microsoftonline.com/common/oauth2/v2.0/authorize?{urlencode(params)}"
)
if error:
Expand Down
Loading