reimandlab/Visualistion-Framework-for-Genome-Mutations

View on GitHub
website/helpers/filters/sqlalchemy_filter.py

Summary

Maintainability
B
4 hrs
Test Coverage
from types import FunctionType, MethodType

from sqlalchemy import and_, or_
from sqlalchemy.ext.associationproxy import AssociationProxy, AssociationProxyInstance
from sqlalchemy.sql.annotation import AnnotatedSelect
from sqlalchemy.sql.sqltypes import Text
from sqlalchemy.orm import RelationshipProperty

from database.types import ScalarSet
from helpers.utilities import is_iterable_but_not_str

from .basic_filter import BasicFilter


class SQLAlchemyAwareFilter(BasicFilter):
    """Extends python-side filtering by SQL filters generation

    Args:
        as_sqlalchemy:
            True if the filter should be executed on the SQL server side.

            A custom callback (any callable: function, method, lambda,
            functools.partial, callable object) can be provided instead.
            The callback should accept a value of the filter as
            an argument and return an SQLAlchemy filter.
            The callback function will be called only if the
            filter is active (i.e. it has a non-default value).
        as_sqlalchemy_joins:
            if a custom as_sqlalchemy callback was provided and it
            requires any joins, the joins can be specified here.
    """

    # filter comparator code -> name of the SQLAlchemy column operator method
    sa_comparators = {
        'ge': '__ge__',
        'le': '__le__',
        'gt': '__gt__',
        'lt': '__lt__',
        'eq': '__eq__',
        'in': 'in_',
        'ni': 'notin_'
    }

    # filter `multiple` mode -> SQL boolean operator combining per-value clauses
    sa_join_operators = {
        'all': and_,
        'any': or_
    }

    def __init__(self, *args, as_sqlalchemy=None, as_sqlalchemy_joins=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.has_sqlalchemy = bool(as_sqlalchemy)

        # Any callable is a custom filter factory; other truthy values
        # (e.g. True) request the generic generation in `as_sqlalchemy`.
        # (Checking `callable` rather than the exact function/method types
        # lets partials, builtins and callable objects work as callbacks.)
        self.as_sqlalchemy_callback = as_sqlalchemy if callable(as_sqlalchemy) else None

        self.as_sqlalchemy_joins = as_sqlalchemy_joins or []

    def as_sqlalchemy(self, target):
        """Build the SQLAlchemy filter clause for the current filter value.

        Args:
            target: the mapped class on which `self.attribute` is looked up
                (a dotted attribute of at most two components is supported).

        Returns:
            A ``(clause, joins)`` tuple: ``clause`` is the SQLAlchemy filter
            expression (``None`` when the mapped value is ``None``) and
            ``joins`` is a list of entities the query has to join for the
            clause to be valid (e.g. the classes traversed through
            association proxies).
        """
        value = self.mapped_value

        if value is None:
            return None, []

        if self.as_sqlalchemy_callback:
            return self.as_sqlalchemy_callback(value), self.as_sqlalchemy_joins

        path = self.attribute.split('.')

        # we are unable to query deeper easily
        assert len(path) < 3, 'Unsupported attribute path: %r' % self.attribute

        field = getattr(target, path[0])

        # Possible upgrade:
        #   from sqlalchemy.orm.attributes import QueryableAttribute
        #   if isinstance(field, QueryableAttribute):
        if type(field) is AnnotatedSelect:
            if self.comparator == 'eq':
                # The select expression itself is used as the clause.
                # NOTE(review): the filter value is not used here; presumably
                # this is meant for boolean expressions filtered with True.
                return field, []

        # refactored in 1.3, field should be either ColumnAssociationProxyInstance or ObjectAssociationProxyInstance
        assert type(field) is not AssociationProxy, (
            'Expected an AssociationProxyInstance, got a bare AssociationProxy'
        )

        # entities that must be joined for `field` to be reachable; they are
        # returned with every clause built below, not only the 'in' ones
        joins = []

        if isinstance(field, AssociationProxyInstance):
            # additional joins may be needed when using proxies:
            # walk the (possibly nested) proxy down to the remote attribute
            while isinstance(field, AssociationProxyInstance):
                joins.append(field.target_class)
                field = field.remote_attr

            if self.comparator == 'in':
                # a single scalar (notably a string) must not be iterated
                values = value if is_iterable_but_not_str(value) else [value]

                if self.multiple == 'any':
                    # this wont give expected result for 'all'
                    func = getattr(field, self.sa_comparators[self.comparator])
                    return func(values), joins

                # this works for 'any' too (but it's uglier)
                combine = self.sa_join_operators[self.multiple]
                return combine(*[field == sub_value for sub_value in values]), joins

        if len(path) == 2:
            if self.comparator == 'in':
                # relationship.any(sub_attr=value) for each requested value
                sub_attr = path[1]
                values = value if is_iterable_but_not_str(value) else [value]
                combine = self.sa_join_operators[self.multiple]
                return combine(*[
                    field.any(**{sub_attr: sub_value})
                    for sub_value in values
                ]), joins

        comparator = self.sa_comparators[self.comparator]

        if self.comparator == 'in':
            if isinstance(field.property, RelationshipProperty):
                comparator = 'contains'
            elif type(field.property.columns[0].type) in [Text, ScalarSet]:
                # substring match for text-like columns
                comparator = 'like'
                value = '%' + value + '%'

        return getattr(field, comparator)(value), joins