bachya/pyairvisual

View on GitHub
pyairvisual/node.py

Summary

Maintainability
B
6 hrs
Test Coverage
"""Define objects to interact with an AirVisual Node/Pro."""

from __future__ import annotations

import asyncio
import csv
import json
import tempfile
from collections import OrderedDict
from collections.abc import Awaitable, Callable
from functools import partial
from types import TracebackType
from typing import IO, Any, TypeVar, cast, overload

import numpy as np
import smb
from smb.SMBConnection import SMBConnection

from .const import LOGGER
from .errors import AirVisualError

# Base URL handed to the Cloud API request helper when looking up a Node by ID.
API_URL_BASE = "https://www.airvisual.com/api/v2/node"

# Default number of seconds to wait for the Samba connection attempt.
DEFAULT_CONNECT_TIMEOUT = 10

# Filename pattern used when listing history files on the Samba share.
SAMBA_HISTORY_PATTERN = "*_AirVisual_values.txt"
# Passed both to SMBConnection (at construction) and as the service argument of
# the listPath/retrieveFile calls.
SMB_SERVICE = "airvisual"
SMB_USERNAME = "airvisual"

# Normalized metric names used as keys in the payloads this module returns.
# NOTE(review): METRIC_PM01 -> "pm0_1" and METRIC_PM10 -> "pm1_0" look shifted
# by a factor of ten relative to their constant names — confirm before relying
# on the literal values; callers may depend on them as-is.
METRIC_AQI_CN = "aqi_cn"
METRIC_AQI_US = "aqi_us"
METRIC_CO2 = "co2"
METRIC_HUMIDITY = "humidity"
METRIC_PM01 = "pm0_1"
METRIC_PM10 = "pm1_0"
METRIC_PM25 = "pm2_5"
METRIC_VOC = "voc"
# Metrics for which _calculate_trends computes a trend (when present in history).
METRICS_TO_TREND = [
    METRIC_AQI_CN,
    METRIC_AQI_US,
    METRIC_CO2,
    METRIC_HUMIDITY,
    METRIC_PM01,
    METRIC_PM10,
    METRIC_PM25,
    METRIC_VOC,
]

# Raw metric keys -> normalized metric names. Looked up by
# _get_normalized_metric_name for history CSV headers, latest-measurement keys
# and sensor-life keys; unknown keys are passed through unchanged.
# Presumably the "...(unit)" keys come from the history files and the
# snake_case keys from the JSON payload — verify against real device output.
METRIC_MAPPING = {
    "AQI(CN)": METRIC_AQI_CN,
    "AQI(US)": METRIC_AQI_US,
    "CO2(ppm)": METRIC_CO2,
    "Humidity(%RH)": METRIC_HUMIDITY,
    "PM01(ug/m3)": METRIC_PM01,
    "PM10(ug/m3)": METRIC_PM10,
    "PM2_5(ug/m3)": METRIC_PM25,
    "VOC(ppb)": METRIC_VOC,
    "co2_ppm": METRIC_CO2,
    "humidity_RH": METRIC_HUMIDITY,
    "pm01_ugm3": METRIC_PM01,
    "pm10_ugm3": METRIC_PM10,
    "pm25": METRIC_PM25,
    "pm25_AQICN": METRIC_AQI_CN,
    "pm25_AQIUS": METRIC_AQI_US,
    "pm25_ugm3": METRIC_PM25,
    "voc_ppb": METRIC_VOC,
}

# Trend labels assigned by _calculate_trends based on the sign of the fitted slope.
TREND_FLAT = "flat"
TREND_INCREASING = "increasing"
TREND_DECREASING = "decreasing"


class NodeProError(AirVisualError):
    """Base error for anything that goes wrong while talking to a Node/Pro."""


class NodeConnectionError(NodeProError):
    """Error raised for connection issues with a Node/Pro."""


class InvalidAuthenticationError(NodeProError):
    """Error raised when authentication against a Node/Pro is rejected."""


def _calculate_trends(
    history: list[OrderedDict], measurements_to_use: int
) -> dict[str, Any]:
    """Calculate the trends of all data points in history data.

    For every metric in METRICS_TO_TREND that appears in at least one
    measurement, a first-degree polynomial is fitted to the selected values and
    the sign of its slope (rounded to two decimals) decides the trend.

    Args:
        history: A list of dict-based measurements.
        measurements_to_use: The number of measurements to include (-1 for all)

    Returns:
        A dict mapping each normalized metric name to one of TREND_INCREASING,
        TREND_DECREASING or TREND_FLAT.
    """
    measured_attributes = set().union(*(d.keys() for d in history))

    trends = {}
    # Walk the canonical list (rather than a set) so the output order is stable.
    for attribute in METRICS_TO_TREND:
        if attribute not in measured_attributes:
            continue

        values = [
            float(measurement[attribute])
            for measurement in history
            if attribute in measurement
        ]

        if measurements_to_use != -1:
            values = values[-measurements_to_use:]

        metric = _get_normalized_metric_name(attribute)

        # A line cannot be fitted through fewer than two points; report such a
        # window as flat instead of handing polyfit a degenerate problem.
        if len(values) < 2:
            trends[metric] = TREND_FLAT
            continue

        # Size the x-axis from the values actually collected: fewer values than
        # requested (or rows missing this metric) would otherwise make x and y
        # differ in length, which np.polyfit rejects.
        linear_fit = np.polyfit(np.arange(len(values)), np.array(values), 1)
        slope = round(linear_fit[0], 2)

        if slope > 0:
            trends[metric] = TREND_INCREASING
        elif slope < 0:
            trends[metric] = TREND_DECREASING
        else:
            trends[metric] = TREND_FLAT

    return trends


def _get_normalized_metric_name(key: str) -> str:
    """Translate a raw metric key into its normalized name.

    Args:
        key: A metric name to examine.

    Returns:
        The mapped name from METRIC_MAPPING, or ``key`` itself when no mapping
        exists.
    """
    if key in METRIC_MAPPING:
        return METRIC_MAPPING[key]
    return key


class NodeCloudAPI:  # pylint: disable=too-few-public-methods
    """Fetch Node/Pro information through the Cloud API."""

    def __init__(self, request: Callable[..., Awaitable]) -> None:
        """Store the awaitable request callable used for all lookups.

        Args:
            request: The request method from the CloudAPI object.
        """
        self._request = request

    async def get_by_node_id(self, node_id: str) -> dict[str, Any]:
        """Return the cloud API data for the Node with the given ID.

        Args:
            node_id: A Node ID.

        Returns:
            An API response payload.
        """
        return cast(
            dict[str, Any],
            await self._request("get", node_id, base_url=API_URL_BASE),
        )


# Constrained type variable listing the result types of the callables handed to
# NodeSamba._execute_samba_operation; it ties that method's return annotation
# to the return annotation of the wrapped callable.
_SambaOperationReturnType = TypeVar(  # pylint: disable=invalid-name
    "_SambaOperationReturnType",
    int,
    list[smb.base.SharedFile],
    list[dict[str, Any]],
    None,
)


class NodeSamba:
    """Define an object to work with getting Node info over Samba.

    Every pysmb call is executed in the default executor of the event loop that
    is running when the operation is awaited, and pysmb/connection errors are
    translated into this module's NodeProError hierarchy.
    """

    def __init__(self, ip_or_hostname: str, password: str) -> None:
        """Initialize.

        Args:
            ip_or_hostname: An IP address or hostname to a Node.
            password: A Samba password for a Node.
        """
        self._conn = SMBConnection(SMB_USERNAME, password, "pyairvisual", SMB_SERVICE)
        self._connected = False
        self._ip_or_hostname = ip_or_hostname
        self._latest_history = None
        # No event loop is captured here: the object may be constructed before
        # (or outside of) the loop that later awaits its operations.

    async def __aenter__(self) -> NodeSamba:
        """Handle the start of a context manager.

        Returns:
            A connected NodeSamba object.
        """
        await self.async_connect()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,  # noqa: F841
        exc_val: BaseException | None,  # noqa: F841
        exc_tb: TracebackType | None,  # noqa: F841
    ) -> None:
        """Handle the end of a context manager.

        Args:
            exc_type: An optional exception if one caused the context manager to close.
            exc_val: The value of the optional exception
            exc_tb: The traceback of the optional exception
        """
        await self.async_disconnect()

    @overload
    async def _execute_samba_operation(
        self, pysmb_func: Callable[..., list[dict[str, Any]]]
    ) -> list[dict[str, Any]]: ...

    @overload
    async def _execute_samba_operation(
        self,
        pysmb_func: Callable[[None], bool],
        ip_or_hostname: str,
        *,
        timeout: int = DEFAULT_CONNECT_TIMEOUT,
    ) -> bool: ...

    @overload
    async def _execute_samba_operation(
        self,
        pysmb_func: Callable[[str, str, IO[bytes]], None],
        service: str,  # noqa: F841
        filepath: str,
        file_obj: IO[bytes],  # noqa: F841
    ) -> None: ...

    @overload
    async def _execute_samba_operation(  # pylint: disable=too-many-arguments
        self,
        pysmb_func: Callable[[str, str], list[smb.base.SharedFile]],
        service: str,  # noqa: F841
        filepath: str,
        *,
        pattern: str | None = None,  # noqa: F841
        search: str | None = None,  # noqa: F841
    ) -> list[smb.base.SharedFile]: ...

    async def _execute_samba_operation(
        self,
        pysmb_func: Callable[..., _SambaOperationReturnType],
        *args: Any,
        **kwargs: Any,
    ) -> _SambaOperationReturnType:
        """Guard a Samba command with appropriate error handling.

        The callable is run in the running loop's default executor so that the
        blocking call does not stall the event loop.

        Args:
            pysmb_func: A pysmb function to run.
            *args: Any args to pass to the pysmb function.
            **kwargs: Any kwargs to pass to the pysmb function.

        Returns:
            Any type supported by pysmb operations.

        Raises:
            InvalidAuthenticationError: Raised on invalid Samba auth.
            NodeConnectionError: Raised on any Samba connection-related error.
            NodeProError: Raised on any unknown error.
        """
        # run_in_executor only forwards positional args, so bind kwargs up front.
        func_with_kwargs = partial(pysmb_func, **kwargs)
        # Resolve the loop at call time: a loop captured at construction may not
        # be the one that is actually awaiting this coroutine.
        loop = asyncio.get_running_loop()
        try:
            res = await loop.run_in_executor(  # type: ignore[func-returns-value]
                None, func_with_kwargs, *args
            )
        except smb.base.NotConnectedError as err:
            raise NodeConnectionError(f"The Pro unit is not connected: {err}") from err
        except smb.base.NotReadyError as err:
            raise InvalidAuthenticationError(
                f"The Pro unit returned an authentication error: {err}"
            ) from err
        except smb.base.SMBTimeout as err:
            raise NodeConnectionError(
                "Timed out while talking to the Pro unit"
            ) from err
        except ConnectionRefusedError as err:
            raise NodeConnectionError(
                "Couldn't find a Pro unit at the provided IP address"
            ) from err
        except Exception as err:  # pylint: disable=broad-except
            raise NodeProError(err) from err

        return res

    async def _async_get_history_files(self) -> list[smb.base.SharedFile]:
        """Return all the history files on a Samba device.

        Returns:
            A list of Samba file references.
        """
        return await self._execute_samba_operation(
            self._conn.listPath,
            SMB_SERVICE,
            "/",
            pattern=SAMBA_HISTORY_PATTERN,
            search=smb.smb_constants.SMB_FILE_ATTRIBUTE_NORMAL,
        )

    async def _async_retrieve_data_from_tempfile(
        self, tmp_file: IO[bytes]
    ) -> list[dict[str, Any]]:
        """Retrieve data from a NamedTemporaryFile.

        The file is re-opened by name in text mode and parsed as
        semicolon-delimited CSV; headers are normalized via
        _get_normalized_metric_name.

        Args:
            tmp_file: A reference to a NamedTemporaryFile.

        Returns:
            One dict per CSV row, keyed by normalized metric name.
        """

        def get_data() -> list[dict[str, Any]]:
            """Parse the temporary file (runs in the executor).

            Returns:
                One dict per CSV row, keyed by normalized metric name.
            """
            with open(tmp_file.name, encoding="utf-8") as file:
                data = [
                    {
                        _get_normalized_metric_name(header): value
                        for header, value in row.items()
                    }
                    for row in csv.DictReader(file, delimiter=";")
                ]

            LOGGER.debug("Loaded data from file: %s", data)
            return data

        return await self._execute_samba_operation(get_data)

    async def _async_store_filepath_in_tempfile(
        self, filepath: str, tmp_file: IO[bytes]
    ) -> None:
        """Save a file to a NamedTemporaryFile object.

        Args:
            filepath: A filepath on the Node Samba share.
            tmp_file: A reference to a NamedTemporaryFile.
        """
        await self._execute_samba_operation(
            self._conn.retrieveFile, SMB_SERVICE, filepath, tmp_file
        )

    async def async_connect(self, *, timeout: int = DEFAULT_CONNECT_TIMEOUT) -> None:
        """Connect to the Node.

        Does nothing when a connection has already been established.

        Args:
            timeout: The number of seconds before timing out the connection attempt.

        Raises:
            InvalidAuthenticationError: Raised when the provided Samba password
                is incorrect.
        """
        if self._connected:
            LOGGER.debug("Already connected!")
            return

        result = await self._execute_samba_operation(
            self._conn.connect, self._ip_or_hostname, timeout=timeout
        )

        # A falsy result from connect is treated as an authentication failure.
        if not result:
            raise InvalidAuthenticationError("Invalid Samba authentication")

        self._connected = True

    async def async_disconnect(self) -> None:
        """Disconnect from the Node (no-op when not connected)."""
        if not self._connected:
            LOGGER.debug("Already disconnected!")
            return

        await self._execute_samba_operation(self._conn.close)
        self._connected = False

    async def async_get_history(
        self, *, include_trends: bool = True, measurements_to_use: int = -1
    ) -> dict[str, Any]:
        """Get history data from the device.

        History files are tried in descending filename order until one yields
        at least one measurement.

        Args:
            include_trends: Whether trend data should be included.
            measurements_to_use: The number of measurements to include (-1 for all)

        Returns:
            An API response payload with a "measurements" list and, when
            requested, a "trends" dict.

        Raises:
            NodeProError: Raised when no history files are found.
        """
        history_files = await self._async_get_history_files()

        if not history_files:
            raise NodeProError(
                f"No history files found that match {SAMBA_HISTORY_PATTERN}"
            )

        history_files.sort(key=lambda file: file.filename, reverse=True)

        data: dict[str, Any] = {}

        for history_file in history_files:
            # The context manager guarantees the temporary file is removed even
            # when a Samba operation raises.
            with tempfile.NamedTemporaryFile() as tmp_file:
                await self._async_store_filepath_in_tempfile(
                    f"/{history_file.filename}", tmp_file
                )
                # Rewind after the download; the parser re-opens the file by name.
                tmp_file.seek(0)
                data["measurements"] = await self._async_retrieve_data_from_tempfile(
                    tmp_file
                )

            if include_trends:
                data["trends"] = _calculate_trends(
                    data["measurements"], measurements_to_use
                )

            if data["measurements"]:
                break

        return data

    async def async_get_latest_measurements(self) -> dict[str, Any]:
        """Get the latest measurements from the device.

        Downloads and parses "/latest_config_measurements.json", then normalizes
        the metric names of the measurements and of the sensor-life status.

        Returns:
            An API response payload.
        """
        # The context manager guarantees the temporary file is removed even
        # when the download raises.
        with tempfile.NamedTemporaryFile() as tmp_file:
            await self._async_store_filepath_in_tempfile(
                "/latest_config_measurements.json", tmp_file
            )
            tmp_file.seek(0)
            raw = tmp_file.read()

        data = json.loads(raw.decode())

        LOGGER.debug("Node measurements loaded: %s", data)

        raw_measurements = data["measurements"]
        if isinstance(raw_measurements, list):
            # Handle a single measurement returned in a list; an empty list
            # yields no measurements rather than an uncaught IndexError.
            latest = raw_measurements[0] if raw_measurements else {}
        else:
            # Handle a single measurement returned as a standalone dict:
            latest = raw_measurements

        data["last_measurement_timestamp"] = int(data["date_and_time"]["timestamp"])
        data["measurements"] = {
            _get_normalized_metric_name(pollutant): value
            for pollutant, value in latest.items()
        }
        data["status"]["sensor_life"] = {
            _get_normalized_metric_name(pollutant): value
            for pollutant, value in data["status"].get("sensor_life", {}).items()
        }

        return data