diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b179b39..bb0c54c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ Types of changes: ### Removed ### Fixed - +- Fixed multiple axes error in circuit visualization of decomposable gates in `draw` method. ([#209](https://github.com/qBraid/pyqasm/pull/210)) ### Dependencies ### Other diff --git a/src/pyqasm/printer.py b/src/pyqasm/printer.py index 5da4d67b..4d7ea4f0 100644 --- a/src/pyqasm/printer.py +++ b/src/pyqasm/printer.py @@ -297,6 +297,7 @@ def _mpl_setup_figure( sections: list[list[list[QuantumStatement]]], width: float, n_lines: int ) -> tuple[plt.Figure, list[plt.Axes]]: import matplotlib.pyplot as plt + import numpy as np fig_ax_tuple: tuple[plt.Figure, list[plt.Axes] | plt.Axes] = plt.subplots( len(sections), @@ -306,7 +307,11 @@ def _mpl_setup_figure( ) fig, axs = fig_ax_tuple - axs = axs if isinstance(axs, list) else [axs] + axs = ( + axs.flatten().tolist() + if isinstance(axs, np.ndarray) + else [axs] if isinstance(axs, plt.Axes) else axs + ) for ax in axs: ax.set_ylim( diff --git a/tests/visualization/test_mpl_draw.py b/tests/visualization/test_mpl_draw.py index c5b2b386..f80b78df 100644 --- a/tests/visualization/test_mpl_draw.py +++ b/tests/visualization/test_mpl_draw.py @@ -77,6 +77,22 @@ def test_draw_qasm3_custom_gate(): _check_fig(circ, fig) +def test_draw_qasm3_decomposable_gate(): + qasm = """ + OPENQASM 3.0; + qubit[2] q1; + qreg q[3]; + creg c[3]; + ccx q[0], q[1], q1[0]; + crx (0.1) q[0], q[1]; + rccx q[0], q[1], q1[0]; + cz q[0], q[1]; + """ + circ = loads(qasm) + fig = mpl_draw(circ) + _check_fig(circ, fig) + + def test_draw_qasm2_simple(): """Test drawing a simple QASM 2.0 circuit.""" qasm = """