diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index 690cd42..95ea855 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -75,5 +75,8 @@ jobs: - name: Run tests run: poetry run pytest + - name: Run type check + run: poetry run pyright + - name: Run lint run: poetry run ruff check . diff --git a/.gitignore b/.gitignore index ccc9588..04f0eb1 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,6 @@ __pycache__/ .vscode/ .idea/ .DS_Store + +# 📦 Poetry (config locale) +poetry.toml diff --git a/app/__init__.py b/app/__init__.py index 15fcd89..531c352 100644 --- a/app/__init__.py +++ b/app/__init__.py @@ -6,15 +6,29 @@ from app.controllers.article_controller import article_bp from app.controllers.comment_controller import comment_bp from app.controllers.login_controller import login_bp -from configurations.configuration_variables import env_vars +from config.configuration_variables import env_vars from database.database_setup import db_session -def shutdown_session(exception=None): +def shutdown_session(exception: BaseException | None = None) -> None: + """ + Removes the database session at the end of the request. + + Args: + exception (BaseException | None): The exception that triggered the teardown, if any. + """ + db_session.remove() -def initialize_flask_application(): +def initialize_flask_application() -> Flask: + """ + Initializes and configures the Flask application. + Sets the secret key based on the environment and registers blueprints. + + Returns: + Flask: The configured Flask application instance. + """ app = Flask(__name__) if os.getenv("PYTEST_CURRENT_TEST") or "pytest" in sys.modules: app.secret_key = env_vars.test_secret_key diff --git a/app/constants.py b/app/constants.py new file mode 100644 index 0000000..a0366a3 --- /dev/null +++ b/app/constants.py @@ -0,0 +1,26 @@ +from enum import Enum + + +class Role(str, Enum): + """ + Enum representing user roles within the application. + """ + ADMIN = "admin" + AUTHOR = "author" + USER = "user" + + +class SessionKey(str, Enum): + """ + Enum representing keys used in the Flask session. + """ + USER_ID = "user_id" + ROLE = "role" + USERNAME = "username" + + +class PaginationConfig: + """ + Configuration for pagination settings. + """ + ARTICLES_PER_PAGE = 10 diff --git a/app/controllers/article_controller.py b/app/controllers/article_controller.py index f102f20..249b372 100644 --- a/app/controllers/article_controller.py +++ b/app/controllers/article_controller.py @@ -1,7 +1,18 @@ import math -from flask import Blueprint, flash, redirect, render_template, request, session, url_for +from flask import ( + Blueprint, + flash, + redirect, + render_template, + request, + session, + url_for, +) +from werkzeug.wrappers import Response +from app.constants import PaginationConfig, Role, SessionKey +from app.controllers.decorators import roles_accepted from app.services.article_service import ArticleService from app.services.comment_service import CommentService from database.database_setup import db_session @@ -10,68 +21,147 @@ @article_bp.route("/") -def list_articles(): - current_page_number = 1 - page_number = request.args.get("page", current_page_number, type=int) - articles_per_page = 10 - articles = ArticleService.get_paginated_articles(page_number, articles_per_page) - total_articles = ArticleService.get_total_count() +def list_articles() -> str: + """ + Renders the homepage with a paginated list of articles. + + Returns: + str: The rendered HTML template for the homepage. + """ + page_number = request.args.get("page", 1, type=int) + articles_per_page = PaginationConfig.ARTICLES_PER_PAGE + + article_service = ArticleService(db_session) + articles = article_service.get_paginated_articles(page_number, articles_per_page) + total_articles = article_service.get_total_count() total_pages = math.ceil(total_articles / articles_per_page) - return render_template("index.html", articles=articles, page_number=page_number, total_pages=total_pages) + + return render_template( + "index.html", + articles=articles, + page_number=page_number, + total_pages=total_pages + ) @article_bp.route("/article/") -def view_article(article_id): - article = ArticleService.get_by_id(article_id) +def view_article(article_id: int) -> str | Response: + """ + Displays the details of a specific article and its comments. + + Args: + article_id (int): ID of the article to view. + + Returns: + str | Response: The rendered HTML template for the article or a redirect if the article is not found. + """ + article_service = ArticleService(db_session) + article = article_service.get_by_id(article_id) if not article: flash("Article not found.") return redirect(url_for("article.list_articles")) - comments = CommentService.get_tree_by_article_id(article_id) + comment_service = CommentService(db_session) + comments = comment_service.get_tree_by_article_id(article_id) return render_template("article_detail.html", article=article, comments=comments) @article_bp.route("/article/new", methods=["GET", "POST"]) -def create_article(): - if session.get("role") not in ["admin", "author"]: - flash("Access restricted.") - return redirect(url_for("article.list_articles")) +@roles_accepted(Role.ADMIN, Role.AUTHOR) +def create_article() -> str | Response: + """ + Handles the creation of a new blog article. + Restricted to 'admin' and 'author' roles. + + Returns: + str | Response: The rendered HTML form (GET) or a redirect to the article list after creation (POST). + """ if request.method == "POST": - ArticleService.create_article(request.form.get("title"), request.form.get("content"), session["user_id"]) + article_service = ArticleService(db_session) + title = str(request.form.get("title") or "") + content = str(request.form.get("content") or "") + user_id = int(session.get(SessionKey.USER_ID) or 0) + + article_service.create_article( + title=title, + content=content, + author_id=user_id + ) db_session.commit() flash("Article published!") return redirect(url_for("article.list_articles")) + return render_template("article_form.html", article=None) @article_bp.route("/article//edit", methods=["GET", "POST"]) -def edit_article(article_id): +@roles_accepted(Role.ADMIN, Role.AUTHOR, Role.USER) +def edit_article(article_id: int) -> str | Response: + """ + Handles the editing of an existing article. + Ensures the user is authorized to perform the update. + + Args: + article_id (int): ID of the article to edit. + + Returns: + str | Response: The rendered HTML form (GET) or a redirect to the updated article (POST). + """ + article_service = ArticleService(db_session) + if request.method == "POST": - article = ArticleService.update_article( - article_id, - session.get("user_id"), - session.get("role"), - request.form.get("title"), - request.form.get("content") + user_id = int(session.get(SessionKey.USER_ID) or 0) + role = str(session.get(SessionKey.ROLE) or "") + title = str(request.form.get("title") or "") + content = str(request.form.get("content") or "") + + article = article_service.update_article( + article_id=article_id, + user_id=user_id, + role=role, + title=title, + content=content ) if article: db_session.commit() flash("Article updated!") return redirect(url_for("article.view_article", article_id=article_id)) + flash("Update failed: Unauthorized or not found.") return redirect(url_for("article.list_articles")) - article = ArticleService.get_by_id(article_id) + + article = article_service.get_by_id(article_id) if not article: flash("Article not found.") return redirect(url_for("article.list_articles")) + return render_template("article_form.html", article=article) @article_bp.route("/article//delete") -def delete_article(article_id): - if ArticleService.delete_article(article_id, session.get("user_id"), session.get("role")): +@roles_accepted(Role.ADMIN, Role.AUTHOR, Role.USER) +def delete_article(article_id: int) -> Response: + """ + Handles the deletion of an article. + + Args: + article_id (int): ID of the article to delete. + + Returns: + Response: A redirect to the article list after deletion. + """ + article_service = ArticleService(db_session) + user_id = int(session.get(SessionKey.USER_ID) or 0) + role = str(session.get(SessionKey.ROLE) or "") + + if article_service.delete_article( + article_id=article_id, + user_id=user_id, + role=role + ): db_session.commit() flash("Article deleted.") else: - flash("Delete failed.") + flash("Delete failed: Unauthorized or not found.") + return redirect(url_for("article.list_articles")) diff --git a/app/controllers/comment_controller.py b/app/controllers/comment_controller.py index 1a608bf..6b491a9 100644 --- a/app/controllers/comment_controller.py +++ b/app/controllers/comment_controller.py @@ -1,5 +1,8 @@ from flask import Blueprint, flash, redirect, request, session, url_for +from werkzeug.wrappers import Response +from app.constants import Role, SessionKey +from app.controllers.decorators import login_required, roles_accepted from app.services.comment_service import CommentService from database.database_setup import db_session @@ -7,27 +10,53 @@ @comment_bp.route("/create/", methods=["POST"]) -def create_comment(article_id): - # Exception to add here - if not session.get("user_id"): - flash("Login required.") - return redirect(url_for("login.render_login_page")) - if CommentService.create_comment(article_id, session["user_id"], request.form.get("content")): +@login_required +def create_comment(article_id: int) -> Response: + """ + Handles the creation of a new comment on an article. + Requires the user to be logged in. + + Args: + article_id (int): ID of the article being commented on. + + Returns: + Response: A redirect to the article view or login page. + """ + comment_service = CommentService(db_session) + content = str(request.form.get("content") or "") + if comment_service.create_comment( + article_id=article_id, + user_id=session[SessionKey.USER_ID], + content=content + ): db_session.commit() flash("Comment added.") else: flash("Error adding comment.") + return redirect(url_for("article.view_article", article_id=article_id)) @comment_bp.route("/reply/", methods=["POST"]) -def reply_to_comment(parent_comment_id): - # Exception to add here - if not session.get("user_id"): - flash("Login required.") - return redirect(url_for("login.render_login_page")) +@login_required +def reply_to_comment(parent_comment_id: int) -> Response: + """ + Handles the creation of a reply to an existing comment. + Requires the user to be logged in. - article_id = CommentService.create_reply(parent_comment_id, session["user_id"], request.form.get("content")) + Args: + parent_comment_id (int): ID of the comment being replied to. + + Returns: + Response: A redirect to the article view or the article list in case of error. + """ + comment_service = CommentService(db_session) + content = str(request.form.get("content") or "") + article_id = comment_service.create_reply( + parent_comment_id=parent_comment_id, + user_id=session[SessionKey.USER_ID], + content=content + ) if article_id: db_session.commit() return redirect(url_for("article.view_article", article_id=article_id)) @@ -37,8 +66,24 @@ def reply_to_comment(parent_comment_id): @comment_bp.route("/delete/") -def delete_comment(comment_id): - article_id = CommentService.delete_comment(comment_id, session.get("role")) +@roles_accepted(Role.ADMIN) +def delete_comment(comment_id: int) -> Response: + """ + Handles the deletion of a comment. + Restricted to users with the 'admin' role. + + Args: + comment_id (int): ID of the comment to delete. + + Returns: + Response: A redirect to the article view or article list after deletion. + """ + comment_service = CommentService(db_session) + role = str(session.get(SessionKey.ROLE) or "") + article_id = comment_service.delete_comment( + comment_id=comment_id, + role=role + ) if article_id: db_session.commit() flash("Comment deleted.") diff --git a/app/controllers/decorators.py b/app/controllers/decorators.py new file mode 100644 index 0000000..03adf42 --- /dev/null +++ b/app/controllers/decorators.py @@ -0,0 +1,54 @@ +from collections.abc import Callable +from functools import wraps +from typing import Any + +from flask import flash, redirect, session, url_for + +from app.constants import Role, SessionKey + + +def login_required(f: Callable[..., Any]) -> Callable[..., Any]: + """ + Decorator to ensure that a user is logged in before accessing a route. + + Args: + f (Callable): The route function to wrap. + + Returns: + Callable: The wrapped function. + """ + + @wraps(f) + def decorated_function(*args: Any, **kwargs: Any) -> Any: + if not session.get(SessionKey.USER_ID): + flash("Login required.") + return redirect(url_for("login.render_login_page")) + return f(*args, **kwargs) + + return decorated_function + + +def roles_accepted(*roles: Role) -> Callable[..., Any]: + """ + Decorator to ensure that a logged-in user has one of the required roles. + + Args: + *roles (Role): Variable list of accepted roles. + + Returns: + Callable: A decorator that wraps the route function. + """ + + def decorator(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + @login_required + def decorated_function(*args: Any, **kwargs: Any) -> Any: + user_role = session.get(SessionKey.ROLE) + if user_role not in [role.value for role in roles]: + flash("Access restricted: Insufficient permissions.") + return redirect(url_for("article.list_articles")) + return f(*args, **kwargs) + + return decorated_function + + return decorator diff --git a/app/controllers/login_controller.py b/app/controllers/login_controller.py index 1b8e42d..7aecea2 100644 --- a/app/controllers/login_controller.py +++ b/app/controllers/login_controller.py @@ -1,28 +1,57 @@ from flask import Blueprint, flash, redirect, render_template, request, session, url_for +from werkzeug.wrappers import Response +from app.constants import SessionKey from app.services.login_service import LoginService +from database.database_setup import db_session login_bp = Blueprint("login", __name__) @login_bp.route("/login-page") -def render_login_page(): +def render_login_page() -> str: + """ + Renders the login page. + + Returns: + str: The rendered HTML template for the login page. + """ return render_template("login.html") @login_bp.route("/login", methods=["POST"]) -def login_authentication(): - user_data = LoginService.authenticate_user(request.form.get("username"), request.form.get("password")) - if user_data: - session["user_id"] = user_data["id"] - session["username"] = user_data["username"] - session["role"] = user_data["role"] +def login_authentication() -> Response: + """ + Handles user authentication. + Validates credentials and sets up the session. + + Returns: + Response: A redirect to the article list on success, or back to the login page on failure. + """ + login_service = LoginService(db_session) + username = str(request.form.get("username") or "") + password = str(request.form.get("password") or "") + user = login_service.authenticate_user( + username=username, + password=password + ) + if user: + session[SessionKey.USER_ID] = user.account_id + session[SessionKey.USERNAME] = user.account_username + session[SessionKey.ROLE] = user.account_role return redirect(url_for("article.list_articles")) + flash("Incorrect credentials.") return redirect(url_for("login.render_login_page")) @login_bp.route("/logout") -def logout(): +def logout() -> Response: + """ + Clears the user session and redirects to the article list. + + Returns: + Response: A redirect to the article list after clearing the session. + """ session.clear() return redirect(url_for("article.list_articles")) diff --git a/app/models/account_model.py b/app/models/account_model.py index 2999cf7..d0d4c9d 100644 --- a/app/models/account_model.py +++ b/app/models/account_model.py @@ -1,26 +1,41 @@ +from datetime import datetime + from sqlalchemy import ( TIMESTAMP, CheckConstraint, - Column, Integer, Text, func, ) -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from database.database_setup import Base class Account(Base): + """ + Represents a user account in the system. + + Attributes: + account_id (int): Unique identifier for the account (Primary Key). + account_username (str): Unique username used for authentication. + account_password (str): Securely hashed password string. + account_email (str | None): Optional email address for the user. + account_role (str): Permissions role ('admin', 'author', or 'user'). + account_created_at (datetime): Automated timestamp of account creation. + articles (list[Article]): Collection of articles authored by this account. + comments (list[Comment]): Collection of comments written by this account. + """ + __tablename__ = "accounts" __table_args__ = (CheckConstraint(sqltext="account_role IN ('admin', 'author', 'user')", name="accounts_role_check"),) - account_id = Column(name="account_id", type_=Integer, primary_key=True, autoincrement=True) - account_username = Column(name="account_username", type_=Text, unique=True, nullable=False) - account_password = Column(name="account_password", type_=Text, nullable=False) - account_email = Column(name="account_email", type_=Text) - account_role = Column(name="account_role", type_=Text, nullable=False) - account_created_at = Column(name="account_created_at", type_=TIMESTAMP, server_default=func.now()) + account_id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + account_username: Mapped[str] = mapped_column(Text, unique=True, nullable=False) + account_password: Mapped[str] = mapped_column(Text, nullable=False) + account_email: Mapped[str | None] = mapped_column(Text) + account_role: Mapped[str] = mapped_column(Text, nullable=False) + account_created_at: Mapped[datetime] = mapped_column(TIMESTAMP, server_default=func.now()) - articles = relationship(argument="Article", back_populates="article_author", cascade="all, delete-orphan") - comments = relationship(argument="Comment", back_populates="comment_author", cascade="all, delete-orphan") + articles = relationship("Article", back_populates="article_author", cascade="all, delete-orphan") + comments = relationship("Comment", back_populates="comment_author", cascade="all, delete-orphan") diff --git a/app/models/article_model.py b/app/models/article_model.py index f985929..aaec205 100644 --- a/app/models/article_model.py +++ b/app/models/article_model.py @@ -1,24 +1,38 @@ +from datetime import datetime + from sqlalchemy import ( TIMESTAMP, - Column, ForeignKey, Integer, Text, func, ) -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from database.database_setup import Base class Article(Base): + """ + Represents a blog article. + + Attributes: + article_id (int): Unique identifier for the article (Primary Key). + article_author_id (int): Foreign key referencing the author's Account. + article_title (str): Title of the article. + article_content (str): Full text content of the article. + article_published_at (datetime): Automated timestamp of publication. + article_author (Account): Relationship to the author's Account instance. + article_comments (list[Comment]): Collection of comments linked to this article. + """ + __tablename__ = "articles" - article_id = Column(name="article_id", type_=Integer, primary_key=True, autoincrement=True) - article_author_id = Column(ForeignKey("accounts.account_id", ondelete="CASCADE"), name="article_author_id", type_=Integer, nullable=False) - article_title = Column(name="article_title", type_=Text, nullable=False) - article_content = Column(name="article_content", type_=Text, nullable=False) - article_published_at = Column(name="article_published_at", type_=TIMESTAMP, server_default=func.now()) + article_id: Mapped[int] = mapped_column("article_id", Integer, primary_key=True, autoincrement=True) + article_author_id: Mapped[int] = mapped_column(ForeignKey("accounts.account_id", ondelete="CASCADE"), name="article_author_id", nullable=False) + article_title: Mapped[str] = mapped_column("article_title", Text, nullable=False) + article_content: Mapped[str] = mapped_column("article_content", Text, nullable=False) + article_published_at: Mapped[datetime] = mapped_column("article_published_at", TIMESTAMP, server_default=func.now()) - article_author = relationship(argument="Account", back_populates="articles") - article_comments = relationship(argument="Comment", back_populates="comment_article", cascade="all, delete-orphan") + article_author = relationship("Account", back_populates="articles") + article_comments = relationship("Comment", back_populates="comment_article", cascade="all, delete-orphan") diff --git a/app/models/comment_model.py b/app/models/comment_model.py index fc1cb62..7e2f3bd 100644 --- a/app/models/comment_model.py +++ b/app/models/comment_model.py @@ -1,36 +1,54 @@ +from datetime import datetime + from sqlalchemy import ( TIMESTAMP, - Column, ForeignKey, Integer, Text, func, ) -from sqlalchemy.orm import relationship +from sqlalchemy.orm import Mapped, mapped_column, relationship from database.database_setup import Base class Comment(Base): + """ + Represents a comment or a reply on an article. + + Attributes: + comment_id (int): Unique identifier for the comment (Primary Key). + comment_article_id (int): Foreign key referencing the associated article. + comment_written_account_id (int): Foreign key referencing the author's account. + comment_reply_to (int | None): Foreign key referencing a parent comment (for replies). + comment_content (str): Text content of the comment. + comment_posted_at (datetime): Automated timestamp of when the comment was posted. + comment_article (Article): Relationship to the parent article instance. + comment_author (Account): Relationship to the author's account instance. + reply_to_comment (Comment | None): Relationship to the parent comment if this is a reply. + comment_replies (list[Comment]): Collection of replies (child comments). + """ + __tablename__ = "comments" - comment_id = Column(name="comment_id", type_=Integer, primary_key=True, autoincrement=True) - comment_article_id = Column(ForeignKey("articles.article_id", ondelete="CASCADE"), name="comment_article_id", type_=Integer, nullable=False) - comment_written_account_id = Column(ForeignKey("accounts.account_id", ondelete="CASCADE"), name="comment_written_account_id", type_=Integer, nullable=False) - comment_reply_to = Column(ForeignKey("comments.comment_id"), name="comment_reply_to", type_=Integer, nullable=True) - comment_content = Column(name="comment_content", type_=Text, nullable=False) - comment_posted_at = Column(name="comment_posted_at", type_=TIMESTAMP, server_default=func.now()) + comment_id: Mapped[int] = mapped_column("comment_id", Integer, primary_key=True, autoincrement=True) + comment_article_id: Mapped[int] = mapped_column(ForeignKey("articles.article_id", ondelete="CASCADE"), name="comment_article_id", nullable=False) + comment_written_account_id: Mapped[int] = mapped_column(ForeignKey("accounts.account_id", ondelete="CASCADE"), name="comment_written_account_id", nullable=False) + comment_reply_to: Mapped[int | None] = mapped_column(ForeignKey("comments.comment_id"), name="comment_reply_to", nullable=True) + comment_content: Mapped[str] = mapped_column("comment_content", Text, nullable=False) + comment_posted_at: Mapped[datetime] = mapped_column("comment_posted_at", TIMESTAMP, server_default=func.now()) + + comment_article = relationship("Article", back_populates="article_comments") + comment_author = relationship("Account", back_populates="comments") - comment_article = relationship(argument="Article", back_populates="article_comments") - comment_author = relationship(argument="Account", back_populates="comments") reply_to_comment = relationship( - argument="Comment", + "Comment", remote_side=[comment_id], back_populates="comment_replies", uselist=False, ) comment_replies = relationship( - argument="Comment", + "Comment", back_populates="reply_to_comment", cascade="all, delete-orphan", ) diff --git a/app/services/article_service.py b/app/services/article_service.py index 80f882f..d9f997c 100644 --- a/app/services/article_service.py +++ b/app/services/article_service.py @@ -1,14 +1,36 @@ -from sqlalchemy import func, select -from sqlalchemy.orm import defer, joinedload +from collections.abc import Sequence +from sqlalchemy import Row, func, select +from sqlalchemy.orm import Session, defer, joinedload, scoped_session + +from app.constants import Role from app.models.account_model import Account from app.models.article_model import Article -from database.database_setup import db_session class ArticleService: - @staticmethod - def get_all_ordered_by_date(): + """ + Service class responsible for business logic operations related to Articles. + Handles creating, retrieving, updating and deleting articles as well as pagination logic. + """ + + def __init__(self, session: Session | scoped_session[Session]): + """ + Initialize the service with a database session (Dependency Injection). + Supports both standard Session and scoped_session. + + Args: + session (Session | scoped_session[Session]): The SQLAlchemy database session. + """ + self.session = session + + def get_all_ordered_by_date(self) -> Sequence[Article]: + """ + Retrieves all articles ordered by their publication date in descending order. + + Returns: + Sequence[Article]: A sequence of Article instances. + """ query = ( select(Article) .options( @@ -17,22 +39,67 @@ def get_all_ordered_by_date(): ) .order_by(Article.article_published_at.desc()) ) - return db_session.execute(query).unique().scalars().all() + return self.session.execute(query).unique().scalars().all() + + def get_by_id(self, article_id: int) -> Article | None: + """ + Retrieves a single article by its ID. + + Args: + article_id (int): The unique identifier of the article. + + Returns: + Article | None: The Article instance if found, None otherwise. + """ + query = ( + select(Article) + .where(Article.article_id == article_id) + .options(joinedload(Article.article_author)) + ) + return self.session.execute(query).unique().scalar_one_or_none() - @staticmethod - def get_by_id(article_id): - query = select(Article).where(Article.article_id == article_id).options(joinedload(Article.article_author)) - return db_session.execute(query).unique().scalar_one_or_none() + def create_article(self, title: str, content: str, author_id: int) -> Article: + """ + Creates a new article instance. - @staticmethod - def create_article(title, content, author_id): - new_article = Article(article_title=title, article_content=content, article_author_id=author_id) - db_session.add(new_article) + Args: + title (str): The title of the new article. + content (str): The body content of the new article. + author_id (int): The unique identifier of the user creating the article. + + Returns: + Article: The newly created Article instance. + """ + new_article = Article( + article_title=title, + article_content=content, + article_author_id=author_id + ) + self.session.add(new_article) return new_article - @staticmethod - def update_article(article_id, user_id, role, title, content): - article = ArticleService.get_by_id(article_id) + def update_article( + self, + article_id: int, + user_id: int, + role: str, + title: str, + content: str + ) -> Article | None: + """ + Updates an existing article ensuring the requester is the original author. + + Args: + article_id (int): ID of the article to update. + user_id (int): ID of the user requesting the update. + role (str): Role of the user requesting the update. + title (str): New title for the article. + content (str): New content for the article. + + Returns: + Article | None: The Article instance or None if unauthorized/not found. + """ + article = self.get_by_id(article_id) if not article or article.article_author_id != user_id: return None @@ -40,17 +107,36 @@ def update_article(article_id, user_id, role, title, content): article.article_content = content return article - @staticmethod - def delete_article(article_id, user_id, role): - article = ArticleService.get_by_id(article_id) - if not article or (role != "admin" and article.article_author_id != user_id): + def delete_article(self, article_id: int, user_id: int, role: str) -> bool: + """ + Deletes an article. Only the original author or an Admin can delete it. + + Args: + article_id (int): ID of the article to delete. + user_id (int): ID of the user requesting deletion. + role (str): Role of the user requesting deletion. + + Returns: + bool: True if deleted, False otherwise. + """ + + article = self.get_by_id(article_id) + if not article: return False - db_session.delete(article) - return True + if article.article_author_id == user_id or role == Role.ADMIN: + self.session.delete(article) + return True + + return False - @staticmethod - def get_paginated_articles(page, per_page): + def get_paginated_articles(self, page: int, per_page: int) -> Sequence[Row]: + """ + Retrieves a paginated list of articles containing specific columns. + + Returns: + Sequence[Row]: A sequence of SQLAlchemy Row objects containing selected columns. + """ query = ( select( Article.article_id, @@ -64,9 +150,15 @@ def get_paginated_articles(page, per_page): .limit(per_page) .offset((page - 1) * per_page) ) - return db_session.execute(query).all() + return self.session.execute(query).all() + + def get_total_count(self) -> int: + """ + Retrieves the total number of articles. - @staticmethod - def get_total_count(): + Returns: + int: The total count. + """ query = select(func.count(Article.article_id)) - return db_session.execute(query).scalar() + count = self.session.execute(query).scalar() + return int(count) if count is not None else 0 diff --git a/app/services/comment_service.py b/app/services/comment_service.py index 224b678..cd5978b 100644 --- a/app/services/comment_service.py +++ b/app/services/comment_service.py @@ -1,55 +1,137 @@ +from collections.abc import Sequence + from sqlalchemy import select -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import Session, joinedload, scoped_session +from app.constants import Role from app.models.article_model import Article from app.models.comment_model import Comment -from database.database_setup import db_session class CommentService: - @staticmethod - def create_reply(parent_comment_id, user_id, content): - parent = db_session.get(Comment, parent_comment_id) + """ + Service class responsible for business logic operations related to Comments. + Handles creating top-level comments, replies, retrieving comments by article, and deleting comments. + """ + + def __init__(self, session: Session | scoped_session[Session]): + """ + Initialize the service with a database session (Dependency Injection). + Supports both standard Session and scoped_session. + + Args: + session (Session | scoped_session[Session]): The SQLAlchemy database session to use for queries. + """ + self.session = session + + def create_reply(self, parent_comment_id: int, user_id: int, content: str) -> int | None: + """ + Creates a reply to an existing comment. A reply is linked either to the parent directly + or to the parent's top-level comment (threading logic). + + Args: + parent_comment_id (int): The ID of the comment being replied to. + user_id (int): The identifier of the user creating the reply. + content (str): The text content of the reply. + + Returns: + int | None: The article ID the comment belongs to if successful, None if the parent comment is not found. + """ + parent = self.session.get(Comment, parent_comment_id) if not parent: return None actual_parent_id = parent.comment_reply_to if parent.comment_reply_to else parent.comment_id + new_reply = Comment( comment_article_id=parent.comment_article_id, comment_written_account_id=user_id, comment_content=content, comment_reply_to=actual_parent_id ) - db_session.add(new_reply) + self.session.add(new_reply) return parent.comment_article_id - @staticmethod - def get_by_id(comment_id): - query = select(Comment).options(joinedload(Comment.comment_author)).where(Comment.comment_id == comment_id) - return db_session.execute(query).unique().scalar_one_or_none() + def get_by_id(self, comment_id: int) -> Comment | None: + """ + Retrieves a single comment by its ID. Eagerly loads the author information. - @staticmethod - def create_comment(article_id, user_id, content): - article = db_session.get(Article, article_id) + Args: + comment_id (int): The unique identifier of the comment. + + Returns: + Comment | None: The Comment instance if found, None otherwise. + """ + query = ( + select(Comment) + .options(joinedload(Comment.comment_author)) + .where(Comment.comment_id == comment_id) + ) + return self.session.execute(query).unique().scalar_one_or_none() + + def create_comment(self, article_id: int, user_id: int, content: str) -> bool: + """ + Creates a top-level comment on an article. + + Args: + article_id (int): The ID of the article being commented on. + user_id (int): The identifier of the user creating the comment. + content (str): The body text of the comment. + + Returns: + bool: True if the comment was created successfully, False if the article does not exist. + """ + article = self.session.get(Article, article_id) if not article: return False - new_comment = Comment(comment_article_id=article_id, comment_written_account_id=user_id, comment_content=content) - db_session.add(new_comment) + + new_comment = Comment( + comment_article_id=article_id, + comment_written_account_id=user_id, + comment_content=content + ) + self.session.add(new_comment) return True - @staticmethod - def delete_comment(comment_id, role): - if role != "admin": - return False - comment = db_session.get(Comment, comment_id) + def delete_comment(self, comment_id: int, role: str) -> int | None: + """ + Deletes a comment. Only users with the 'admin' role can delete comments. + + Args: + comment_id (int): The ID of the comment to delete. + role (str): The role of the user attempting the deletion. + + Returns: + int | None: The article ID the comment belonged to if successful, None if unauthorized or not found. + """ + if role != Role.ADMIN: + return None + + comment = self.session.get(Comment, comment_id) if not comment: - return False + return None + article_id = comment.comment_article_id - db_session.delete(comment) + self.session.delete(comment) return article_id - @staticmethod - def get_tree_by_article_id(article_id): - query = select(Comment).where(Comment.comment_article_id == article_id).options(joinedload(Comment.comment_author)).order_by(Comment.comment_posted_at.asc()) - all_comments = db_session.execute(query).unique().scalars().all() + def get_tree_by_article_id(self, article_id: int) -> Sequence[Comment]: + """ + Retrieves all comments for a specific article as a threaded tree structure. + Returns only the top-level comments. + + Args: + article_id (int): The ID of the article. + + Returns: + Sequence[Comment]: A sequence of top-level Comment instances for the given article. + """ + query = ( + select(Comment) + .where(Comment.comment_article_id == article_id) + .options(joinedload(Comment.comment_author)) + .order_by(Comment.comment_posted_at.asc()) + ) + all_comments = self.session.execute(query).unique().scalars().all() + return [c for c in all_comments if c.comment_reply_to is None] diff --git a/app/services/login_service.py b/app/services/login_service.py index b5a06f8..bc1c5ad 100644 --- a/app/services/login_service.py +++ b/app/services/login_service.py @@ -1,18 +1,39 @@ from sqlalchemy import select +from sqlalchemy.orm import Session, scoped_session from app.models.account_model import Account -from database.database_setup import db_session class LoginService: - @staticmethod - def authenticate_user(username, password): + """ + Service class responsible for handling user authentication logic. + """ + + def __init__(self, session: Session | scoped_session[Session]): + """ + Initialize the service with a database session (Dependency Injection). + Supports both standard Session and scoped_session proxy objects. + + Args: + session (Session | scoped_session[Session]): The SQLAlchemy database session to use for queries. + """ + self.session = session + + def authenticate_user(self, username: str, password: str) -> Account | None: + """ + Validates the user's credentials against the database. + + Args: + username (str): The username provided by the user. + password (str): The plaintext password provided by the user. + + Returns: + Account | None: The authenticated Account instance if credentials match, None otherwise. + """ query = select(Account).where(Account.account_username == username) - user = db_session.execute(query).scalar_one_or_none() + user = self.session.execute(query).scalar_one_or_none() + if user and user.account_password == password: - return { - "id": user.account_id, - "username": user.account_username, - "role": user.account_role - } + return user + return None diff --git a/configurations/__init__.py b/config/__init__.py similarity index 100% rename from configurations/__init__.py rename to config/__init__.py diff --git a/config/configuration_variables.py b/config/configuration_variables.py new file mode 100644 index 0000000..41e7425 --- /dev/null +++ b/config/configuration_variables.py @@ -0,0 +1,76 @@ +import os +from pathlib import Path + +from dotenv import load_dotenv + +BASE_DIR = Path(__file__).resolve().parent.parent + +load_dotenv(BASE_DIR / ".env") +load_dotenv(BASE_DIR / ".env.test") + +def get_env_variable(name: str) -> str: + """ + Retrieves an environment variable or raises a RuntimeError if missing. + + Args: + name (str): The name of the environment variable. + + Returns: + str: The value of the environment variable. + + Raises: + RuntimeError: If the environment variable is not set. + """ + value = os.getenv(name) + if not value: + raise RuntimeError(f"Missing mandatory environment variable: '{name}'") + return value + + +class EnvVariablesConfig: + """ + Configuration class to manage environment variables for the application. + """ + + @property + def database_url(self) -> str: + """ + The production database URL. + + Returns: + str: Database connection string. + """ + return get_env_variable("DATABASE_URL") + + @property + def secret_key(self) -> str: + """ + The secret key for Flask sessions. + + Returns: + str: Secret key string. + """ + return get_env_variable("SECRET_KEY") + + @property + def test_database_url(self) -> str: + """ + The test database URL. + + Returns: + str: Test database connection string. + """ + return get_env_variable("TEST_DATABASE_URL") + + @property + def test_secret_key(self) -> str: + """ + The secret key for Flask sessions during testing. + + Returns: + str: Test secret key string. + """ + return get_env_variable("TEST_SECRET_KEY") + + +env_vars = EnvVariablesConfig() diff --git a/configurations/configuration_variables.py b/configurations/configuration_variables.py deleted file mode 100644 index 78356b5..0000000 --- a/configurations/configuration_variables.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -from pathlib import Path - -from dotenv import load_dotenv - -BASE_DIR = Path(__file__).resolve().parent.parent - -load_dotenv(BASE_DIR / ".env") -load_dotenv(BASE_DIR / ".env.test") - -def get_env_variable(name: str): - value = os.getenv(name) - if not value: - raise RuntimeError(f"Missing mandatory environment variable: '{name}'") - return value - -class EnvVariablesConfig: - @property - def database_url(self): - return get_env_variable("DATABASE_URL") - - @property - def secret_key(self): - return get_env_variable("SECRET_KEY") - - @property - def test_database_url(self): - return get_env_variable("TEST_DATABASE_URL") - - @property - def test_secret_key(self): - return get_env_variable("TEST_SECRET_KEY") - -env_vars = EnvVariablesConfig() diff --git a/database/database_setup.py b/database/database_setup.py index eb05f8f..9bddb9c 100644 --- a/database/database_setup.py +++ b/database/database_setup.py @@ -4,12 +4,16 @@ from sqlalchemy import create_engine from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker -from configurations.configuration_variables import env_vars +from config.configuration_variables import env_vars + +""" +Database engine and session management. +""" if os.getenv("PYTEST_CURRENT_TEST") or "pytest" in sys.modules: - database_url = env_vars.test_database_url + database_url: str = env_vars.test_database_url else: - database_url = env_vars.database_url + database_url: str = env_vars.database_url database_engine = create_engine(database_url) session_factory = sessionmaker(bind=database_engine) diff --git a/main.py b/main.py index c02d1f8..3491f7b 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,8 @@ from app import initialize_flask_application if __name__ == "__main__": + """ + Main entry point for running the Flask development server. + """ app = initialize_flask_application() app.run(debug=True) diff --git a/poetry.lock b/poetry.lock index 521da8d..d3967d2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -455,6 +455,18 @@ files = [ {file = "mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558"}, ] +[[package]] +name = "nodeenv" +version = "1.10.0" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +files = [ + {file = "nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827"}, + {file = "nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb"}, +] + [[package]] name = "packaging" version = "25.0" @@ -737,6 +749,27 @@ files = [ [package.extras] windows-terminal = ["colorama (>=0.4.6)"] +[[package]] +name = "pyright" +version = "1.1.408" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +groups = ["dev"] +files = [ + {file = "pyright-1.1.408-py3-none-any.whl", hash = "sha256:090b32865f4fdb1e0e6cd82bf5618480d48eecd2eb2e70f960982a3d9a4c17c1"}, + {file = "pyright-1.1.408.tar.gz", hash = "sha256:f28f2321f96852fa50b5829ea492f6adb0e6954568d1caa3f3af3a5f555eb684"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" +typing-extensions = ">=4.1" + +[package.extras] +all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] +nodejs = ["nodejs-wheel-binaries"] + [[package]] name = "pytest" version = "9.0.0" @@ -942,7 +975,7 @@ version = "4.15.0" description = "Backported and Experimental Type Hints for Python 3.9+" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, diff --git a/pyproject.toml b/pyproject.toml index 16dafe5..41a49f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,8 @@ dev = [ "flake8 (>=7.3.0,<8.0.0)", "pytest (>=9.0.0,<10.0.0)", "pytest-flask (>=1.3.0,<2.0.0)", - "ruff (>=0.14.5,<0.15.0)" + "ruff (>=0.14.5,<0.15.0)", + "pyright (>=1.1.408,<2.0.0)" ] [tool.ruff] diff --git a/tests/conftest.py b/tests/conftest.py index 28ccc23..7c19d36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,16 +1,29 @@ +from collections.abc import Generator +from typing import Any + import pytest +from flask import Flask +from flask.testing import FlaskClient from sqlalchemy import text +from sqlalchemy.engine import Connection +from sqlalchemy.orm import Session, scoped_session from app import initialize_flask_application from app.models.account_model import Account from app.models.article_model import Article from app.models.comment_model import Comment -from configurations.configuration_variables import env_vars +from config.configuration_variables import env_vars from database.database_setup import Base, database_engine from database.database_setup import db_session as app_db_session -def truncate_all_tables(connection): +def truncate_all_tables(connection: Connection) -> None: + """ + Truncates all tables in the database to ensure a clean state for tests. + + Args: + connection (Connection): SQLAlchemy connection object. + """ tables = Base.metadata.sorted_tables table_names = ", ".join(f'"{t.name}"' for t in tables) if table_names: @@ -18,7 +31,13 @@ def truncate_all_tables(connection): @pytest.fixture(scope="function") -def app(): +def app() -> Generator[Flask, Any, None]: + """ + Pytest fixture that initializes the Flask application for testing. + + Yields: + Flask: The Flask application instance. + """ flask_app = initialize_flask_application() flask_app.config.update({ "TESTING": True, @@ -28,19 +47,34 @@ def app(): @pytest.fixture(scope="function") -def client(app): +def client(app: Flask) -> FlaskClient: + """ + Pytest fixture that provides a test client for the Flask application. + + Args: + app (Flask): The Flask application instance. + + Returns: + FlaskClient: A test client. + """ return app.test_client() @pytest.fixture(scope="function") -def db_session(app): - # We include the 'app' fixture as a dependency to ensure that the Flask application - # is fully initialized before the database session is established. This guarantees - # that all configurations and model discoveries are completed +def db_session(app) -> Generator[scoped_session[Session], None, None]: + """ + Pytest fixture that provides a clean database session for each test function. + Truncates all tables before yielding the session. + + Args: + app (Flask): The Flask application instance. - # Explicitly referencing models to satisfy linters (prevent unused import errors) - # Ensure SQLAlchemy's Base metadata is populated for TRUNCATE operations + Yields: + Session: A scoped SQLAlchemy session. + """ + # Explicitly referencing models to satisfy linters and ensure metadata is populated _ = (Account, Article, Comment) + if database_engine.url.render_as_string(hide_password=False) != env_vars.test_database_url: pytest.exit("SECURITY ERROR: The current database URL does not match the configured TEST database URL.") diff --git a/tests/factories.py b/tests/factories.py index ed542cc..5afc913 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -4,11 +4,23 @@ def make_account( - account_username="Xxx__D4RK_V4D0R__xxX", - account_password="password123", - account_email="vador@empire.com", - account_role="user", -): + account_username: str = "Xxx__D4RK_V4D0R__xxX", + account_password: str = "password123", + account_email: str = "vador@empire.com", + account_role: str = "user", +) -> Account: + """ + Factory to create an Account instance for testing. + + Args: + account_username (str): Username. + account_password (str): Password. + account_email (str): Email. + account_role (str): Role. + + Returns: + Account: An Account instance. + """ return Account( account_username=account_username, account_password=account_password, @@ -16,14 +28,46 @@ def make_account( account_role=account_role, ) -def make_article(article_author_id, article_title="Luke, I'm your father !", article_content="On the platform, Darth Vader stepped forward..."): + +def make_article( + article_author_id: int, + article_title: str = "Luke, I'm your father !", + article_content: str = "On the platform, Darth Vader stepped forward...", +) -> Article: + """ + Factory to create an Article instance for testing. + + Args: + article_author_id (int): ID of the author account. + article_title (str): Title. + article_content (str): Content. + + Returns: + Article: An Article instance. + """ return Article( article_author_id=article_author_id, article_title=article_title, article_content=article_content, ) -def make_comment(comment_article_id, comment_written_account_id, comment_content="Bravo !"): + +def make_comment( + comment_article_id: int, + comment_written_account_id: int, + comment_content: str = "Bravo !", +) -> Comment: + """ + Factory to create a Comment instance for testing. + + Args: + comment_article_id (int): ID of the article. + comment_written_account_id (int): ID of the author account. + comment_content (str): Content. + + Returns: + Comment: A Comment instance. + """ return Comment( comment_article_id=comment_article_id, comment_written_account_id=comment_written_account_id, diff --git a/tests/test_controllers/test_article_controller.py b/tests/test_controllers/test_article_controller.py index a4d1f2d..6b04955 100644 --- a/tests/test_controllers/test_article_controller.py +++ b/tests/test_controllers/test_article_controller.py @@ -1,5 +1,6 @@ from unittest.mock import patch +from app.constants import Role, SessionKey from app.models.article_model import Article from tests.factories import make_account, make_article @@ -30,21 +31,21 @@ def test_view_article_not_found(client): def test_create_article_restricted(client): with client.session_transaction() as sess: - sess["user_id"] = 1 - sess["role"] = "user" + sess[SessionKey.USER_ID] = 1 + sess[SessionKey.ROLE] = Role.USER response = client.get("/article/new", follow_redirects=True) assert b"Access restricted" in response.data def test_create_article_success(client, db_session): - author = make_account(account_role="author") + author = make_account(account_role=Role.AUTHOR) db_session.add(author) db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = author.account_id - sess["role"] = "author" + sess[SessionKey.USER_ID] = author.account_id + sess[SessionKey.ROLE] = Role.AUTHOR response = client.post("/article/new", data={"title": "Nouveau Titre", "content": "Contenu"}, follow_redirects=True) assert b"Article published!" in response.data @@ -53,13 +54,13 @@ def test_create_article_success(client, db_session): def test_create_article_atomicity_failure(client, db_session): - author = make_account(account_role="author") + author = make_account(account_role=Role.AUTHOR) db_session.add(author) db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = author.account_id - sess["role"] = "author" + sess[SessionKey.USER_ID] = author.account_id + sess[SessionKey.ROLE] = Role.AUTHOR with patch("database.database_setup.db_session.commit") as mock_commit: mock_commit.side_effect = Exception("Database Failure") @@ -74,8 +75,8 @@ def test_create_article_atomicity_failure(client, db_session): def test_edit_article_unauthorized(client, db_session): - author1 = make_account(account_username="Author1", account_role="author") - author2 = make_account(account_username="Author2", account_role="author") + author1 = make_account(account_username="Author1", account_role=Role.AUTHOR) + author2 = make_account(account_username="Author2", account_role=Role.AUTHOR) db_session.add_all([author1, author2]) db_session.commit() @@ -84,8 +85,8 @@ def test_edit_article_unauthorized(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = author2.account_id - sess["role"] = "author" + sess[SessionKey.USER_ID] = author2.account_id + sess[SessionKey.ROLE] = Role.AUTHOR response = client.post(f"/article/{article.article_id}/edit", data={"title": "Hack", "content": "Hack"}, follow_redirects=True) assert b"Update failed" in response.data @@ -94,7 +95,7 @@ def test_edit_article_unauthorized(client, db_session): def test_delete_article_success(client, db_session): - author = make_account(account_role="admin") + author = make_account(account_role=Role.ADMIN) db_session.add(author) db_session.commit() article = make_article(author.account_id) @@ -102,8 +103,8 @@ def test_delete_article_success(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = author.account_id - sess["role"] = "admin" + sess[SessionKey.USER_ID] = author.account_id + sess[SessionKey.ROLE] = Role.ADMIN response = client.get(f"/article/{article.article_id}/delete", follow_redirects=True) assert b"Article deleted" in response.data @@ -129,8 +130,8 @@ def test_list_articles_pagination(client, db_session): def test_edit_article_not_found(client): with client.session_transaction() as sess: - sess["user_id"] = 1 - sess["role"] = "admin" + sess[SessionKey.USER_ID] = 1 + sess[SessionKey.ROLE] = Role.ADMIN response = client.get("/article/999/edit", follow_redirects=True) assert response.status_code == 200 @@ -138,7 +139,7 @@ def test_edit_article_not_found(client): def test_edit_article_success_by_author(client, db_session): - author = make_account(account_role="author") + author = make_account(account_role=Role.AUTHOR) db_session.add(author) db_session.commit() article = make_article(author.account_id, article_title="Ancien Titre") @@ -146,8 +147,8 @@ def test_edit_article_success_by_author(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = author.account_id - sess["role"] = "author" + sess[SessionKey.USER_ID] = author.account_id + sess[SessionKey.ROLE] = Role.AUTHOR response = client.post(f"/article/{article.article_id}/edit", data={"title": "Titre Modifié", "content": "Nouveau contenu"}, follow_redirects=True) @@ -158,7 +159,7 @@ def test_edit_article_success_by_author(client, db_session): def test_admin_cannot_edit_others_article(client, db_session): author = make_account(account_username="Auteur") - admin = make_account(account_username="Admin", account_role="admin") + admin = make_account(account_username="Admin", account_role=Role.ADMIN) db_session.add_all([author, admin]) db_session.commit() article = make_article(author.account_id, article_title="Titre Intouchable") @@ -166,8 +167,8 @@ def test_admin_cannot_edit_others_article(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = admin.account_id - sess["role"] = "admin" + sess[SessionKey.USER_ID] = admin.account_id + sess[SessionKey.ROLE] = Role.ADMIN response = client.post(f"/article/{article.article_id}/edit", data={"title": "Hack par Admin", "content": "..."}, follow_redirects=True) @@ -177,7 +178,7 @@ def test_admin_cannot_edit_others_article(client, db_session): def test_delete_article_success_by_author(client, db_session): - author = make_account(account_role="author") + author = make_account(account_role=Role.AUTHOR) db_session.add(author) db_session.commit() article = make_article(author.account_id) @@ -185,8 +186,8 @@ def test_delete_article_success_by_author(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = author.account_id - sess["role"] = "author" + sess[SessionKey.USER_ID] = author.account_id + sess[SessionKey.ROLE] = Role.AUTHOR response = client.get(f"/article/{article.article_id}/delete", follow_redirects=True) assert b"Article deleted" in response.data diff --git a/tests/test_controllers/test_comment_controller.py b/tests/test_controllers/test_comment_controller.py index 9719501..4cc9b4b 100644 --- a/tests/test_controllers/test_comment_controller.py +++ b/tests/test_controllers/test_comment_controller.py @@ -1,5 +1,6 @@ from unittest.mock import patch +from app.constants import Role, SessionKey from app.models.comment_model import Comment from tests.factories import make_account, make_article, make_comment @@ -8,6 +9,7 @@ def test_create_comment_unauthorized(client): response = client.post("/comments/create/1", data={"content": "Test"}, follow_redirects=True) assert b"Login required." in response.data + def test_create_comment_success(client, db_session): user = make_account() db_session.add(user) @@ -17,7 +19,7 @@ def test_create_comment_success(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = user.account_id + sess[SessionKey.USER_ID] = user.account_id response = client.post(f"/comments/create/{article.article_id}", data={"content": "Mon super commentaire"}, follow_redirects=True) @@ -25,16 +27,18 @@ def test_create_comment_success(client, db_session): comment = db_session.query(Comment).filter_by(comment_content="Mon super commentaire").first() assert comment is not None + def test_create_comment_on_invalid_article(client, db_session): user = make_account() db_session.add(user) db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = user.account_id + sess[SessionKey.USER_ID] = user.account_id response = client.post("/comments/create/9999", data={"content": "Error"}, follow_redirects=True) assert b"Error adding comment." in response.data + def test_create_reply_success(client, db_session): user = make_account() db_session.add(user) @@ -47,7 +51,7 @@ def test_create_reply_success(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = user.account_id + sess[SessionKey.USER_ID] = user.account_id response = client.post(f"/comments/reply/{parent.comment_id}", data={"content": "Réponse"}, follow_redirects=True) assert response.status_code == 200 @@ -56,6 +60,7 @@ def test_create_reply_success(client, db_session): assert reply.comment_reply_to == parent.comment_id assert reply.comment_article_id == article.article_id + def test_create_reply_atomicity_failure(client, db_session): user = make_account() db_session.add(user) @@ -68,7 +73,7 @@ def test_create_reply_atomicity_failure(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = user.account_id + sess[SessionKey.USER_ID] = user.account_id with patch("database.database_setup.db_session.commit") as mock_commit: mock_commit.side_effect = Exception("Atomic Failure") @@ -81,9 +86,10 @@ def test_create_reply_atomicity_failure(client, db_session): reply = db_session.query(Comment).filter_by(comment_content="My answer").first() assert reply is None + def test_delete_comment_admin_only(client, db_session): - admin = make_account(account_username="Admin", account_role="admin") - user = make_account(account_username="User", account_role="user") + admin = make_account(account_username="Admin", account_role=Role.ADMIN) + user = make_account(account_username="User", account_role=Role.USER) db_session.add_all([admin, user]) db_session.commit() article = make_article(admin.account_id) @@ -94,23 +100,24 @@ def test_delete_comment_admin_only(client, db_session): db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = user.account_id - sess["role"] = "user" + sess[SessionKey.USER_ID] = user.account_id + sess[SessionKey.ROLE] = Role.USER response = client.get(f"/comments/delete/{comment.comment_id}", follow_redirects=True) - assert b"Unauthorized" in response.data + assert b"Access restricted" in response.data or b"Insufficient permissions" in response.data with client.session_transaction() as sess: - sess["user_id"] = admin.account_id - sess["role"] = "admin" + sess[SessionKey.USER_ID] = admin.account_id + sess[SessionKey.ROLE] = Role.ADMIN response = client.get(f"/comments/delete/{comment.comment_id}", follow_redirects=True) assert b"Comment deleted." in response.data + def test_reply_to_non_existent_comment(client, db_session): user = make_account() db_session.add(user) db_session.commit() with client.session_transaction() as sess: - sess["user_id"] = user.account_id + sess[SessionKey.USER_ID] = user.account_id response = client.post("/comments/reply/999", data={"content": "Hello"}, follow_redirects=True) assert b"Error replying" in response.data diff --git a/tests/test_controllers/test_login_controller.py b/tests/test_controllers/test_login_controller.py index f60591d..53f829e 100644 --- a/tests/test_controllers/test_login_controller.py +++ b/tests/test_controllers/test_login_controller.py @@ -1,3 +1,4 @@ +from app.constants import Role, SessionKey from tests.factories import make_account @@ -6,10 +7,12 @@ def test_render_login_page(client): assert response.status_code == 200 assert b'action="/login"' in response.data + def test_login_method_not_allowed(client): response = client.get("/login") assert response.status_code == 405 + def test_login_success(client, db_session): user = make_account(account_username="Vador", account_password="dark_password") db_session.add(user) @@ -17,9 +20,10 @@ def test_login_success(client, db_session): response = client.post("/login", data={"username": "Vador", "password": "dark_password"}, follow_redirects=True) assert response.status_code == 200 with client.session_transaction() as session: - assert session["user_id"] == user.account_id - assert session["username"] == "Vador" - assert session["role"] == "user" + assert session[SessionKey.USER_ID] == user.account_id + assert session[SessionKey.USERNAME] == "Vador" + assert session[SessionKey.ROLE] == Role.USER + def test_login_failure_wrong_password(client, db_session): user = make_account(account_username="Luke", account_password="correct_password") @@ -28,18 +32,20 @@ def test_login_failure_wrong_password(client, db_session): response = client.post("/login", data={"username": "Luke", "password": "wrong_password"}, follow_redirects=True) assert b"Incorrect credentials." in response.data with client.session_transaction() as session: - assert "user_id" not in session + assert SessionKey.USER_ID not in session + def test_login_non_existent_user(client): response = client.post("/login", data={"username": "Inconnu", "password": "password"}, follow_redirects=True) assert b"Incorrect credentials." in response.data + def test_logout(client): with client.session_transaction() as session: - session["user_id"] = 1 - session["username"] = "Test" + session[SessionKey.USER_ID] = 1 + session[SessionKey.USERNAME] = "Test" response = client.get("/logout", follow_redirects=True) assert response.status_code == 200 with client.session_transaction() as session: - assert "user_id" not in session - assert "username" not in session + assert SessionKey.USER_ID not in session + assert SessionKey.USERNAME not in session diff --git a/tests/test_models/__init__.py b/tests/test_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_models/test_account_model.py b/tests/test_models/test_account_model.py index 25d783d..0d9e54b 100644 --- a/tests/test_models/test_account_model.py +++ b/tests/test_models/test_account_model.py @@ -1,7 +1,8 @@ from datetime import datetime +from typing import cast import pytest -import sqlalchemy +from sqlalchemy import exc from app.models.account_model import Account from tests.factories import make_account @@ -23,19 +24,19 @@ def test_account_username_unique(db_session): db_session.commit() db_session.add(make_account(account_username="unique")) - with pytest.raises(sqlalchemy.exc.IntegrityError): + with pytest.raises(exc.IntegrityError): db_session.commit() def test_account_missing_username(db_session): - account = make_account(account_username=None) + account = make_account(account_username=cast(str, None)) db_session.add(account) - with pytest.raises(sqlalchemy.exc.IntegrityError): + with pytest.raises(exc.IntegrityError): db_session.commit() def test_account_role_invalid(db_session): account = make_account(account_role="super_admin") db_session.add(account) - with pytest.raises(sqlalchemy.exc.IntegrityError): + with pytest.raises(exc.IntegrityError): db_session.commit() diff --git a/tests/test_models/test_article_model.py b/tests/test_models/test_article_model.py index 7c750e9..eb3c885 100644 --- a/tests/test_models/test_article_model.py +++ b/tests/test_models/test_article_model.py @@ -1,5 +1,7 @@ +from typing import cast + import pytest -import sqlalchemy +from sqlalchemy import exc from app.models.article_model import Article from tests.factories import make_account, make_article @@ -25,9 +27,9 @@ def test_article_missing_title(db_session): author = make_account() db_session.add(author) db_session.commit() - article = make_article(article_author_id=author.account_id, article_title=None) + article = make_article(article_author_id=author.account_id, article_title=cast(str, None)) db_session.add(article) - with pytest.raises(sqlalchemy.exc.IntegrityError): + with pytest.raises(exc.IntegrityError): db_session.commit() @@ -36,9 +38,9 @@ def test_article_missing_content(db_session): db_session.add(author) db_session.commit() - article = make_article(article_author_id=author.account_id, article_content=None) + article = make_article(article_author_id=author.account_id, article_content=cast(str, None)) db_session.add(article) - with pytest.raises(sqlalchemy.exc.IntegrityError): + with pytest.raises(exc.IntegrityError): db_session.commit() diff --git a/tests/test_services/__init__.py b/tests/test_services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_services/test_article_service.py b/tests/test_services/test_article_service.py index 702a448..186d62f 100644 --- a/tests/test_services/test_article_service.py +++ b/tests/test_services/test_article_service.py @@ -1,3 +1,4 @@ +from app.constants import Role from app.models.article_model import Article from app.services.article_service import ArticleService from tests.factories import make_account, make_article @@ -7,7 +8,8 @@ def test_create_article_service(db_session): author = make_account(account_role="author") db_session.add(author) db_session.commit() - article = ArticleService.create_article("Titre", "Contenu", author.account_id) + article_service = ArticleService(db_session) + article = article_service.create_article("Titre", "Contenu", author.account_id) db_session.commit() assert article.article_id is not None assert article.article_title == "Titre" @@ -22,7 +24,8 @@ def test_get_paginated_articles(db_session): db_session.add(make_article(author.account_id, article_title=f"Art {i}")) db_session.commit() - results = ArticleService.get_paginated_articles(page=1, per_page=3) + article_service = ArticleService(db_session) + results = article_service.get_paginated_articles(page=1, per_page=3) assert len(results) == 3 @@ -33,7 +36,8 @@ def test_update_article_success(db_session): article = make_article(author.account_id, article_title="Old Title") db_session.add(article) db_session.commit() - result = ArticleService.update_article(article.article_id, author.account_id, "user", "New Title", "New Content") + article_service = ArticleService(db_session) + result = article_service.update_article(article.article_id, author.account_id, "user", "New Title", "New Content") db_session.commit() assert result is not None assert result.article_title == "New Title" @@ -47,7 +51,8 @@ def test_update_article_unauthorized(db_session): article = make_article(author.account_id) db_session.add(article) db_session.commit() - result = ArticleService.update_article(article.article_id, wrong_user.account_id, "user", "Hacked", "...") + article_service = ArticleService(db_session) + result = article_service.update_article(article.article_id, wrong_user.account_id, "user", "Hacked", "...") assert result is None @@ -59,7 +64,8 @@ def test_delete_article_by_admin(db_session): article = make_article(author.account_id) db_session.add(article) db_session.commit() - result = ArticleService.delete_article(article.article_id, admin.account_id, "admin") + article_service = ArticleService(db_session) + result = article_service.delete_article(article.article_id, admin.account_id, Role.ADMIN) db_session.commit() assert result is True assert db_session.get(Article, article.article_id) is None @@ -74,7 +80,8 @@ def test_get_total_count(db_session): db_session.add(make_article(author.account_id)) db_session.commit() - count = ArticleService.get_total_count() + article_service = ArticleService(db_session) + count = article_service.get_total_count() assert count == 3 @@ -85,7 +92,8 @@ def test_get_all_ordered_by_date(db_session): db_session.add(make_article(author.account_id, article_title="First")) db_session.add(make_article(author.account_id, article_title="Second")) db_session.commit() - articles = ArticleService.get_all_ordered_by_date() + article_service = ArticleService(db_session) + articles = article_service.get_all_ordered_by_date() assert len(articles) == 2 assert articles[0].article_title is not None @@ -98,6 +106,7 @@ def test_delete_article_unauthorized(db_session): article = make_article(author.account_id) db_session.add(article) db_session.commit() - result = ArticleService.delete_article(article.article_id, stranger.account_id, "user") + article_service = ArticleService(db_session) + result = article_service.delete_article(article.article_id, stranger.account_id, "user") assert result is False assert db_session.get(Article, article.article_id) is not None diff --git a/tests/test_services/test_comment_service.py b/tests/test_services/test_comment_service.py index 5a39ed0..74560ee 100644 --- a/tests/test_services/test_comment_service.py +++ b/tests/test_services/test_comment_service.py @@ -1,3 +1,4 @@ +from app.constants import Role from app.models.comment_model import Comment from app.services.comment_service import CommentService from tests.factories import make_account, make_article, make_comment @@ -10,12 +11,13 @@ def test_create_reply_logic(db_session): article = make_article(author.account_id) db_session.add(article) db_session.commit() - CommentService.create_comment(article.article_id, author.account_id, "Parent") + comment_service = CommentService(db_session) + comment_service.create_comment(article.article_id, author.account_id, "Parent") db_session.commit() - parent = CommentService.get_tree_by_article_id(article.article_id)[0] - CommentService.create_reply(parent.comment_id, author.account_id, "Reply") + parent = comment_service.get_tree_by_article_id(article.article_id)[0] + comment_service.create_reply(parent.comment_id, author.account_id, "Reply") db_session.commit() - tree = CommentService.get_tree_by_article_id(article.article_id) + tree = comment_service.get_tree_by_article_id(article.article_id) assert len(tree) == 1 assert len(tree[0].comment_replies) == 1 assert tree[0].comment_replies[0].comment_content == "Reply" @@ -28,13 +30,14 @@ def test_comment_flattening_logic(db_session): article = make_article(author.account_id) db_session.add(article) db_session.commit() - CommentService.create_comment(article.article_id, author.account_id, "Root") + comment_service = CommentService(db_session) + comment_service.create_comment(article.article_id, author.account_id, "Root") db_session.commit() root = db_session.query(Comment).filter_by(comment_content="Root").first() - CommentService.create_reply(root.comment_id, author.account_id, "Reply A") + comment_service.create_reply(root.comment_id, author.account_id, "Reply A") db_session.commit() reply_a = db_session.query(Comment).filter_by(comment_content="Reply A").first() - CommentService.create_reply(reply_a.comment_id, author.account_id, "Reply B") + comment_service.create_reply(reply_a.comment_id, author.account_id, "Reply B") db_session.commit() reply_b = db_session.query(Comment).filter_by(comment_content="Reply B").first() assert reply_b.comment_reply_to == root.comment_id @@ -51,7 +54,8 @@ def test_delete_comment_as_admin(db_session): comment = make_comment(article.article_id, author.account_id) db_session.add(comment) db_session.commit() - result = CommentService.delete_comment(comment.comment_id, "admin") + comment_service = CommentService(db_session) + result = comment_service.delete_comment(comment.comment_id, Role.ADMIN) db_session.commit() assert result == article.article_id assert db_session.get(Comment, comment.comment_id) is None @@ -61,5 +65,6 @@ def test_create_comment_invalid_article(db_session): author = make_account() db_session.add(author) db_session.commit() - result = CommentService.create_comment(999, author.account_id, "Hello") + comment_service = CommentService(db_session) + result = comment_service.create_comment(999, author.account_id, "Hello") assert result is False diff --git a/tests/test_services/test_login_service.py b/tests/test_services/test_login_service.py index ccd4761..f84118f 100644 --- a/tests/test_services/test_login_service.py +++ b/tests/test_services/test_login_service.py @@ -6,21 +6,24 @@ def test_authenticate_user_success(db_session): user = make_account(account_username="leia", account_password="password123") db_session.add(user) db_session.commit() - result = LoginService.authenticate_user("leia", "password123") + login_service = LoginService(db_session) + result = login_service.authenticate_user("leia", "password123") assert result is not None - assert result["username"] == "leia" - assert result["role"] == "user" - assert result["id"] == user.account_id + assert result.account_username == "leia" + assert result.account_role == "user" + assert result.account_id == user.account_id def test_authenticate_user_wrong_password(db_session): user = make_account(account_username="leia", account_password="password123") db_session.add(user) db_session.commit() - result = LoginService.authenticate_user("leia", "mauvais_pass") + login_service = LoginService(db_session) + result = login_service.authenticate_user("leia", "mauvais_pass") assert result is None def test_authenticate_user_non_existent(db_session): - result = LoginService.authenticate_user("fantome", "rien") + login_service = LoginService(db_session) + result = login_service.authenticate_user("fantome", "rien") assert result is None