Davidelanz/quantum-robot

View on GitHub
src/qrobot/qunits/qunit.py

Summary

Maintainability
A
1 hr
Test Coverage
import json
from typing import Dict, List, Optional

from ..bursts import Burst
from ..models import Model
from . import redis_utils
from .base import BaseUnit


class QUnit(BaseUnit):  # pylint: disable=too-many-instance-attributes
    """[QUnit description]

    Parameters
    ------------
    name : str
        The qUnit name
    model : qrobot.models.Model
        The model the qUnit implements
    burst : qrobot.bursts.Burst
        The burst the qUnit implements
    Ts : float
        The sampling time with wich the qUnit reads the input
    query : list, optional
        The target state for the model queries. Defaults to ``None``
    in_units : dict[int, str], optional
        Dictionary containing {``dim`` : ``qunit_id``} inputs
        couplings, i.e. ``qunit_id`` output is the input for dimension
        ``dim``. Defaults to ``None``.
    default_input: List[float]
        Default input vector of scalar values to use as default value
        when qunit does not have an available one.
        Defaults to ``model.n*[0.0]``

    Attributes
    ----------
    id : str
        The unique instance identifier of the qUnit
    name : str
        The unique instance identifier of the qUnit
    model : qrobot.models.Model
        The model which the qUnit implements
    burst : qrobot.bursts.Burst
        The burst the qUnit implements
    Ts : float
        The sampling period for which the qUnit samples an event
    default_input: List[float]
        Default input vector of scalar values to use as default value
        when qunit does not have an available one
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        name: str,
        model: Model,
        burst: Burst,
        Ts: float,  # pylint: disable=invalid-name
        query: list = None,
        in_qunits: Dict[int, str] = None,
        default_input: float = None,
    ) -> None:
        # Call the BaseUnit constructor
        super().__init__(name, Ts)

        # Store the qUnits name and properties
        self.model = model
        self.burst = burst
        self.default_input = default_input or model.n * [0.0]

        # Default query to all 0s if not specified
        if query is None:
            query = [0.0] * (self.model.n)

        # Initialize multiprocessing variables
        # - Query array variable
        self._query = self._multiproc_manager.list(query)
        # - Output unit dictionary
        self._in_qunits = self._multiproc_manager.dict(in_qunits or {})
        # - Time window index
        self._t_idx = self._multiproc_manager.Value("i", 0)

        # Log properties
        self._logger.debug(f"Properties: {self}")

    def __iter__(self):
        yield "name", self.name
        yield "id", self.id
        yield "model", str(self.model)
        yield "burst", str(self.burst.__class__)
        yield "query", self.query
        yield "Ts", self.Ts

    @property
    def query(self) -> List[float]:
        """Current target state for the model queries

        Returns
        -------
        list
            The query target state array in the computational basis
        """
        return list(self._query)

    @query.setter
    def query(self, query: list) -> None:
        """Set a new query state for the qunit

        Parameters
        -----------
        query : list
            The query target state array in the computational basis
        """
        # Check arguments
        query = self.model._target_vector_check(query)
        # Update accumulator
        self._logger.debug(f"Changing query from {self._query} to {query}")
        for idx, value in enumerate(query):
            self._query[idx] = value
        self._logger.debug(f"_query={self._query}")

    @property
    def in_qunits(self) -> Dict[int, str]:
        """Current output ``{dim : qunit_id}`` couplings.

        Returns
        -------
        dict
            The current output ``{dim : qunit_id}`` couplings dictionary
        """
        in_qunits = {}
        for dim in range(self.model.n):
            try:
                in_qunits[dim] = self._in_qunits[dim]
            except KeyError:
                in_qunits[dim] = None
        return in_qunits

    @property
    def input_vector(self) -> List[float]:
        """The current input vector of the unit

        Returns
        -------
        list
            The current input vector
        """
        input_vector = self.default_input
        for dim, qunit_id in self._in_qunits.items():
            _r = redis_utils.get_redis()
            val = _r.get(qunit_id + " output")
            if val is not None:
                input_vector[dim] = float(val)
            else:
                self._logger.info(f"Unable to read {qunit_id} input")
        return input_vector

    def set_input(self, dim: int, qunit_id: str) -> None:
        """Set a new input qunit for the desired dimension

        Parameters
        -----------
        dim : int
            The input dimension index
        qunit_id : str
            The input qunit id
        """
        # Check arguments
        dim = self.model._dim_index_check(dim)  # pylint: disable=protected-access
        # Update accumulator
        self._logger.debug(
            f"Changing dim {dim} input from " + f"{self.in_qunits[dim]} to {qunit_id}"
        )
        self._in_qunits[dim] = qunit_id
        self._logger.debug(f"_in_qunits={self._in_qunits}")

    def get_burst_output(self) -> Optional[float]:
        """Get the latest burst output from the qUnit

        Returns
        -------
        float
            The latest burst output written by the unit on the Redis database
        """
        global_status = redis_utils.redis_status()
        out = global_status.get(f"{self.id} output", None)
        return float(out) if out else None

    def _clean_redis(self) -> None:
        """Clean all the redis entries created by the unit when the loop stops."""
        _r = redis_utils.get_redis()
        _r.delete(self.id + " output")
        _r.delete(self.id + " state")
        _r.delete(self.id + " query")
        _r.delete(self.id + " in_qunits")

    def _unit_task(self) -> None:
        """Single iteration of the processing loop."""
        # "_t_idx" is the event index of the temporal window
        self._logger.debug(
            f"Temporal window event {self._t_idx.value+1}/{self.model.tau}"
        )
        # Get input
        input_vector = self.input_vector
        self._logger.debug(f"input_vector={input_vector}")
        # Loop through the dimensions to encode data
        for dim in range(self.model.n):
            self.model.encode(input_vector[dim], dim)
        # Wait for the next input in the time window
        self._t_idx.value += 1
        # If at the end of the time window
        if self._t_idx.value == self.model.tau:
            # Apply the query
            self._logger.debug(f"Querying for state {self._query}")
            self.model.query(self.query)
            # Decode
            out_state = self.model.decode()
            self._logger.debug(f"Output state = {out_state}")
            # Write output on Redis database
            self._logger.debug("Opening a connection to redis...")
            _r = redis_utils.get_redis()
            self._logger.debug(f"Redis connected: {_r}")
            if not (
                _r.mset({self.id + " output": self.burst(out_state)})
                and _r.mset({self.id + " state": str(out_state)})
                and _r.mset({self.id + " query": json.dumps(self.query)})
                and _r.mset({self.id + " in_qunits": json.dumps(self.in_qunits)})
            ):
                raise Exception(
                    f"Problem in writing qunit {self.id} output on Redis database!"
                )
            # Initialize new temporal window
            self._logger.debug("Initializing a new temporal window")
            self.model.clear()
            self._t_idx.value = 0