cgtobi/PyRMVtransport

View on GitHub
RMVtransport/rmvtransport.py

Summary

Maintainability
A
0 mins
Test Coverage
A
100%
"""A module to query bus and train departure times."""
import asyncio
import json
import logging
import urllib.parse
import urllib.request
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import async_timeout
import httpx
from lxml import etree, objectify  # type: ignore

from .const import (
    ALL_PRODUCTS,
    BASE_URI,
    GETSTOP_PATH,
    KNOWN_XML_ISSUES,
    MAX_RETRIES,
    PRODUCTS,
    STBOARD_PATH,
)
from .errors import (
    RMVtransportApiConnectionError,
    RMVtransportDataError,
    RMVtransportError,
)
from .rmvjourney import RMVjourney
from .rmvtravel import RMVtravel

_LOGGER = logging.getLogger(__name__)


class RMVtransport:
    """Connection data and travel information."""

    def __init__(self, timeout: float = 10) -> None:
        """Initialize connection data."""
        self._timeout: float = timeout

        self.now: datetime

        self.station_id: str
        self.direction_id: Optional[str]
        self.products_filter: str

        self.max_journeys: int

        self.obj: objectify.ObjectifiedElement
        self.journeys: List[RMVjourney] = []

    async def get_departures(
        self,
        station_id: str,
        direction_id: Optional[str] = None,
        max_journeys: int = 20,
        products: Optional[List[str]] = None,
    ) -> RMVtravel:
        """Fetch data from rmv.de."""
        url = self.build_journey_query(station_id, direction_id, max_journeys, products)
        xml = await self._query_rmv_api(url)
        self.obj = extract_data_from_xml(xml)

        try:
            self.now = self.current_time()
        except RMVtransportDataError:
            _LOGGER.debug(
                "XML contains unexpected data %s", objectify.dump(self.obj)[:100]
            )
            raise

        self.journeys.clear()
        try:
            for journey in self.obj.SBRes.JourneyList.Journey:
                self.journeys.append(RMVjourney(journey, self.now))
        except AttributeError as err:
            _LOGGER.debug("Extract journeys: %s", objectify.dump(self.obj.SBRes))
            raise RMVtransportError(err) from err

        return self.travel_data()

    def build_journey_query(
        self,
        station_id: str,
        direction_id: Optional[str] = None,
        max_journeys: int = 20,
        products: Optional[List[str]] = None,
    ) -> str:
        """Build query to request journey data."""
        self.station_id = station_id
        self.direction_id = direction_id
        self.max_journeys = max_journeys
        self.products_filter = product_filter(products or ALL_PRODUCTS)

        params: Dict[str, Union[str, int]] = {
            "selectDate": "today",
            "time": "now",
            "input": self.station_id,
            "maxJourneys": self.max_journeys,
            "boardType": "dep",
            "productsFilter": self.products_filter,
            "disableEquivs": "discard_nearby",
            "output": "xml",
            "start": "yes",
        }
        if self.direction_id:
            params["dirInput"] = self.direction_id

        return base_url() + urllib.parse.urlencode(params)

    async def search_station(self, name: str, max_results: int = 25) -> Dict[str, Dict]:
        """Search station/stop my name."""
        suggestions: List = await self._fetch_sugestions(name, max_results)

        return {
            item["extId"]: {
                "id": item["extId"],
                "name": item["value"],
                "lat": convert_coordinates(item["ycoord"]),
                "long": convert_coordinates(item["xcoord"]),
            }
            for item in suggestions
        }

    async def _fetch_sugestions(
        self, name: str, max_results: int
    ) -> List[Optional[Dict]]:
        """Fetch suggestsions for the given station name from the backend."""
        params: Dict[str, Union[str, int]] = {
            "getstop": 1,
            "REQ0JourneyStopsS0A": max_results,
            "REQ0JourneyStopsS0G": name,
        }

        url = base_url(GETSTOP_PATH) + urllib.parse.urlencode(params)
        _LOGGER.debug("URL: %s", url)

        response = await self._query_rmv_api(url)
        data = extract_json_data(response)

        try:
            json_data = json.loads(data)
        except (TypeError, json.JSONDecodeError) as err:
            _LOGGER.debug("Error in JSON: %s...", data[:100])
            raise RMVtransportError(err) from err

        return list(json_data["suggestions"][:max_results])

    async def _query_rmv_api(self, url: str) -> Any:
        """Query RMV API."""
        async with async_timeout.timeout(self._timeout):
            async with httpx.AsyncClient() as client:
                try:
                    response = await client.get(url)
                except (
                    asyncio.TimeoutError,
                    httpx.ReadTimeout,
                    httpx.ConnectTimeout,
                    httpx.ConnectError,
                ) as err:
                    _LOGGER.error("Can not load data from RMV API")
                    raise RMVtransportApiConnectionError(err) from err

        _LOGGER.debug("Response from RMV API: %s", response.status_code)
        return response.read()

    def travel_data(self) -> RMVtravel:
        """Return travel data."""
        return RMVtravel(
            self.station,
            self.station_id,
            self.products_filter,
            self.journeys,
            self.max_journeys,
        )

    @property
    def station(self) -> str:
        """Extract station name."""
        return str(self.obj.SBRes.SBReq.Start.Station.HafasName.Text.pyval)

    def current_time(self) -> datetime:
        """Extract current time."""
        try:
            _date = datetime.strptime(self.obj.SBRes.SBReq.StartT.get("date"), "%Y%m%d")
            _time = datetime.strptime(self.obj.SBRes.SBReq.StartT.get("time"), "%H:%M")
        except (ValueError, AttributeError) as err:
            raise RMVtransportDataError(err) from err

        return datetime.combine(_date.date(), _time.time())

    def print(self) -> None:
        """Pretty print travel times."""
        print(f"{self.station} - {self.now}")
        print("-------------")
        print(self.travel_data())


def product_filter(products) -> str:
    """Calculate the product filter."""
    _filter = sum({PRODUCTS[p] for p in products})
    return format(_filter, "b")[::-1]


def base_url(path: str = STBOARD_PATH) -> str:
    """Build base url."""
    _lang: str = "d"
    _type: str = "n"
    _with_suggestions: str = "?"
    return BASE_URI + path + _lang + _type + _with_suggestions


def convert_coordinates(value: str) -> float:
    """Convert coordinates to lat/long."""
    if len(value) < 8:
        return float(value[0] + "." + value[1:])
    return float(value[0:2] + "." + value[2:])


def extract_data_from_xml(xml: bytes) -> Any:
    """Extract data from xml."""
    retry = 0
    while retry < MAX_RETRIES:
        try:
            return objectify.fromstring(xml)

        except etree.XMLSyntaxError as err:
            xml = fix_xml(xml, err)
            retry -= 1


def fix_xml(data: bytes, err: etree.XMLSyntaxError) -> Any:
    """Try to fix known issues in XML data."""
    xml_issue = data.decode().split("\n")[err.lineno - 1]

    if xml_issue not in KNOWN_XML_ISSUES.keys():
        _LOGGER.debug("Unknown xml issue in: %s", xml_issue)
        raise RMVtransportError()

    return data.decode().replace(xml_issue, KNOWN_XML_ISSUES[xml_issue]).encode()


def extract_json_data(response) -> str:
    """Extract json from response."""
    data = response.decode("utf-8")
    return str(data[data.find("{") : data.rfind("}") + 1])  # noqa: E203