Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,877 changes: 0 additions & 2,877 deletions poetry.lock

This file was deleted.

1 change: 1 addition & 0 deletions src/lpm_plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .plot_fidelity import plot_fidelity
from .plot_heatmap import plot_heatmap, reformat_data
from .plot_lines import plot_lines
from .plot_marginal import (
plot_marginal_1d,
plot_marginal_2d,
Expand Down
63 changes: 63 additions & 0 deletions src/lpm_plot/plot_lines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import altair as alt
import polars as pl

alt.data_transformers.enable("vegafusion")


def plot_lines(
data: dict[str, list[float]],
x_title: str = "Step",
y_title: str = "Value",
width: int = 500,
height: int = 300,
y_scale: str | None = None,
) -> alt.Chart:
"""Plot multiple lines on a single chart.

Args:
data: Dictionary mapping series names to lists of y-values.
All lists must have the same length. X-axis is the index.
x_title: Label for x-axis.
y_title: Label for y-axis.
width: Chart width in pixels.
height: Chart height in pixels.
y_scale: Scale type for y-axis (e.g., "log", "sqrt", "symlog").
If None, uses linear scale.

Returns:
Altair Chart with one line per series.
"""
if not data:
raise ValueError("data must not be empty")

lengths = [len(v) for v in data.values()]
if len(set(lengths)) != 1:
raise ValueError("All series must have the same length")

n_steps = lengths[0]
series_names = list(data.keys())

df = pl.DataFrame(
{
"x": list(range(n_steps)) * len(series_names),
"y": [v for series in series_names for v in data[series]],
"series": [name for name in series_names for _ in range(n_steps)],
}
)

chart = (
alt.Chart(df)
.mark_line()
.encode(
x=alt.X("x:Q", title=x_title),
y=alt.Y(
"y:Q",
title=y_title,
scale=alt.Scale(type=y_scale) if y_scale else alt.Undefined,
),
color=alt.Color("series:N", legend=alt.Legend(title="Series")),
)
.properties(width=width, height=height)
)

return chart
2 changes: 1 addition & 1 deletion tests/test_plot_fidelity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import polars as pl
import altair
import polars as pl

from lpm_plot import plot_fidelity

Expand Down
49 changes: 49 additions & 0 deletions tests/test_plot_lines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import altair
import pytest

from lpm_plot import plot_lines


def test_plot_lines_smoke():
data = {
"train": [5.0, 4.0, 3.0, 2.5],
"test": [5.2, 4.3, 3.5, 3.0],
}
chart = plot_lines(data)
assert isinstance(chart, altair.vegalite.v5.api.Chart)


def test_plot_lines_empty_raises():
with pytest.raises(ValueError, match="must not be empty"):
plot_lines({})


def test_plot_lines_unequal_lengths_raises():
data = {
"a": [1.0, 2.0, 3.0],
"b": [1.0, 2.0],
}
with pytest.raises(ValueError, match="same length"):
plot_lines(data)


def test_plot_lines_y_scale_log():
data = {"series": [1.0, 10.0, 100.0]}
chart = plot_lines(data, y_scale="log")
spec = chart.to_dict(format="vega")
assert spec["scales"][1]["type"] == "log"


def test_plot_lines_y_scale_sqrt():
data = {"series": [1.0, 4.0, 9.0]}
chart = plot_lines(data, y_scale="sqrt")
spec = chart.to_dict(format="vega")
assert spec["scales"][1]["type"] == "sqrt"


def test_plot_lines_y_scale_default():
data = {"series": [1.0, 2.0, 3.0]}
chart = plot_lines(data)
spec = chart.to_dict(format="vega")
# Default is linear scale
assert spec["scales"][1]["type"] == "linear"
3 changes: 1 addition & 2 deletions tests/test_plot_marginal.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import altair
import polars as pl

from lpm_plot import plot_marginal_1d
from lpm_plot import plot_marginal_2d
from lpm_plot import plot_marginal_1d, plot_marginal_2d


def test_plot_marginal_1d_smoke():
Expand Down