yafeunteun/sncf-cli

View on GitHub
locomotive/stores.py

Summary

Maintainability
A
0 mins
Test Coverage
import difflib
import logging
import re
import sqlite3
from csv import DictReader
from pathlib import Path
from typing import List, Optional, Tuple

import requests
from text_unidecode import unidecode

from .exceptions import StationNotFoundException
from .models import Station

# TODO: Update the normalized names stored in the stations db
# whenever `normalize()` is changed.


def normalize(s: str) -> str:
    """Return the canonical form of a station name used for matching.

    The name is lower-cased and passed through `unidecode`, then a fixed,
    ordered list of substitutions is applied (punctuation to spaces, common
    abbreviations). A trailing two-character word is shortened to its first
    character, and runs of whitespace are collapsed to a single space.
    """
    # Order matters: "-" must become " " before "saint " can match "saint-".
    substitutions = (
        ("-", " "),
        (".", " "),
        ("gare de ", ""),
        ("charles de gaulle", "cdg"),
        ("provence", "pce"),
        ("perrache", "pche"),
        ("saint ", "st "),
        ("zuerich", "zurich"),
    )
    text = unidecode(s.lower())
    for old, new in substitutions:
        text = text.replace(old, new)
    text = text.strip()
    # Shorten a trailing two-character word to its first character:
    # "Clermont Fd." => "Clermont F", "Paris MP" => "Paris M".
    text = re.sub(r"\s(\w)\w\.?$", r" \1", text)
    # Collapse any run of whitespace into a single space.
    return re.sub(r"\s+", " ", text)


class Stations:
    """
    A train stations database backed by a local SQLite file.
    See https://github.com/trainline-eu/stations.

    The database is opened read-only for lookups; `download()` (re)builds it
    from the upstream CSV.
    """

    # Executed by `download()`: drops and recreates the `stations` table.
    schema = """
    DROP TABLE IF EXISTS stations;
    CREATE TABLE stations (
        name        TEXT NOT NULL,
        name_norm   TEXT NOT NULL,
        country     TEXT NOT NULL,
        sncf_id     TEXT,
        sncf_tvs_id TEXT,
        latitude    REAL,
        longitude   REAL
    );
    """

    def __init__(self, path: Optional[Path] = None, download: bool = True):
        """
        Open the stations database at `path`.

        :param path: location of the SQLite file; defaults to `default_path()`.
        :param download: when true and the file does not exist, build it by
            calling `download()` first.
        """
        self.path = path
        if not self.path:
            self.path = self.default_path()
        if not self.path.exists() and download:
            print(f"{self.path} not found, downloading...")
            self.download(self.path)

    def __conn__(self) -> sqlite3.Connection:
        """
        Return a new read-only connection to the database file.

        The caller is responsible for closing the returned connection.
        """
        # SQLite URI filenames treat "?" and "#" as delimiters and decode
        # "%HH" sequences, so those characters must be escaped in the path.
        # "%" is escaped first so the escapes themselves are left intact.
        escaped = (
            str(self.path)
            .replace("%", "%25")
            .replace("?", "%3f")
            .replace("#", "%23")
        )
        # We can set `check_same_thread` to False
        # since we open the database in read-only mode.
        conn = sqlite3.connect(
            f"file:{escaped}?mode=ro", check_same_thread=False, uri=True
        )
        conn.set_trace_callback(logging.debug)  # type: ignore
        return conn

    def fetchone(self, sql: str, parameters: Tuple = tuple()) -> Tuple:
        """
        Run `sql` with `parameters` and return the first result row.

        Returns whatever `cursor.fetchone()` yields, i.e. `None` when the
        query produces no row.
        """
        conn = self.__conn__()
        # `with conn:` would only commit/rollback; close explicitly so that
        # every query does not leak an open connection.
        try:
            row: Tuple = conn.execute(sql, parameters).fetchone()
            return row
        finally:
            conn.close()

    def fetchall(self, sql: str, parameters: Tuple = tuple()) -> List[Tuple]:
        """Run `sql` with `parameters` and return all result rows as a list."""
        conn = self.__conn__()
        try:
            rows: List[Tuple] = conn.execute(sql, parameters).fetchall()
            return rows
        finally:
            conn.close()

    def count(self) -> int:
        """Return the number of rows in the `stations` table."""
        return int(self.fetchone("SELECT COUNT(*) FROM stations")[0])

    def find_by_id(self, query: str) -> Optional[Station]:
        """
        Return the station whose `sncf_id` matches `query` (case-insensitive),
        or `None` when there is no match or `query` is empty.
        """
        # Stations without an id are stored with an empty `sncf_id`, so an
        # empty query would otherwise match an arbitrary one of them.
        if not query:
            return None
        sql = "SELECT * FROM stations WHERE lower(sncf_id) LIKE ?"
        row = self.fetchone(sql, (query.lower(),))
        if row:
            return Station.from_row(row)
        return None

    def find_by_name(self, query: str) -> Optional[Station]:
        """
        Return the station whose normalized name is the closest match to
        `query`, or `None` when nothing matches.

        Candidates are the rows whose `name_norm` contains the normalized
        query; the best one is chosen with `difflib.get_close_matches`.
        """
        query_norm = normalize(query)
        # An empty pattern would select every row only for difflib to reject
        # them all; skip the full table scan.
        if not query_norm:
            return None
        sql = "SELECT * FROM stations WHERE name_norm LIKE ?"
        rows = self.fetchall(sql, (f"%{query_norm}%",))
        # Column 1 is `name_norm` (see `schema`).
        matches = difflib.get_close_matches(
            query_norm, [x[1] for x in rows], cutoff=0.1, n=1
        )
        if matches:
            row = next(x for x in rows if x[1] == matches[0])
            return Station.from_row(row)
        return None

    def find(self, query: str) -> Optional[Station]:
        """Look `query` up as an id first, then as a name; `None` if not found."""
        station = self.find_by_id(query)
        if not station:
            station = self.find_by_name(query)
        return station

    def find_or_create(self, query: str) -> Station:
        """
        Return the station matching `query`, or a new `Station` built from
        `query` and its normalized form when none is found.
        """
        station = self.find(query)
        if not station:
            station = Station(query, normalize(query), None, None, None, None)
        return station

    def find_or_raise(self, query: str) -> Station:
        """
        Return the station matching `query`.

        :raises StationNotFoundException: when no station matches.
        """
        station = self.find(query)
        if not station:
            raise StationNotFoundException(query)
        return station

    @classmethod
    def default_path(cls) -> Path:
        """Return the default database location under the user's home directory."""
        return Path.home().joinpath(".locomotive", "stations.sqlite3")

    @classmethod
    def download(cls, path: Path) -> None:
        """
        Download the trainline-eu stations CSV and (re)build the SQLite
        database at `path`, creating parent directories as needed.

        :raises requests.RequestException: when the download fails, times
            out, or the server answers with an HTTP error status.
        """
        path.parent.mkdir(exist_ok=True, parents=True)

        print("Downloading trainline-eu/stations/stations.csv...")
        url = "https://github.com/trainline-eu/stations/raw/master/stations.csv"
        # A timeout avoids hanging forever on a stalled connection, and the
        # status check keeps an HTML error page from being parsed as CSV.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        lines = response.content.decode("utf-8").split("\n")

        print("Extracting stations...")
        rdr = DictReader(lines, delimiter=";")
        stations = [
            (
                row["name"],
                normalize(row["name"]),
                row["country"],
                row["sncf_id"],
                row["sncf_tvs_id"],
                row["latitude"],
                row["longitude"],
            )
            for row in rdr
        ]

        conn = sqlite3.connect(str(path))
        try:
            # `with conn:` commits the inserts (or rolls back on error) but
            # does not close the connection, hence the surrounding `finally`.
            with conn:
                print("Inserting stations...")
                conn.executescript(cls.schema)
                conn.executemany(
                    "INSERT INTO stations VALUES (?,?,?,?,?,?,?)", stations
                )
        finally:
            conn.close()