# Source code for bqskit.qis.state.state

"""This module implements the StateVector class."""
from __future__ import annotations

import logging
from typing import Any
from typing import cast
from typing import Iterator
from typing import Sequence
from typing import TYPE_CHECKING
from typing import Union

import numpy as np
import numpy.typing as npt
from numpy.lib.mixins import NDArrayOperatorsMixin
from scipy.stats import unitary_group

from bqskit.utils.typing import is_integer
from bqskit.utils.typing import is_valid_radixes
from bqskit.utils.typing import is_vector

if TYPE_CHECKING:
    from typing import TypeGuard
    from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
    from bqskit.ir.location import CircuitLocationLike

# Module-level logger; used to emit a debug message when a pure-state check fails.
_logger = logging.getLogger(__name__)


class StateVector(NDArrayOperatorsMixin):
    """
    A vector representing a pure quantum state.

    The amplitudes are stored as a one-dimensional complex128 NumPy array
    together with the radixes (number of orthogonal states) of each qudit.
    """

    def __init__(
        self,
        input: StateLike,
        radixes: Sequence[int] = [],
        check_arguments: bool = True,
    ) -> None:
        """
        Constructs a `StateVector` from the supplied vector.

        Args:
            input (StateLike): The state vector input.

            radixes (Sequence[int]): A sequence with its length equal to
                the number of qudits this `StateVector` represents. Each
                element specifies the base, number of orthogonal states,
                for the corresponding qudit. By default, the constructor
                will attempt to calculate `radixes` from `input`.

            check_arguments (bool): If true, check arguments for type
                and value errors.

        Raises:
            TypeError: If `input` is not a vector or if the radixes are
                invalid.

            ValueError: If `input` is not a pure quantum state.

            ValueError: If the dimension of `input` does not match the
                expected dimension from `radixes`.

            RuntimeError: If `radixes` is not specified and the
                constructor cannot infer it.
        """
        # Copy constructor: reuse the already-validated fields. Note that
        # the underlying array object is shared with `input`, not copied.
        if isinstance(input, StateVector):
            self._vec = input.numpy
            self._radixes = input.radixes
            self._dim = input.dim
            return

        if check_arguments and not is_vector(input):
            raise TypeError(f'Expected vector, got {type(input)}.')

        if check_arguments and not StateVector.is_pure_state(input):
            raise ValueError('Input failed state vector condition.')

        dim = len(input)

        # The default `radixes` list is only read, never mutated.
        if radixes:
            self._radixes = tuple(radixes)

        # Check if the state dimension is a power of two
        elif dim & (dim - 1) == 0:
            self._radixes = tuple([2] * int(np.round(np.log2(dim))))

        # Check if the state dimension is a power of three
        elif 3 ** int(np.round(np.log(dim) / np.log(3))) == dim:  # noqa
            num_qutrits = int(np.round(np.log(dim) / np.log(3)))
            self._radixes = tuple([3] * num_qutrits)

        else:
            raise RuntimeError(
                f'Unable to determine radixes for StateVector with dim {dim}.',
            )

        if check_arguments and not is_valid_radixes(self.radixes):
            raise TypeError('Invalid qudit radixes.')

        if check_arguments and np.prod(self.radixes) != dim:
            raise ValueError('Qudit radixes mismatch with dimension.')

        self._vec = np.array(input, dtype=np.complex128)
        self._dim = dim

    @property
    def numpy(self) -> npt.NDArray[np.complex128]:
        """The NumPy array holding the vector."""
        return self._vec

    @property
    def shape(self) -> tuple[int, ...]:
        """The one-dimensional shape of the vector."""
        return self._vec.shape

    @property
    def dtype(self) -> np.typing.DTypeLike:
        """The NumPy data type of the vector."""
        return self._vec.dtype

    @property
    def num_qudits(self) -> int:
        """The number of qudits in the state."""
        return len(self.radixes)

    @property
    def dim(self) -> int:
        """The vector dimension for this state."""
        return self._dim

    @property
    def radixes(self) -> tuple[int, ...]:
        """The number of orthogonal states for each qudit."""
        return self._radixes

    def __len__(self) -> int:
        """The dimension of the state vector."""
        return self.shape[0]

    def __iter__(self) -> Iterator[np.complex128]:
        """An iterator that iterates through the elements of the vector."""
        return self._vec.__iter__()

    def __getitem__(
        self,
        index: Any,
    ) -> np.complex128 | npt.NDArray[np.complex128]:
        """Implements NumPy API for the StateVector class."""
        return self._vec[index]

    def get_probs(self) -> tuple[float, ...]:
        """Return the probabilities for each classical outcome."""
        return tuple(np.abs(elem) ** 2 for elem in self)

    def is_qubit_only(self) -> bool:
        """Return true if every qudit in this state is a qubit."""
        return all(radix == 2 for radix in self.radixes)

    def is_qutrit_only(self) -> bool:
        """Return true if every qudit in this state is a qutrit."""
        return all(radix == 3 for radix in self.radixes)

    def is_qudit_only(self, radix: int) -> bool:
        """
        Return true if every qudit in this state has radix `radix`.

        Args:
            radix (int): Check all qudits have this many orthogonal
                states.
        """
        return all(r == radix for r in self.radixes)

    @staticmethod
    def is_pure_state(V: Any, tol: float = 1e-8) -> TypeGuard[StateLike]:
        """
        Check if V is a pure state vector.

        Args:
            V (Any): The vector to check.

            tol (float): The numerical precision of the check.

        Returns:
            bool: True if V is a pure quantum state vector.
        """
        if isinstance(V, StateVector):
            return True

        # Imported locally, as in the rest of this module's lazy imports.
        from bqskit.qis.state import StateSystem
        if isinstance(V, StateSystem):
            return False

        # A pure state has unit norm: the squared magnitudes sum to one.
        norm_sq = np.sum(np.square(np.abs(V)))
        if not np.allclose(norm_sq, 1, rtol=0, atol=tol):
            _logger.debug('Failed pure state criteria.')
            return False

        return True

    @staticmethod
    def zero(num_qudits: int, radixes: Sequence[int] = []) -> StateVector:
        """
        Prepares the zero state.

        Args:
            num_qudits (int): The number of qudits in the state. Only
                used to build qubit radixes when `radixes` is empty.

            radixes (Sequence[int]): The radixes for the StateVector.
                Defaults to all qubits.

        Returns:
            StateVector: The state with amplitude one on the first basis
            element and zero elsewhere.
        """
        if len(radixes) == 0:
            radixes = [2] * num_qudits

        state = np.zeros(int(np.prod(radixes)), dtype=np.complex128)
        state[0] = 1.0

        # Pass the radixes through explicitly; otherwise the constructor
        # re-infers them from the dimension, which is wrong or impossible
        # for mixed or non-qubit/qutrit systems (e.g. [4] or [2, 3]).
        return StateVector(state, radixes)

    @staticmethod
    def random(num_qudits: int, radixes: Sequence[int] = []) -> StateVector:
        """
        Sample a random pure state.

        Args:
            num_qudits (int): The number of qudits in the state.
                This is not the dimension.

            radixes (Sequence[int]): The radixes for the StateVector.
                Defaults to all qubits.

        Returns:
            StateVector: A random pure quantum state.

        Raises:
            TypeError: If `num_qudits` is not an integer or `radixes`
                is invalid.

            ValueError: If `num_qudits` is nonpositive.

            ValueError: If the length of `radixes` is not equal to
                `num_qudits`.
        """
        if not is_integer(num_qudits):
            raise TypeError(
                f'Expected int for num_qudits, got {type(num_qudits)}.',
            )

        if num_qudits <= 0:
            raise ValueError('Expected positive number for num_qudits.')

        radixes = tuple(radixes if len(radixes) > 0 else [2] * num_qudits)

        if not is_valid_radixes(radixes):
            raise TypeError('Invalid qudit radixes.')

        if len(radixes) != num_qudits:
            raise ValueError(
                'Expected length of radixes to be equal to num_qudits:'
                ' %d != %d' % (len(radixes), num_qudits),
            )

        # The first column of a random unitary is used as the state.
        U = unitary_group.rvs(int(np.prod(radixes)))
        return StateVector(U[:, 0], radixes)

    def __hash__(self) -> int:
        """
        Hash the state vector.

        The hash is computed from the exact amplitudes, while `__eq__` is
        approximate, so two states that compare equal may hash differently.
        """
        return hash(tuple(self.numpy))

    def __eq__(self, other: object) -> bool:
        """Check if `self` is approximately equal to `other`."""
        if isinstance(other, StateVector):
            other_vec = other.numpy
        elif isinstance(other, np.ndarray):
            other_vec = other
        else:
            return NotImplemented

        try:
            return bool(np.allclose(self.numpy, other_vec))
        except ValueError:
            # Shapes that cannot be broadcast together are simply unequal.
            return False

    def apply(
        self,
        utry: UnitaryMatrix,
        location: CircuitLocationLike,
        inverse: bool = False,
        check_arguments: bool = True,
    ) -> None:
        """
        Apply the specified unitary on the right of this StateVector.

        ..
            .---.   .------.
            |   |---|      |- 0
            |   |---| utry |- 1
            .   .   '------'
            .   .
            .   .
            |   |------------ n-1
            '---'

        Args:
            utry (UnitaryMatrix): The unitary to apply.

            location (CircuitLocationLike): The qudits to apply the
                unitary on.

            inverse (bool): If true, apply the inverse of the unitary.

            check_arguments (bool): If true, check the inputs for type
                and value errors.

        Raises:
            TypeError: If `utry` is not a UnitaryMatrix or `location` is
                not a valid location for this state.

            ValueError: If `utry`'s size does not match the given location.

            ValueError: if `utry`'s radixes does not match the given
                location.

        Notes:
            - The state is updated in place (the method returns None).
              Applying the unitary on the right in circuit notation
              corresponds to left-multiplying the state's amplitudes by
              the unitary on the targeted qudits.

            - This operation is performed using tensor contraction.
        """
        from bqskit.ir.location import CircuitLocation
        from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix

        if check_arguments:
            if not isinstance(utry, UnitaryMatrix):
                raise TypeError(
                    f'Expected UnitaryMatrix, got {type(utry)}.',
                )

            if not CircuitLocation.is_location(location, self.num_qudits):
                raise TypeError('Invalid location.')

            location = CircuitLocation(location)

            if len(location) != utry.num_qudits:
                raise ValueError('Unitary and location size mismatch.')

            for utry_radix, qudit_idx in zip(utry.radixes, location):
                if utry_radix != self.radixes[qudit_idx]:
                    raise ValueError('Unitary and location radix mismatch.')

        location = cast(CircuitLocation, location)

        # Bring the targeted qudits to the front, untouched ones after.
        qudits = list(range(self.num_qudits))
        identity_action_perm = [x for x in qudits if x not in location]
        unitary_action_perm = list(location)
        perm = unitary_action_perm + identity_action_perm

        left_dim = int(
            np.prod([self.radixes[x] for x in unitary_action_perm]),
        )

        utry = utry.dagger if inverse else utry

        # View the vector as a tensor with one axis per qudit, move the
        # targeted axes first and flatten them into the matrix row index.
        self._vec = self._vec.reshape(self.radixes)
        self._vec = self._vec.transpose(perm)
        self._vec = self._vec.reshape((left_dim, -1))

        self._vec = utry @ self._vec

        # Undo the flattening and the permutation to restore qudit order.
        permuted_shape = [self.radixes[p] for p in perm]
        self._vec = self._vec.reshape(permuted_shape)
        inv_perm = list(np.argsort(perm))
        self._vec = self._vec.transpose(inv_perm)
        self._vec = self._vec.reshape(-1)

    def get_distance_from(self, other: StateLike) -> float:
        """
        Return the distance between `self` and `other`.

        The distance is given as the infidelity between the two states.

        Args:
            other (StateLike): The state to measure distance from.

        Returns:
            float: A value between 1 and 0, where 0 means the two states
            are equal up to global phase and 1 means the two states are
            very unsimilar or far apart.
        """
        other = StateVector(other)

        # np.vdot conjugates its first argument, giving <self|other>.
        overlap = np.vdot(self.numpy, other.numpy)
        dist = 1 - np.abs(overlap) ** 2

        # Clamp small negative values caused by floating point round-off.
        return dist if dist > 0.0 else 0.0

    def __array__(
        self,
        dtype: np.typing.DTypeLike = np.complex128,
        copy: bool | None = None,
    ) -> npt.NDArray[np.complex128]:
        """
        Implements NumPy API for the StateVector class.

        Args:
            dtype (np.typing.DTypeLike): The requested data type. Only
                complex128 (or None, meaning no preference) is supported.

            copy (bool | None): If true, return a copy of the underlying
                array; otherwise return the array itself.

        Raises:
            ValueError: If a dtype other than complex128 is requested.
        """
        if dtype is not None and dtype != np.complex128:
            raise ValueError('StateVector only supports Complex128 dtype.')

        if copy:
            return self._vec.copy()

        return self._vec

    def __array_ufunc__(
        self,
        ufunc: np.ufunc,
        method: str,
        *inputs: npt.NDArray[Any],
        **kwargs: Any,
    ) -> StateVector | npt.NDArray[np.complex128]:
        """Implements NumPy API for the StateVector class."""
        if method != '__call__':
            return NotImplemented

        # Unwrap StateVector operands to their raw arrays.
        non_state_involved = False
        args: list[npt.NDArray[Any]] = []
        for operand in inputs:
            if isinstance(operand, StateVector):
                args.append(operand.numpy)
            else:
                args.append(operand)
                non_state_involved = True

        out = ufunc(*args, **kwargs)

        # The results are state vectors
        # if only states are involved
        # and state vectors are closed under the specific operation.
        # Multiplication also qualifies when every operand is a state or
        # a scalar of unit magnitude (a global phase).
        convert_back = (
            (not non_state_involved and ufunc.__name__ == 'conjugate')
            or (
                ufunc.__name__ == 'multiply'
                and all(
                    np.isscalar(operand) or isinstance(operand, StateVector)
                    for operand in inputs
                )
                and all(
                    np.abs(np.abs(operand) - 1) <= 1e-14
                    for operand in inputs
                    if np.isscalar(operand)
                )
            )
        )

        if convert_back:
            return StateVector(out, self.radixes)

        return out

    def __str__(self) -> str:
        """Return the string representation of the vector."""
        return str(self._vec)

    def __repr__(self) -> str:
        """Return the repr representation of the vector."""
        return repr(self._vec)
# Anything accepted where a pure state is expected: an existing StateVector,
# a NumPy array, or a plain sequence of numeric amplitudes.
StateLike = Union[
    StateVector,
    np.ndarray,
    Sequence[Union[int, float, complex]],
]