Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.vscode
*.pyc
# Packages
!/tests/**/*.egg
Expand Down
10 changes: 0 additions & 10 deletions .vscode/extensions.json

This file was deleted.

27 changes: 0 additions & 27 deletions .vscode/settings.json

This file was deleted.

21 changes: 19 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,27 @@ See the [tests](/tests/test_py_d2) for more detailed usage examples.
- [x] Icons in shapes
- [x] Support for empty labels
- [x] Shape links
- [ ] SQL table shapes
- [x] SQL table shapes
- [x] Layers
- [ ] Class shapes
- [ ] Comments
- [ ] Layers

## Examples

`examples/`

```sh
poetry run python examples/<example>.py
```

SQL Table:

```sh
poetry run python example/simple_sql_schema.py
# Open diagram:
open simple_sql_schema.svg
```



## Development
Expand Down
42 changes: 42 additions & 0 deletions examples/simple_sql_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import subprocess

from py_d2.diagram import D2Diagram
from py_d2.sql_table import SQLConstraint, SQLTable, create_foreign_key_connection

FILE_NAME = "simple_sql_schema"

# Create a new diagram
diagram = D2Diagram()

# Create Users table
users = SQLTable("users")
users.add_field("id", "int", SQLConstraint.PRIMARY_KEY)
users.add_field("name", "varchar(255)")

# Create Orders table
orders = SQLTable("orders")
orders.add_field("id", "int", SQLConstraint.PRIMARY_KEY)
orders.add_field("user_id", "int", SQLConstraint.FOREIGN_KEY)
orders.add_field("total", "decimal(10,2)")

# Create connection
fk = create_foreign_key_connection("orders", "user_id", "users", "id")

# Add tables and connections to the diagram
diagram.add_shape(users)
diagram.add_shape(orders)
diagram.add_connection(fk)

# Write the diagram to a file
with open(f"{FILE_NAME}.d2", "w") as f:
f.write(str(diagram))

print(f"D2 diagram file created: {FILE_NAME}.d2")
print(str(diagram))

try:
subprocess.run(["d2", "--layout", "elk", f"{FILE_NAME}.d2", f"{FILE_NAME}.svg"], check=True)
print(f"SVG diagram generated: {os.path.abspath(f'{FILE_NAME}.svg')}")
except Exception as e:
print(f"Error generating SVG: {e}")
3 changes: 0 additions & 3 deletions poetry.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,3 @@ cache-dir = ".cache"
[virtualenvs]
create = true
in-project = true

[virtualenvs.options]
always-copy = true
155 changes: 155 additions & 0 deletions src/py_d2/sql_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from __future__ import annotations

from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from py_d2.connection import D2Connection
from py_d2.connection import Direction
from py_d2.shape import D2Shape
from py_d2.shape import Shape


class SQLConstraint(Enum):
PRIMARY_KEY = "primary_key"
FOREIGN_KEY = "foreign_key"
UNIQUE = "unique"


class SQLField:
def __init__(
self,
name: str,
data_type: str,
constraint: Optional[Union[SQLConstraint, str, List[Union[SQLConstraint, str]]]] = None,
):
self.name = name
self.data_type = data_type

# Handle constraint(s)
if constraint is None:
self.constraints = []
elif isinstance(constraint, list):
self.constraints = constraint
else:
self.constraints = [constraint]

def to_d2_format(self) -> str:
"""Convert the field to D2 format."""
if not self.constraints:
return f"{self.name}: {self.data_type}"

constraints_str = "; ".join([c.value if isinstance(c, SQLConstraint) else c for c in self.constraints])

if len(self.constraints) == 1:
constraint_part = constraints_str
else:
constraint_part = f"[{constraints_str}]"

return f"{self.name}: {self.data_type} {{constraint: {constraint_part}}}"


class SQLTable(D2Shape):
def __init__(
self,
name: str,
fields: Optional[Dict[str, Union[str, Dict[str, Any]]]] = None,
label: Optional[str] = None,
style: Optional[Any] = None,
icon: Optional[str] = None,
near: Optional[str] = None,
link: Optional[str] = None,
):
super().__init__(
name=name,
label=label,
shape=Shape.sql_table,
style=style,
icon=icon,
near=near,
link=link,
)

self.fields: List[SQLField] = []

# Process fields if provided
if fields:
for field_name, field_info in fields.items():
if isinstance(field_info, str):
# Simple case: just a type
self.add_field(field_name, field_info)
elif isinstance(field_info, dict):
# Complex case: type and constraints
field_type = field_info.get("type", "")
constraint = field_info.get("constraint", None)
self.add_field(field_name, field_type, constraint)

def add_field(
self,
name: str,
data_type: str,
constraint: Optional[Union[SQLConstraint, str, List[Union[SQLConstraint, str]]]] = None,
) -> SQLField:
"""Add a field to the SQL table."""
field = SQLField(name, data_type, constraint)
self.fields.append(field)
return field

def lines(self) -> List[str]:
"""Generate D2 lines for the SQL table."""
# Get the base properties from parent class
properties = []

# Add shape property
if self.shape:
properties.append(f"shape: {self.shape.value}")

# Add fields
for field in self.fields:
properties.append(field.to_d2_format())

# Add other properties
if self.near:
properties.append(f"near: {self.near}")

if self.link:
properties.append(f"link: {self.link}")

if self.style:
properties.extend(self.style.lines())

if self.icon:
properties.append(f"icon: {self.icon}")

# Add child shapes and connections
shapes = [shape.lines() for shape in self.shapes]
connections = [connection.lines() for connection in self.connections]

for shape_lines in shapes:
properties.extend(shape_lines)

for connection_lines in connections:
properties.extend(connection_lines)

# Create the final lines
from py_d2.helpers import add_label_and_properties

lines = add_label_and_properties(self.name, self.label, properties)

return lines


def create_foreign_key_connection(
source_table: str,
source_field: str,
target_table: str,
target_field: str,
label: Optional[str] = None,
) -> D2Connection:
"""Create a foreign key connection between two tables."""
source = f"{source_table}.{source_field}"
target = f"{target_table}.{target_field}"
return D2Connection(source, target, label, Direction.TO)
Loading