src/pyff/api.py

import importlib
import threading
from datetime import datetime, timedelta
from json import dumps
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple

import pkg_resources
import pyramid.httpexceptions as exc
import pytz
import requests
from accept_types import AcceptableType
from cachetools import TTLCache
from lxml import etree
from pyramid.config import Configurator
from pyramid.events import NewRequest
from pyramid.request import Request
from pyramid.response import Response
from six import b
from six.moves.urllib_parse import quote_plus

from pyff.constants import config
from pyff.exceptions import ResourceException
from pyff.logs import get_log
from pyff.pipes import plumbing
from pyff.repo import MDRepository
from pyff.resource import Resource
from pyff.samlmd import entity_display_name
from pyff.utils import b2u, dumptree, hash_id, json_serializer, utc_now

log = get_log(__name__)


class NoCache(object):
    """ Dummy implementation for when caching isn't enabled """

    def __init__(self) -> None:
        pass

    def __getitem__(self, item: Any) -> None:
        return None

    def __setitem__(self, instance: Any, value: Any) -> Any:
        return value
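
# NoCache mimics just enough of the TTLCache mapping interface used in
# request_handler: lookups always return None (a cache miss) and stores are
# discarded, so every request is recomputed when caching is disabled.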


def robots_handler(request: Request) -> Response:
    """
    Implements robots.txt

    :param request: the HTTP request
    :return: robots.txt
    """
    return Response(
        """
User-agent: *
Disallow: /
"""
    )


def status_handler(request: Request) -> Response:
    """
    Implements the /api/status endpoint

    :param request: the HTTP request
    :return: JSON status
    """
    d = {}
    for r in request.registry.md.rm:
        if 'Validation Errors' in r.info and r.info['Validation Errors']:
            d[r.url] = r.info['Validation Errors']
    _status = dict(
        version=pkg_resources.require("pyFF")[0].version,
        invalids=d,
        icon_store=dict(size=request.registry.md.icon_store.size()),
        jobs=[dict(id=j.id, next_run_time=j.next_run_time) for j in request.registry.scheduler.get_jobs()],
        threads=[t.name for t in threading.enumerate()],
        store=dict(size=request.registry.md.store.size()),
    )
    response = Response(dumps(_status, default=json_serializer))
    response.headers['Content-Type'] = 'application/json'
    return response


class MediaAccept(object):
    def __init__(self, accept: str):
        self._type = AcceptableType(accept)

    def has_key(self, key: Any) -> bool:
        # Every key is reported as present; get() performs the actual matching.
        return True

    def get(self, item: Any) -> Any:
        return self._type.matches(item)

    def __contains__(self, item: Any) -> Any:
        return self._type.matches(item)

    def __str__(self) -> str:
        return str(self._type)
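
# MediaAccept wraps accept_types.AcceptableType with a dict-like interface so
# it can be queried from the plumbing state with `in` and .get(). For example
# (illustrative): 'application/json' in MediaAccept('application/json') is
# truthy, while 'text/html' would not match.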


xml_types = ('text/xml', 'application/xml', 'application/samlmetadata+xml')


def _is_xml_type(accepter: MediaAccept) -> bool:
    return any(x in accepter for x in xml_types)


def _is_xml(data: Any) -> bool:
    return isinstance(data, (etree._Element, etree._ElementTree))


def _fmt(data: Any, accepter: MediaAccept) -> Tuple[str, str]:
    """
    Format data according to the accepted content type of the requester.
    Return data as string (either XML or json) and a content-type.
    """
    if data is None or len(data) == 0:
        return "", 'text/plain'
    if _is_xml(data) and _is_xml_type(accepter):
        return dumptree(data), 'application/samlmetadata+xml'
    if isinstance(data, (dict, list)) and accepter.get('application/json'):
        return dumps(data, default=json_serializer), 'application/json'

    raise exc.exception_response(406)
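
# Illustrative behaviour of _fmt above: an lxml tree paired with an XML Accept
# header yields a serialized tree with Content-Type
# 'application/samlmetadata+xml'; a dict or list paired with an
# 'application/json' Accept yields a JSON string; any other combination
# raises 406 Not Acceptable.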


def call(entry: str) -> None:
    url = f'{config.base_url}/api/call/{entry}'
    log.debug(f'Calling API endpoint at {url}')
    resp = requests.post(url)
    if resp.status_code >= 300:
        log.error(f'POST request to API endpoint at {url} failed: {resp.status_code} {resp.reason}')
    return None
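
# Note: call() is what the scheduler job registered in mkapp() invokes; it
# triggers pipeline execution by POSTing to this instance's own
# /api/call/<entry> endpoint (e.g. /api/call/update) over HTTP rather than
# running the pipeline in-process.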


def request_handler(request: Request) -> Response:
    """
    The main GET request handler for pyFF. Implements caching and forwards the request to process_handler

    :param request: the HTTP request object
    :return: the data to send to the client
    """
    key = request.path_qs
    r = None
    try:
        r = request.registry.cache[key]
    except KeyError:
        pass
    if r is None:
        r = process_handler(request)
        request.registry.cache[key] = r
    return r


def process_handler(request: Request) -> Response:
    """
    The main request handler for pyFF. Implements API call hooks and content negotiation.

    :param request: the HTTP request object
    :return: the data to send to the client
    """
    _ctypes = {'xml': 'application/samlmetadata+xml;application/xml;text/xml', 'json': 'application/json'}

    def _d(x: Optional[str], do_split: bool = True) -> Tuple[Optional[str], Optional[str]]:
        """ Split a path into a base component and an extension. """
        if x is not None:
            x = x.strip()

        if x is None or len(x) == 0:
            return None, None

        if '.' in x:
            (pth, dot, extn) = x.rpartition('.')
            assert dot == '.'
            if extn in _ctypes:
                return pth, extn

        return x, None

    log.debug(f'Processing request: {request}')

    if request.matchdict is None:
        raise exc.exception_response(400)

    if request.body:
        try:
            request.matchdict.update(request.json_body)
        except ValueError:
            pass  # ignore request bodies that aren't valid JSON

    entry = request.matchdict.get('entry', 'request')
    path_elem = list(request.matchdict.get('path', []))
    match = request.params.get('q', request.params.get('query', None))

    # Enable matching on scope.
    match = match.split('@').pop() if match and not match.endswith('@') else match
    log.debug("match={}".format(match))

    if not path_elem:
        path_elem = ['entities']

    alias = path_elem.pop(0)
    path = '/'.join(path_elem)

    # Ugly workaround: WSGI collapses '//' to '/', so restore '://' in URLs.
    path = path.replace(':/', '://')

    msg = "handling entry={}, alias={}, path={}"
    log.debug(msg.format(entry, alias, path))

    pfx = None
    if 'entities' not in alias:
        pfx = request.registry.aliases.get(alias, None)
        if pfx is None:
            log.debug("alias {} not found - passing to storage lookup".format(alias))
            path=alias #treat as path

    # content_negotiation_policy is one of three values:
    # 1. extension - current default, inspect the path and if it ends in
    #    an extension, e.g. .xml or .json, always strip off the extension to
    #    get the entityID and if no accept header or a wildcard header, then
    #    use the extension to determine the return Content-Type.
    #
    # 2. adaptive - only if no accept header or if a wildcard, then inspect
    #    the path and if it ends in an extension strip off the extension to
    #    get the entityID and use the extension to determine the return
    #    Content-Type.
    #
    # 3. header - future default, do not inspect the path for an extension and
    #    use only the Accept header to determine the return Content-Type.
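    #
    # For example (illustrative): under 'extension', GET /entities/foo.json
    # with a wildcard Accept header strips '.json' from the lookup key and
    # returns JSON; under 'header', '.json' stays part of the path and only
    # the Accept header decides the response Content-Type.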
    policy = config.content_negotiation_policy

    # TODO: clients sometimes send multiple Accept values separated by ','; only the first is honored.
    accept = str(request.accept).split(',')[0]
    valid_accept = accept and not ('application/*' in accept or 'text/*' in accept or '*/*' in accept)

    new_path: Optional[str] = path
    path_no_extension, extension = _d(new_path, True)
    accept_from_extension = accept
    if extension:
        accept_from_extension = _ctypes.get(extension, accept)

    if policy == 'extension':
        new_path = path_no_extension
        if not valid_accept:
            accept = accept_from_extension
    elif policy == 'adaptive':
        if not valid_accept:
            new_path = path_no_extension
            accept = accept_from_extension

    if not accept:
        log.warning('Could not determine accepted response type')
        raise exc.exception_response(400)

    q: Optional[str]
    if pfx and new_path:
        q = f'{{{pfx}}}{new_path}'
        new_path = f'/{alias}/{new_path}'
    else:
        q = new_path

    try:
        accepter = MediaAccept(accept)
        for p in request.registry.plumbings:
            state = {
                entry: True,
                'headers': {'Content-Type': None},
                'accept': accepter,
                'url': request.current_route_url(),
                'select': q,
                'match': match.lower() if match else match,
                'path': new_path,
                'stats': {},
            }

            r = p.process(request.registry.md, state=state, raise_exceptions=True, scheduler=request.registry.scheduler)
            log.debug(f'Plumbing process result: {r}')
            if r is None:
                r = []

            response = Response()
            _headers = state.get('headers', {})
            response.headers.update(_headers)
            ctype = _headers.get('Content-Type', None)
            if not ctype:
                r, t = _fmt(r, accepter)
                ctype = t

            response.text = b2u(r)
            response.size = len(r)
            response.content_type = ctype
            cache_ttl = int(state.get('cache', 0))
            response.expires = datetime.now() + timedelta(seconds=cache_ttl)
            return response
    except ResourceException as ex:
        import traceback

        log.debug(traceback.format_exc())
        log.warning(f'Exception from processing pipeline: {ex}')
        raise exc.exception_response(409)
    except BaseException as ex:
        import traceback

        log.debug(traceback.format_exc())
        log.error(f'Exception from processing pipeline: {ex}')
        raise exc.exception_response(500)

    if request.method == 'GET':
        raise exc.exception_response(404)


def webfinger_handler(request: Request) -> Response:
    """An implementation of the WebFinger protocol
    (http://tools.ietf.org/html/draft-ietf-appsawg-webfinger-12)
    used to provide information about upstream and downstream metadata
    available at this pyFF instance.

    Example:

    .. code-block:: bash

        # curl http://my.org/.well-known/webfinger?resource=http://my.org

    This should result in a JSON structure that looks something like this:

    .. code-block:: json

        {
         "expires": "2013-04-13T17:40:42.188549",
         "links": [
          {
           "href": "http://reep.refeds.org:8080/role/sp.xml",
           "rel": "urn:oasis:names:tc:SAML:2.0:metadata"
          },
          {
           "href": "http://reep.refeds.org:8080/role/sp.json",
           "rel": "disco-json"
          }
         ],
         "subject": "http://reep.refeds.org:8080"
        }

    Depending on which version of pyFF you're running and on your configuration,
    you may also see downstream metadata listed using the 'role' attribute on
    the link elements.
    """

    resource = request.params.get('resource', None)
    rel = request.params.get('rel', None)

    if resource is None:
        resource = request.host_url

    jrd: Dict[str, Any] = dict()
    dt = datetime.now() + timedelta(hours=1)
    jrd['expires'] = dt.isoformat()
    jrd['subject'] = request.host_url
    links: List[Dict[str, Any]] = list()
    jrd['links'] = links

    _dflt_rels = {
        'urn:oasis:names:tc:SAML:2.0:metadata': ['.xml', 'application/xml'],
        'disco-json': ['.json', 'application/json'],
    }

    if rel is None or len(rel) == 0:
        rel = _dflt_rels.keys()
    else:
        rel = [rel]

    def _links(url: str, title: Any = None) -> None:
        if url.startswith('/'):
            url = url.lstrip('/')
        for r in rel:
            suffix = ""
            if not url.endswith('/'):
                suffix = _dflt_rels[r][0]
            links.append(dict(rel=r, type=_dflt_rels[r][1], href='%s/%s%s' % (request.host_url, url, suffix)))

    _links('/entities/')
    for a in request.registry.md.store.collections():
        if a is not None and '://' not in a:
            _links(a)

    for entity in request.registry.md.store.lookup('entities'):
        entity_display = entity_display_name(entity)
        _links("/entities/%s" % hash_id(entity.get('entityID')), title=entity_display)

    aliases = request.registry.aliases
    for a in aliases.keys():
        for v in request.registry.md.store.attribute(aliases[a]):
            _links('%s/%s' % (a, quote_plus(v)))

    response = Response(dumps(jrd, default=json_serializer))
    response.headers['Content-Type'] = 'application/json'

    return response


def resources_handler(request: Request) -> Response:
    """
    Implements the /api/resources endpoint

    :param request: the HTTP request
    :return: a JSON representation of the set of resources currently loaded by the server
    """

    def _infos(resources: Iterable[Resource]) -> List[Mapping[str, Any]]:
        return [_info(r) for r in resources if r.info.state is not None]

    def _info(r: Resource) -> Mapping[str, Any]:
        nfo = r.info.to_dict()
        nfo['Valid'] = r.is_valid()
        nfo['Parser'] = r.last_parser
        if r.last_seen is not None:
            nfo['Last Seen'] = r.last_seen
        if len(r.children) > 0:
            nfo['Children'] = _infos(r.children)

        return nfo

    response = Response(dumps(_infos(request.registry.md.rm.children), default=json_serializer))
    response.headers['Content-Type'] = 'application/json'

    return response


def pipeline_handler(request: Request) -> Response:
    """
    Implements the /api/pipeline endpoint

    :param request: the HTTP request
    :return: a JSON representation of the active pipeline
    """
    response = Response(dumps(request.registry.plumbings, default=json_serializer))
    response.headers['Content-Type'] = 'application/json'

    return response


def search_handler(request: Request) -> Response:
    """
    Implements the /api/search endpoint

    :param request: the HTTP request with the 'query' request parameter
    :return: a JSON search result
    """
    match = request.params.get('q', request.params.get('query', ""))

    # Enable matching on scope.
    match = match.split('@').pop() if match and not match.endswith('@') else match

    entity_filter = request.params.get('entity_filter', '{http://pyff.io/role}idp')
    log.debug("match={}".format(match))
    store = request.registry.md.store

    def _response() -> Generator[bytes, None, None]:
        # Stream the JSON array one element at a time so large result sets
        # aren't buffered in memory before the response starts.
        yield b('[')
        in_loop = False
        entities = store.search(query=match.lower(), entity_filter=entity_filter)
        for e in entities:
            if in_loop:
                yield b(',')
            yield b(dumps(e))
            in_loop = True
        yield b(']')

    response = Response(content_type='application/json')
    response.app_iter = _response()
    return response


def add_cors_headers_response_callback(event: NewRequest) -> None:
    def cors_headers(request: Request, response: Response) -> None:
        response.headers.update(
            {
                'Access-Control-Allow-Origin': '*',
                'Access-Control-Allow-Methods': 'POST,GET,DELETE,PUT,OPTIONS',
                'Access-Control-Allow-Headers': 'Origin, Content-Type, Accept, Authorization',
                'Access-Control-Allow-Credentials': 'true',
                'Access-Control-Max-Age': '1728000',
            }
        )

    event.request.add_response_callback(cors_headers)


def launch_memory_usage_server(port: int = 9002) -> None:
    import cherrypy
    import dowser

    cherrypy.tree.mount(dowser.Root())
    cherrypy.config.update({'environment': 'embedded', 'server.socket_port': port})

    cherrypy.engine.start()


def mkapp(*args: Any, **kwargs: Any) -> Any:
    md = kwargs.pop('md', None)
    if md is None:
        md = MDRepository()

    if config.devel_memory_profile:
        launch_memory_usage_server()

    with Configurator(debug_logger=log) as ctx:
        ctx.add_subscriber(add_cors_headers_response_callback, NewRequest)

        if config.aliases is None:
            config.aliases = dict()

        if config.modules is None:
            config.modules = []

        ctx.registry.config = config
        config.modules.append('pyff.builtins')
        for mn in config.modules:
            importlib.import_module(mn)

        pipeline = None
        if args:
            pipeline = list(args)
        if pipeline is None and config.pipeline:
            pipeline = [config.pipeline]

        ctx.registry.scheduler = md.scheduler
        if pipeline is not None:
            ctx.registry.pipeline = pipeline
            ctx.registry.plumbings = [plumbing(v) for v in pipeline]
        ctx.registry.aliases = config.aliases
        ctx.registry.md = md
        if config.caching_enabled:
            ctx.registry.cache = TTLCache(config.cache_size, config.cache_ttl)
        else:
            ctx.registry.cache = NoCache()

        ctx.add_route('robots', '/robots.txt')
        ctx.add_view(robots_handler, route_name='robots')

        ctx.add_route('webfinger', '/.well-known/webfinger', request_method='GET')
        ctx.add_view(webfinger_handler, route_name='webfinger')

        ctx.add_route('search', '/api/search', request_method='GET')
        ctx.add_view(search_handler, route_name='search')

        ctx.add_route('status', '/api/status', request_method='GET')
        ctx.add_view(status_handler, route_name='status')

        ctx.add_route('resources', '/api/resources', request_method='GET')
        ctx.add_view(resources_handler, route_name='resources')

        ctx.add_route('pipeline', '/api/pipeline', request_method='GET')
        ctx.add_view(pipeline_handler, route_name='pipeline')

        ctx.add_route('call', '/api/call/{entry}', request_method=['POST', 'PUT'])
        ctx.add_view(process_handler, route_name='call')

        ctx.add_route('request', '/*path', request_method='GET')
        ctx.add_view(request_handler, route_name='request')

        start = utc_now() + timedelta(seconds=1)
        if config.update_frequency > 0:
            ctx.registry.scheduler.add_job(
                call,
                'interval',
                id="call/update",
                args=['update'],
                start_date=start,
                misfire_grace_time=10,
                seconds=config.update_frequency,
                replace_existing=True,
                max_instances=1,
                timezone=pytz.utc,
            )

        return ctx.make_wsgi_app()
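
# Usage sketch (illustrative, not part of the module): mkapp() returns a
# standard WSGI application, so any WSGI server can serve it, e.g.:
#
#   from wsgiref.simple_server import make_server
#   make_server('0.0.0.0', 8080, mkapp()).serve_forever()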