diff --git a/.gitignore b/.gitignore index b36e01e..ca9fb53 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.vscode *.pyc # Packages !/tests/**/*.egg diff --git a/.vscode/extensions.json b/.vscode/extensions.json deleted file mode 100644 index 292ac1a..0000000 --- a/.vscode/extensions.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "recommendations": [ - "tamasfe.even-better-toml", - "ms-python.python", - "yzhang.markdown-all-in-one", - "ms-python.black-formatter", - "lextudio.restructuredtext", - "trond-snekvik.simple-rst" - ] -} diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 916a1af..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "editor.rulers": [ - 120 - ], - "python.formatting.provider": "black", - "python.linting.enabled": true, - "python.linting.lintOnSave": true, - "python.linting.flake8Enabled": true, - "python.linting.flake8Args": [ - "--max-line-length=120", - "--extend-ignore=E203,E501,W503", - "--exclude=.venv" - ], - "python.sortImports.args": [ - "--profile=black" - ], - "[python]": { - "editor.formatOnSave": true, - "editor.codeActionsOnSave": { - "source.organizeImports": true - } - }, - "python.autoComplete.extraPaths": [ - - ], - "python.analysis.extraPaths": [] -} diff --git a/README.md b/README.md index 2b85c82..264d9f9 100644 --- a/README.md +++ b/README.md @@ -66,10 +66,27 @@ See the [tests](/tests/test_py_d2) for more detailed usage examples. - [x] Icons in shapes - [x] Support for empty labels - [x] Shape links -- [ ] SQL table shapes +- [x] SQL table shapes +- [x] Layers - [ ] Class shapes - [ ] Comments -- [ ] Layers + +## Examples + +`examples/` + +```sh +poetry run python examples/.py +``` + +SQL Table: + +```sh +poetry run python example/simple_sql_schema.py +# Open diagram: +open simple_sql_schema.svg +``` + ## Development diff --git a/examples/simple_sql_schema.py b/examples/simple_sql_schema.py new file mode 100644 index 0000000..d8317ea --- /dev/null +++ b/examples/simple_sql_schema.py @@ -0,0 +1,42 @@ +import os +import subprocess + +from py_d2.diagram import D2Diagram +from py_d2.sql_table import SQLConstraint, SQLTable, create_foreign_key_connection + +FILE_NAME = "simple_sql_schema" + +# Create a new diagram +diagram = D2Diagram() + +# Create Users table +users = SQLTable("users") +users.add_field("id", "int", SQLConstraint.PRIMARY_KEY) +users.add_field("name", "varchar(255)") + +# Create Orders table +orders = SQLTable("orders") +orders.add_field("id", "int", SQLConstraint.PRIMARY_KEY) +orders.add_field("user_id", "int", SQLConstraint.FOREIGN_KEY) +orders.add_field("total", "decimal(10,2)") + +# Create connection +fk = create_foreign_key_connection("orders", "user_id", "users", "id") + +# Add tables and connections to the diagram +diagram.add_shape(users) +diagram.add_shape(orders) +diagram.add_connection(fk) + +# Write the diagram to a file +with open(f"{FILE_NAME}.d2", "w") as f: + f.write(str(diagram)) + +print(f"D2 diagram file created: {FILE_NAME}.d2") +print(str(diagram)) + +try: + subprocess.run(["d2", "--layout", "elk", f"{FILE_NAME}.d2", f"{FILE_NAME}.svg"], check=True) + print(f"SVG diagram generated: {os.path.abspath(f'{FILE_NAME}.svg')}") +except Exception as e: + print(f"Error generating SVG: {e}") diff --git a/poetry.toml b/poetry.toml index 4924ae7..5d2f4cc 100644 --- a/poetry.toml +++ b/poetry.toml @@ -3,6 +3,3 @@ cache-dir = ".cache" [virtualenvs] create = true in-project = true - -[virtualenvs.options] -always-copy = true diff --git a/src/py_d2/sql_table.py b/src/py_d2/sql_table.py new file mode 100644 index 0000000..905692a --- /dev/null +++ b/src/py_d2/sql_table.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +from py_d2.connection import D2Connection +from py_d2.connection import Direction +from py_d2.shape import D2Shape +from py_d2.shape import Shape + + +class SQLConstraint(Enum): + PRIMARY_KEY = "primary_key" + FOREIGN_KEY = "foreign_key" + UNIQUE = "unique" + + +class SQLField: + def __init__( + self, + name: str, + data_type: str, + constraint: Optional[Union[SQLConstraint, str, List[Union[SQLConstraint, str]]]] = None, + ): + self.name = name + self.data_type = data_type + + # Handle constraint(s) + if constraint is None: + self.constraints = [] + elif isinstance(constraint, list): + self.constraints = constraint + else: + self.constraints = [constraint] + + def to_d2_format(self) -> str: + """Convert the field to D2 format.""" + if not self.constraints: + return f"{self.name}: {self.data_type}" + + constraints_str = "; ".join([c.value if isinstance(c, SQLConstraint) else c for c in self.constraints]) + + if len(self.constraints) == 1: + constraint_part = constraints_str + else: + constraint_part = f"[{constraints_str}]" + + return f"{self.name}: {self.data_type} {{constraint: {constraint_part}}}" + + +class SQLTable(D2Shape): + def __init__( + self, + name: str, + fields: Optional[Dict[str, Union[str, Dict[str, Any]]]] = None, + label: Optional[str] = None, + style: Optional[Any] = None, + icon: Optional[str] = None, + near: Optional[str] = None, + link: Optional[str] = None, + ): + super().__init__( + name=name, + label=label, + shape=Shape.sql_table, + style=style, + icon=icon, + near=near, + link=link, + ) + + self.fields: List[SQLField] = [] + + # Process fields if provided + if fields: + for field_name, field_info in fields.items(): + if isinstance(field_info, str): + # Simple case: just a type + self.add_field(field_name, field_info) + elif isinstance(field_info, dict): + # Complex case: type and constraints + field_type = field_info.get("type", "") + constraint = field_info.get("constraint", None) + self.add_field(field_name, field_type, constraint) + + def add_field( + self, + name: str, + data_type: str, + constraint: Optional[Union[SQLConstraint, str, List[Union[SQLConstraint, str]]]] = None, + ) -> SQLField: + """Add a field to the SQL table.""" + field = SQLField(name, data_type, constraint) + self.fields.append(field) + return field + + def lines(self) -> List[str]: + """Generate D2 lines for the SQL table.""" + # Get the base properties from parent class + properties = [] + + # Add shape property + if self.shape: + properties.append(f"shape: {self.shape.value}") + + # Add fields + for field in self.fields: + properties.append(field.to_d2_format()) + + # Add other properties + if self.near: + properties.append(f"near: {self.near}") + + if self.link: + properties.append(f"link: {self.link}") + + if self.style: + properties.extend(self.style.lines()) + + if self.icon: + properties.append(f"icon: {self.icon}") + + # Add child shapes and connections + shapes = [shape.lines() for shape in self.shapes] + connections = [connection.lines() for connection in self.connections] + + for shape_lines in shapes: + properties.extend(shape_lines) + + for connection_lines in connections: + properties.extend(connection_lines) + + # Create the final lines + from py_d2.helpers import add_label_and_properties + + lines = add_label_and_properties(self.name, self.label, properties) + + return lines + + +def create_foreign_key_connection( + source_table: str, + source_field: str, + target_table: str, + target_field: str, + label: Optional[str] = None, +) -> D2Connection: + """Create a foreign key connection between two tables.""" + source = f"{source_table}.{source_field}" + target = f"{target_table}.{target_field}" + return D2Connection(source, target, label, Direction.TO) diff --git a/tests/test_py_d2/test_d2_sql_table.py b/tests/test_py_d2/test_d2_sql_table.py new file mode 100644 index 0000000..c2a7aec --- /dev/null +++ b/tests/test_py_d2/test_d2_sql_table.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +from py_d2.connection import Direction +from py_d2.diagram import D2Diagram +from py_d2.shape import D2Shape +from py_d2.sql_table import ( + SQLConstraint, + SQLField, + SQLTable, + create_foreign_key_connection, +) +from py_d2.style import D2Style + + +def test_sql_field_simple(): + """Test creating a simple SQL field without constraints.""" + field = SQLField("name", "varchar(255)") + assert field.to_d2_format() == "name: varchar(255)" + + +def test_sql_field_with_constraint(): + """Test creating a SQL field with a single constraint.""" + field = SQLField("id", "int", SQLConstraint.PRIMARY_KEY) + assert field.to_d2_format() == "id: int {constraint: primary_key}" + + +def test_sql_field_with_string_constraint(): + """Test creating a SQL field with a string constraint.""" + field = SQLField("status", "varchar(50)", "not null") + assert field.to_d2_format() == "status: varchar(50) {constraint: not null}" + + +def test_sql_field_with_multiple_constraints(): + """Test creating a SQL field with multiple constraints.""" + field = SQLField("email", "varchar(255)", [SQLConstraint.UNIQUE, "not null"]) + assert field.to_d2_format() == "email: varchar(255) {constraint: [unique; not null]}" + + +def test_sql_table_empty(): + """Test creating an empty SQL table.""" + table = SQLTable("users") + assert str(table) == "users: {\n shape: sql_table\n}" + + +def test_sql_table_with_fields(): + """Test creating a SQL table with fields.""" + table = SQLTable("users") + table.add_field("id", "int", SQLConstraint.PRIMARY_KEY) + table.add_field("name", "varchar(255)") + table.add_field("email", "varchar(255)", SQLConstraint.UNIQUE) + + expected = "\n".join( + [ + "users: {", + " shape: sql_table", + " id: int {constraint: primary_key}", + " name: varchar(255)", + " email: varchar(255) {constraint: unique}", + "}", + ] + ) + + assert str(table) == expected + + +def test_sql_table_with_fields_dict(): + """Test creating a SQL table with fields provided as a dictionary.""" + fields = { + "id": {"type": "int", "constraint": SQLConstraint.PRIMARY_KEY}, + "name": "varchar(255)", + "email": {"type": "varchar(255)", "constraint": SQLConstraint.UNIQUE}, + } + + table = SQLTable("users", fields=fields) + + expected = "\n".join( + [ + "users: {", + " shape: sql_table", + " id: int {constraint: primary_key}", + " name: varchar(255)", + " email: varchar(255) {constraint: unique}", + "}", + ] + ) + + assert str(table) == expected + + +def test_sql_table_with_label_and_style(): + """Test creating a SQL table with a label and style.""" + + table = SQLTable("users", label="User Table", style=D2Style(fill="lightblue")) + table.add_field("id", "int", SQLConstraint.PRIMARY_KEY) + + expected = "\n".join( + [ + "users: User Table {", + " shape: sql_table", + " id: int {constraint: primary_key}", + " style: {", + " fill: lightblue", + " }", + "}", + ] + ) + + assert str(table) == expected + + +def test_foreign_key_connection(): + """Test creating a foreign key connection between tables.""" + connection = create_foreign_key_connection("orders", "user_id", "users", "id") + + assert connection.shape_1 == "orders.user_id" + assert connection.shape_2 == "users.id" + assert connection.direction == Direction.TO + assert str(connection) == "orders.user_id -> users.id" + + +def test_foreign_key_connection_with_label(): + """Test creating a foreign key connection with a label.""" + connection = create_foreign_key_connection("orders", "user_id", "users", "id", "belongs to") + + assert connection.label == "belongs to" + assert str(connection) == "orders.user_id -> users.id: belongs to" + + +def test_complex_sql_table_relationship(): + """Test creating multiple SQL tables with relationships.""" + # Create tables + users = SQLTable("users") + users.add_field("id", "int", SQLConstraint.PRIMARY_KEY) + users.add_field("name", "varchar(255)") + + orders = SQLTable("orders") + orders.add_field("id", "int", SQLConstraint.PRIMARY_KEY) + orders.add_field("user_id", "int", SQLConstraint.FOREIGN_KEY) + orders.add_field("total", "decimal(10,2)") + + # Create connection + fk = create_foreign_key_connection("orders", "user_id", "users", "id") + + diagram = D2Diagram() + diagram.add_shape(users) + diagram.add_shape(orders) + diagram.add_connection(fk) + + expected = "\n".join( + [ + "users: {", + " shape: sql_table", + " id: int {constraint: primary_key}", + " name: varchar(255)", + "}", + "orders: {", + " shape: sql_table", + " id: int {constraint: primary_key}", + " user_id: int {constraint: foreign_key}", + " total: decimal(10,2)", + "}", + "orders.user_id -> users.id", + ] + ) + + assert str(diagram) == expected + + +def test_nested_sql_tables(): + """Test creating nested SQL tables within a container.""" + # Create container + + cloud = D2Shape("cloud", label="Cloud Infrastructure") + + # Create SQL tables + disks = SQLTable("disks") + disks.add_field("id", "int", SQLConstraint.PRIMARY_KEY) + + blocks = SQLTable("blocks") + blocks.add_field("id", "int", SQLConstraint.PRIMARY_KEY) + blocks.add_field("disk", "int", SQLConstraint.FOREIGN_KEY) + blocks.add_field("blob", "blob") + + # Create connection + fk = create_foreign_key_connection("blocks", "disk", "disks", "id") + + # Add tables to container + cloud.add_shape(disks) + cloud.add_shape(blocks) + cloud.add_connection(fk) + + expected = "\n".join( + [ + "cloud: Cloud Infrastructure {", + " disks: {", + " shape: sql_table", + " id: int {constraint: primary_key}", + " }", + " blocks: {", + " shape: sql_table", + " id: int {constraint: primary_key}", + " disk: int {constraint: foreign_key}", + " blob: blob", + " }", + " blocks.disk -> disks.id", + "}", + ] + ) + + assert str(cloud) == expected