Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion grace/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.9.10-alpha"
__version__ = "0.10.4-alpha"
24 changes: 13 additions & 11 deletions grace/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from logging.handlers import RotatingFileHandler

from types import ModuleType
from typing import Generator, Any, Union, Dict
from typing import Generator, Any, Union, Dict, no_type_check

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
Expand All @@ -21,13 +21,15 @@
create_database,
drop_database
)

from pathlib import Path
from grace.config import Config
from grace.exceptions import ConfigError
from grace.importer import find_all_importables, import_module


ConfigReturn = Union[str, int, float, None]


class Application:
"""This class is the core of the application In other words,
this class that manage the database, the application environment
Expand All @@ -38,26 +40,28 @@ class Application:
__session: Union[Session, None] = None
__base: DeclarativeMeta = declarative_base()

def __init__(self):
def __init__(self) -> None:
database_config_path: Path = Path("config/database.cfg")

if not database_config_path.exists():
raise ConfigError("Unable to find the 'database.cfg' file.")

self.__token: str = self.config.get("discord", "token")
self.__token: str = str(self.config.get("discord", "token"))
self.__engine: Union[Engine, None] = None

self.command_sync: bool = True
self.watch: bool = False

@property
def base(self) -> DeclarativeMeta:
return self.__base

@property
def token(self) -> str:
return self.__token
return str(self.__token)

@property
@no_type_check
def session(self) -> Session:
"""Instantiate the session for querying the database."""

Expand Down Expand Up @@ -109,15 +113,13 @@ def get_extension_module(self, extension_name) -> Union[str, None]:
return extension
return None

def load(self, environment: str, command_sync: bool = True):
def load(self, environment: str):
"""
Sets the environment and loads all the component of the application
"""

self.command_sync = command_sync
self.environment = environment

self.environment: str = environment
self.config.set_environment(environment)

self.load_logs()
self.load_models()
self.load_database()
Expand All @@ -129,7 +131,7 @@ def load_models(self):
for module in find_all_importables(models):
import_module(module)

def load_logs(self):
def load_logs(self) -> None:
file_handler: RotatingFileHandler = RotatingFileHandler(
f"logs/{self.config.current_environment}.log",
maxBytes=10000,
Expand Down
50 changes: 42 additions & 8 deletions grace/bot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from logging import info, warning, critical
from discord import Intents, LoginFailure
from discord.ext.commands import Bot as DiscordBot, when_mentioned_or
from discord.ext.commands.errors import (
ExtensionNotLoaded,
ExtensionAlreadyLoaded
)
from grace.application import Application, SectionProxy
from grace.watcher import Watcher, Observer


class Bot(DiscordBot):
Expand All @@ -17,6 +22,7 @@ class Bot(DiscordBot):
def __init__(self, app: Application, **kwargs):
self.app: Application = app
self.config: SectionProxy = self.app.client
self.watcher: Watcher = Watcher(self.on_reload)

command_prefix = kwargs.pop(
'command_prefix',
Expand All @@ -43,21 +49,49 @@ async def _load_extensions(self) -> None:
info(f"Loading module '{module}'")
await self.load_extension(module)

async def _sync_commands(self) -> None:
warning("Syncing application commands. This may take some time.")

if guild_id := self.config.get("guild"):
guild = self.get_guild(int(guild_id))
await self.tree.sync(guild=guild)

async def invoke(self, ctx):
if ctx.command:
info(f"'{ctx.command}' has been invoked by {ctx.author} "
f"({ctx.author.display_name})")
await super().invoke(ctx)

async def setup_hook(self) -> None:
await self._load_extensions()

if self.app.command_sync:
warning("Syncing application commands. This may take some time.")
await self._sync_commands()

if guild_id := self.config.get("guild"):
guild = self.get_guild(int(guild_id))
await self.tree.sync(guild=guild)
if self.app.watch:
self.watcher.start()

def run(self) -> None: # type: ignore[override]
"""Run the bot
async def load_extension(self, name: str) -> None:
try:
await super().load_extension(name)
except ExtensionAlreadyLoaded:
warning(f"Extension '{name}' already loaded, skipping.")

async def unload_extension(self, name: str) -> None:
try:
await super().unload_extension(name)
except ExtensionNotLoaded:
warning(f"Extension '{name}' was not loaded, skipping.")

async def on_reload(self):
for module in self.app.extension_modules:
info(f"Reloading extension '{module}'")

Override the `run` method to handle the token retrieval
"""
await self.unload_extension(module)
await self.load_extension(module)

def run(self) -> None: # type: ignore[override]
"""Override the `run` method to handle the token retrieval"""
try:
if self.app.token:
super().run(self.app.token)
Expand Down
7 changes: 6 additions & 1 deletion grace/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
| PID: {pid}
| Environment: {env}
| Syncing command: {command_sync}
| Watcher enabled: {watch}
| Using database: {database} with {dialect}
""".rstrip()

Expand Down Expand Up @@ -63,10 +64,13 @@ def db():

@app_cli.command()
@option("--sync/--no-sync", default=True, help="Sync the application command.")
@option("--watch/--no-watch", default=False, help="Enables hot reload.")
@pass_context
def run(ctx, sync):
def run(ctx, sync, watch):
app = ctx.obj["app"]
bot = ctx.obj["bot"]

app.watch = watch
app.command_sync = sync

_load_database(app)
Expand Down Expand Up @@ -147,6 +151,7 @@ def _show_application_info(app):
env=app.environment,
pid=getpid(),
command_sync=app.command_sync,
watch=app.watch,
database=app.database_infos["database"],
dialect=app.database_infos["dialect"],
))
Expand Down
4 changes: 2 additions & 2 deletions grace/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Config:
instantiate a second or multiple Config object, they will all share the
same environment. This is to say, that the config objects are identical.
"""
def __init__(self):
def __init__(self) -> None:
load_dotenv(".env")

self.__environment: Optional[str] = None
Expand Down Expand Up @@ -96,7 +96,7 @@ def current_environment(self) -> Optional[str]:
return self.__environment

@property
def database_uri(self) -> Union[str, URL]:
def database_uri(self) -> Union[str, URL, None]:
if self.database.get("url"):
return self.database.get("url")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ sqlalchemy_echo = True

[test]
log_level = ERROR
sqlalchemy_echo = True
sqlalchemy_echo = True
142 changes: 142 additions & 0 deletions grace/watcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import sys
import asyncio
import importlib.util

from pathlib import Path
from typing import Callable, Coroutine, Any, Union
from logging import WARNING, getLogger, info, error
from watchdog.events import FileSystemEvent, FileSystemEventHandler
from watchdog.observers import Observer


# Suppress verbose watchdog logs
getLogger("watchdog").setLevel(WARNING)


ReloadCallback = Callable[[], Coroutine[Any, Any, None]]


class Watcher:
"""
Wrapper around the watchdog observer that watches a specified
directory (./bot) for Python file changes and manages event handling.

:param bot: The bot instance, must implement `on_reload()` and `unload_extension()`.
:type bot: Callable
"""
def __init__(self, callback: ReloadCallback) -> None:
self.callback: ReloadCallback = callback
self.observer: Observer = Observer()
self.watch_path: str = "./bot"

self.observer.schedule(
BotEventHandler(self.callback, self.watch_path),
self.watch_path,
recursive=True
)

def start(self) -> None:
"""Starts the file system observer."""
info("Starting file watcher...")
self.observer.start()

def stop(self) -> None:
"""Stops the file system observer and waits for it to shut down."""
info("Stopping file watcher...")
self.observer.stop()
self.observer.join()


class BotEventHandler(FileSystemEventHandler):
"""
Handles file events in the bot directory and calls the provided
async callback.

:param callback: Async function to call with the module name.
:type callback: Callable[[str], Coroutine]
:param base_path: Directory path to watch.
:type base_path: Path or str
"""
def __init__(self, callback: ReloadCallback, base_path: Union[Path, str]):
self.callback = callback
self.bot_path = Path(base_path).resolve()

def path_to_module_name(self, path: Path) -> str:
"""
Converts a file path to a Python module name.

:param path: Full path to the Python file.
:type path: Path
:return: Dotted module path (e.g., 'bot.module.sub').
:rtype: str
"""
relative_path = path.resolve().relative_to(self.bot_path)
parts = relative_path.with_suffix('').parts
return '.'.join(['bot'] + list(parts))

def reload_module(self, module_name: str) -> None:
"""
Reloads a module if it's already in sys.modules.

:param module_name: Dotted module name to reload.
:type module_name: str
"""
try:
if module_name in sys.modules:
info(f"Reloading module '{module_name}'")
importlib.reload(sys.modules[module_name])
except Exception as e:
error(f"Failed to reload module {module_name}: {e}")

def run_callback(self) -> None:
"""Runs a coroutine callback in the current or a new event loop."""
try:
loop = asyncio.get_running_loop()
asyncio.ensure_future(self.callback())
except RuntimeError:
asyncio.run(self.callback())

def on_modified(self, event: FileSystemEvent) -> None:
"""
Handles modified Python files by reloading them and calling the callback.

:param event: The filesystem event.
:type event: FileSystemEvent
"""
try:
if event.is_directory:
return

module_path = Path(event.src_path)
if module_path.suffix != '.py':
return

module_name = self.path_to_module_name(module_path)
if not module_name:
return

self.reload_module(module_name)
self.run_callback()
except Exception as e:
error(f"Failed to reload module {module_name}: {e}")


def on_deleted(self, event: FileSystemEvent) -> None:
"""
Handles deleted Python files by calling the callback with the module name.

:param event: The filesystem event.
:type event: FileSystemEvent
"""
try:
module_path = Path(event.src_path)
if module_path.suffix != '.py':
return

module_name = self.path_to_module_name(module_path)
if not module_name:
return

self.run_coro(self.callback())
except Exception as e:
error(f"Failed to reload module {module_name}: {e}")
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["setuptools>=64", "setuptools_scm>=8"]

[project]
name = "grace-framework"
version = "0.9.10-alpha"
version = "0.10.5-alpha"
authors = [
{ name="Code Society Lab" },
{ name="Simon Roy", email="simon.roy1211@gmail.com" }
Expand Down Expand Up @@ -31,6 +31,7 @@ dependencies = [
"flake8",
"pytest-mock",
"coverage",
"watchdog"
]

[project.urls]
Expand Down