"""This module implements the ForEachBlockPass class."""
from __future__ import annotations
import functools
import logging
from typing import Callable
from bqskit.compiler.basepass import _sub_do_work
from bqskit.compiler.basepass import BasePass
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
from bqskit.ir.gates.circuitgate import CircuitGate
from bqskit.ir.gates.constant.unitary import ConstantUnitaryGate
from bqskit.ir.gates.parameterized.pauli import PauliGate
from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate
from bqskit.ir.location import CircuitLocation
from bqskit.ir.operation import Operation
from bqskit.ir.point import CircuitPoint
from bqskit.runtime import get_runtime
# Module-level logger, named after this module.
_logger = logging.getLogger(name=__name__)
class ForEachBlockPass(BasePass):
    """
    A pass that executes other passes on each block in the circuit.

    This is a control pass that executes a workflow on every block in the
    circuit. This will be done in parallel.
    """

    key = 'ForEachBlockPass_data'
    """The key in data, where block data will be put."""

    pass_down_key_prefix = 'ForEachBlockPass_pass_down_'
    """If a key exists in the pass data with this prefix, pass it to blocks."""

    pass_down_block_specific_key_prefix = (
        'ForEachBlockPass_specific_pass_down_'
    )
    """
    Data specific to the processing of individual blocks in a partitioned
    circuit can be injected into the `PassData` in `run` by using this prefix.
    The expected type of the associated value is `dict[int, Any]`, where
    integer (sub-)keys correspond to block numbers in a partitioned quantum
    circuit. A block's number is its position among the operations selected
    by `collection_filter`, in `circuit.operations_with_cycles()` order.

    Pseudocode example for seed circuits:
        seeds = {block_id: [seed_circuit_a, seed_circuit_b, ...], ...}
        key = self.pass_down_block_specific_key_prefix + 'seed_circuits'
        seed_updater = UpdateDataPass(key, seeds)
        workflow = Workflow([..., seed_updater, ForEachBlockPass(...), ...])
    """

    def __init__(
        self,
        loop_body: WorkflowLike,
        calculate_error_bound: bool = False,
        collection_filter: Callable[[Operation], bool] | None = None,
        replace_filter: ReplaceFilterFn | str | None = 'always',
        batch_size: int | None = None,
    ) -> None:
        """
        Construct a ForEachBlockPass.

        Args:
            loop_body (WorkflowLike): The workflow to execute on every block.

            calculate_error_bound (bool): If set to true, will calculate
                errors on blocks after running `loop_body` on them and
                use these block errors to calculate an upper bound on the
                full circuit error. (Default: False)

            collection_filter (Callable[[Operation], bool] | None):
                A predicate that determines which operations should have
                `loop_body` called on them. Called with each operation
                in the circuit. If this returns true, that operation will
                be formed into an individual circuit and passed through
                `loop_body`. Defaults to all CircuitGates,
                ConstantUnitaryGates, VariableUnitaryGates, and PauliGates.
                #TODO: address importability

            replace_filter (ReplaceFilterFn | str | None):
                A predicate that determines if the resulting circuit, after
                calling `loop_body` on a block, should replace the original
                operation. Called with the circuit output from `loop_body`
                and the original operation. If this returns true, the
                operation will be replaced with the new circuit.
                If None is passed, a filter that always replaces is used.
                If a string is passed, the matching filter is generated
                when the pass runs. An unknown string is only reported,
                as a ValueError, at that point. The string should be one of
                'always', 'less-than', 'less-than-multi', 'less-than-many',
                'less-than-respecting', 'less-than-respecting-multi',
                'less-than-respecting-many', 'less-than-respecting-fully',
                'less-than-respecting-fully-multi', or
                'less-than-respecting-fully-many'.

                - 'always' will always replace
                - 'less-than' will replace if the new circuit has fewer
                    gates than the old circuit.
                - 'less-than-multi' will replace if the new circuit has
                    fewer multi-qudit gates than the old circuit.
                - 'less-than-many' will replace if the new circuit has
                    fewer many-qudit gates than the old circuit.
                - 'less-than-respecting' will replace if the new circuit
                    has fewer gates than the old circuit or the old
                    doesn't respect the model (ignoring single-qudit
                    gate sets).
                - 'less-than-respecting-multi' will replace if the new
                    circuit has fewer multi-qudit gates than the old
                    circuit or the old doesn't respect the model
                    (ignoring single-qudit gate sets).
                - 'less-than-respecting-many' will replace if the new
                    circuit has fewer many-qudit gates than the old
                    circuit or the old doesn't respect the model
                    (ignoring single-qudit gate sets).
                - 'less-than-respecting-fully' will replace if the new
                    circuit has fewer gates than the old circuit or
                    the old doesn't respect the model.
                - 'less-than-respecting-fully-multi' will replace if
                    the new circuit has fewer multi-qudit gates than
                    the old circuit or the old doesn't respect the model.
                - 'less-than-respecting-fully-many' will replace if
                    the new circuit has fewer many-qudit gates than
                    the old circuit or the old doesn't respect the model.

                Defaults to 'always'. #TODO: address importability

            batch_size (int | None): Deprecated and ignored. Passing any
                value other than None emits a DeprecationWarning.

        Raises:
            TypeError: If `collection_filter` is not callable, or if
                `replace_filter` is neither a string nor a callable.
        """
        if batch_size is not None:
            import warnings

            # stacklevel=2 attributes the warning to the caller that passed
            # `batch_size`. With the default (1) it would be attributed to
            # this module, where Python's default filters hide
            # DeprecationWarning.
            warnings.warn(
                'Batch size is no longer supported, this warning will'
                ' become an error in a future update.',
                DeprecationWarning,
                stacklevel=2,
            )

        self.calculate_error_bound = calculate_error_bound
        self.collection_filter = collection_filter or default_collection_filter
        self.replace_filter = replace_filter or default_replace_filter
        self.workflow = Workflow(loop_body)

        if not callable(self.collection_filter):
            raise TypeError(
                'Expected callable method that maps Operations to booleans for'
                f' collection_filter, got {type(self.collection_filter)}.',
            )

        # Strings are validated lazily in `run` via `gen_replace_filter`.
        if not isinstance(self.replace_filter, str) and not callable(
            self.replace_filter,
        ):
            raise TypeError(
                'Expected either string representing a valid replacement'
                ' filter or callable method that maps Circuit and'
                ' Operations to bools for replace_filter'
                f', got {type(self.replace_filter)}.',
            )

    async def run(self, circuit: Circuit, data: PassData) -> None:
        """
        Perform the pass's operation, see :class:`BasePass` for more.

        Every operation accepted by `collection_filter` is turned into its
        own subcircuit and run through the loop body via the runtime's
        `map`. Results accepted by the replace filter are swapped back into
        `circuit` as CircuitGate operations. Exactly one entry is appended
        to `data[self.key]` per call. That entry is either the per-block
        pass data, each tagged with a boolean 'replaced' flag, or an empty
        list when no blocks were collected.
        """
        # Resolve the replace filter. Named filters are built here because
        # some of them need the machine model stored in `data`.
        if isinstance(self.replace_filter, str):
            method = self.replace_filter
            replace_filter = gen_replace_filter(method, data.model)
        else:
            replace_filter = self.replace_filter

        # Make room in data for block data
        if self.key not in data:
            data[self.key] = []

        # Collect blocks
        blocks: list[tuple[int, Operation]] = [
            (cycle, op)
            for cycle, op in circuit.operations_with_cycles()
            if self.collection_filter(op)
        ]

        # No blocks, no work (still record an empty entry for this run)
        if not blocks:
            data[self.key].append([])
            return

        # Get the machine model
        model = data.model
        coupling_graph = data.connectivity

        # Preprocess blocks
        subcircuits: list[Circuit] = []
        block_datas: list[PassData] = []
        for i, (cycle, op) in enumerate(blocks):

            # Form Subcircuit. CircuitGates are unwrapped by copying the
            # inner circuit and applying the operation's parameters.
            if isinstance(op.gate, CircuitGate):
                subcircuit = op.gate._circuit.copy()
                subcircuit.set_params(op.params)
            else:
                subcircuit = Circuit.from_operation(op)

            # Form Submodel: map the block's qudits to 0..k-1 and build the
            # matching coupling graph with `get_subgraph`.
            subradixes = [circuit.radixes[q] for q in op.location]
            subnumbering = {q: idx for idx, q in enumerate(op.location)}
            submodel = MachineModel(
                len(op.location),
                coupling_graph.get_subgraph(op.location, subnumbering),
                model.gate_set,
                subradixes,
            )

            # Form Subdata
            block_data: PassData = PassData(subcircuit)
            block_data['subnumbering'] = subnumbering
            block_data['model'] = submodel
            block_data['point'] = CircuitPoint(cycle, op.location[0])
            block_data['calculate_error_bound'] = self.calculate_error_bound
            for key in data:
                if key.startswith(self.pass_down_key_prefix):
                    # Shared pass-down data goes to every block unchanged.
                    block_data[key] = data[key]
                elif key.startswith(
                    self.pass_down_block_specific_key_prefix,
                ) and i in data[key]:
                    # Block-specific data is looked up by block number.
                    block_data[key] = data[key][i]
            block_data.seed = data.seed

            subcircuits.append(subcircuit)
            block_datas.append(block_data)

        # Do the work
        results = await get_runtime().map(
            _sub_do_work,
            [self.workflow] * len(subcircuits),
            subcircuits,
            block_datas,
        )

        # Unpack results
        completed_subcircuits, completed_block_datas = zip(*results)

        # Postprocess blocks
        points: list[CircuitPoint] = []
        ops: list[Operation] = []
        error_sum = 0.0
        completed = zip(blocks, completed_subcircuits, completed_block_datas)
        for i, ((cycle, op), subcircuit, block_data) in enumerate(completed):

            # Mark Blocks to be Replaced
            if replace_filter(subcircuit, op):
                _logger.debug('Replacing block %d.', i)
                points.append(CircuitPoint(cycle, op.location[0]))
                ops.append(
                    Operation(
                        CircuitGate(subcircuit, True),
                        op.location,
                        subcircuit.params,
                    ),
                )
                block_data['replaced'] = True

                # Only replaced blocks contribute to the circuit error.
                error_sum += block_data.error
            else:
                block_data['replaced'] = False

        # Replace all accepted blocks in one batch
        circuit.batch_replace(points, ops)

        # Record block data into pass data
        data[self.key].append(completed_block_datas)

        # Record error
        data.update_error_mul(error_sum)
        if self.calculate_error_bound:
            _logger.debug('New circuit error is %s.', data.error)
def default_collection_filter(op: Operation) -> bool:
    """
    Select the operations ForEachBlockPass processes when no filter is given.

    Args:
        op (Operation): The candidate operation.

    Returns:
        True when `op` applies a CircuitGate, ConstantUnitaryGate,
        VariableUnitaryGate, or PauliGate; False otherwise.
    """
    block_like_gates = (
        CircuitGate,
        ConstantUnitaryGate,
        VariableUnitaryGate,
        PauliGate,
    )
    gate = op.gate
    return isinstance(gate, block_like_gates)
def default_replace_filter(circuit: Circuit, op: Operation) -> bool:
    """
    Replace filter that accepts every candidate block.

    Both arguments are ignored; the answer is unconditionally True. The
    name and style are kept for backwards compatibility.
    """
    accept_all = True
    return accept_all
def _less_than(new: Circuit, old: Operation) -> bool:
    """
    Return true if the new circuit has fewer gates than the old block.

    Gate counts are compared via `num_operations`. When `old` does not wrap
    a CircuitGate there is no inner circuit to compare against, and the new
    circuit is always accepted.
    """
    if not isinstance(old.gate, CircuitGate):
        return True  # TODO: Re-evaluate always true when old is not a circuit
    new_count = new.num_operations
    return new_count < old.gate._circuit.num_operations
def _less_than_multi(new: Circuit, old: Operation) -> bool:
    """
    Return true if the new circuit has fewer multi-qudit gates.

    Counts are compared lexicographically as (multi-qudit gates,
    single-qudit gates), so a tie on multi-qudit gates is broken by the
    single-qudit count. When `old` does not wrap a CircuitGate, the new
    circuit is always accepted.
    """
    if not isinstance(old.gate, CircuitGate):
        return True

    def tally(circ: Circuit) -> tuple[int, int]:
        multi = single = 0
        for gate, count in circ.gate_counts.items():
            if gate.num_qudits > 1:
                multi += count
            elif gate.num_qudits == 1:
                single += count
        return multi, single

    old_tally = tally(old.gate._circuit)
    new_tally = tally(new)
    return new_tally < old_tally
def _less_than_many(new: Circuit, old: Operation) -> bool:
    """
    Return true if the new circuit has fewer many-qudit gates.

    Counts are compared lexicographically as (gates on more than two
    qudits, two-qudit gates, single-qudit gates). A tie at one size is
    therefore broken by the next smaller size. When `old` does not wrap a
    CircuitGate, the new circuit is always accepted.
    """
    if not isinstance(old.gate, CircuitGate):
        return True

    def bucket_counts(circ: Circuit) -> tuple[int, int, int]:
        many = two = one = 0
        for gate, count in circ.gate_counts.items():
            width = gate.num_qudits
            if width > 2:
                many += count
            elif width == 2:
                two += count
            elif width == 1:
                one += count
        return many, two, one

    old_buckets = bucket_counts(old.gate._circuit)
    new_buckets = bucket_counts(new)
    return new_buckets < old_buckets
def _is_respecting(
    circuit: Circuit,
    location: CircuitLocation,
    model: MachineModel,
    fully: bool = False,
) -> bool:
    """
    Return true if the `circuit` respects the `model` at `location`.

    Three conditions are checked in order:
    1. Every multi-qudit gate of `circuit` is in the model's gate set.
    2. If `fully` is set, every single-qudit gate is in it as well.
    3. Every edge of the circuit's coupling graph, mapped through
       `location`, is in the model's coupling graph.

    Args:
        circuit (Circuit): The circuit to check.

        location (CircuitLocation): The location to check.

        model (MachineModel): The machine model to check against.

        fully (bool): If set to true, will check if the circuit respects
            the model fully. If set to false, will ignore single-qudit
            gate sets. (Default: False)

    Returns:
        True if the circuit respects the model at the location. This implies
        that the circuit can be run on the machine at the location.
    """
    native_gates = model.gate_set
    multi_gates = circuit.gate_set.multi_qudit_gates
    single_gates = circuit.gate_set.single_qudit_gates

    if not all(gate in native_gates for gate in multi_gates):
        return False

    if fully and not all(gate in native_gates for gate in single_gates):
        return False

    hardware_edges = model.coupling_graph
    return all(
        (location[edge[0]], location[edge[1]]) in hardware_edges
        for edge in circuit.coupling_graph
    )
def _less_than_fn_respecting(
    new: Circuit,
    old: Operation,
    model: MachineModel,
    fn: ReplaceFilterFn,
) -> bool:
    """
    Decide on replacement while taking model conformance into account.

    Single-qudit gate sets are ignored when checking conformance.

    - If `old` wraps a CircuitGate that does not respect `model`, the new
      circuit is accepted. A debug message is logged if the new circuit
      does not respect it either.
    - Otherwise, a new circuit that does not respect `model` is rejected.
    - Otherwise, the decision is delegated to `fn(new, old)`.
    """
    location = old.location
    old_violates = (
        isinstance(old.gate, CircuitGate)
        and not _is_respecting(old.gate._circuit, location, model)
    )
    new_respects = _is_respecting(new, location, model)

    if old_violates:
        if not new_respects:
            _logger.debug("New block doesn't respect model.")
        return True

    if not new_respects:
        _logger.debug("New block doesn't respect model.")
        return False

    return fn(new, old)
def _less_than_fn_respecting_fully(
    new: Circuit,
    old: Operation,
    model: MachineModel,
    fn: ReplaceFilterFn,
) -> bool:
    """
    Decide on replacement while taking full model conformance into account.

    This is like `_less_than_fn_respecting`, but conformance also requires
    single-qudit gates to be in the model's gate set.

    - If `old` wraps a CircuitGate that does not fully respect `model`, the
      new circuit is accepted. A debug message is logged if the new circuit
      does not respect it either.
    - Otherwise, a new circuit that does not fully respect `model` is
      rejected.
    - Otherwise, the decision is delegated to `fn(new, old)`.
    """
    location = old.location
    old_violates = (
        isinstance(old.gate, CircuitGate)
        and not _is_respecting(old.gate._circuit, location, model, fully=True)
    )
    new_respects = _is_respecting(new, location, model, fully=True)

    if old_violates:
        if not new_respects:
            _logger.debug("New block doesn't respect model.")
        return True

    if not new_respects:
        _logger.debug("New block doesn't respect model.")
        return False

    return fn(new, old)
def gen_always(model: MachineModel) -> ReplaceFilterFn:
    """
    Build the 'always' replace filter, which accepts every block.

    `model` is unused. It is accepted only so that every generator can be
    called the same way by `gen_replace_filter`. The name and style are kept
    for backwards compatibility.
    """
    del model  # The decision never depends on the machine model.
    return default_replace_filter
def gen_less_than(model: MachineModel) -> ReplaceFilterFn:
    """
    Build the 'less-than' replace filter.

    The filter accepts a new block when it has fewer operations than the
    original CircuitGate block. `model` is unused; it is accepted only so
    every generator shares one signature.
    """
    del model  # The gate-count comparison ignores the machine model.
    return _less_than
def gen_less_than_multi(model: MachineModel) -> ReplaceFilterFn:
    """
    Build the 'less-than-multi' replace filter.

    The filter accepts a new block with fewer multi-qudit gates, breaking
    ties on the single-qudit count. `model` is unused; it is accepted only
    so every generator shares one signature.
    """
    del model  # The gate-count comparison ignores the machine model.
    return _less_than_multi
def gen_less_than_many(model: MachineModel) -> ReplaceFilterFn:
    """
    Build the 'less-than-many' replace filter.

    The filter accepts a new block with fewer gates on more than two
    qudits, breaking ties on two-qudit and then single-qudit counts.
    `model` is unused; it is accepted only so every generator shares one
    signature.
    """
    del model  # The gate-count comparison ignores the machine model.
    return _less_than_many
def gen_less_than_rspt(model: MachineModel) -> ReplaceFilterFn:
    """
    Build the 'less-than-respecting' replace filter.

    A block is replaced when the old one does not respect `model`. Otherwise
    the new circuit must respect `model` and have fewer gates. Single-qudit
    gate sets are ignored in the conformance checks.
    """
    bound_args = {'model': model, 'fn': _less_than}
    return functools.partial(_less_than_fn_respecting, **bound_args)
def gen_less_than_rspt_multi(model: MachineModel) -> ReplaceFilterFn:
    """
    Build the 'less-than-respecting-multi' replace filter.

    A block is replaced when the old one does not respect `model`. Otherwise
    the new circuit must respect `model` and have fewer multi-qudit gates.
    Single-qudit gate sets are ignored in the conformance checks.
    """
    bound_args = {'model': model, 'fn': _less_than_multi}
    return functools.partial(_less_than_fn_respecting, **bound_args)
def gen_less_than_rspt_many(model: MachineModel) -> ReplaceFilterFn:
    """
    Build the 'less-than-respecting-many' replace filter.

    A block is replaced when the old one does not respect `model`. Otherwise
    the new circuit must respect `model` and have fewer many-qudit gates.
    Single-qudit gate sets are ignored in the conformance checks.
    """
    bound_args = {'model': model, 'fn': _less_than_many}
    return functools.partial(_less_than_fn_respecting, **bound_args)
def gen_less_than_rspt_fully(model: MachineModel) -> ReplaceFilterFn:
    """
    Generate a replace filter that replaces if the new circuit has fewer
    gates or the old doesn't fully respect the model.

    This differs from the 'less-than-respecting' filter: conformance also
    requires single-qudit gates to be in the model's gate set. If the old
    block conforms, the new circuit must also fully conform to be accepted.
    """
    return functools.partial(
        _less_than_fn_respecting_fully,
        model=model,
        fn=_less_than,
    )
def gen_less_than_rspt_fully_multi(model: MachineModel) -> ReplaceFilterFn:
    """
    Generate a replace filter that replaces if the new circuit has fewer
    multi-qudit gates or the old doesn't fully respect the model.

    This differs from the 'less-than-respecting-multi' filter: conformance
    also requires single-qudit gates to be in the model's gate set. If the
    old block conforms, the new circuit must also fully conform to be
    accepted.
    """
    return functools.partial(
        _less_than_fn_respecting_fully,
        model=model,
        fn=_less_than_multi,
    )
def gen_less_than_rspt_fully_many(model: MachineModel) -> ReplaceFilterFn:
    """
    Generate a replace filter that replaces if the new circuit has fewer
    many-qudit gates or the old doesn't fully respect the model.

    This differs from the 'less-than-respecting-many' filter: conformance
    also requires single-qudit gates to be in the model's gate set. If the
    old block conforms, the new circuit must also fully conform to be
    accepted.
    """
    return functools.partial(
        _less_than_fn_respecting_fully,
        model=model,
        fn=_less_than_many,
    )
def gen_replace_filter(method: str, model: MachineModel) -> ReplaceFilterFn:
    """
    Generate a replace filter for use during the standard workflow.

    Args:
        method (str): Name of the replacement strategy, for example
            'always' or 'less-than-respecting-multi'. See
            :class:`ForEachBlockPass` for the full list and meanings.

        model (MachineModel): The machine model to potentially respect.
            Only the 'less-than-respecting*' strategies bind it.

    Returns:
        A replace filter function.

    Raises:
        ValueError: If `method` does not name a known strategy.
    """
    factories = {
        'always': gen_always,
        'less-than': gen_less_than,
        'less-than-multi': gen_less_than_multi,
        'less-than-many': gen_less_than_many,
        'less-than-respecting': gen_less_than_rspt,
        'less-than-respecting-multi': gen_less_than_rspt_multi,
        'less-than-respecting-many': gen_less_than_rspt_many,
        'less-than-respecting-fully': gen_less_than_rspt_fully,
        'less-than-respecting-fully-multi': gen_less_than_rspt_fully_multi,
        'less-than-respecting-fully-many': gen_less_than_rspt_fully_many,
    }

    factory = factories.get(method)
    if factory is None:
        raise ValueError(f'Unknown replace filter method {method}.')
    return factory(model)
# Signature of a replace filter. ForEachBlockPass calls it with the circuit
# produced by the loop body for a block and the original operation. If it
# returns True, the operation is replaced by that circuit.
ReplaceFilterFn = Callable[[Circuit, Operation], bool]
class ClearAllBlockData(BasePass):
    """
    Clear all block data and passed-down data from the pass data.

    A key is removed if it starts with any of these:
    - ``ForEachBlockPass.key`` (recorded block data),
    - ``ForEachBlockPass.pass_down_key_prefix`` (data shared with every
      block),
    - ``ForEachBlockPass.pass_down_block_specific_key_prefix`` (per-block
      data).
    """

    async def run(self, circuit: Circuit, data: PassData) -> None:
        """Perform the pass's operation, see :class:`BasePass` for more."""
        prefixes = (
            ForEachBlockPass.key,
            ForEachBlockPass.pass_down_key_prefix,
            # Block-specific pass-down data is also handed to blocks by
            # ForEachBlockPass, so it must be cleared alongside the rest.
            ForEachBlockPass.pass_down_block_specific_key_prefix,
        )
        # Snapshot the keys so deletions do not disturb the iteration.
        for key in list(data.keys()):
            if key.startswith(prefixes):
                del data[key]