Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ scaffoldr --no-banner generate my-project
When you generate a FastAPI project, you get:

- **Complete FastAPI Application**: Pre-configured with proper structure
- **Database Integration**: SQLAlchemy with Alembic migrations
- **Database Integration**: SQLAlchemy with Alembic migrations or MongoDB with Beanie ODM
- **File Storage**: Built-in file upload/download endpoints
- **API Documentation**: Auto-generated OpenAPI/Swagger docs
- **Development Tools**: Pre-configured with ruff, mypy, pytest
Expand Down Expand Up @@ -117,7 +117,9 @@ my-project/
│ ├── core/ # Core components (config, utils, etc.)
│ ├── features/ # Feature modules (business logic)
│ ├── services/ # External service integrations
│ └── database/ # Data access layer
{% if database %}
│ └── database/ # Data access layer (SQLAlchemy or MongoDB)
{% endif %}
├── tests/ # Comprehensive test suite
├── scripts/ # Development scripts
├── .github/ # GitHub workflows and templates
Expand Down
5 changes: 5 additions & 0 deletions src/scaffoldr/cli/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def generate(
if use_cloud:
cloud_type = Helper.cloud_type()

database_type = None
if use_database:
database_type = Helper.database_type()

project_name = project_name.replace(" ", "-").lower()
# Check if directory already exists
if destination == ".":
Expand Down Expand Up @@ -71,6 +75,7 @@ def generate(
"use_docker": docker,
"cloud_type": cloud_type,
"database": use_database,
"database_type": database_type,
"framework": framework,
}

Expand Down
3 changes: 2 additions & 1 deletion src/scaffoldr/core/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .art import ascii_art
from .const import CloudTypes, Frameworks, console
from .const import CloudTypes, DatabaseTypes, Frameworks, console

__all__ = [
"CloudTypes",
"DatabaseTypes",
"Frameworks",
"ascii_art",
"console",
Expand Down
5 changes: 5 additions & 0 deletions src/scaffoldr/core/constants/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ class CloudTypes(str, Enum):
GCP = "gcp"
AZURE = "azure"
NONE = "none"


class DatabaseTypes(str, Enum):
SQLALCHEMY = "sqlalchemy"
MONGODB = "mongodb"
20 changes: 17 additions & 3 deletions src/scaffoldr/core/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typer
from copier import subprocess

from scaffoldr.core.constants.const import CloudTypes
from scaffoldr.core.constants.const import CloudTypes, DatabaseTypes, console

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -40,6 +40,20 @@ def cloud_type() -> str:
cloud_type: str = cast("str", typer.prompt(f"Cloud type: {' | '.join(cloud_types)}"))
if cloud_type in cloud_types:
return cloud_type
raise ValueError(
f"Invalid cloud type: {cloud_type}. Must be one of: {' | '.join(cloud_types)}",
console.print(f"[red]Error:[/red] Invalid cloud type: {cloud_type}")
raise typer.Exit(1)

@staticmethod
def database_type() -> str:
"""
Prompt for database type selection.
"""
database_types = list(DatabaseTypes)
database_type: str = cast(
"str",
typer.prompt(f"Database type: {' | '.join(database_types)}"),
)
if database_type in database_types:
return database_type
console.print(f"[red]Error:[/red] Invalid database type: {database_type}")
raise typer.Exit(1)
7 changes: 7 additions & 0 deletions templates/fastapi_template/{{ project_name }}/.env.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@
{{ project_slug|upper }}_LOG_GRANIAN_ERROR_LEVEL=40

{% if database %}
{% if database_ype == 'postgresql' %}
{{ project_slug|upper }}_DB_URL=postgresql+asyncpg://user:password@localhost:5432/{{ project_slug }}
{{ project_slug|upper }}_DB_ECHO=false
{% elif database_type == 'sqlite' %}
{{ project_slug|upper }}_DB_URL=sqlite+aiosqlite:///./{{ project_name|upper }}.db
{{ project_slug|upper }}_DB_ECHO=false
{% elif database_type == 'mongodb' %}
{{ project_slug|upper }}_DB_URL=mongodb://localhost:27017
{% endif %}
{% endif %}

{% if cloud_type %}
Expand Down
7 changes: 7 additions & 0 deletions templates/fastapi_template/{{ project_name }}/README.md.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@

- **FastAPI**: Modern, fast web framework for building APIs
- **Pydantic**: Data validation and serialization
{% if database %}
{% if database_type == 'sqlalchemy' %}
- **SQLAlchemy**: SQL toolkit and ORM
- **Alembic**: Database migration tool
{% elif database_type == 'mongodb' %}
- **Beanie**: MongoDB ODM for Python
- **Motor**: Asynchronous MongoDB driver
{% endif %}
{% endif %}
- **Uvicorn**: ASGI web server
- **Granian**: High-performance web server
{% if cloud_type == 'aws' %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@ dependencies = [
"azure-identity>=1.25.0",
{% endif %}
{% if database %}
{% if database_type == 'sqlalchemy' %}
"aiosqlite>=0.21.0",
"sqlalchemy[asyncio]>=2.0.0",
{% elif database_type == 'mongodb' %}
"beanie>=2.0.0",
"motor>=3.7.1",
"pymongo>=4.15.1",
{% endif %}
{% endif %}
"fastapi>=0.117.1",
"granian[reload]>=2.5.4",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,31 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI

{% if database %}
{% if database_type == 'sqlalchemy' %}
from {{ project_slug }}.connection import Base, database
{% elif database_type == 'mongodb' %}
from {{ project_slug }}.connection import mongodb
{% endif %}
{% endif %}

{% if framework == 'fastapi' %}

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
{% if database %}
{% if database_type == 'sqlalchemy' %}
async with database.engine.begin() as conn:
# run_sync executes the given function in a sync context using the connection
await conn.run_sync(Base.metadata.create_all)
{% endif %}
{% elif database_type == 'mongodb' %}
await mongodb.connect()
{% endif %}
{% endif %}
yield
{% if database %}
{% if database_type == 'mongodb' %}
await mongodb.close()
{% endif %}
{% endif %}

{% endif %}
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
{% if database %}
{% if database_type == 'sqlalchemy' %}
from .database import database, inject_session, Base
__all__ = ["Base", "database", "inject_session"]
{% elif database_type == 'mongodb' %}
from .database import mongodb
__all__ = ["mongodb"]
{% endif %}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{% if database_type == 'sqlalchemy' %}
from collections.abc import AsyncGenerator, Callable
from functools import wraps
from typing import Any, cast
Expand Down Expand Up @@ -74,3 +75,44 @@ def inject_session(func: Callable[..., Any]) -> Callable[..., Any]:
return None # This should never be reached, but for type safety

return cast("Callable[..., Any]", wrapper)
{% elif database_type == 'mongodb' %}
from beanie import init_beanie
from motor.motor_asyncio import AsyncIOMotorClient
from typing import Optional

from {{ project_slug }}.core.config import settings


class MongoDB:
"""MongoDB connection manager using Beanie ODM."""

def __init__(self) -> None:
"""Initialize the MongoDB connection."""
self.client: Optional[AsyncIOMotorClient] = None
self.database = None

async def connect(self) -> None:
"""Connect to MongoDB and initialize Beanie."""
self.client = AsyncIOMotorClient(settings.database.URL)
self.database = self.client[settings.database.DATABASE_NAME]

# Initialize Beanie with document models
# Note: Document models should be imported here when they exist
await init_beanie(
database=self.database,
document_models=[], # Add your document models here
)

async def close(self) -> None:
"""Close the MongoDB connection."""
if self.client:
self.client.close()

async def get_database(self):
"""Get the database instance."""
return self.database


# Global MongoDB instance
mongodb = MongoDB()
{% endif %}
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,16 @@ class CloudSettings(BaseSettings):


class DatabaseSettings(BaseSettings):
{% if database_type == 'sqlalchemy' %}
URL: str = "sqlite+aiosqlite:///./{{ project_name|upper }}.db"
ECHO: Annotated[
bool,
AfterValidator(true_bool_validator),
] = False
{% elif database_type == 'mongodb' %}
URL: str = "mongodb://localhost:27017"
DATABASE_NAME: str = "{{ project_name|lower }}_db"
{% endif %}

model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
extra="ignore",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@ from typing import Generator

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession

from {{ project_slug }}.api.application import app
{% if database %}
{% if database_type == 'sqlalchemy' %}
from sqlalchemy.ext.asyncio import AsyncSession

from {{ project_slug }}.connection import database
{% elif database_type == 'mongodb' %}
from {{ project_slug }}.connection import mongodb
{% endif %}
{% endif %}


Expand All @@ -28,6 +33,7 @@ def client() -> Generator[TestClient, None, None]:


{% if database %}
{% if database_type == 'sqlalchemy' %}
@pytest.fixture(scope="function")
async def db_session() -> AsyncGenerator[AsyncSession, None]:
"""Create a database session for testing."""
Expand All @@ -40,4 +46,14 @@ async def db_session() -> AsyncGenerator[AsyncSession, None]:
except Exception:
await session.rollback()
raise
{% elif database_type == 'mongodb' %}
@pytest.fixture(scope="function")
async def mongodb_client():
"""Create a MongoDB client for testing."""
await mongodb.connect()
try:
yield mongodb
finally:
await mongodb.close()
{% endif %}
{% endif %}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import pytest
from fastapi.testclient import TestClient
{% if database %}
{% if database_type == 'sqlalchemy' %}
from sqlalchemy.ext.asyncio import AsyncSession
{% endif %}
{% endif %}


class TestHealth:
Expand All @@ -19,6 +21,7 @@ class TestHealth:
assert data["status"] == "success"

{% if database %}
{% if database_type == 'sqlalchemy' %}
@pytest.mark.asyncio
async def test_health_endpoint_with_database(
self, client: TestClient, db_session: AsyncSession
Expand All @@ -30,4 +33,17 @@ class TestHealth:
assert "status" in data
assert data["status"] == "success"
# Note: Database health check not implemented in basic health endpoint
{% elif database_type == 'mongodb' %}
@pytest.mark.asyncio
async def test_health_endpoint_with_mongodb(
self, client: TestClient, mongodb_client: None,
) -> None:
"""Test health check with MongoDB connectivity."""
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert data["status"] == "success"
# Note: Database health check not implemented in basic health endpoint
{% endif %}
{% endif %}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class TestSettings:
assert settings.server.PORT > 0

{% if database %}
{% if database_type == 'sqlalchemy' %}
def test_database_settings(self) -> None:
"""Test database settings have expected attributes."""
settings = get_settings()
Expand All @@ -47,4 +48,14 @@ class TestSettings:
assert isinstance(settings.database.URL, str)
assert isinstance(settings.database.ECHO, bool)
assert "sqlite" in settings.database.URL # Should contain sqlite for basic setup
{% elif database_type == 'mongodb' %}
def test_database_settings(self) -> None:
"""Test MongoDB database settings have expected attributes."""
settings = get_settings()
assert hasattr(settings.database, 'URL')
assert hasattr(settings.database, 'DATABASE_NAME')
assert isinstance(settings.database.URL, str)
assert isinstance(settings.database.DATABASE_NAME, str)
assert "mongodb" in settings.database.URL # Should contain mongodb for MongoDB setup
{% endif %}
{% endif %}
Loading