Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions samplomatic/builders/box_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,30 @@ def rhs(self):
class RightBoxBuilder(BoxBuilder):
"""Box builder for right dressings."""

def parse(self, instr: CircuitInstruction):
if (name := instr.operation.name).startswith("meas"):
raise RuntimeError("Boxes with measurements cannot have dressing=right.")
def __init__(self, collection: CollectionSpec, emission: EmissionSpec):
super().__init__(collection=collection, emission=emission)

self.measured_qubits = QubitPartition(1, [])
self.clbit_idxs = []

if name == "barrier":
def parse(self, instr: CircuitInstruction):
if (name := instr.operation.name).startswith("barrier"):
spec = InstructionSpec(params=self.template_state.append_remapped_gate(instr))
return

if name.startswith("meas"):
for qubit in instr.qubits:
if (qubit,) not in self.measured_qubits:
self.measured_qubits.add((qubit,))
else:
raise SamplexBuildError(
"Cannot measure the same qubit twice in a twirling box."
)
self.template_state.append_remapped_gate(instr)
self.clbit_idxs.extend(
[self.template_state.template.find_bit(clbit)[0] for clbit in instr.clbits]
)
return

elif (num_qubits := instr.operation.num_qubits) == 1:
self.entangled_qubits.update(instr.qubits)
Expand Down Expand Up @@ -220,5 +238,12 @@ def lhs(self):
def rhs(self):
self._append_barrier("M")
param_idxs = self._append_dressed_layer()
if twirl_type := self.emission.twirl_register_type:
if len(self.measured_qubits) != 0:
if twirl_type != VirtualType.PAULI:
raise SamplexBuildError(
f"Cannot use {twirl_type.value} twirl in a box with measurements."
)
self.samplex_state.add_z2_collect(self.measured_qubits, self.clbit_idxs)
self.samplex_state.add_collect(self.collection.qubits, self.collection.synth, param_idxs)
self._append_barrier("R")
15 changes: 12 additions & 3 deletions samplomatic/pre_samplex/pre_samplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@
from ..samplex.noise_model_requirement import NoiseModelRequirement
from ..synths import Synth
from ..tensor_interface import TensorSpecification
from ..virtual_registers import U2Register
from ..virtual_registers import PauliRegister, U2Register
from ..virtual_registers.pauli_register import PAULI_GATE_NAMES
from ..visualization import plot_graph
from .graph_data import (
PreBasisTransform,
Expand Down Expand Up @@ -1351,8 +1352,8 @@ def add_propagate_node(
for predecssor_idx in self.graph.predecessor_indices(pre_propagate_idx):
incoming.add(samplex.graph[pre_nodes_to_nodes[predecssor_idx]].outgoing_register_type)
if mode is InstructionMode.MULTIPLY and pre_propagate.operation.num_qubits == 1:
combined_register_type = VirtualType.U2
if pre_propagate.operation.is_parameterized():
combined_register_type = VirtualType.U2
param_idxs = [
samplex.append_parameter_expression(param)
for _, param in pre_propagate.spec.params
Expand All @@ -1366,7 +1367,15 @@ def add_propagate_node(
op_name, combined_register_name, param_idxs
)
else:
register = U2Register(np.array(pre_propagate.operation).reshape(1, 1, 2, 2))
if (
incoming == {VirtualType.PAULI}
and (name := pre_propagate.operation.name) in PAULI_GATE_NAMES
):
combined_register_type = VirtualType.PAULI
register = PauliRegister.from_name(name)
else:
combined_register_type = VirtualType.U2
register = U2Register(np.array(pre_propagate.operation).reshape(1, 1, 2, 2))
if pre_propagate.direction is Direction.LEFT:
propagate_node = RightMultiplicationNode(register, combined_register_name)
else:
Expand Down
23 changes: 21 additions & 2 deletions samplomatic/virtual_registers/pauli_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from .u2_register import U2Register
from .z2_register import Z2Register

PAULI_GATE_NAMES = {"id": 0, "x": 2, "z": 1, "y": 3}

PAULI_TO_U2 = np.array(
[np.diag([1, 1]), np.diag([1, -1]), np.diag([1, 1])[::-1], np.diag([-1j, 1j])[::-1]],
dtype=U2Register.DTYPE,
Expand Down Expand Up @@ -60,6 +62,23 @@ def __init__(self, virtual_gates):
def identity(cls, num_subsystems, num_samples):
return cls(np.zeros((num_subsystems, num_samples), dtype=np.uint8))

@classmethod
def from_name(cls, name: str) -> PauliRegister:
"""Returns a Pauli register given a its name.

Args:
name: The name of the Pauli.

Returns: The Pauli register.

Raises:
VirtualGateError: If ``name`` is not in 'PAULI_GATE_NAMES'.
"""
try:
return cls(np.array([PAULI_GATE_NAMES[name]], dtype=np.uint8).reshape(1, 1))
except KeyError:
raise VirtualGateError(f"'{name}' is not a valid Pauli.")

def convert_to(self, register_type):
if register_type is VirtualType.U2:
return U2Register(PAULI_TO_U2[self._array, :, :])
Expand All @@ -76,7 +95,7 @@ def multiply(self, other, subsystem_idxs: list[SubsystemIndex] | slice = slice(N
except (ValueError, IndexError) as exc:
raise VirtualGateError(
f"Register {self} and {other} have incompatible shapes or types, "
f"given subsystem_idxs {subsystem_idxs}"
f"given subsystem_idxs {subsystem_idxs}."
) from exc

def inplace_multiply(self, other, subsystem_idxs: list[SubsystemIndex] | slice = slice(None)):
Expand All @@ -87,7 +106,7 @@ def inplace_multiply(self, other, subsystem_idxs: list[SubsystemIndex] | slice =
except (ValueError, IndexError) as exc:
raise VirtualGateError(
f"Register {self} and {other} have incompatible shapes or types, "
f"given subsystem_idxs {subsystem_idxs}"
f"given subsystem_idxs {subsystem_idxs}."
) from exc

def invert(self):
Expand Down
18 changes: 18 additions & 0 deletions test/integration/test_measurement_twirling.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ def test_measure_all(self, save_plot):
circuit.measure_all()
sample_simulate_and_compare_counts(circuit, save_plot)

def test_measure_all_right(self, save_plot):
circuit = QuantumCircuit(3)
with circuit.box([Twirl()]):
circuit.noop(*circuit.qubits)
with circuit.box([Twirl(dressing="right")]):
circuit.measure_all()
sample_simulate_and_compare_counts(circuit, save_plot)

def test_gates_and_measure_all(self, save_plot):
circuit = QuantumCircuit(3)
with circuit.box([Twirl(dressing="left")]):
Expand All @@ -75,6 +83,16 @@ def test_separate_measures(self, save_plot):

sample_simulate_and_compare_counts(circuit, save_plot)

def test_measure_with_different_dressings(self, save_plot):
circuit = QuantumCircuit(QuantumRegister(size=2), ClassicalRegister(name="meas", size=2))
with circuit.box([Twirl(dressing="left")]):
circuit.measure(0, 1)
circuit.x(1)
with circuit.box([Twirl(dressing="right")]):
circuit.measure(1, 0)

sample_simulate_and_compare_counts(circuit, save_plot)

@pytest.mark.skip(reason="QiskitAer bug #2367")
def test_separate_measure_boxes(self, save_plot):
"""Test separate measurement boxes, with non-standard cbit associations"""
Expand Down
13 changes: 13 additions & 0 deletions test/unit/test_virtual_registers/test_pauli_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
"""Test the PauliRegister"""

import numpy as np
import pytest

from samplomatic.annotations import VirtualType
from samplomatic.exceptions import VirtualGateError
from samplomatic.virtual_registers import PauliRegister, U2Register, VirtualRegister, Z2Register


Expand Down Expand Up @@ -160,3 +162,14 @@ def test_invert():
assert paulis == inverted
paulis[0, 0] = 2
assert inverted.virtual_gates[0, 0] != 2


def test_from_name():
"""Test the from_name() method."""
assert PauliRegister.from_name("x") == PauliRegister(np.array(2, dtype=np.uint8).reshape(1, 1))
assert PauliRegister.from_name("y") == PauliRegister(np.array(3, dtype=np.uint8).reshape(1, 1))
assert PauliRegister.from_name("z") == PauliRegister(np.array(1, dtype=np.uint8).reshape(1, 1))
assert PauliRegister.from_name("id") == PauliRegister(np.zeros((1, 1), dtype=np.uint8))

with pytest.raises(VirtualGateError, match="'not-pauli' is not a valid Pauli"):
PauliRegister.from_name("not-pauli")