SEIAROTg/autobean

View on GitHub
autobean/truelayer/importer.py

Summary

Maintainability
B
6 hrs
Test Coverage
F
0%
import datetime
from decimal import Decimal
import http.server
import logging
import os
import re
import time
import secrets
import sys
from typing import Any, Optional
import urllib.parse
import webbrowser

from autobean.utils import deduplicate
from beancount.core.amount import Amount
from beancount.core.data import Transaction, Posting, Balance, Directive, new_metadata
from beancount.core import inventory
from beancount.ingest import importer
from beancount.ingest import cache
import dateutil.parser
import requests
import yaml


# File-name suffix that marks a YAML file as a TrueLayer config handled by
# this importer; it is also stripped from the file name to derive the
# config's display name.
CONFIG_SUFFIX = '.truelayer.yaml'
# Account categories processed by the importer. Each value is used both to
# select the TrueLayer endpoint family and as the top-level config key under
# which that category's accounts are stored.
ACCOUNT_TYPES = ('accounts', 'cards')


def escape_account_component(s: str) -> str:
    """Turn free-form text into a single account-name component.

    Every character that is not a regex word character is dropped, and the
    first remaining character (if any) is upper-cased. The rest of the string
    keeps its original casing.
    """
    stripped = re.sub(r'\W', '', s)
    head, tail = stripped[:1], stripped[1:]
    return head.upper() + tail


def format_iso_datetime(timestamp_s: float) -> str:
    """Format a Unix timestamp as a UTC ISO 8601 string without an offset.

    Args:
        timestamp_s: Seconds since the Unix epoch. Fractional seconds are
            truncated by ``int()`` before conversion.

    Returns:
        A string such as ``'1970-01-01T00:00:00'`` expressing the instant in
        UTC, with no timezone suffix.
    """
    # datetime.utcfromtimestamp() is deprecated since Python 3.12. Convert via
    # an aware UTC datetime and drop the tzinfo so the output keeps its
    # historical, suffix-free shape.
    utc_time = datetime.datetime.fromtimestamp(
        int(timestamp_s), tz=datetime.timezone.utc)
    return utc_time.replace(tzinfo=None).isoformat()


def currency_to_decimal(currency: float) -> Decimal:
    """Convert a float amount to a ``Decimal`` with exactly two decimals.

    The float is first rendered with the ``.2f`` format, then parsed, so the
    result always carries two fractional digits.
    """
    two_places = format(currency, '.2f')
    return Decimal(two_places)


class Importer(importer.ImporterProtocol):
    """Importer entry point that handles TrueLayer config files.

    The client credentials are stored at construction time and handed to the
    per-file config whenever ``extract`` runs.
    """

    def __init__(self, client_id: str, client_secret: str):
        """Remember the TrueLayer client credentials for later extraction."""
        self._client_id, self._client_secret = client_id, client_secret

    def name(self) -> str:
        """Return the fixed identifier of this importer."""
        return 'autobean.truelayer'

    def identify(self, file: cache._FileMemo) -> bool:
        """Return True when the file name ends with the config suffix."""
        filename = file.name
        return filename.endswith(CONFIG_SUFFIX)

    def extract(self, file: cache._FileMemo, existing_entries: Optional[list[Directive]] = None) -> list[Directive]:
        """Load the config from ``file`` and return the extracted directives.

        ``existing_entries`` is forwarded unchanged to the extractor.
        """
        return _Extractor(
            _Config(self._client_id, self._client_secret, file),
        ).extract(existing_entries)


class _Config:
    """Client credentials plus the parsed contents of one YAML config file.

    ``data`` is the mutable mapping loaded from the file; ``dump`` writes it
    back to the same path.
    """

    def __init__(self, client_id: str, client_secret: str, file: cache._FileMemo):
        """Store the credentials and parse the YAML contents of ``file``."""
        self.client_id = client_id
        self.client_secret = client_secret
        parsed = yaml.safe_load(file.contents())
        # An empty document parses to a falsy value; fall back to a fresh dict.
        self.data = parsed if parsed else {}
        self._filename = file.name

    @property
    def name(self) -> str:
        """Base name of the config file up to the last config suffix."""
        base = os.path.basename(self._filename)
        stem, *_ = base.rsplit(CONFIG_SUFFIX, 1)
        return stem

    def dump(self) -> None:
        """Overwrite the config file with the current ``data``."""
        with open(self._filename, 'w') as stream:
            yaml.safe_dump(self.data, stream)


class _Extractor:
    """Fetches TrueLayer data for one config and converts it to directives.

    The extractor first refreshes the account list stored in the config, then
    downloads settled transactions, pending transactions and balances for
    every enabled account and turns them into beancount ``Transaction`` and
    ``Balance`` entries.
    """

    # Root of the TrueLayer Data API v1 endpoints used below.
    _API_ROOT = 'https://api.truelayer.com/data/v1'
    # Seconds to wait for a TrueLayer response before requests raises a
    # timeout instead of blocking indefinitely.
    _REQUEST_TIMEOUT_S = 60
    # Look-back window (90 days) used as the initial 'from' of new accounts.
    _DEFAULT_HISTORY_S = 86400 * 90

    def __init__(self, config: _Config):
        """Bind the extractor to ``config`` and create its OAuth manager."""
        self._config = config
        self._oauth_manager = _OAuthManager(config)

    def extract(self, existing_entries: Optional[list[Directive]] = None) -> list[Directive]:
        """Refresh accounts, fetch all entries and optionally deduplicate.

        Args:
            existing_entries: Entries already known; when non-empty the new
                entries are passed through ``deduplicate.deduplicate``.

        Returns:
            The extracted directives.
        """
        for type_ in ACCOUNT_TYPES:
            self._update_accounts(type_)
        entries = self._fetch_all_transactions()
        if existing_entries:
            entries = deduplicate.deduplicate(entries, existing_entries)
        return entries

    @property
    def _auth_headers(self) -> dict[str, str]:
        """HTTP headers carrying the current bearer access token."""
        return {
            'Authorization': f'Bearer {self._oauth_manager.access_token}'
        }

    def _update_accounts(self, type_: str) -> None:
        """Merge the remote account list of ``type_`` into the config.

        Existing per-account settings are never overwritten: only missing keys
        get defaults. A non-OK response is logged and otherwise ignored, in
        which case the config is left untouched.
        """
        r = requests.get(
            f'{self._API_ROOT}/{type_}',
            headers=self._auth_headers,
            timeout=self._REQUEST_TIMEOUT_S)
        if not r.ok:
            logging.warning('Could not fetch %s: %s', type_, r.text)
            return
        accounts = r.json().get('results', [])
        config_accounts = self._config.data.setdefault(type_, {})
        for account in accounts:
            config_account = config_accounts.setdefault(
                account['account_id'], {})
            config_account.setdefault('name', account['display_name'])
            config_account.setdefault(
                'liability',
                type_ == 'cards' and account['card_type'] == 'CREDIT')
            config_account.setdefault('enabled', True)
            config_account.setdefault('beancount_account', ':'.join([
                'Liabilities' if config_account['liability'] else 'Assets',
                escape_account_component(self._config.name),
                escape_account_component(config_account['name'])
            ]))
            config_account.setdefault(
                'from', int(time.time()) - self._DEFAULT_HISTORY_S)

        self._config.dump()

    def _fetch_transactions(
            self,
            account_id: str,
            account: dict[str, Any],
            type_: str,
            is_pending: bool) -> list[dict[str, Any]]:
        """Download raw transactions of one account.

        Args:
            account_id: TrueLayer identifier of the account.
            account: Config entry of the account ('name' and 'from' are read).
            type_: Account category, one of ``ACCOUNT_TYPES``.
            is_pending: Whether to query the pending-transactions endpoint.

        Returns:
            The 'results' list of the response (empty when absent).

        Raises:
            requests.HTTPError: When the server answers with an error status.
        """
        endpoint = f'{self._API_ROOT}/{type_}/{account_id}/transactions'
        if is_pending:
            endpoint += '/pending'
        log_transaction = 'pending transactions' if is_pending else 'transactions'
        logging.info(
            'Fetching %s for account %s (%s).',
            log_transaction, account['name'], account_id)
        r = requests.get(
            endpoint,
            headers=self._auth_headers,
            params={
                'from': format_iso_datetime(account['from']),
                'to': format_iso_datetime(time.time()),
            },
            timeout=self._REQUEST_TIMEOUT_S,
        )
        if not r.ok:
            logging.error('Error fetching transactions: %s', r.text)
            r.raise_for_status()
        txns = r.json().get('results', [])
        logging.info(
            'Fetched %s %s for account %s (%s).',
            len(txns), log_transaction, account['name'], account_id)
        return txns

    def _fetch_balances(
            self,
            account_id: str,
            account: dict[str, Any],
            type_: str) -> list[dict[str, Any]]:
        """Download the raw balance entries of one account.

        Returns:
            The 'results' list of the response (empty when absent).

        Raises:
            requests.HTTPError: When the server answers with an error status.
        """
        logging.info(
            'Fetching balance for account %s (%s).',
            account['name'], account_id)
        r = requests.get(
            f'{self._API_ROOT}/{type_}/{account_id}/balance',
            headers=self._auth_headers,
            timeout=self._REQUEST_TIMEOUT_S)
        if not r.ok:
            logging.error('Error fetching balance: %s', r.text)
            r.raise_for_status()
        balances = r.json().get('results', [])
        logging.info(
            'Fetched %s balance entries for account %s (%s).',
            len(balances), account['name'], account_id)
        return balances

    def _fetch_time_txns(
            self,
            account_id: str,
            account: dict[str, Any],
            type_: str,
            is_pending: bool) -> list[tuple[datetime.datetime, Transaction]]:
        """Fetch transactions and pair each with its parsed timestamp.

        Timestamps are normalised with ``astimezone()`` so that a naive value
        (interpreted as local time, the same way ``_transform_transaction``
        derives the date) can be compared with the timezone-aware balance
        time without raising ``TypeError``.
        """
        return [
            (
                dateutil.parser.parse(truelayer_txn['timestamp']).astimezone(),
                self._transform_transaction(
                    truelayer_txn, account['beancount_account'], is_pending),
            )
            for truelayer_txn in self._fetch_transactions(
                account_id, account, type_, is_pending)
        ]

    def _fetch_all_transactions(self) -> list[Directive]:
        """Build transactions and balance assertions for all enabled accounts.

        Returns:
            For each enabled account: settled transactions, then pending
            transactions, then one balance entry per fetched balance.
        """
        entries: list[Directive] = []
        for type_ in ACCOUNT_TYPES:
            # The key is missing when listing this category failed and the
            # config never had it; treat that as "no accounts" rather than
            # failing with KeyError.
            accounts = self._config.data.get(type_) or {}
            for account_id, account in accounts.items():
                if not account['enabled']:
                    continue
                time_txns = self._fetch_time_txns(
                    account_id, account, type_, False)
                pending_time_txns = self._fetch_time_txns(
                    account_id, account, type_, True)
                entries.extend(txn for _, txn in time_txns)
                entries.extend(txn for _, txn in pending_time_txns)

                balances = self._fetch_balances(account_id, account, type_)
                entries.extend(
                    self._transform_balance(
                        balance, account, time_txns, pending_time_txns)
                    for balance in balances)
        return entries

    @staticmethod
    def _sum_units(txns: list[Transaction], currency: str) -> Decimal:
        """Return the total number of ``currency`` units posted by ``txns``."""
        holdings = inventory.Inventory()
        for txn in txns:
            for posting in txn.postings:
                holdings.add_position(posting)
        return holdings.get_currency_units(currency).number

    def _transform_balance(
            self,
            truelayer_balance: dict[str, Any],
            account: dict[str, Any],
            time_txns: list[tuple[datetime.datetime, Transaction]],
            pending_time_txns: list[tuple[datetime.datetime, Transaction]],
    ) -> Balance:
        """Transforms TrueLayer Balance to beancount Balance.

        Balance from TrueLayer can be effective at the middle of a day with
        pending transactions ignored, while beancount balance assertions
        must be applied at the beginning of a day and pending transactions
        are taken into account.

        It is not always possible to get pending transactions. If that is not
        available balance assertions may have to be corrected retrospectively.

        Args:
            truelayer_balance: Raw balance entry; 'update_timestamp',
                'currency' and 'current' are read.
            account: Config entry; 'liability' and 'beancount_account' are
                read.
            time_txns: Settled transactions paired with their timestamps.
            pending_time_txns: Pending transactions paired with their
                timestamps.

        Returns:
            A balance assertion dated at the local day of the balance update.
        """
        currency = truelayer_balance['currency']
        balance_time = dateutil.parser.parse(
            truelayer_balance['update_timestamp']).astimezone()
        # Start of the local day on which the balance was reported.
        assertion_time = datetime.datetime.combine(
            balance_time, datetime.time.min, balance_time.tzinfo)

        # Settled transactions between the start of the day and the balance
        # update are part of the reported figure but happen after the
        # assertion point, so they are backed out.
        amount_to_remove = self._sum_units(
            [txn for t, txn in time_txns if assertion_time <= t < balance_time],
            currency)
        # Pending transactions dated before the assertion point are added to
        # the reported figure.
        amount_to_add = self._sum_units(
            [txn for t, txn in pending_time_txns if t < assertion_time],
            currency)

        number = currency_to_decimal(truelayer_balance['current'])
        if account['liability']:
            # Liability accounts are asserted with the opposite sign.
            # NOTE(review): presumably TrueLayer reports the amount owed as a
            # positive number - confirm.
            number = -number
        number += amount_to_add
        number -= amount_to_remove
        return Balance(
            meta=new_metadata('', 0),
            date=assertion_time.date(),
            account=account['beancount_account'],
            amount=Amount(number, currency),
            tolerance=None,
            diff_amount=None,
        )

    def _transform_transaction(
            self,
            truelayer_txn: dict[str, Any],
            beancount_account: str,
            is_pending: bool = False) -> Transaction:
        """Transforms TrueLayer Transaction to beancount Transaction.

        Args:
            truelayer_txn: Raw transaction as returned by TrueLayer.
            beancount_account: Account used for the single posting.
            is_pending: Pending transactions are flagged '!' instead of '*'.

        Returns:
            A transaction with one posting; DEBIT amounts are negative and
            CREDIT amounts positive, regardless of the sign reported.

        Raises:
            ValueError: If 'transaction_type' is neither DEBIT nor CREDIT.
        """
        number = abs(currency_to_decimal(truelayer_txn['amount']))
        transaction_type = truelayer_txn['transaction_type']
        if transaction_type == 'DEBIT':
            number = -number
        elif transaction_type != 'CREDIT':
            # An explicit exception survives ``python -O``, unlike an assert,
            # which would silently book the amount as a credit.
            raise ValueError(
                f'Unexpected transaction_type: {transaction_type!r}')

        posting = Posting(
            account=beancount_account,
            units=Amount(number, truelayer_txn['currency']),
            cost=None,
            price=None,
            flag=None,
            meta=None,
        )
        # Tolerate a missing 'meta' mapping instead of raising KeyError.
        provider_meta = truelayer_txn.get('meta') or {}
        payee = (
            truelayer_txn.get('merchant_name', None) or
            provider_meta.get('provider_merchant_name', None))
        return Transaction(
            meta=new_metadata('', 0),
            date=dateutil.parser.parse(truelayer_txn['timestamp']).astimezone().date(),
            flag='!' if is_pending else '*',
            payee=payee,
            narration=truelayer_txn['description'],
            tags=set(),
            links=set(),
            postings=[posting],
        )


class _OAuthManager:
    """Obtains and caches TrueLayer OAuth tokens in the config file.

    Tokens are read from and written to ``config.data``. When no valid access
    token is stored, a refresh is attempted first, followed by the interactive
    browser flow, which receives the authorization code on a local HTTP
    server.
    """

    ADDRESS = '127.0.0.1'
    PORT = 3000
    REDIRECT_URI = 'http://localhost:3000/callback'
    # Seconds to wait for the token endpoint before requests raises a timeout
    # instead of blocking indefinitely.
    _REQUEST_TIMEOUT_S = 60

    def __init__(self, config: _Config):
        """Bind the manager to the config that stores the tokens."""
        self._config = config

    @property
    def access_token(self) -> str:
        """A currently valid access token.

        Tries, in order: the stored token, a refresh with the stored refresh
        token, and the interactive OAuth flow.

        Raises:
            RuntimeError: If none of the strategies yields a valid token.
        """
        strategies = (
            None,
            self._refresh_access_token,
            self._request_access_token,
        )
        for obtain in strategies:
            if obtain is not None:
                obtain()
            access_token = self._get_valid_access_token()
            if access_token:
                return access_token
        raise RuntimeError('Unable to get a valid access token.')

    def _get_valid_access_token(self) -> Optional[str]:
        """Get access token from config file.

        Returns:
            The stored token when it exists and its stored expiry time is
            still in the future, otherwise None.
        """
        access_token = self._config.data.get('access_token')
        expiry_time = self._config.data.get('access_token_expiry_time')
        now = int(time.time())
        if access_token and expiry_time and expiry_time > now:
            return access_token
        return None

    def _refresh_access_token(self) -> None:
        """Refresh access token with refresh token, if one is stored."""
        logging.info('Attempt to refresh access token.')
        refresh_token = self._config.data.get('refresh_token', None)
        if not refresh_token:
            logging.info(
                'Failed to refresh access token: refresh token not available.')
            return
        self._grant_access_token(refresh_token=refresh_token)

    def _request_access_token(self) -> None:
        """Get access token with regular OAuth flow."""
        logging.info('Attempt to request access token with regular OAuth flow.')
        code = self._request_code()
        self._grant_access_token(code=code)
        logging.info('Successfully requested access token.')

    def _grant_access_token(self, code: Optional[str] = None, refresh_token: Optional[str] = None) -> None:
        """Grant access token with code or refresh_token.

        On success the new tokens and the expiry time are stored in the config
        and written to disk. A non-200 response is logged and leaves the
        config unchanged.

        Args:
            code: Authorization code from the interactive flow; takes
                precedence over ``refresh_token``.
            refresh_token: Refresh token from a previous grant.

        Raises:
            ValueError: If neither ``code`` nor ``refresh_token`` is given.
        """
        logging.info('Attempt to grant access token.')
        req: dict[str, str] = {
            'client_id': self._config.client_id,
            'client_secret': self._config.client_secret,
        }
        if code:
            req['grant_type'] = 'authorization_code'
            req['redirect_uri'] = self.REDIRECT_URI
            req['code'] = code
        elif refresh_token:
            req['grant_type'] = 'refresh_token'
            req['refresh_token'] = refresh_token
        else:
            # An explicit exception survives ``python -O``, unlike an assert.
            raise ValueError('Either code or refresh_token must be provided.')
        r = requests.post(
            'https://auth.truelayer.com/connect/token',
            data=req,
            timeout=self._REQUEST_TIMEOUT_S)
        if r.status_code != 200:
            logging.warning(
                'Failed to grant access token: server returns %s',
                r.status_code)
            return
        data = r.json()
        self._config.data['access_token'] = data['access_token']
        self._config.data['access_token_expiry_time'] = (
            int(time.time()) + data['expires_in'])
        # Keep the stored refresh token when the response carries none rather
        # than failing with KeyError after a successful grant.
        new_refresh_token = data.get('refresh_token')
        if new_refresh_token:
            self._config.data['refresh_token'] = new_refresh_token
        self._config.dump()
        logging.info('Successfully granted access token.')

    def _request_code(self) -> str:
        """Get the code to redeem access token with regular OAuth flow.

        Starts a local HTTP server on ``ADDRESS:PORT``, opens the
        authorization link in a browser and handles requests until a POST
        delivers a code together with the expected state.

        Returns:
            The authorization code received from the callback.
        """
        state = secrets.token_urlsafe(16)
        # Built up front so the handler can always redirect back to it.
        auth_link = self._build_auth_link(state)
        code: Optional[str] = None

        class HttpHandler(http.server.BaseHTTPRequestHandler):
            def do_POST(self) -> None:
                nonlocal code
                logging.info('OAuth response received.')
                # A missing Content-Length is treated as an empty body.
                length = int(self.headers.get('Content-Length') or 0)
                body = self.rfile.read(length).decode('utf-8')
                data = dict(urllib.parse.parse_qsl(body))
                received_state = data.get('state')
                received_code = data.get('code')

                if received_code and received_state == state:
                    code = received_code
                    response = b'You can now close this tab.\n'
                    self.send_response(200)
                    self.send_header('Content-Type', 'text/plain')
                    self.send_header('Content-Length', str(len(response)))
                    self.end_headers()
                    self.wfile.write(response)
                    return

                if received_state != state:
                    logging.warning('OAuth response state mismatches.')
                else:
                    logging.warning('OAuth response misses code.')
                # Send the user back to the authorization page. Headers are
                # only written out by end_headers(), so it must be called for
                # the redirect to reach the browser.
                self.send_response(302)
                self.send_header('Location', auth_link)
                self.send_header('Content-Length', '0')
                self.end_headers()

        # The context manager closes the listening socket even when the wait
        # loop is interrupted by an exception.
        with http.server.HTTPServer(
                (self.ADDRESS, self.PORT), HttpHandler) as httpd:
            socketname = httpd.socket.getsockname()
            logging.info('OAuth server listening at %s', socketname)
            webbrowser.open_new(auth_link)

            print(
                f'Please navigate to the following URL to complete the '
                f'authorization process:\n\n'
                f'{auth_link}\n\n'
                f'If you are unable to visit the link on the same '
                f'host that this script is running on, you might need to '
                f'forward TCP port {socketname[1]} to the host where the '
                f'browser will be running on during the process.',
                file=sys.stderr)

            while not code:
                httpd.handle_request()
        return code

    def _build_auth_link(self, state: str) -> str:
        """Build the TrueLayer authorization URL for the given state token."""
        qs = urllib.parse.urlencode({
            'response_type': 'code',
            'response_mode': 'form_post',
            'client_id': self._config.client_id,
            'redirect_uri': self.REDIRECT_URI,
            'scope': ' '.join([
                'accounts',
                'cards',
                'transactions',
                'balance',
                'offline_access',
            ]),
            'state': state,
        })
        return f'https://auth.truelayer.com/?{qs}'

# Import-time side effect: configure root logging, taking the level from the
# LOG_LEVEL environment variable and defaulting to WARNING when it is unset.
logging.basicConfig(level=os.getenv('LOG_LEVEL', 'WARNING'))