MediaMath/t1-python

View on GitHub
terminalone/service.py

Summary

Maintainability
D
2 days
Test Coverage
# -*- coding: utf-8 -*-
"""Provides service object for T1."""

from __future__ import absolute_import, division
try:
    from collections.abc import Iterator
except ImportError:  # Python 2: the ABCs live directly in `collections`
    from collections import Iterator
from types import GeneratorType
from .models import ACL
from .t1mappings import SINGULAR, CLASSES, CHILD_PATHS, MODEL_PATHS
from .connection import Connection
from .entity import Entity
from .errors import ClientError
from .reports import Report
from .utils import filters
from .vendor import six


def _detect_auth_method(username, password, session_id,
                        api_key, client_id, client_secret, token):
    if client_id is not None and client_secret is not None:
        return 'oauth2-resourceowner'
    else:
        return 'cookie'


class T1(Connection):
    """Service class for ALL other T1 entities, e.g.: t1 = T1(auth).

    Accepts authentication parameters. Supports get methods to get
    collections or an entity, and a find method for inner-join-like queries.
    """

    def __init__(self,
                 username=None,
                 password=None,
                 api_key=None,
                 client_id=None,
                 client_secret=None,
                 auth_method=None,
                 session_id=None,
                 environment='production',
                 api_base=None,
                 json=False,
                 redirect_uri=None,
                 token=None,
                 token_updater=None,
                 access_token=None,
                 realm=None,
                 scope=None,
                 **kwargs):
        """Set up session for main service object.

        :param username: str T1 Username
        :param password: str T1 Password
        :param api_key: str API Key approved in Developer Portal
        :param client_id: str Client ID for use with OAuth2 resource-owner
            authentication.
        :param client_secret: str Client Secret for use with OAuth2
        :param auth_method: str method for authentication. Values handled
            by this class: 'cookie', 'oauth2', 'oauth2-resourceowner',
            'oauth2-existingaccesstoken' and 'delayed'. If None, it is
            'oauth2-resourceowner' when both client_id and client_secret
            are given, otherwise 'cookie'.
        :param session_id: str API-provided prior session cookie.
            For instance, if you have a session ID provided by browser cookie,
            you can use that to authenticate a server-side connection.
        :param environment: str to look up API Base to use. e.g. 'production'
            for https://api.mediamath.com/api/v2.0
        :param api_base: str API domain. should be the qualified domain name
            without trailing slash. e.g. "api.mediamath.com".
        :param json: bool use JSON header for serialization. Currently
            for internal experimentation, JSON will become the default in a
            future version.
        :param redirect_uri: str redirect URI for OAuth2 authentication.
            Must match the redirect URI set in the application settings.
        :param token: dict OAuth2 token as generated by OAuth2Session.
            If you have a web app, you can store the token in the browser
            session, and then use that to generate a new T1 session.
            See the documentation for examples.
        :param token_updater: function with one argument, token, to be used to
            update your token database on automatic token refresh. If not
            set, a TokenUpdated warning will be raised when a token
            has been refreshed. This warning will carry the token
            in its token argument.
        :param access_token: str existing OAuth2 access token. Used by
            `authenticate` for the 'oauth2-resourceowner' and
            'oauth2-existingaccesstoken' methods when no session_id is given.
        :param realm: passed on to the resource-owner password token request.
        :param scope: passed on to the resource-owner password token request.
        :param kwargs: forwarded to Connection.__init__.

        Authentication is performed immediately unless auth_method is
        'oauth2' or 'delayed'.
        """
        if auth_method is None:
            auth_method = _detect_auth_method(username,
                                              password,
                                              session_id,
                                              api_key,
                                              client_id,
                                              client_secret,
                                              token)
        self.auth_params = {'method': auth_method, 'api_key': api_key}

        if auth_method == 'oauth2':
            self.auth_params.update({
                'client_secret': client_secret,
                'redirect_uri': redirect_uri,
                'token': token,
                'token_updater': token_updater
            })
        elif auth_method in ('oauth2-resourceowner',
                             'oauth2-existingaccesstoken'):
            # authenticate() falls back to a password grant for both of these
            # methods, and that grant reads client_id/client_secret, so both
            # must be stored (previously only done for resourceowner, which
            # made existingaccesstoken without a token fail with KeyError).
            self.auth_params.update({
                'username': username,
                'password': password,
                'client_secret': client_secret,
                'client_id': client_id
            })
        else:
            self.auth_params.update({
                'username': username,
                'password': password
            })

        super(T1, self).__init__(environment, api_base=api_base,
                                 json=json,
                                 auth_params=self.auth_params,
                                 _create_session=True, **kwargs)

        self._authenticated = False
        self._auth = (username, password, api_key, client_secret)
        self.environment = environment
        self.realm = realm
        self.scope = scope
        self.json = json
        self.api_key = api_key

        if auth_method != 'oauth2' and auth_method != 'delayed':
            self.authenticate(auth_method, session_id=session_id,
                              access_token=access_token)

    def authenticate(self, auth_method, **kwargs):
        """Authenticate using the method given.

        Precedence: an explicit ``session_id`` wins, then an explicit
        ``access_token`` (OAuth2 methods only), then the credentials stored
        in ``self.auth_params``.

        :param auth_method: str 'cookie', 'oauth2-resourceowner' or
            'oauth2-existingaccesstoken'.
        :param kwargs: optional ``session_id`` and ``access_token``.
        :return: the result of the underlying Connection auth call.
        :raise ClientError: if auth_method is 'basic' (not supported).
        :raise AttributeError: for any other unrecognised auth_method.
        """
        session_id = kwargs.get('session_id')
        access_token = kwargs.get('access_token')
        oauth2_methods = ('oauth2-resourceowner', 'oauth2-existingaccesstoken')

        if session_id is not None and (auth_method == 'cookie' or
                                       auth_method in oauth2_methods):
            return super(T1, self)._auth_session_id(
                session_id,
                self.auth_params['api_key']
            )

        if access_token is not None and auth_method in oauth2_methods:
            return super(T1, self)._auth_access_token(access_token)

        if auth_method == 'cookie':
            return super(T1, self)._auth_cookie(self.auth_params['username'],
                                                self.auth_params['password'],
                                                self.auth_params['api_key'])
        elif auth_method in oauth2_methods:
            return super(T1, self).fetch_resource_owner_password_token(
                self.auth_params['username'],
                self.auth_params['password'],
                self.auth_params['client_id'],
                self.auth_params['client_secret'],
                self.environment,
                self.realm,
                self.scope)
        elif auth_method == 'basic':
            raise ClientError(
                'basic authentication is not supported')
        else:
            # format() rather than `+` so a non-string method still produces
            # the intended AttributeError instead of a TypeError.
            raise AttributeError(
                'No authentication method for {0!s}'.format(auth_method))

    def new(self, collection, report=None, properties=None, version=None,
            *args, **kwargs):
        """Return a fresh class instance for a new entity.

        ac = t1.new('atomic_creative') OR
        ac = t1.new('atomic_creatives') OR even
        ac = t1.new(terminalone.models.AtomicCreative)

        :param collection: Entity subclass, or str singular/plural collection
            name. Any name containing '_acl' maps to ACL.
        :param report: report name; only used when the class is Report.
        :param properties: dict initial properties (ignored for Report).
        :param version: report API version; only used for Report.
        :raise KeyError: if a string collection is not a known name.
        """
        # isinstance (not `type(x) == type`) so classes built with a custom
        # metaclass are still recognised as classes.
        if isinstance(collection, type) and issubclass(collection, Entity):
            ret = collection
        elif '_acl' in collection:
            ret = ACL
        else:
            try:
                ret = SINGULAR[collection]
            except KeyError:
                ret = CLASSES[collection]

        if ret == Report:
            return ret(self.session,
                       report=report,
                       environment=self.environment,
                       api_base=self.api_base,
                       version=version,
                       **kwargs)

        return ret(self.session,
                   environment=self.environment,
                   api_base=self.api_base,
                   properties=properties,
                   json=self.json,
                   *args, **kwargs)

    def _return_class(self, ent_dict,
                      child=None, child_id=None,
                      entity_id=None, collection=None):
        """Build an entity object from a raw API record.

        ent_dict is mutated in place: child metadata is added, and every
        entry of its 'relations' mapping is converted (recursively) into
        entity objects stored under the relation name, after which
        'relations' is removed. The '_type' (falling back to 'type') key
        selects the model class via `new`.
        """
        ent_type = ent_dict.get('_type', ent_dict.get('type'))
        relations = ent_dict.get('relations')
        if child is not None:
            # Child can be either a target dimension (with an ID) or
            # a bare child, like concepts or permissions. These should not
            # have an ID passed in.
            if child_id is not None:
                ent_dict['id'] = child_id
            ent_dict['parent_id'] = entity_id
            ent_dict['parent'] = collection
        if relations is not None:
            for rel_name, data in six.iteritems(relations):
                if isinstance(data, list):
                    ent_dict[rel_name] = [self._return_class(item)
                                          for item in data]
                else:
                    ent_dict[rel_name] = self._return_class(data)
            ent_dict.pop('relations', None)
        return self.new(ent_type, properties=ent_dict)

    def _gen_classes(self, entities, child, child_id, entity_id, collection):
        """Lazily yield an entity object for each raw record in entities."""
        for entity in entities:
            yield self._return_class(
                entity, child, child_id, entity_id, collection)

    @staticmethod
    def _construct_params(entity, **kwargs):
        """Construct URL query params.

        Paging/sorting/filtering params are only added for collection
        requests (entity is None). Does not modify any argument passed in.
        """
        if entity is not None:
            params = {}
        else:
            params = {'page_limit': kwargs.get('page_limit'),
                      'page_offset': kwargs.get('page_offset'),
                      'sort_by': kwargs.get('sort_by'),
                      'parent': kwargs.get('parent'),
                      'q': kwargs.get('query'), }

        # include can be either a string (e.g. 'advertiser'),
        # list of *non-traversable* relations (e.g. ['vendor', 'concept']),
        # or a list of lists/strings of traversable elements, e.g.
        # [['advertiser', 'agency'], 'vendor'],
        # [['advertiser', 'agency'], ['vendor', 'vendor_domains']]
        # If we're given a string, leave it as-is
        # If we're given a list, for each element:
        # -> If the item is a string, leave it as-is
        # -> If the item is a list, comma-join it
        # Examples from above:
        # include='advertiser' -> with=advertiser
        # include=['vendor', 'concept'] -> with=vendor&with=concept
        # include=[['advertiser', 'agency'], 'vendor']
        # -> with=advertiser,agency&with=vendor
        # include=[['advertiser', 'agency'], ['vendor', 'vendor_domains']]
        # -> with=advertiser,agency&with=vendor,vendor_domains
        include = kwargs.get('include')
        if include:
            if isinstance(include, list):
                # Build a new list: joining in place would silently rewrite
                # the caller's `include` argument.
                include = [','.join(item) if isinstance(item, list) else item
                           for item in include]
            params['with'] = include

        full = kwargs.get('full')
        if isinstance(full, list):
            params['full'] = ','.join(full)
        elif full is True:
            params['full'] = '*'
        elif full is not None:
            params['full'] = full

        # `or {}` also covers an explicit other_params=None.
        params.update(kwargs.get('other_params') or {})

        return params

    @staticmethod
    def _construct_url(collection, entity, child, limit):
        """Construct the request URL path.

        :return: tuple (url, child_id); child_id is the ID from CHILD_PATHS
            for child targets that have one, else None.
        :raise ClientError: if child is not a known child string, or limit
            is a dict that does not have exactly one item.
        """
        url = [collection, ]
        if entity is not None:
            url.append(str(entity))  # str so that we can use join

        child_id = None
        if child is not None:
            try:
                child_path = CHILD_PATHS[child.lower()]
            except AttributeError:
                raise ClientError(
                    "`child` must be a string of the entity to retrieve")
            except KeyError:
                raise ClientError("`child` must correspond to an entity in T1")
            # child_path should always be a tuple of (path, id). For children
            # that do not have IDs, like concepts and permissions, ID is 0
            url.append(child_path[0])
            if child_path[1]:
                child_id = child_path[1]
                # All values need to be strings for join
                url.append(str(child_path[1]))

        if isinstance(limit, dict):
            if len(limit) != 1:
                raise ClientError(
                    'Limit must consist of one parent collection '
                    '(or chained parent collection) and a single '
                    'value for it (e.g. {"advertiser": 1}, or '
                    '{"advertiser.agency": 2})')

            key, value = next(six.iteritems(limit))
            # integer_types also covers Python 2 `long` IDs.
            if isinstance(value, six.integer_types):
                url.extend(['limit', '{0!s}={1:d}'.format(key, value)])
            else:
                url.extend(['limit', '{0!s}={1:s}'.format(key, value)])

        return '/'.join(url), child_id

    def get(self,
            collection,
            entity=None,
            child=None,
            limit=None,
            include=None,
            full=None,
            page_limit=100,
            page_offset=0,
            sort_by='id',
            get_all=False,
            parent=None,
            query=None,
            other_params=None,
            count=False,
            _url=None,
            _params=None):
        """Main retrieval method for T1 Entities.

        :param collection: str T1 collection, e.g. "advertisers", "agencies",
            or an Entity subclass.
        :param entity: int ID of entity being retrieved from T1
        :param child: str child, e.g. "dma", "acl". Forces full=True.
        :param limit: dict[str]int query for relation entity,
            e.g. {"advertiser": 123456}
        :param include: str/list of relations to include, e.g. "advertiser",
            ["campaign", "advertiser"]
        :param full: str/bool when retrieving multiple entities, specifies
            which types to return the full record for.
            e.g. "campaign", True, ["campaign", "advertiser"]
        :param page_limit: int number of entities to return per query, 100 max
        :param page_offset: int offset for results returned.
        :param sort_by: str sort order. Default "id". e.g. "-id", "name"
        :param get_all: bool whether to retrieve all results for a query
            or just a single page
        :param parent: only return entities with this parent id
        :param query: str search parameter. Invoked by `find`
        :param other_params: optional dict of additional service
            specific params (None means no extra params)
        :param count: bool return the number of entities as a second parameter
        :param _url: str shortcut to bypass URL determination.
        :param _params: dict query string parameters to bypass
            query determination
        :return: If:
            Collection is requested => generator over collection of entity
                objects
            Entity ID is provided => Entity object
            `count` is True => number of entities as second return val
        :raise ClientError: if `child` is not a known child or `limit` is a
            dict without exactly one item (only when `_url` is not given).
        """
        if isinstance(collection, type) and issubclass(collection, Entity):
            collection = MODEL_PATHS[collection]

        child_id = None
        if _url is None:
            _url, child_id = self._construct_url(
                collection, entity, child, limit)

        # some child endpoints need to have full overridden to ensure
        # correct behaviour
        if child:
            full = True

        if get_all:
            gen = self._get_all(collection,
                                entity=entity,
                                child=child,
                                include=include,
                                full=full,
                                sort_by=sort_by,
                                parent=parent,
                                query=query,
                                page_limit=page_limit,
                                count=count,
                                other_params=other_params,
                                _params=_params,
                                _url=_url)
            if count:
                # _get_all yields the total count before any entity.
                ent_count = next(gen)
                return gen, ent_count
            return gen

        if _params is None:
            _params = self._construct_params(entity,
                                             include=include,
                                             full=full,
                                             page_limit=page_limit,
                                             page_offset=page_offset,
                                             sort_by=sort_by,
                                             parent=parent,
                                             query=query,
                                             other_params=other_params)

        entities, ent_count = super(T1, self)._get(
            self._get_service_path(collection), _url, params=_params)

        if not isinstance(entities, (GeneratorType, Iterator)):
            # Not a stream of records: treat the response as one entity.
            return self._return_class(entities,
                                      child,
                                      child_id,
                                      entity,
                                      collection)

        ent_gen = self._gen_classes(
            entities, child, child_id, entity, collection)
        if count:
            return ent_gen, ent_count
        return ent_gen

    def get_all(self, collection, **kwargs):
        """Retrieve all entities in a collection.

        Has same signature as .get; any get_all argument is overridden.
        """
        kwargs.pop('get_all', None)
        return self.get(collection, get_all=True, **kwargs)

    def _get_all(self, collection, **kwargs):
        """Construct iterator to get all entities in a collection.

        First issues a page_limit=1 request to learn the total record count,
        then pages over the collection `page_limit` records at a time
        (100 when page_limit is missing or falsy). When `count` is truthy
        the total is yielded first, before any entity.
        This method should not be called directly: it's called from T1.get.
        """
        # `or 100` guards against a None/0 page_limit, which would otherwise
        # be an invalid range() step below.
        num_to_fetch = kwargs.get('page_limit') or 100
        params = {
            'page_limit': 1,
            'parent': kwargs.get('parent'),
            'q': kwargs.get('query'),
        }
        if kwargs.get('other_params'):
            params.update(kwargs.get('other_params'))

        _, num_recs = super(T1, self)\
            ._get(self._get_service_path(collection), kwargs['_url'],
                  params=params)

        if kwargs.get('count'):
            yield num_recs
        for page_offset in six.moves.range(0, num_recs, num_to_fetch):
            # get_all=False, otherwise we could go in a loop
            gen = self.get(collection,
                           _url=kwargs['_url'],
                           entity=kwargs.get('entity'),
                           include=kwargs.get('include'),
                           full=kwargs.get('full'),
                           page_offset=page_offset,
                           sort_by=kwargs.get('sort_by'),
                           page_limit=num_to_fetch,
                           parent=kwargs.get('parent'),
                           query=kwargs.get('query'),
                           other_params=kwargs.get('other_params'),
                           get_all=False)
            if not isinstance(gen, GeneratorType):
                gen = iter([gen])
            for item in gen:
                yield item

    @staticmethod
    def _parse_candidate(candidate):
        """Parse filter candidates so that you can use None, True, False.

        None -> "null", True -> "1", False -> "0"; anything else is
        returned unchanged.
        """
        val = candidate
        if candidate is None:
            val = "null"
        elif candidate is True:
            val = "1"
        elif candidate is False:
            val = "0"
        return val

    def find(self, collection, variable, operator, candidates, **kwargs):
        """Find objects based on query criteria.

        Helper method for T1.get, with same return values.

        :param collection: str T1 collection, e.g. "advertisers", "agencies"
        :param variable: str Field to query for, e.g. "name". If operator is
            terminalone.filters.IN, this is ignored and None can be provided
        :param operator: str Arithmetic operator, e.g. "=:". Package provides
            helper object filters to help, e.g. terminalone.filters.IN or
            terminalone.filters.CASE_INS_STRING
        :param candidates: str/int/None/bool value to search for, or a
            list/tuple of values if operator is IN.
        :param kwargs: additional keyword args to pass on to T1.get. See that
            method's signature for details.
        :return: generator over collection of objects matching query
        :raise TypeError: if operator is IN and candidates is not a list
            or tuple
        """
        if operator == filters.IN:
            if not isinstance(candidates, (list, tuple)):
                raise TypeError(
                    '`candidates` must be list of entities for `IN`')
            query = '(' + ','.join(str(c) for c in candidates) + ')'
        else:
            candidate = self._parse_candidate(candidates)
            if not isinstance(candidate, six.string_types):
                # e.g. an int ID: str.join only accepts strings.
                candidate = str(candidate)
            query = operator.join([variable, candidate])
        return self.get(collection, query=query, **kwargs)


# Alternate name for the T1 service class; both names refer to the same
# class object.
T1Service = T1