diff --git a/samplomatic/builders/box_builder.py b/samplomatic/builders/box_builder.py index 28983089..f8a89ff0 100644 --- a/samplomatic/builders/box_builder.py +++ b/samplomatic/builders/box_builder.py @@ -172,12 +172,30 @@ def rhs(self): class RightBoxBuilder(BoxBuilder): """Box builder for right dressings.""" - def parse(self, instr: CircuitInstruction): - if (name := instr.operation.name).startswith("meas"): - raise RuntimeError("Boxes with measurements cannot have dressing=right.") + def __init__(self, collection: CollectionSpec, emission: EmissionSpec): + super().__init__(collection=collection, emission=emission) + + self.measured_qubits = QubitPartition(1, []) + self.clbit_idxs = [] - if name == "barrier": + def parse(self, instr: CircuitInstruction): + if (name := instr.operation.name).startswith("barrier"): spec = InstructionSpec(params=self.template_state.append_remapped_gate(instr)) + return + + if name.startswith("meas"): + for qubit in instr.qubits: + if (qubit,) not in self.measured_qubits: + self.measured_qubits.add((qubit,)) + else: + raise SamplexBuildError( + "Cannot measure the same qubit twice in a twirling box." + ) + self.template_state.append_remapped_gate(instr) + self.clbit_idxs.extend( + [self.template_state.template.find_bit(clbit)[0] for clbit in instr.clbits] + ) + return elif (num_qubits := instr.operation.num_qubits) == 1: self.entangled_qubits.update(instr.qubits) @@ -220,5 +238,12 @@ def lhs(self): def rhs(self): self._append_barrier("M") param_idxs = self._append_dressed_layer() + if twirl_type := self.emission.twirl_register_type: + if len(self.measured_qubits) != 0: + if twirl_type != VirtualType.PAULI: + raise SamplexBuildError( + f"Cannot use {twirl_type.value} twirl in a box with measurements." + ) + self.samplex_state.add_z2_collect(self.measured_qubits, self.clbit_idxs) self.samplex_state.add_collect(self.collection.qubits, self.collection.synth, param_idxs) self._append_barrier("R") diff --git a/samplomatic/pre_samplex/pre_samplex.py b/samplomatic/pre_samplex/pre_samplex.py index 63660cf2..99b58df1 100644 --- a/samplomatic/pre_samplex/pre_samplex.py +++ b/samplomatic/pre_samplex/pre_samplex.py @@ -79,7 +79,8 @@ from ..samplex.noise_model_requirement import NoiseModelRequirement from ..synths import Synth from ..tensor_interface import TensorSpecification -from ..virtual_registers import U2Register +from ..virtual_registers import PauliRegister, U2Register +from ..virtual_registers.pauli_register import PAULI_GATE_NAMES from ..visualization import plot_graph from .graph_data import ( PreBasisTransform, @@ -1351,8 +1352,8 @@ def add_propagate_node( for predecssor_idx in self.graph.predecessor_indices(pre_propagate_idx): incoming.add(samplex.graph[pre_nodes_to_nodes[predecssor_idx]].outgoing_register_type) if mode is InstructionMode.MULTIPLY and pre_propagate.operation.num_qubits == 1: - combined_register_type = VirtualType.U2 if pre_propagate.operation.is_parameterized(): + combined_register_type = VirtualType.U2 param_idxs = [ samplex.append_parameter_expression(param) for _, param in pre_propagate.spec.params @@ -1366,7 +1367,15 @@ def add_propagate_node( op_name, combined_register_name, param_idxs ) else: - register = U2Register(np.array(pre_propagate.operation).reshape(1, 1, 2, 2)) + if ( + incoming == {VirtualType.PAULI} + and (name := pre_propagate.operation.name) in PAULI_GATE_NAMES + ): + combined_register_type = VirtualType.PAULI + register = PauliRegister.from_name(name) + else: + combined_register_type = VirtualType.U2 + register = U2Register(np.array(pre_propagate.operation).reshape(1, 1, 2, 2)) if pre_propagate.direction is Direction.LEFT: propagate_node = RightMultiplicationNode(register, combined_register_name) else: diff --git a/samplomatic/virtual_registers/pauli_register.py b/samplomatic/virtual_registers/pauli_register.py index d1343308..340592a4 100644 --- a/samplomatic/virtual_registers/pauli_register.py +++ b/samplomatic/virtual_registers/pauli_register.py @@ -23,6 +23,8 @@ from .u2_register import U2Register from .z2_register import Z2Register +PAULI_GATE_NAMES = {"id": 0, "x": 2, "z": 1, "y": 3} + PAULI_TO_U2 = np.array( [np.diag([1, 1]), np.diag([1, -1]), np.diag([1, 1])[::-1], np.diag([-1j, 1j])[::-1]], dtype=U2Register.DTYPE, @@ -60,6 +62,23 @@ def __init__(self, virtual_gates): def identity(cls, num_subsystems, num_samples): return cls(np.zeros((num_subsystems, num_samples), dtype=np.uint8)) + @classmethod + def from_name(cls, name: str) -> PauliRegister: + """Returns a Pauli register given a its name. + + Args: + name: The name of the Pauli. + + Returns: The Pauli register. + + Raises: + VirtualGateError: If ``name`` is not in 'PAULI_GATE_NAMES'. + """ + try: + return cls(np.array([PAULI_GATE_NAMES[name]], dtype=np.uint8).reshape(1, 1)) + except KeyError: + raise VirtualGateError(f"'{name}' is not a valid Pauli.") + def convert_to(self, register_type): if register_type is VirtualType.U2: return U2Register(PAULI_TO_U2[self._array, :, :]) @@ -76,7 +95,7 @@ def multiply(self, other, subsystem_idxs: list[SubsystemIndex] | slice = slice(N except (ValueError, IndexError) as exc: raise VirtualGateError( f"Register {self} and {other} have incompatible shapes or types, " - f"given subsystem_idxs {subsystem_idxs}" + f"given subsystem_idxs {subsystem_idxs}." ) from exc def inplace_multiply(self, other, subsystem_idxs: list[SubsystemIndex] | slice = slice(None)): @@ -87,7 +106,7 @@ def inplace_multiply(self, other, subsystem_idxs: list[SubsystemIndex] | slice = except (ValueError, IndexError) as exc: raise VirtualGateError( f"Register {self} and {other} have incompatible shapes or types, " - f"given subsystem_idxs {subsystem_idxs}" + f"given subsystem_idxs {subsystem_idxs}." ) from exc def invert(self): diff --git a/test/integration/test_measurement_twirling.py b/test/integration/test_measurement_twirling.py index ef33d5b7..e60d0969 100644 --- a/test/integration/test_measurement_twirling.py +++ b/test/integration/test_measurement_twirling.py @@ -54,6 +54,14 @@ def test_measure_all(self, save_plot): circuit.measure_all() sample_simulate_and_compare_counts(circuit, save_plot) + def test_measure_all_right(self, save_plot): + circuit = QuantumCircuit(3) + with circuit.box([Twirl()]): + circuit.noop(*circuit.qubits) + with circuit.box([Twirl(dressing="right")]): + circuit.measure_all() + sample_simulate_and_compare_counts(circuit, save_plot) + def test_gates_and_measure_all(self, save_plot): circuit = QuantumCircuit(3) with circuit.box([Twirl(dressing="left")]): @@ -75,6 +83,16 @@ def test_separate_measures(self, save_plot): sample_simulate_and_compare_counts(circuit, save_plot) + def test_measure_with_different_dressings(self, save_plot): + circuit = QuantumCircuit(QuantumRegister(size=2), ClassicalRegister(name="meas", size=2)) + with circuit.box([Twirl(dressing="left")]): + circuit.measure(0, 1) + circuit.x(1) + with circuit.box([Twirl(dressing="right")]): + circuit.measure(1, 0) + + sample_simulate_and_compare_counts(circuit, save_plot) + @pytest.mark.skip(reason="QiskitAer bug #2367") def test_separate_measure_boxes(self, save_plot): """Test separate measurement boxes, with non-standard cbit associations""" diff --git a/test/unit/test_virtual_registers/test_pauli_register.py b/test/unit/test_virtual_registers/test_pauli_register.py index e6c5fdc0..f9a2e849 100644 --- a/test/unit/test_virtual_registers/test_pauli_register.py +++ b/test/unit/test_virtual_registers/test_pauli_register.py @@ -13,8 +13,10 @@ """Test the PauliRegister""" import numpy as np +import pytest from samplomatic.annotations import VirtualType +from samplomatic.exceptions import VirtualGateError from samplomatic.virtual_registers import PauliRegister, U2Register, VirtualRegister, Z2Register @@ -160,3 +162,14 @@ def test_invert(): assert paulis == inverted paulis[0, 0] = 2 assert inverted.virtual_gates[0, 0] != 2 + + +def test_from_name(): + """Test the from_name() method.""" + assert PauliRegister.from_name("x") == PauliRegister(np.array(2, dtype=np.uint8).reshape(1, 1)) + assert PauliRegister.from_name("y") == PauliRegister(np.array(3, dtype=np.uint8).reshape(1, 1)) + assert PauliRegister.from_name("z") == PauliRegister(np.array(1, dtype=np.uint8).reshape(1, 1)) + assert PauliRegister.from_name("id") == PauliRegister(np.zeros((1, 1), dtype=np.uint8)) + + with pytest.raises(VirtualGateError, match="'not-pauli' is not a valid Pauli"): + PauliRegister.from_name("not-pauli")