From 3fe1780a19339320ff9bb0e1886f45d3bf2ada00 Mon Sep 17 00:00:00 2001 From: Jan Kudlicka Date: Thu, 26 Mar 2026 13:29:11 +0100 Subject: [PATCH] Add compile and run argument classes - Introduce classes for compile and run arguments. - Make compile arguments mirror tpplc arguments. - Support saving and loading run arguments. - Enable saving inference results to a file. - Remove obsolete syntax highlighting in Jupyter notebooks. - Refactor code for better PEP 8 compliance. --- examples/coin.py | 29 +++-- examples/crbd.py | 14 ++- examples/generative_crbd.py | 11 +- examples/treeppl_in_jupyter.ipynb | 32 +++--- setup.py | 14 +-- tests/test_all.py | 21 ++-- treeppl/__init__.py | 2 +- treeppl/base.py | 171 +++++++++++++++++++----------- treeppl/ipython/__init__.py | 73 ++----------- treeppl/serialization.py | 7 +- treeppl/stdlib.py | 8 +- 11 files changed, 183 insertions(+), 199 deletions(-) diff --git a/examples/coin.py b/examples/coin.py index 7bc33b1..0c29c26 100755 --- a/examples/coin.py +++ b/examples/coin.py @@ -4,15 +4,30 @@ import matplotlib.pyplot as plt import seaborn as sns - -with treeppl.Model(filename="coin.tppl", samples=100000) as coin: +with treeppl.Model(filename="coin.tppl", m="smc-bpf", particles=100000) as coin: res = coin( outcomes=[ - True, True, True, False, True, False, False, True, True, False, - False, False, True, False, True, False, False, False, False, False, + True, + True, + True, + False, + True, + False, + False, + True, + True, + False, + False, + False, + True, + False, + True, + False, + False, + False, + False, + False, ] ) - sns.histplot( - x=res.samples, weights=res.nweights, bins=100, stat="density", kde=True - ) + sns.histplot(x=res.samples, weights=res.nweights, bins=100, stat="density", kde=True) plt.show() diff --git a/examples/crbd.py b/examples/crbd.py index dd2b3e0..1a2d834 100755 --- a/examples/crbd.py +++ b/examples/crbd.py @@ -6,18 +6,16 @@ import matplotlib.pyplot as plt import seaborn as sns - alcedinidae = 
treeppl.Tree.load("trees/Alcedinidae.phyjson", format="phyjson") samples = None -with treeppl.Model(filename="crbd.tppl", samples=10000, subsamples=10) as crbd: +with treeppl.Model( + filename="crbd.tppl", m="smc-bpf", particles=10000, resample="align", subsample=True, subsample_size=10 +) as crbd: for i in range(1000): res = crbd(tree=alcedinidae) - samples = pd.concat([ - samples, - pd.DataFrame({ - "lambda": res.items(0), "mu": res.items(1), "lweight": res.norm_const - }) - ]) + samples = pd.concat( + [samples, pd.DataFrame({"lambda": res.items(0), "mu": res.items(1), "lweight": res.norm_const})] + ) weights = np.exp(samples.lweight - samples.lweight.max()) plt.clf() sns.kdeplot(data=samples, x="lambda", weights=weights) diff --git a/examples/generative_crbd.py b/examples/generative_crbd.py index 9d0e117..6065e29 100755 --- a/examples/generative_crbd.py +++ b/examples/generative_crbd.py @@ -3,13 +3,10 @@ import treeppl from Bio import Phylo +args = {"time": 5.0, "lambda": 1.0, "mu": 0.1} -params = {"time": 5.0, "lambda": 1.0, "mu": 0.1} - -with treeppl.Model(filename="generative_crbd.tppl", samples=1) as generative_crbd: - result = generative_crbd(**params) +with treeppl.Model(filename="generative_crbd.tppl", particles=1) as generative_crbd: + result = generative_crbd(args) tree = result.samples[0] - tree = Phylo.BaseTree.Clade( - branch_length=params["time"] - tree.age, clades=[tree.to_biopython()] - ) + tree = Phylo.BaseTree.Clade(branch_length=args["time"] - tree.age, clades=[tree.to_biopython()]) Phylo.draw(tree) diff --git a/examples/treeppl_in_jupyter.ipynb b/examples/treeppl_in_jupyter.ipynb index e85ff17..ff7e90b 100644 --- a/examples/treeppl_in_jupyter.ipynb +++ b/examples/treeppl_in_jupyter.ipynb @@ -57,9 +57,7 @@ "id": "a67a0418", "metadata": {}, "source": [ - "Once the extension is loaded, users can use the `%%treeppl` cell magic to write and compile a TreePPL program. 
Executing the cell creates a `treeppl.Model` object, which allows interaction with the compiled program directly in Python. This object is stored in a Python variable specified as an argument to the `%%treeppl` magic (i.e., immediately following `%%treeppl` on the first line). Optionally, the variable name can be followed by parameters and values. Most of these are passed to the TreePPL compiler, except for the `samples` parameter, which specifies the number of samples to draw when the program is executed. Examples are provided below.\n", - "\n", - "The extension also supports basic syntax highlighting for TreePPL programs.\n", + "Once the extension is loaded, users can use the `%%treeppl` cell magic to write and compile a TreePPL program. Executing the cell creates a `treeppl.Model` object, which allows interaction with the compiled program directly in Python. This object is stored in a Python variable specified as an argument to the `%%treeppl` magic (i.e., immediately following `%%treeppl` on the first line). The variable name can be followed by arguments and values, all of which are passed directly to the TreePPL compiler. Examples are provided below.\n", "\n", "For example, the following cell demonstrates a simple TreePPL program for simulating the flip of a fair coin:" ] @@ -71,7 +69,7 @@ "metadata": {}, "outputs": [], "source": [ - "%%treeppl flip samples=10\n", + "%%treeppl flip --particles 10\n", "\n", "model function flip() => Bool {\n", " assume p ~ Bernoulli(0.5);\n", @@ -84,9 +82,9 @@ "id": "d71e9d44", "metadata": {}, "source": [ - "In this example, a `treeppl.Model` instance is created, and the program is compiled. The variable name `flip` (specified after `%%treeppl`) provides an interface for interacting with the model. The argument `samples=10` specifies the number of samples to generate when the program is executed.\n", + "In this example, a `treeppl.Model` instance is created, and the program is compiled. 
The variable name `flip` (specified after `%%treeppl`) provides an interface for interacting with the model. The argument `--particles 10` sets the number of particles used when executing the program. If no inference method is specified via the `-m` parameter, the default is importance sampling.\n", "\n", - "To run the TreePPL program, simply call the variable as a function (e.g., `flip()`). This executes the program and returns a `treeppl.InferenceResult` object, which includes a `samples` attribute containing the generated samples. While samples may have different weights in more complex programs, they are equally weighted in this simple example. We will cover programs with weighted samples later.\n", + "To run the TreePPL program, simply call the variable as a function (e.g., `flip()`). This executes the program and returns a `treeppl.InferenceResult` object, which includes a `samples` attribute containing the generated samples. While samples may have different weights in more complex programs, they are equally weighted in this simple example. 
Programs with weighted samples will be demonstrated later.\n", "\n", "Here’s an example of how to use the compiled program:" ] @@ -125,7 +123,7 @@ "metadata": {}, "outputs": [], "source": [ - "%%treeppl coin samples=100000\n", + "%%treeppl coin -m smc-bpf --particles 100000\n", "\n", "model function coin(outcomes: Bool[]) => Real {\n", " assume p ~ Uniform(0.0, 1.0);\n", @@ -169,7 +167,7 @@ "metadata": {}, "outputs": [], "source": [ - "%%treeppl generative_crbd samples=1\n", + "%%treeppl generative_crbd --particles 1\n", "\n", "model function generativeCrbd(time: Real, lambda: Real, mu: Real) => Tree {\n", " assume waitingTime ~ Exponential(lambda + mu);\n", @@ -198,16 +196,16 @@ "metadata": {}, "outputs": [], "source": [ - "params = {\n", + "args = {\n", " \"time\": 5.0,\n", " \"lambda\": 1.0,\n", " \"mu\": 0.1\n", "}\n", "\n", - "result = generative_crbd(**params)\n", + "result = generative_crbd(args)\n", "tree = result.samples[0]\n", "tree = Phylo.BaseTree.Clade(\n", - " branch_length=params[\"time\"] - tree.age,\n", + " branch_length=args[\"time\"] - tree.age,\n", " clades=[tree.to_biopython()]\n", ")\n", "Phylo.draw(tree)" @@ -228,7 +226,7 @@ "metadata": {}, "outputs": [], "source": [ - "%%treeppl crbd samples=10000 subsamples=10\n", + "%%treeppl crbd -m smc-bpf --particles 10000 --resample align --subsample --subsample-size 10\n", "\n", "function simulateExtinctSubtree(time: Real, lambda: Real, mu: Real) {\n", " assume waitingTime ~ Exponential(lambda + mu);\n", @@ -275,9 +273,7 @@ "cell_type": "code", "execution_count": null, "id": "00bfd918", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "alcedinidae = treeppl.Tree.load(\"trees/Alcedinidae.phyjson\", format=\"phyjson\")\n", @@ -306,9 +302,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python [conda env:base] *", "language": "python", - "name": "python3" + "name": "conda-base-py" }, "language_info": { 
"codemirror_mode": { @@ -320,7 +316,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.13.9" } }, "nbformat": 4, diff --git a/setup.py b/setup.py index 8d9e8e3..f0f8cc3 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,13 @@ from setuptools import setup setup( - name='treeppl', - version='0.1', - description='Python Interface to TreePPL', - author='Jan Kudlicka', - author_email='github@kudlicka.eu', - packages=['treeppl'], + name="treeppl", + version="0.1", + description="Python Interface to TreePPL", + author="Jan Kudlicka", + author_email="github@kudlicka.eu", + packages=["treeppl"], install_requires=[ - 'numpy', + "numpy", ], ) diff --git a/tests/test_all.py b/tests/test_all.py index f89a189..270442c 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -5,23 +5,20 @@ def test_compilation_error(): with pytest.raises(treeppl.CompileError): - with treeppl.Model( - source="""\ + with treeppl.Model(source="""\ model incorrect() { return result; } -""", - samples=1, - ) as model: +""") as model: model() def test_coin(): - nsamples = 100 - with treeppl.Model(filename="examples/coin.tppl", samples=nsamples) as coin: + particles = 100 + with treeppl.Model(filename="examples/coin.tppl", m="smc-bpf", particles=particles) as coin: res = coin(outcomes=(np.random.random(10) < 0.5).tolist()) assert isinstance(res, treeppl.InferenceResult) - assert len(res.samples) == nsamples + assert len(res.samples) == particles def test_matrix_mul(): @@ -35,7 +32,7 @@ def test_matrix_mul(): return a *@ b; } """, - samples=1, + particles=1, ) as matrix_mul: result = matrix_mul(a=a, b=b).samples[0] assert (result == expected).all() @@ -43,9 +40,7 @@ def test_matrix_mul(): def test_tree_input(): tree = treeppl.Tree.Node( - left=treeppl.Tree.Node( - left=treeppl.Tree.Leaf(age=0.0), right=treeppl.Tree.Leaf(age=0.0), age=0.5 - ), + left=treeppl.Tree.Node(left=treeppl.Tree.Leaf(age=0.0), right=treeppl.Tree.Leaf(age=0.0), 
age=0.5), right=treeppl.Tree.Leaf(age=0.0), age=1.0, ) @@ -62,7 +57,7 @@ def test_tree_input(): return count_leaves(tree); } """, - samples=1, + particles=1, ) as count_tree_leaves: res = count_tree_leaves(tree=tree) assert res.samples[0] == 3 diff --git a/treeppl/__init__.py b/treeppl/__init__.py index ffb3a2d..5779b43 100644 --- a/treeppl/__init__.py +++ b/treeppl/__init__.py @@ -1,4 +1,4 @@ -from .base import Model, InferenceResult +from .base import Model, InferenceResult, CompileArguments, RunArguments from .exceptions import CompileError, InferenceError from .serialization import Object, constructor from .stdlib import Tree diff --git a/treeppl/base.py b/treeppl/base.py index 5a3422e..a99877a 100644 --- a/treeppl/base.py +++ b/treeppl/base.py @@ -1,20 +1,20 @@ import json from operator import itemgetter from tempfile import TemporaryDirectory -from subprocess import Popen, PIPE, STDOUT +import subprocess import numpy as np - import tarfile -import shutil import os import importlib +import shlex from .exceptions import CompileError, InferenceError from .serialization import from_json, to_json + def get_tpplc_binary(): - if os.environ.get('TPPLC'): - return os.environ['TPPLC'] + if os.environ.get("TPPLC"): + return os.environ["TPPLC"] # NOTE(vipa, 2025-06-04): The selfcontained compiler must be # deployed to a directory somewhere. There are three important # limitatations: @@ -33,64 +33,105 @@ def get_tpplc_binary(): # might share (immutable) deployed directories between virtual # envs, which seems fine, and that differing versions will get # differing deployed directories. 
- tmp_dir = '/tmp' - deployed_basename = "@DEPLOYED_BASENAME@" # This will be substituted by nix when building the wheel - tarball_name = "@TARBALL_NAME@" # This will be substituted by nix when building the wheel + tmp_dir = "/tmp" + deployed_basename = "@DEPLOYED_BASENAME@" # This will be substituted by nix when building the wheel + tarball_name = "@TARBALL_NAME@" # This will be substituted by nix when building the wheel tppl_dir_path = os.path.join(tmp_dir, deployed_basename) if not os.path.isdir(tppl_dir_path): - with importlib.resources.path('treeppl', tarball_name) as tarball: + with importlib.resources.path("treeppl", tarball_name) as tarball: with tarfile.open(tarball) as tar: - if hasattr(tarfile, 'tar_filter'): - tar.extractall(tmp_dir, filter='tar') + if hasattr(tarfile, "tar_filter"): + tar.extractall(tmp_dir, filter="tar") else: # NOTE(vipa, 2025-11-20): This allows us to support Python < 3.12 tar.extractall(tmp_dir) return os.path.join(tppl_dir_path, "tpplc") + +class Arguments(dict): + def __getattr__(self, name): + return self[name] + + def __setattr__(self, name, value): + self[name] = value + + +class CompileArguments(Arguments): + def as_list(self): + res = [] + for k, v in self.items(): + k = k.replace("_", "-") + res.append(f"-{k}" if len(k) == 1 else f"--{k}") + if v is not True: + res.append(str(v)) + return res + + def parse_opts(self, s): + t = shlex.split(s) + i = 0 + while i < len(t): + k = t[i].lstrip("-").replace("-", "_") + if i + 1 < len(t) and not t[i + 1].startswith("-"): + v = t[i + 1] + i += 1 + if v.isdigit(): + v = int(v) + else: + try: + v = float(v) + except: + pass + else: + v = True + self[k] = v + i += 1 + + +class RunArguments(Arguments): + def load_json(self, filename): + with open(filename) as f: + self.update(from_json(f)) + + @classmethod + def from_json(cls, filename): + with open(filename) as f: + data = from_json(f) + return cls(**data) + + def save_json(self, filename): + with open(filename, "w") as f: + 
to_json(self, f) + + class Model: - def __init__( - self, - source=None, - filename=None, - method="smc-bpf", - samples=1000, - subsamples=None, - **kwargs, - ): - self.temp_dir = TemporaryDirectory(prefix="treeppl_") + def __init__(self, source=None, filename=None, **kwargs): + self.compile_arguments = None + self.run_arguments = None + self.source = source if filename: - source = open(filename).read() - if not source: - raise CompileError("No source code to compile.") + with open(filename) as f: + self.source = f.read() + self.temp_dir = TemporaryDirectory(prefix="treeppl_") + self.compile(**kwargs) + + def compile(self, **kwargs): + self.compile_arguments = CompileArguments(kwargs) + if not self.source: + raise CompileError("no source code to compile") with open(self.temp_dir.name + "/__main__.tppl", "w") as f: - f.write(source) - args = [ - get_tpplc_binary(), - "__main__.tppl", - "-m", - method, - "-p", - str(samples), - ] - if subsamples: - args.extend(["--subsample", "-n", str(subsamples)]) - for k, v in kwargs.items(): - args.append(f"--{k.replace('_', '-')}") - if v is not True: - args.append(str(v)) - with Popen( - args=args, cwd=self.temp_dir.name, stdout=PIPE, stderr=STDOUT - ) as proc: - try: - proc.wait() - except KeyboardInterrupt: - output = proc.stdout.read().decode("utf-8") - output = output.replace("__main__.tppl", "source code") - raise CompileError(f"Could not compile the TreePPL model:\n{output}") - if proc.returncode != 0: - output = proc.stdout.read().decode("utf-8") - output = output.replace("__main__.tppl", "source code") - raise CompileError(f"Could not compile the TreePPL model:\n{output}") + f.write(self.source) + args = [get_tpplc_binary(), "__main__.tppl"] + args.extend(self.compile_arguments.as_list()) + result = None + try: + result = subprocess.run(args=args, cwd=self.temp_dir.name, capture_output=True, text=True) + except KeyboardInterrupt: + pass + if result is None: + raise CompileError("could not compile the TreePPL model") + 
elif result.returncode != 0: + output = result.stdout.replace("__main__.tppl", "source code") + raise CompileError(f"could not compile the TreePPL model:\n{output}") def __enter__(self): return self @@ -98,27 +139,35 @@ def __enter__(self): def __exit__(self, *args): self.temp_dir.cleanup() - def __call__(self, **kwargs): - with open(self.temp_dir.name + "/input.json", "w") as f: - to_json(kwargs or {}, f) + def run(self, arguments=None, **kwargs): + self.run_arguments = RunArguments(arguments or {}, **kwargs) + self.run_arguments.save_json(self.temp_dir.name + "/input.json") args = [ f"{self.temp_dir.name}/out", f"{self.temp_dir.name}/input.json", ] - with Popen(args=args, stdout=PIPE) as proc: + with subprocess.Popen(args=args, stdout=subprocess.PIPE) as proc: return InferenceResult(proc.stdout) + def __call__(self, *args, **kwargs): + return self.run(*args, **kwargs) + class InferenceResult: def __init__(self, stdout): + self.result = None try: - result = from_json(stdout) + self.result = from_json(stdout) except json.decoder.JSONDecodeError: - raise InferenceError("Could not parse the output from TreePPL.") - self.samples = result.get("samples", []) - self.weights = np.array(result.get("weights", [])) + raise InferenceError("could not parse the output from TreePPL") + self.samples = self.result.get("samples", []) + self.weights = np.array(self.result.get("weights", [])) self.nweights = np.exp(self.weights) - self.norm_const = result.get("normConst", np.nan) + self.norm_const = self.result.get("normConst", np.nan) def items(self, *args): return list(map(itemgetter(*args), self.samples)) + + def save_json(self, filename): + with open(filename, "w") as f: + to_json(self.result, f) diff --git a/treeppl/ipython/__init__.py b/treeppl/ipython/__init__.py index 5d6a8a0..825ab1c 100644 --- a/treeppl/ipython/__init__.py +++ b/treeppl/ipython/__init__.py @@ -1,26 +1,24 @@ from IPython.core.magic import register_cell_magic -from IPython.display import Javascript, display 
import treeppl as treeppl_ @register_cell_magic def treeppl(line, cell): - args = line.split() + args = line.split(maxsplit=1) if not args: - raise ValueError("You must provide a variable name after %%treeppl") + raise ValueError("you must provide a variable name after %%treeppl") ip = get_ipython() - kwargs = dict( - arg.strip().split("=", 1) if "=" in arg else (arg.strip(), True) - for arg in args[1:] - ) - ip.user_ns[args[0]] = treeppl_.Model(source=cell, **kwargs) + compile_arguments = treeppl_.CompileArguments() + if len(args) == 2: + compile_arguments.parse_opts(args[1]) + ip.user_ns[args[0]] = treeppl_.Model(source=cell, **compile_arguments) @register_cell_magic def treeppl_source(line, cell): variable_name = line.strip() if not variable_name: - raise ValueError("You must provide a variable name after %%treeppl_source") + raise ValueError("you must provide a variable name after %%treeppl_source") ip = get_ipython() ip.user_ns[variable_name] = cell @@ -28,60 +26,3 @@ def treeppl_source(line, cell): def load_ipython_extension(ipython): ipython.register_magic_function(treeppl, "cell") ipython.register_magic_function(treeppl_source, "cell") - display( - Javascript( - r""" - require(['notebook/js/codecell'], function(codecell) { - CodeMirror.defineMode("treeppl", function(config, parserConfig) { - return { - startState: function() { - return { def: false }; - }, - token: function(stream, state) { - if (stream.eatSpace()) { - return null; - } - let def = state.def; - state.def = false; - if (stream.match('//')) { - stream.skipToEnd(); - return "comment"; - } - if (stream.match('/*')) { - stream.skipTo('*\\/'); - return "comment"; - } - if (stream.match(/\b(type|function)\b/)) { - state.def = true; - return "keyword" - } - if (stream.match(/\b(model|let|is|in|to|assume|observe|weight|logWeight|resample|if|else|for|return)\b/)) { - return "keyword"; - } - if (stream.match(/\b(true|false)\b/)) { - return "atom"; - } - if 
(stream.match(/~|=|==|!=|<|>|<=|>=|\+@|\+|-\*@|\*\$|\*|\$|\^|\/|\^|\|\|/)) { - return "operator"; - } - if (stream.match(/\b([A-Za-z][_A-Za-z0-9]*)\b/)) { - return def? "def": "variable"; - } - if (stream.match(/\b[+-]?\d+(\.\d)*([Ee][+-]?\d+)?\b/)) { - return "number"; - } - if (stream.match(/"([^"\\]|\\.)*"/)) { - return "string"; - } - stream.next(); - return null; - } - }; - }); - codecell.CodeCell.options_default.highlight_modes['treeppl'] = { - reg: ['^%%treeppl'] - }; - }); - """ - ) - ) diff --git a/treeppl/serialization.py b/treeppl/serialization.py index 8e925d1..c5c5669 100644 --- a/treeppl/serialization.py +++ b/treeppl/serialization.py @@ -3,7 +3,6 @@ from .exceptions import SerializationError - constructor_to_class = {} class_to_constructor = {} @@ -54,13 +53,11 @@ def default(self, obj): except TypeError: try: return { - "__constructor__": class_to_constructor.get( - obj.__class__, obj.__class__.__name__ - ), + "__constructor__": class_to_constructor.get(obj.__class__, obj.__class__.__name__), "__data__": vars(obj), } except: - raise SerializationError("Could not serialize the data.") + raise SerializationError("could not serialize the data") def to_json(value, fp): diff --git a/treeppl/stdlib.py b/treeppl/stdlib.py index 525000c..3eb848e 100644 --- a/treeppl/stdlib.py +++ b/treeppl/stdlib.py @@ -18,9 +18,7 @@ def to_biopython(self, parent=None): parent = self return Phylo.BaseTree.Clade( branch_length=parent.age - self.age, - clades=[ - child.to_biopython(parent=self) for child in [self.left, self.right] - ], + clades=[child.to_biopython(parent=self) for child in [self.left, self.right]], ) @constructor("Leaf") @@ -64,9 +62,7 @@ def load_phyjson(filename): def age(node): children = node.get("children") if children: - return max(0.0, node.get("branch_length", 0)) + max( - age(children[0]), age(children[1]) - ) + return max(0.0, node.get("branch_length", 0)) + max(age(children[0]), age(children[1])) return node["branch_length"] def convert(node, age):