Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ on:
- 'tests/**'
- 'scripts/**'
- 'pyproject.toml'
- '.pyrefly-baseline.json'
- '.github/workflows/lint.yml'
pull_request:
branches: [main]
Expand All @@ -18,7 +17,6 @@ on:
- 'tests/**'
- 'scripts/**'
- 'pyproject.toml'
- '.pyrefly-baseline.json'
- '.github/workflows/lint.yml'
workflow_dispatch:

Expand Down Expand Up @@ -59,10 +57,10 @@ jobs:
continue-on-error: true
run: uv run python scripts/audit_test_quality.py

- name: Pyrefly type check (against baseline)
- name: Pyrefly type check
id: pyrefly
continue-on-error: true
run: uv run pyrefly check --baseline=.pyrefly-baseline.json
run: uv run pyrefly check
Comment thread
patrick-chinchill marked this conversation as resolved.

- name: Minimize uv cache
run: uv cache prune --ci
Expand All @@ -80,9 +78,9 @@ jobs:
echo "| Pyrefly | ${{ steps.pyrefly.outcome }} |" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
if [ "${{ steps.pyrefly.outcome }}" = "success" ]; then
echo "No new type issues above the baseline." >> $GITHUB_STEP_SUMMARY
echo "Zero type errors." >> $GITHUB_STEP_SUMMARY
else
echo "New type issues above the baseline. Either fix them or refresh the baseline with \`uv run pyrefly check --baseline=.pyrefly-baseline.json --update-baseline\` if intentional." >> $GITHUB_STEP_SUMMARY
echo "Type errors detected — see run output." >> $GITHUB_STEP_SUMMARY
fi

- name: Fail if any step failed
Expand Down
2,560 changes: 0 additions & 2,560 deletions .pyrefly-baseline.json

This file was deleted.

21 changes: 13 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,8 @@ dev = [
# ---------------------------------------------------------------------------
# Pyrefly type checker configuration
# ---------------------------------------------------------------------------
# Baseline workflow:
# Check against baseline: uv run pyrefly check --baseline=.pyrefly-baseline.json
# Refresh baseline: uv run pyrefly check --baseline=.pyrefly-baseline.json --update-baseline
# Target: zero errors. Run:
# uv run pyrefly check
Comment thread
coderabbitai[bot] marked this conversation as resolved.
# ---------------------------------------------------------------------------

[tool.pyrefly]
Expand All @@ -99,21 +98,27 @@ python-platform = "linux"
# Optional adapter deps that aren't in the default install. Treating them as
# Any keeps pyrefly from flagging missing-import every time we lazy-load one.
replace-imports-with-any = [
# Slack
# Slack — covers top-level + submodule imports like slack_sdk.web.async_client
"slack_sdk",
# Discord
"slack_sdk.*",
# Discord — signed-request verification uses pynacl's nacl.signing
"nacl",
# Google Chat
"nacl.*",
# Google Chat — auth uses google.auth.* subpackages
"google",
"google.*",
# State backends
"redis",
"redis.*",
"asyncpg",
"asyncpg.*",
# HTTP clients used in lazy paths
"httpx",
"httpx.*",
# GitHub App auth
"jwt",
"jwt.*",
]

[tool.pyrefly.errors]
# Default severity. Known issues tracked in .pyrefly-baseline.json —
# new errors fail CI, existing ones are allowed.
# CI policy: zero pyrefly errors.
63 changes: 41 additions & 22 deletions src/chat_sdk/adapters/discord/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from __future__ import annotations

import hmac
import inspect
import json
import os
import re
from contextvars import ContextVar
from datetime import datetime, timezone
from typing import Any
from typing import Any, Literal, cast
from urllib.parse import quote

from chat_sdk.adapters.discord.cards import (
Expand Down Expand Up @@ -52,8 +53,10 @@
FetchResult,
FileUpload,
FormattedContent,
LockScope,
Message,
MessageMetadata,
PostableRaw,
RawMessage,
ReactionEvent,
SlashCommandEvent,
Expand Down Expand Up @@ -160,7 +163,7 @@ def bot_user_id(self) -> str | None:
return self._bot_user_id

@property
def lock_scope(self) -> str | None:
def lock_scope(self) -> LockScope | None:
return None

@property
Expand Down Expand Up @@ -362,7 +365,7 @@ def _handle_component_interaction(
),
message_id=message_id,
thread_id=thread_id,
thread=None,
thread=None, # pyrefly: ignore[bad-argument-type] # filled in by Chat
adapter=self,
raw=interaction,
),
Expand All @@ -379,7 +382,11 @@ def _handle_application_command_interaction(
self._logger.warn("Chat instance not initialized, ignoring interaction")
return

data = interaction.get("data", {})
# `interaction["data"]` is a union of several TypedDicts (one per
# interaction type). Cast to a plain dict so we can access shared
# fields like `name` and `options` without pyrefly rejecting keys
# that only appear on one variant.
data = cast("dict[str, Any]", interaction.get("data", {}))
command_name = data.get("name")
if not command_name:
self._logger.warn("No command name in application command interaction")
Expand Down Expand Up @@ -447,7 +454,7 @@ def _handle_application_command_interaction(
is_me=user.get("id") == self._application_id,
),
adapter=self,
channel=None,
channel=None, # pyrefly: ignore[bad-argument-type] # filled in by Chat
raw=interaction,
)
event.channel_id = channel_id # type: ignore[attr-defined]
Expand Down Expand Up @@ -704,7 +711,7 @@ async def _handle_forwarded_reaction(
self._chat.process_reaction(
ReactionEvent(
adapter=self,
thread=None,
thread=None, # pyrefly: ignore[bad-argument-type] # filled in by Chat
thread_id=thread_id,
message_id=data.get("message_id", ""),
emoji=normalized,
Expand Down Expand Up @@ -953,7 +960,7 @@ async def remove_reaction(
"DELETE",
)

async def start_typing(self, thread_id: str, _status: str | None = None) -> None:
async def start_typing(self, thread_id: str, status: str | None = None) -> None:
"""Start typing indicator in a Discord channel or thread."""
decoded = self.decode_thread_id(thread_id)
target_channel_id = decoded.thread_id or decoded.channel_id
Expand Down Expand Up @@ -1159,7 +1166,7 @@ async def stream(

accumulated += text

postable: AdapterPostableMessage = {"raw": accumulated}
postable: AdapterPostableMessage = PostableRaw(raw=accumulated)

if message_id:
await self.edit_message(thread_id, message_id, postable)
Expand Down Expand Up @@ -1256,7 +1263,7 @@ def _parse_discord_message(self, raw: dict[str, Any], thread_id: str) -> Message
],
)

def _get_attachment_type(self, mime_type: str | None) -> str:
def _get_attachment_type(self, mime_type: str | None) -> Literal["audio", "file", "image", "video"]:
"""Determine attachment type from MIME type."""
if not mime_type:
return "file"
Expand Down Expand Up @@ -1396,23 +1403,35 @@ async def _discord_fetch(
# Request/Response helpers (framework-agnostic)
# =========================================================================

async def _get_request_body(self, request: Any) -> str:
@staticmethod
async def _get_request_body(request: Any) -> str:
"""Extract the request body as a string."""
if hasattr(request, "body"):
body = request.body
# `hasattr` narrows `Any` → `object` (not awaitable); using
# `getattr(..., None)` preserves `Any` for framework duck-typing.
# Handle both callable and non-callable `request.text`. Gating
# entry on callability would drop populated string attributes.
text_attr = getattr(request, "text", None)
if text_attr is not None:
if callable(text_attr):
result = text_attr()
text_attr = await result if inspect.isawaitable(result) else result
return text_attr.decode("utf-8") if isinstance(text_attr, (bytes, bytearray)) else str(text_attr)
body = getattr(request, "body", None)
if body is not None:
if callable(body):
body = body()
# Some frameworks expose `body` as an async method; if calling it
# produced a coroutine, await it before treating as bytes/str.
if inspect.isawaitable(body):
body = await body
if hasattr(body, "read"):
raw = await body.read() if hasattr(body.read, "__await__") else body.read()
return raw.decode("utf-8") if isinstance(raw, bytes) else raw
return body.decode("utf-8") if isinstance(body, bytes) else str(body)
if hasattr(request, "text"):
if callable(request.text):
return await request.text()
return request.text
if hasattr(request, "data"):
data = request.data
return data.decode("utf-8") if isinstance(data, bytes) else str(data)
raw_result = body.read()
raw = await raw_result if inspect.isawaitable(raw_result) else raw_result
return raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else str(raw)
return body.decode("utf-8") if isinstance(body, (bytes, bytearray)) else str(body)
data = getattr(request, "data", None)
if data is not None:
return data.decode("utf-8") if isinstance(data, (bytes, bytearray)) else str(data)
return ""

def _get_header(self, request: Any, name: str) -> str | None:
Expand Down
14 changes: 7 additions & 7 deletions src/chat_sdk/adapters/discord/cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Any
from typing import Any, cast

from chat_sdk.adapters.discord.types import DiscordActionRow, DiscordButton
from chat_sdk.cards import (
Expand Down Expand Up @@ -113,12 +113,12 @@ def _process_child(
elif child_type == "fields":
_convert_fields_element(child, fields) # type: ignore[arg-type]
elif child_type == "link":
label = child.get("label", "") # type: ignore[union-attr]
url = child.get("url", "") # type: ignore[union-attr]
label = cast("str", child.get("label", ""))
url = cast("str", child.get("url", ""))
text_parts.append(f"[{_convert_emoji(label)}]({url})")
elif child_type == "table":
headers = child.get("headers", []) # type: ignore[union-attr]
rows = child.get("rows", []) # type: ignore[union-attr]
headers = cast("list[str]", child.get("headers", []))
rows = cast("list[list[str]]", child.get("rows", []))
text_parts.append("\n".join(render_gfm_table(headers, rows)))
else:
text = card_child_to_fallback_text(child)
Expand Down Expand Up @@ -270,8 +270,8 @@ def _child_to_fallback_text(child: CardChild) -> str | None:
if (t := _child_to_fallback_text(c))
)
if child_type == "table":
headers = child.get("headers", []) # type: ignore[union-attr]
rows = child.get("rows", []) # type: ignore[union-attr]
headers = cast("list[str]", child.get("headers", []))
rows = cast("list[list[str]]", child.get("rows", []))
return f"```\n{table_element_to_ascii(headers, rows)}\n```"
if child_type == "divider":
return "---"
Expand Down
Loading
Loading