From d433dfe8ef9c9647a70ec09ba915fb9b7109c4a8 Mon Sep 17 00:00:00 2001 From: "eve n.u" Date: Tue, 16 Dec 2025 21:43:19 -0800 Subject: [PATCH] feat(plot_lines): use ordinal x-axis for integer tick labels Switch x-axis encoding from quantitative (Q) to ordinal (O) to guarantee integer-only tick labels. The x-axis represents a step/epoch index, making ordinal encoding semantically correct. Quantitative encoding with tickMinStep was stripped by vegafusion during Vega compilation. Also adds IPython cell markers (# %%) to test files for interactive visual verification in IDEs with ipykernel support. Each plot function can be visually inspected by running the corresponding cell. Build: adds ipykernel dev dependency for cell execution. --- pyproject.toml | 1 + src/lpm_plot/plot_lines.py | 2 +- tests/test_plot_fidelity.py | 84 ++++++++++++++++++++++++++++++++ tests/test_plot_heatmap.py | 37 +++++++++++++++ tests/test_plot_lines.py | 13 +++++ tests/test_plot_marginal.py | 95 +++++++++++++++++++++++++++++++++++++ uv.lock | 4 +- 7 files changed, 234 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8ac7a4f..754193f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ [dependency-groups] dev = [ + "ipykernel>=6.29.5", "jupyter>=1.1.1", "pytest>=8.3.3", "ruff>=0.13.2", diff --git a/src/lpm_plot/plot_lines.py b/src/lpm_plot/plot_lines.py index 90c5552..6bbcaca 100644 --- a/src/lpm_plot/plot_lines.py +++ b/src/lpm_plot/plot_lines.py @@ -49,7 +49,7 @@ def plot_lines( alt.Chart(df) .mark_line() .encode( - x=alt.X("x:Q", title=x_title), + x=alt.X("x:O", title=x_title, axis=alt.Axis(labelAngle=0)), y=alt.Y( "y:Q", title=y_title, diff --git a/tests/test_plot_fidelity.py b/tests/test_plot_fidelity.py index b0125cb..f32b8f6 100644 --- a/tests/test_plot_fidelity.py +++ b/tests/test_plot_fidelity.py @@ -84,3 +84,87 @@ def test_plot_fidelity_smoke_kl(): assert isinstance( plot_fidelity(fidelity_data_kl, metric="kl"), altair.vegalite.v5.api.LayerChart ) + + +# %% +if __name__ == "__main__": + import polars as pl + + from lpm_plot import plot_fidelity + + fidelity_data = pl.DataFrame( + [ + { + "column-1": "age", + "column-2": "income", + "tvd": 0.02, + "model": "LPM", + "index": 0, + }, + { + "column-1": "age", + "column-2": "education", + "tvd": 0.03, + "model": "LPM", + "index": 1, + }, + { + "column-1": "income", + "column-2": "education", + "tvd": 0.05, + "model": "LPM", + "index": 2, + }, + { + "column-1": "age", + "column-2": "occupation", + "tvd": 0.08, + "model": "LPM", + "index": 3, + }, + { + "column-1": "income", + "column-2": "occupation", + "tvd": 0.12, + "model": "LPM", + "index": 4, + }, + { + "column-1": "age", + "column-2": "income", + "tvd": 0.04, + "model": "Baseline", + "index": 0, + }, + { + "column-1": "age", + "column-2": "education", + "tvd": 0.06, + "model": "Baseline", + "index": 1, + }, + { + "column-1": "income", + "column-2": "education", + "tvd": 0.09, + "model": "Baseline", + "index": 2, + }, + { + "column-1": "age", + "column-2": "occupation", + "tvd": 0.15, + "model": "Baseline", + "index": 3, + }, + { + "column-1": "income", + "column-2": "occupation", + "tvd": 0.20, + "model": "Baseline", + "index": 4, + }, + ] + ) + chart = plot_fidelity(fidelity_data, metric="tvd") + chart.show() diff --git a/tests/test_plot_heatmap.py b/tests/test_plot_heatmap.py index 0845d60..e36b4e8 100644 --- a/tests/test_plot_heatmap.py +++ b/tests/test_plot_heatmap.py @@ -22,3 +22,40 @@ def test_plot_heatmap_smoke(): plot_heatmap(df), altair.vegalite.v5.api.Chart, ) + + +# %% +if __name__ == "__main__": + import polars as pl + + from lpm_plot import plot_heatmap + + df = pl.DataFrame( + { + "Column 1": [ + "age", + "age", + "age", + "income", + "income", + "income", + "education", + "education", + "education", + ], + "Column 2": [ + "age", + "income", + "education", + "age", + "income", + "education", + "age", + "income", + "education", + ], + "Score": [None, 0.8, 0.5, 0.8, None, 0.6, 0.5, 0.6, None], + } + ) + chart = plot_heatmap(df, interactive=False) + chart.show() diff --git a/tests/test_plot_lines.py b/tests/test_plot_lines.py index 735eee9..90ac672 100644 --- a/tests/test_plot_lines.py +++ b/tests/test_plot_lines.py @@ -47,3 +47,16 @@ def test_plot_lines_y_scale_default(): spec = chart.to_dict(format="vega") # Default is linear scale assert spec["scales"][1]["type"] == "linear" + + +# %% + +if __name__ == "__main__": + data = { + "train_loss": [5.0, 3.5, 2.8, 2.2, 1.9, 1.6, 1.4], + "val_loss": [5.2, 3.8, 3.2, 2.8, 2.5, 2.3, 2.1], + "test_loss": [5.5, 4.0, 3.5, 3.0, 2.7, 2.5, 2.3], + } + chart = plot_lines(data, x_title="Epoch", y_title="Loss") + chart.show() +# %% diff --git a/tests/test_plot_marginal.py b/tests/test_plot_marginal.py index 73834df..167f6a3 100644 --- a/tests/test_plot_marginal.py +++ b/tests/test_plot_marginal.py @@ -54,3 +54,98 @@ def test_plot_marginal_2d_smoke(): plot_marginal_2d(combined_df, x, y), altair.vegalite.v5.api.HConcatChart, ) + + +# %% +if __name__ == "__main__": + import polars as pl + + from lpm_plot import ( + plot_marginal_1d, + plot_marginal_2d, + plot_marginal_numerical_categorical, + plot_marginal_numerical_numerical, + ) + + # --- plot_marginal_1d --- + observed_df = pl.DataFrame( + { + "category": ["A", "A", "B", "B", "B", "C"], + "status": ["active", "inactive", "active", "active", "inactive", "active"], + } + ) + synthetic_df = pl.DataFrame( + { + "category": ["A", "B", "B", "C", "C", "C"], + "status": ["active", "active", "inactive", "inactive", "active", "active"], + } + ) + chart_1d = plot_marginal_1d(observed_df, synthetic_df, ["category", "status"]) + chart_1d.show() + +# %% +if __name__ == "__main__": + # --- plot_marginal_2d --- + x, y = "category", "status" + df1 = pl.DataFrame( + { + x: ["A", "A", "B", "B", "C", "C"], + y: ["active", "inactive", "active", "inactive", "active", "inactive"], + "Normalized frequency": [0.25, 0.15, 0.20, 0.10, 0.18, 0.12], + } + ).with_columns(pl.lit("Observed").alias("Source")) + df2 = pl.DataFrame( + { + x: ["A", "A", "B", "B", "C", "C"], + y: ["active", "inactive", "active", "inactive", "active", "inactive"], + "Normalized frequency": [0.22, 0.18, 0.18, 0.12, 0.16, 0.14], + } + ).with_columns(pl.lit("Synthetic").alias("Source")) + combined_df = pl.concat([df1, df2]) + chart_2d = plot_marginal_2d(combined_df, x, y) + chart_2d.show() + +# %% +if __name__ == "__main__": + # --- plot_marginal_numerical_numerical --- + import random + + random.seed(42) + observed_num = pl.DataFrame( + { + "age": [random.gauss(35, 10) for _ in range(100)], + "income": [random.gauss(50000, 15000) for _ in range(100)], + } + ) + synthetic_num = pl.DataFrame( + { + "age": [random.gauss(36, 11) for _ in range(100)], + "income": [random.gauss(52000, 14000) for _ in range(100)], + } + ) + chart_num_num = plot_marginal_numerical_numerical( + observed_num, synthetic_num, "age", "income" + ) + chart_num_num.show() + +# %% +if __name__ == "__main__": + # --- plot_marginal_numerical_categorical --- + observed_cat = pl.DataFrame( + { + "department": ["Sales", "Sales", "Engineering", "Engineering", "HR", "HR"] + * 10, + "salary": [random.gauss(60000, 10000) for _ in range(60)], + } + ) + synthetic_cat = pl.DataFrame( + { + "department": ["Sales", "Sales", "Engineering", "Engineering", "HR", "HR"] + * 10, + "salary": [random.gauss(62000, 11000) for _ in range(60)], + } + ) + chart_num_cat = plot_marginal_numerical_categorical( + observed_cat, synthetic_cat, "department", "salary" + ) + chart_num_cat.show() diff --git a/uv.lock b/uv.lock index c45e4f8..f9554cf 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" [[package]] @@ -783,6 +783,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "ipykernel" }, { name = "jupyter" }, { name = "pytest" }, { name = "ruff" }, @@ -801,6 +802,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "ipykernel", specifier = ">=6.29.5" }, { name = "jupyter", specifier = ">=1.1.1" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "ruff", specifier = ">=0.13.2" },