Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions nb_cli/cli/commands/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import sys
import json
import shlex
from typing import Any
from pathlib import Path
from logging import Logger
from functools import partial
from typing_extensions import Required
from typing import TypeAlias, TypedDict
from dataclasses import field, dataclass
from collections.abc import Sequence, MutableMapping

import click
import nonestorage
Expand Down Expand Up @@ -50,10 +52,24 @@
"bootstrap": _("bootstrap (for beginner or user)"),
"simple": _("simple (for plugin developer)"),
}
HIDDEN_FILE_OVERRIDES = {".env", ".env.dev", ".env.prod", ".gitignore", ".vscode"}
SerializedJSON: TypeAlias = str

BLACKLISTED_PROJECT_NAME.update(sys.stdlib_module_names)


class ProjectTemplateProps(TypedDict):
"""项目模板渲染变量字典集"""

project_name: Required[str]
inplace: bool
adapters: SerializedJSON
drivers: SerializedJSON
environment: MutableMapping[str, str]
use_src: bool
devtools: Sequence[str]


@dataclass
class ProjectContext:
"""项目模板生成上下文
Expand All @@ -63,12 +79,14 @@ class ProjectContext:
packages: 项目需要安装的包
"""

variables: dict[str, Any] = field(default_factory=dict)
variables: ProjectTemplateProps = field( # pyright: ignore[reportAssignmentType]
default_factory=dict
)
packages: list[str] = field(default_factory=list)


def project_name_validator(name: str) -> bool:
return (
return name == "." or (
bool(re.match(VALID_PROJECT_NAME, name))
and name not in BLACKLISTED_PROJECT_NAME
)
Expand All @@ -92,6 +110,25 @@ async def prompt_common_context(context: ProjectContext) -> ProjectContext:
error_message=_("Invalid project name!"),
).prompt_async(style=CLI_DEFAULT_STYLE)
context.variables["project_name"] = project_name
context.variables["inplace"] = False

if project_name == ".":
_parent_dirname = Path(".").absolute().name
if not project_name_validator(_parent_dirname):
click.secho(_("Invalid project name!"), fg="red")
raise CancelledError
if any(
(f.name in HIDDEN_FILE_OVERRIDES or not f.name.startswith("."))
for f in Path(project_name).iterdir()
):
if not await ConfirmPrompt(
_("Current folder is not empty. Overwrite existing files?"),
False,
).prompt_async(style=CLI_DEFAULT_STYLE):
click.echo(_("Stopped creating bot."))
raise CancelledError
project_name = context.variables["project_name"] = _parent_dirname
context.variables["inplace"] = True

confirm = False
adapters = []
Expand Down Expand Up @@ -297,7 +334,9 @@ async def create(
use_venv = False
project_dir_name = context.variables["project_name"].replace(" ", "-")
project_dir = Path(output_dir or ".") / project_dir_name
venv_dir = project_dir / ".venv"
venv_dir = (
Path("./.venv") if context.variables["inplace"] else project_dir / ".venv"
)

if install_dependencies:
try:
Expand Down
8 changes: 2 additions & 6 deletions nb_cli/handlers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
from cookiecutter.main import cookiecutter

from nb_cli import _
from nb_cli.config import (
SimpleInfo,
PackageInfo,
NoneBotConfig,
LegacyNoneBotConfig,
)
from nb_cli.config import SimpleInfo, PackageInfo, NoneBotConfig, LegacyNoneBotConfig

from . import templates
from .driver import list_drivers
Expand Down Expand Up @@ -51,6 +46,7 @@ def create_project(
no_input=no_input,
extra_context=context,
output_dir=output_dir or ".",
overwrite_if_exists=True,
)


Expand Down
Loading