diff --git a/grace/__init__.py b/grace/__init__.py index 25fa04c..3b87ff3 100644 --- a/grace/__init__.py +++ b/grace/__init__.py @@ -1 +1 @@ -__version__ = "0.9.10-alpha" +__version__ = "0.10.4-alpha" diff --git a/grace/application.py b/grace/application.py index 925de1b..fff1768 100644 --- a/grace/application.py +++ b/grace/application.py @@ -5,7 +5,7 @@ from logging.handlers import RotatingFileHandler from types import ModuleType -from typing import Generator, Any, Union, Dict +from typing import Generator, Any, Union, Dict, no_type_check from sqlalchemy import create_engine from sqlalchemy.engine import Engine @@ -21,13 +21,15 @@ create_database, drop_database ) - from pathlib import Path from grace.config import Config from grace.exceptions import ConfigError from grace.importer import find_all_importables, import_module +ConfigReturn = Union[str, int, float, None] + + class Application: """This class is the core of the application In other words, this class that manage the database, the application environment @@ -38,16 +40,17 @@ class Application: __session: Union[Session, None] = None __base: DeclarativeMeta = declarative_base() - def __init__(self): + def __init__(self) -> None: database_config_path: Path = Path("config/database.cfg") if not database_config_path.exists(): raise ConfigError("Unable to find the 'database.cfg' file.") - self.__token: str = self.config.get("discord", "token") + self.__token: str = str(self.config.get("discord", "token")) self.__engine: Union[Engine, None] = None self.command_sync: bool = True + self.watch: bool = False @property def base(self) -> DeclarativeMeta: @@ -55,9 +58,10 @@ def base(self) -> DeclarativeMeta: @property def token(self) -> str: - return self.__token + return str(self.__token) @property + @no_type_check def session(self) -> Session: """Instantiate the session for querying the database.""" @@ -109,15 +113,13 @@ def get_extension_module(self, extension_name) -> Union[str, None]: return extension return None - def load(self, environment: str, command_sync: bool = True): + def load(self, environment: str): """ Sets the environment and loads all the component of the application """ - - self.command_sync = command_sync - self.environment = environment - + self.environment: str = environment self.config.set_environment(environment) + self.load_logs() self.load_models() self.load_database() @@ -129,7 +131,7 @@ def load_models(self): for module in find_all_importables(models): import_module(module) - def load_logs(self): + def load_logs(self) -> None: file_handler: RotatingFileHandler = RotatingFileHandler( f"logs/{self.config.current_environment}.log", maxBytes=10000, diff --git a/grace/bot.py b/grace/bot.py index 00e9054..64c0bd1 100644 --- a/grace/bot.py +++ b/grace/bot.py @@ -1,7 +1,12 @@ from logging import info, warning, critical from discord import Intents, LoginFailure from discord.ext.commands import Bot as DiscordBot, when_mentioned_or +from discord.ext.commands.errors import ( + ExtensionNotLoaded, + ExtensionAlreadyLoaded +) from grace.application import Application, SectionProxy +from grace.watcher import Watcher, Observer class Bot(DiscordBot): @@ -17,6 +22,7 @@ class Bot(DiscordBot): def __init__(self, app: Application, **kwargs): self.app: Application = app self.config: SectionProxy = self.app.client + self.watcher: Watcher = Watcher(self.on_reload) command_prefix = kwargs.pop( 'command_prefix', @@ -43,21 +49,49 @@ async def _load_extensions(self) -> None: info(f"Loading module '{module}'") await self.load_extension(module) + async def _sync_commands(self) -> None: + warning("Syncing application commands. This may take some time.") + + if guild_id := self.config.get("guild"): + guild = self.get_guild(int(guild_id)) + await self.tree.sync(guild=guild) + + async def invoke(self, ctx): + if ctx.command: + info(f"'{ctx.command}' has been invoked by {ctx.author} " + f"({ctx.author.display_name})") + await super().invoke(ctx) + async def setup_hook(self) -> None: await self._load_extensions() if self.app.command_sync: - warning("Syncing application commands. This may take some time.") + await self._sync_commands() - if guild_id := self.config.get("guild"): - guild = self.get_guild(int(guild_id)) - await self.tree.sync(guild=guild) + if self.app.watch: + self.watcher.start() - def run(self) -> None: # type: ignore[override] - """Run the bot + async def load_extension(self, name: str) -> None: + try: + await super().load_extension(name) + except ExtensionAlreadyLoaded: + warning(f"Extension '{name}' already loaded, skipping.") + + async def unload_extension(self, name: str) -> None: + try: + await super().unload_extension(name) + except ExtensionNotLoaded: + warning(f"Extension '{name}' was not loaded, skipping.") + + async def on_reload(self): + for module in self.app.extension_modules: + info(f"Reloading extension '{module}'") - Override the `run` method to handle the token retrieval - """ + await self.unload_extension(module) + await self.load_extension(module) + + def run(self) -> None: # type: ignore[override] + """Override the `run` method to handle the token retrieval""" try: if self.app.token: super().run(self.app.token) diff --git a/grace/cli.py b/grace/cli.py index 4ca33bc..d4bb480 100644 --- a/grace/cli.py +++ b/grace/cli.py @@ -14,6 +14,7 @@ | PID: {pid} | Environment: {env} | Syncing command: {command_sync} +| Watcher enabled: {watch} | Using database: {database} with {dialect} """.rstrip() @@ -63,10 +64,13 @@ def db(): @app_cli.command() @option("--sync/--no-sync", default=True, help="Sync the application command.") +@option("--watch/--no-watch", default=False, help="Enables hot reload.") @pass_context -def run(ctx, sync): +def run(ctx, sync, watch): app = ctx.obj["app"] bot = ctx.obj["bot"] + + app.watch = watch app.command_sync = sync _load_database(app) @@ -147,6 +151,7 @@ def _show_application_info(app): env=app.environment, pid=getpid(), command_sync=app.command_sync, + watch=app.watch, database=app.database_infos["database"], dialect=app.database_infos["dialect"], )) diff --git a/grace/config.py b/grace/config.py index 4a4d51d..5c5402d 100644 --- a/grace/config.py +++ b/grace/config.py @@ -67,7 +67,7 @@ class Config: instantiate a second or multiple Config object, they will all share the same environment. This is to say, that the config objects are identical. """ - def __init__(self): + def __init__(self) -> None: load_dotenv(".env") self.__environment: Optional[str] = None @@ -96,7 +96,7 @@ def current_environment(self) -> Optional[str]: return self.__environment @property - def database_uri(self) -> Union[str, URL]: + def database_uri(self) -> Union[str, URL, None]: if self.database.get("url"): return self.database.get("url") diff --git a/grace/generators/templates/project/{{ cookiecutter.__project_slug }}/config/environment.cfg b/grace/generators/templates/project/{{ cookiecutter.__project_slug }}/config/environment.cfg index 69ef798..fee1dc6 100644 --- a/grace/generators/templates/project/{{ cookiecutter.__project_slug }}/config/environment.cfg +++ b/grace/generators/templates/project/{{ cookiecutter.__project_slug }}/config/environment.cfg @@ -8,4 +8,4 @@ sqlalchemy_echo = True [test] log_level = ERROR -sqlalchemy_echo = True +sqlalchemy_echo = True \ No newline at end of file diff --git a/grace/watcher.py b/grace/watcher.py new file mode 100644 index 0000000..1956d80 --- /dev/null +++ b/grace/watcher.py @@ -0,0 +1,142 @@ +import sys +import asyncio +import importlib.util + +from pathlib import Path +from typing import Callable, Coroutine, Any, Union +from logging import WARNING, getLogger, info, error +from watchdog.events import FileSystemEvent, FileSystemEventHandler +from watchdog.observers import Observer + + +# Suppress verbose watchdog logs +getLogger("watchdog").setLevel(WARNING) + + +ReloadCallback = Callable[[], Coroutine[Any, Any, None]] + + +class Watcher: + """ + Wrapper around the watchdog observer that watches a specified + directory (./bot) for Python file changes and manages event handling. + + :param bot: The bot instance, must implement `on_reload()` and `unload_extension()`. + :type bot: Callable + """ + def __init__(self, callback: ReloadCallback) -> None: + self.callback: ReloadCallback = callback + self.observer: Observer = Observer() + self.watch_path: str = "./bot" + + self.observer.schedule( + BotEventHandler(self.callback, self.watch_path), + self.watch_path, + recursive=True + ) + + def start(self) -> None: + """Starts the file system observer.""" + info("Starting file watcher...") + self.observer.start() + + def stop(self) -> None: + """Stops the file system observer and waits for it to shut down.""" + info("Stopping file watcher...") + self.observer.stop() + self.observer.join() + + +class BotEventHandler(FileSystemEventHandler): + """ + Handles file events in the bot directory and calls the provided + async callback. + + :param callback: Async function to call with the module name. + :type callback: Callable[[str], Coroutine] + :param base_path: Directory path to watch. + :type base_path: Path or str + """ + def __init__(self, callback: ReloadCallback, base_path: Union[Path, str]): + self.callback = callback + self.bot_path = Path(base_path).resolve() + + def path_to_module_name(self, path: Path) -> str: + """ + Converts a file path to a Python module name. + + :param path: Full path to the Python file. + :type path: Path + :return: Dotted module path (e.g., 'bot.module.sub'). + :rtype: str + """ + relative_path = path.resolve().relative_to(self.bot_path) + parts = relative_path.with_suffix('').parts + return '.'.join(['bot'] + list(parts)) + + def reload_module(self, module_name: str) -> None: + """ + Reloads a module if it's already in sys.modules. + + :param module_name: Dotted module name to reload. + :type module_name: str + """ + try: + if module_name in sys.modules: + info(f"Reloading module '{module_name}'") + importlib.reload(sys.modules[module_name]) + except Exception as e: + error(f"Failed to reload module {module_name}: {e}") + + def run_callback(self) -> None: + """Runs a coroutine callback in the current or a new event loop.""" + try: + loop = asyncio.get_running_loop() + asyncio.ensure_future(self.callback()) + except RuntimeError: + asyncio.run(self.callback()) + + def on_modified(self, event: FileSystemEvent) -> None: + """ + Handles modified Python files by reloading them and calling the callback. + + :param event: The filesystem event. + :type event: FileSystemEvent + """ + try: + if event.is_directory: + return + + module_path = Path(event.src_path) + if module_path.suffix != '.py': + return + + module_name = self.path_to_module_name(module_path) + if not module_name: + return + + self.reload_module(module_name) + self.run_callback() + except Exception as e: + error(f"Failed to reload module {module_name}: {e}") + + + def on_deleted(self, event: FileSystemEvent) -> None: + """ + Handles deleted Python files by calling the callback with the module name. + + :param event: The filesystem event. + :type event: FileSystemEvent + """ + try: + module_path = Path(event.src_path) + if module_path.suffix != '.py': + return + + module_name = self.path_to_module_name(module_path) + if not module_name: + return + + self.run_coro(self.callback()) + except Exception as e: + error(f"Failed to reload module {module_name}: {e}") \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 90cfa32..85628a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=64", "setuptools_scm>=8"] [project] name = "grace-framework" -version = "0.9.10-alpha" +version = "0.10.5-alpha" authors = [ { name="Code Society Lab" }, { name="Simon Roy", email="simon.roy1211@gmail.com" } @@ -31,6 +31,7 @@ dependencies = [ "flake8", "pytest-mock", "coverage", + "watchdog" ] [project.urls]