Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ jobs:
permissions:
id-token: write # trusted publishing
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v4
- name: Build
run: uv build
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9e8454a21b30b1df74fa848 # release/v1
20 changes: 20 additions & 0 deletions .github/workflows/test-publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
name: Publish to TestPyPI

on:
workflow_dispatch: # manual trigger only

jobs:
test-publish:
runs-on: ubuntu-latest
environment: testpypi
permissions:
id-token: write # trusted publishing
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v4
- name: Build
run: uv build
- name: Publish to TestPyPI
uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9e8454a21b30b1df74fa848 # release/v1
with:
repository-url: https://test.pypi.org/legacy/
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ jobs:
matrix:
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v4
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: astral-sh/setup-uv@f0ec1fc3b38f5e7cd731bb6ce540c5af426746bb # v4
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Install dependencies
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ pip install chat-sdk[all] # all adapters + state backends
## Quick Start

```python
from chat_sdk import Chat, Card, Button, Actions
from chat_sdk import Chat, Card, Button, Actions, MemoryStateAdapter
from chat_sdk.adapters.slack import create_slack_adapter

chat = Chat(
adapters={"slack": create_slack_adapter()},
Expand Down
15 changes: 15 additions & 0 deletions src/chat_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@
resolve_emoji_from_slack,
)
from chat_sdk.errors import ChatError, ChatNotImplementedError, LockError, RateLimitError
from chat_sdk.shared.errors import (
AdapterRateLimitError,
AuthenticationError,
NetworkError,
PermissionError as AdapterPermissionError,
ResourceNotFoundError,
ValidationError,
)
from chat_sdk.from_full_stream import from_full_stream
from chat_sdk.logger import ConsoleLogger, Logger, LogLevel
from chat_sdk.message_history import MessageHistoryCache, MessageHistoryConfig
Expand Down Expand Up @@ -226,6 +234,13 @@
"ChatNotImplementedError",
"LockError",
"RateLimitError",
# Adapter errors
"AdapterRateLimitError",
"AuthenticationError",
"ValidationError",
"NetworkError",
"ResourceNotFoundError",
"AdapterPermissionError",
# Format converter
"BaseFormatConverter",
# Logger
Expand Down
3 changes: 2 additions & 1 deletion src/chat_sdk/adapters/discord/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

import hmac
import json
import os
import re
Expand Down Expand Up @@ -178,7 +179,7 @@ async def handle_webhook(
# Check if this is a forwarded Gateway event (uses bot token for auth)
gateway_token = self._get_header(request, "x-discord-gateway-token")
if gateway_token:
if gateway_token != self._bot_token:
if not hmac.compare_digest(gateway_token, self._bot_token):
self._logger.warn("Invalid gateway token")
return self._make_response("Invalid gateway token", 401)
self._logger.info("Discord forwarded Gateway event received")
Expand Down
12 changes: 7 additions & 5 deletions src/chat_sdk/adapters/google_chat/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def __init__(self, config: GoogleChatAdapterConfig | None = None) -> None:
self._warned_no_webhook_verification = False
self._warned_no_pubsub_verification = False

# Cached JWKS client for JWT verification (lazy init on first use)
self._jwks_client: Any | None = None

# Auth setup
self._credentials: ServiceAccountCredentials | None = None
self._use_adc = False
Expand Down Expand Up @@ -651,11 +654,10 @@ async def _verify_bearer_token(
import jwt as pyjwt
from jwt import PyJWKClient

# TODO: For direct Chat webhook JWTs signed by chat@system.gserviceaccount.com,
# use https://www.googleapis.com/service_accounts/v1/metadata/x509/chat@system.gserviceaccount.com
# Currently only supports OIDC-based verification (Pub/Sub push tokens and HTTP endpoint auth)
jwks_client = PyJWKClient("https://www.googleapis.com/oauth2/v3/certs")
signing_key = jwks_client.get_signing_key_from_jwt(token)
# Lazily create and cache the JWKS client (avoid per-request instantiation)
if self._jwks_client is None:
self._jwks_client = PyJWKClient("https://www.googleapis.com/oauth2/v3/certs")
signing_key = self._jwks_client.get_signing_key_from_jwt(token)
payload = pyjwt.decode(
token,
signing_key.key,
Expand Down
7 changes: 7 additions & 0 deletions src/chat_sdk/adapters/slack/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2691,6 +2691,13 @@ async def _send_to_response_url(
thread_ts: str | None = None,
) -> dict[str, Any]:
"""Send a request to Slack's response_url to modify an ephemeral message."""
# Validate response_url points to Slack (prevent SSRF)
from urllib.parse import urlparse

parsed = urlparse(response_url)
if not (parsed.scheme == "https" and parsed.hostname and parsed.hostname.endswith(".slack.com")):
raise ValidationError("slack", f"Invalid response_url: must be https://*.slack.com, got {response_url}")

import httpx

payload: dict[str, Any]
Expand Down
98 changes: 98 additions & 0 deletions src/chat_sdk/adapters/teams/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from __future__ import annotations

import base64
import hmac
import json
import os
import re
from datetime import datetime, timezone
from typing import Any
from urllib.parse import urlparse

from chat_sdk.adapters.teams.cards import card_to_adaptive_card
from chat_sdk.adapters.teams.format_converter import TeamsFormatConverter
Expand Down Expand Up @@ -57,6 +59,37 @@
MESSAGEID_STRIP_PATTERN = re.compile(r";messageid=\d+")
CACHE_TTL_MS = 30 * 24 * 60 * 60 * 1000 # 30 days

# Allowed Microsoft Bot Framework service URL patterns (SSRF protection).
# Covers commercial, GCC, GCCH, DoD, and sovereign cloud endpoints.
ALLOWED_SERVICE_URL_PATTERNS = [
re.compile(r"^https://smba\.trafficmanager\.net/"),
re.compile(r"^https://[a-z0-9.-]+\.botframework\.com/"),
re.compile(r"^https://[a-z0-9.-]+\.botframework\.us/"),
re.compile(r"^https://[a-z0-9.-]+\.teams\.microsoft\.com/"),
re.compile(r"^https://[a-z0-9.-]+\.teams\.microsoft\.us/"),
re.compile(r"^https://smba\.infra\.(gcc|gov)\.teams\.microsoft\.(com|us)/"),
]

# Bot Framework OpenID configuration URL for JWT verification
BOT_FRAMEWORK_OPENID_CONFIG_URL = (
"https://login.botframework.com/v1/.well-known/openid-configuration"
)


def _validate_service_url(url: str) -> None:
"""Validate that a service URL matches known Microsoft Bot Framework endpoints.

Raises :class:`~chat_sdk.shared.errors.ValidationError` if the URL is not
in the allow-list, preventing SSRF attacks via crafted ``serviceUrl`` values.
"""
for pattern in ALLOWED_SERVICE_URL_PATTERNS:
if pattern.match(url):
return
raise ValidationError(
"teams",
f"Service URL is not an allowed Bot Framework endpoint: {url}",
)


def _handle_teams_error(error: Any, operation: str) -> None:
"""Convert Teams SDK errors to adapter errors and raise.
Expand Down Expand Up @@ -132,6 +165,7 @@ def __init__(self, config: TeamsAdapterConfig | None = None) -> None:
self._bot_user_id: str | None = self._app_id or None
self._access_token: str | None = None
self._token_expiry: float = 0
self._jwks_client: Any | None = None # Cached PyJWKClient for JWT verification

@property
def name(self) -> str:
Expand Down Expand Up @@ -170,6 +204,12 @@ async def handle_webhook(
body = await self._get_request_body(request)
self._logger.debug("Teams webhook raw body", {"body": body[:500] if body else ""})

# ---- JWT verification (Bot Framework tokens) ----
if self._app_id:
auth_result = await self._verify_bot_framework_token(request)
if auth_result is not None:
return auth_result

try:
activity: dict[str, Any] = json.loads(body)
except (json.JSONDecodeError, ValueError):
Expand Down Expand Up @@ -1529,6 +1569,7 @@ async def _teams_send(
"""Send an activity to a Teams conversation via Bot Framework REST API."""
import aiohttp # lazy import

_validate_service_url(decoded.service_url)
token = await self._get_access_token()
url = f"{decoded.service_url}v3/conversations/{decoded.conversation_id}/activities"

Expand Down Expand Up @@ -1560,6 +1601,7 @@ async def _teams_update(
"""Update an activity in a Teams conversation via Bot Framework REST API."""
import aiohttp # lazy import

_validate_service_url(decoded.service_url)
token = await self._get_access_token()
url = f"{decoded.service_url}v3/conversations/{decoded.conversation_id}/activities/{message_id}"

Expand Down Expand Up @@ -1589,6 +1631,7 @@ async def _teams_delete(
"""Delete an activity from a Teams conversation via Bot Framework REST API."""
import aiohttp # lazy import

_validate_service_url(decoded.service_url)
token = await self._get_access_token()
url = f"{decoded.service_url}v3/conversations/{decoded.conversation_id}/activities/{message_id}"

Expand All @@ -1606,6 +1649,61 @@ async def _teams_delete(
f"Teams API error: {response.status} {error_text}",
)

# =========================================================================
# JWT verification (Bot Framework)
# =========================================================================

async def _verify_bot_framework_token(self, request: Any) -> Any | None:
"""Verify the JWT Bearer token from the Bot Framework.

Returns a 401 response dict if authentication fails, or ``None`` if
the token is valid.
"""
auth_header: str | None = self._get_header(request, "authorization")
if not auth_header or not auth_header.startswith("Bearer "):
self._logger.warn("Missing or invalid Authorization header on Teams webhook")
return self._make_response("Unauthorized", 401)

token = auth_header[7:]
try:
import jwt as pyjwt
from jwt import PyJWKClient

# Lazily create and cache the JWKS client
if self._jwks_client is None:
import aiohttp

async with aiohttp.ClientSession() as session:
async with session.get(BOT_FRAMEWORK_OPENID_CONFIG_URL) as resp:
if resp.status != 200:
self._logger.error("Failed to fetch Bot Framework OpenID config", {"status": resp.status})
return self._make_response("Unauthorized", 401)
openid_config = await resp.json()
Comment on lines +1677 to +1681
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved robustness and clearer error logging, it's a good practice to explicitly check if the HTTP request to fetch the OpenID configuration was successful before attempting to parse the JSON response. This will provide more specific logs if the endpoint is unavailable or returns an error.

Suggested change
async with session.get(BOT_FRAMEWORK_OPENID_CONFIG_URL) as resp:
openid_config = await resp.json()
async with session.get(BOT_FRAMEWORK_OPENID_CONFIG_URL) as resp:
if not resp.ok:
self._logger.error(
"Failed to fetch Bot Framework OpenID config",
{"status": resp.status, "url": BOT_FRAMEWORK_OPENID_CONFIG_URL},
)
return self._make_response("Unauthorized", 401)
openid_config = await resp.json()

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already addressed in the existing code -- the resp.status != 200 check was present at line 1700 before this review comment was posted. No changes needed.

jwks_uri = openid_config.get("jwks_uri")
if not jwks_uri:
self._logger.error("No jwks_uri in Bot Framework OpenID config")
return self._make_response("Unauthorized", 401)
self._jwks_client = PyJWKClient(jwks_uri)

signing_key = self._jwks_client.get_signing_key_from_jwt(token)
payload = pyjwt.decode(
token,
signing_key.key,
algorithms=["RS256"],
audience=self._app_id,
)
self._logger.debug(
"Teams JWT verified",
{
"iss": payload.get("iss"),
"aud": payload.get("aud"),
},
)
return None # success
except Exception as exc:
self._logger.warn(f"Teams JWT verification failed: {exc}")
return self._make_response("Unauthorized", 401)

# =========================================================================
# Request/Response helpers (framework-agnostic)
# =========================================================================
Expand Down
28 changes: 25 additions & 3 deletions src/chat_sdk/adapters/whatsapp/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,10 +637,32 @@ async def download_media(self, media_id: str) -> bytes:

media_info = await meta_response.json()

# Step 2: Download the actual file
# Validate the download URL to prevent SSRF
download_url = media_info["url"]
parsed = urlparse(download_url)
if parsed.scheme != "https":
raise ValidationError(
"whatsapp",
f"Media download URL must use HTTPS, got: {parsed.scheme}",
)
host = (parsed.hostname or "").lower()
allowed_suffixes = (
".facebook.com", ".fbcdn.net", ".fbsbx.com",
".whatsapp.net", ".whatsapp.com",
)
allowed_exact = {"facebook.com", "fbcdn.net", "fbsbx.com", "whatsapp.net", "whatsapp.com"}
if not (
any(host.endswith(s) for s in allowed_suffixes)
or host in allowed_exact
):
raise ValidationError(
"whatsapp",
f"Media download URL host is not an allowed Meta domain: {host}",
)

# Step 2: Download the actual file (no Bearer token -- CDN URLs are pre-signed)
async with session.get(
media_info["url"],
headers={"Authorization": f"Bearer {self._access_token}"},
download_url,
) as data_response:
if data_response.status != 200:
self._logger.error(
Expand Down
6 changes: 2 additions & 4 deletions src/chat_sdk/state/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

import logging
import os
import random
import string
import secrets
import time
import warnings
from dataclasses import dataclass
Expand All @@ -35,8 +34,7 @@ class _CachedValue:


def _generate_token() -> str:
suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=14))
return f"mem_{int(time.time() * 1000)}_{suffix}"
return f"mem_{int(time.time() * 1000)}_{secrets.token_hex(16)}"


def _now_ms() -> float:
Expand Down
6 changes: 2 additions & 4 deletions src/chat_sdk/state/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import json
import logging
import os
import random
import string
import secrets
import time
from typing import Any

Expand All @@ -27,8 +26,7 @@


def _generate_token() -> str:
suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=14))
return f"redis_{int(time.time() * 1000)}_{suffix}"
return f"redis_{int(time.time() * 1000)}_{secrets.token_hex(16)}"


# ---------------------------------------------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions tests/test_teams_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,15 @@ def data(self) -> bytes:


class TestHandleWebhook:
@pytest.fixture(autouse=True)
def _skip_jwt(self, monkeypatch):
"""Bypass JWT verification in unit tests."""
monkeypatch.setattr(
TeamsAdapter,
"_verify_bot_framework_token",
AsyncMock(return_value=None),
)

@pytest.mark.asyncio
async def test_400_for_invalid_json(self):
adapter = _make_adapter(logger=_make_logger())
Expand Down
Loading
Loading