"""This module implements the Operation class."""
from __future__ import annotations
from typing import Any
import numpy as np
import numpy.typing as npt
from bqskit.ir.gate import Gate
from bqskit.ir.gates.composed.frozenparam import FrozenParameterGate
from bqskit.ir.location import CircuitLocation
from bqskit.ir.location import CircuitLocationLike
from bqskit.qis.unitary.differentiable import DifferentiableUnitary
from bqskit.qis.unitary.unitary import RealVector
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
from bqskit.utils.typing import is_sequence
[docs]
class Operation(DifferentiableUnitary):
    """
    An Operation groups together a gate, its parameters and location.

    Unitary, gradient, inverse, and QASM queries are forwarded to the
    underlying gate, using the operation's stored parameters whenever the
    caller does not supply any.
    """

    def __init__(
        self,
        gate: Gate,
        location: CircuitLocationLike,
        params: RealVector = [],
    ) -> None:
        """
        Construct an operation.

        Args:
            gate (Gate): The operation's gate.

            location (CircuitLocationLike): The set of qudits this gate
                is applied to.

            params (RealVector): The parameters for the gate. If an empty
                sequence is given and the gate is parameterized, every
                parameter defaults to 0.0. (Default: [])

        Raises:
            TypeError: If `gate` is not a Gate.

            TypeError: If `location` is not a valid circuit location.

            ValueError: If `gate`'s size doesn't match `location`'s length.

            ValueError: If `gate`'s num_params doesn't match `params`'s
                length.
        """
        if not isinstance(gate, Gate):
            raise TypeError('Expected gate, got %s.' % type(gate))

        if not CircuitLocation.is_location(location):
            raise TypeError('Invalid location.')

        # An empty parameter sequence means "all zeros" for a parameterized
        # gate, so callers may omit params entirely.
        if is_sequence(params) and len(params) == 0 and gate.num_params != 0:
            params = [0.0] * gate.num_params

        gate.check_parameters(params)

        location = CircuitLocation(location)

        if len(location) != gate.num_qudits:
            raise ValueError('Gate and location size mismatch.')

        self._num_params = gate.num_params
        self._radixes = gate.radixes
        self._num_qudits = gate.num_qudits
        self._gate = gate
        self._location = location
        # Copy into a private list so later mutation of the caller's
        # sequence (or of the shared default) cannot alter this operation.
        self._params = list(params)

    @property
    def gate(self) -> Gate:
        """The operation's gate."""
        return self._gate

    @property
    def location(self) -> CircuitLocation:
        """The qudits this operation is applied to."""
        return self._location

    @property
    def params(self) -> list[float]:
        """The operation's parameters for its gate."""
        return self._params

    @params.setter
    def params(self, params: list[float]) -> None:
        """Validate `params` and store a private list copy of them."""
        self.check_parameters(params)
        # Copy, matching __init__: storing the caller's object directly would
        # alias it (letting unvalidated mutations leak in), and storing a
        # NumPy array would turn the `==` in __eq__ into an elementwise
        # comparison whose truth value is ambiguous.
        self._params = list(params)

    def get_qasm(self) -> str:
        """
        Return the qasm string for this operation.

        Returns:
            str: The operation as a qasm line.
        """
        # A FrozenParameterGate is handed the result of get_full_params
        # rather than the stored params; presumably this re-inserts the
        # frozen values so QASM emission sees every parameter.
        if isinstance(self.gate, FrozenParameterGate):
            full_params = self.gate.get_full_params(self.params)
        else:
            full_params = self.params

        return self.gate.get_qasm(full_params, self.location)

    def get_unitary(self, params: RealVector = []) -> UnitaryMatrix:
        """
        Return the unitary for this gate, see :class:`Unitary` for more.

        Non-empty `params` override the stored parameters; otherwise the
        operation's own parameters are used.
        """
        if len(params) != 0:
            return self.gate.get_unitary(params)

        return self.gate.get_unitary(self.params)

    def get_inverse(self) -> Operation:
        """Return the operation's inverse operation."""
        return Operation(
            self.gate.get_inverse(),
            self.location,
            self.gate.get_inverse_params(self.params),
        )

    def get_grad(self, params: RealVector = []) -> npt.NDArray[np.complex128]:
        """
        Return the gradient for this operation.

        Non-empty `params` override the stored parameters.
        See :class:`DifferentiableUnitary` for more info.
        """
        # The gate is not statically known to be differentiable; callers can
        # check is_differentiable() first.
        if len(params) != 0:
            return self.gate.get_grad(params)  # type: ignore

        return self.gate.get_grad(self.params)  # type: ignore

    def get_unitary_and_grad(
        self,
        params: RealVector = [],
    ) -> tuple[UnitaryMatrix, npt.NDArray[np.complex128]]:
        """
        Return the unitary and gradient for this gate.

        Non-empty `params` override the stored parameters.
        See :class:`DifferentiableUnitary` for more info.
        """
        # See get_grad for why the type checker is silenced here.
        if len(params) != 0:
            return self.gate.get_unitary_and_grad(params)  # type: ignore

        return self.gate.get_unitary_and_grad(self.params)  # type: ignore

    def __eq__(self, rhs: Any) -> bool:
        """Check for equality of gate, parameters, and location."""
        if self is rhs:
            return True

        if not isinstance(rhs, Operation):
            return NotImplemented

        return (
            self.gate == rhs.gate
            and self.params == rhs.params
            and self.location == rhs.location
        )

    def __hash__(self) -> int:
        """Hash on gate and location only."""
        # Parameters are excluded from the hash. This remains consistent with
        # __eq__ because equal operations always share gate and location.
        return hash((self.gate, self.location))

    def __str__(self) -> str:
        return f'{self.gate}@{self.location}'

    def __repr__(self) -> str:
        if len(self.params) == 0:
            return f'{self.gate}@{self.location}'

        return f'{self.gate}({self.params})@{self.location}'

    def is_differentiable(self) -> bool:
        """Check if operation is differentiable."""
        return isinstance(self.gate, DifferentiableUnitary)