knipknap/django-find

View on GitHub
django_find/serializers/sql.py

Summary

Maintainability
A
1 hr
Test Coverage
from builtins import str
from collections import defaultdict, OrderedDict
from MySQLdb import escape_string
from ..refs import get_join_for
from .serializer import Serializer
from .util import parse_date, parse_datetime

# Operator substitutions applied by SQLSerializer.int_term(): string-style
# operators are mapped to their closest numeric comparison.
int_op_map = dict(
    equals='equals',
    contains='equals',
    startswith='gte',
    endswith='lte',
)

# Operator substitutions applied by SQLSerializer.str_term() and
# SQLSerializer.lcstr_term(): ordering comparisons become prefix/suffix
# matches.
str_op_map = dict(
    gt='startswith',
    gte='startswith',
    lt='endswith',
    lte='endswith',
)

# Operator substitutions applied by SQLSerializer.date_datetime_common().
date_op_map = dict(
    contains='equals',
    startswith='gte',
    endswith='lte',
)

# SQL fragment for each supported operator. The fragment is appended to the
# column name and "{}" is filled in through str.format() by _mk_condition().
# str.format() leaves "%%" untouched, so the doubled percent signs end up
# verbatim in the generated SQL.
# NOTE(review): presumably "%%" is there so that a later %-style parameter
# interpolation collapses it into a single LIKE wildcard - confirm against
# the code that executes the query.
# NOTE(review): the 'regex' template also surrounds the pattern with "%%";
# "%" has no wildcard meaning in a regular expression - verify this is
# intended.
operator_map = {
    # Exact matches.
    'equals': "='{}'",
    'iequals': " LIKE '{}'",
    # Ordering comparisons.
    'lt': "<'{}'",
    'lte': "<='{}'",
    'gt': ">'{}'",
    'gte': ">='{}'",
    # Pattern matches.
    'startswith': " LIKE '{}%%'",
    'endswith': " LIKE '%%{}'",
    'contains': " LIKE '%%{}%%'",
    'regex': " REGEXP '%%{}%%'",
}

def _mkcol(tbl, name, alias):
    return tbl+'.'+name+' '+tbl+'_'+alias

def _mk_condition(db_column, operator, data):
    """
    Build a single SQL condition such as "table.column LIKE '%%foo%%'".

    :param db_column: The column reference; used verbatim as the left-hand
                      side of the condition.
    :param operator: A key of operator_map.
    :param data: The value to compare against. Integers are inserted as
                 they are; anything else is passed through MySQLdb's
                 escape_string() and decoded as UTF-8 first.
    :return: The condition as a string.
    :raises Exception: If the operator is not in operator_map.
    """
    template = operator_map.get(operator)
    if not template:
        raise Exception('unsupported operator:' + str(operator))

    # I would prefer to use a prepared statement, but collecting arguments
    # and passing them back along the string everywhere would be awful design.
    # (Also, I didn't find any API from Django to generate a prepared statement
    # without already executing it, e.g. django.db.connection.execute())
    # NOTE(review): escape_string() does not double "%" characters in the
    # data, while the templates in operator_map use "%%". If the resulting
    # SQL is later run through %-style interpolation, a "%" in user input
    # could break it - confirm against the caller.
    if isinstance(data, int):
        value = data
    else:
        value = escape_string(data).decode('utf-8')
    return db_column + template.format(value)

class SQLSerializer(Serializer):
    """
    Serializer that renders a search query as a raw SQL string.

    In 'SELECT' mode the result is a complete
    "SELECT DISTINCT ... FROM ... LEFT JOIN ... WHERE (...)" statement.
    In 'WHERE' mode only the parenthesized condition is produced.

    The logical_*() and *_term() methods each return an SQL fragment;
    an empty string means "no condition" and is dropped by the logical
    combinators.
    NOTE(review): presumably these methods are callbacks invoked while the
    query is being serialized - the calling code is not part of this file.
    """

    def __init__(self, model, mode='SELECT', fullnames=None, extra_model=None):
        """
        :param model: The model class that the query is run against. It must
                      provide get_class_from_fullname() and
                      get_object_vector_for().
        :param mode: Either 'SELECT' (full statement) or 'WHERE' (condition
                     only).
        :param fullnames: Optional list of term names that selects the
                          columns of the result. If empty or None, the term
                          names of the query itself are used.
        :param extra_model: Optional model that is additionally included
                            when computing the table joins.
        :raises AttributeError: If mode is not one of the supported modes.
        """
        modes = 'SELECT', 'WHERE'
        if mode not in modes:
            raise AttributeError('invalid mode: {}. Must be one of {}'.format(mode, modes))
        Serializer.__init__(self)
        self.model = model
        self.mode = mode
        self.fullnames = fullnames
        self.extra_model = extra_model

    def _create_db_column_list(self, dom):
        """
        Resolve the selected term names to database columns.

        :param dom: The query object; only its get_term_names() method is
                    used, and only if no fullnames were passed to the
                    constructor.
        :return: A list of (target_model, db_table, column) tuples, one per
                 term name.
        """
        fullnames = self.fullnames if self.fullnames else dom.get_term_names()
        result = []
        for fullname in fullnames:
            model, alias = self.model.get_class_from_fullname(fullname)
            selector = model.get_selector_from_alias(alias)
            target_model, field = model.get_field_from_selector(selector)
            result.append((target_model, target_model._meta.db_table, field.column))
        return result

    def _create_select(self, fields):
        """
        Build the "SELECT DISTINCT ... FROM ... LEFT JOIN ..." part.

        :param fields: A non-empty list of (target_model, table, column)
                       tuples, or 4-tuples that additionally carry the
                       column alias.
        :return: The SQL up to, but not including, the WHERE keyword.
        """
        # Create the "SELECT DISTINCT table1.col1, table2.col2, ..."
        # part of the SQL.
        col_numbers = defaultdict(int)
        fullfields = []
        for field in fields:
            table, column = field[1:3]
            key = "%s.%s" % (table, column)
            col_number = col_numbers[key]
            col_numbers[key] += 1
            if len(field) == 3:
                # No alias given: use the column name, and add a numeric
                # suffix if the same table.column was already selected.
                alias = column if col_number == 0 else "%s__%d" % (column, col_number)
                field = (field[0], table, column, alias)
            fullfields.append(field)

        select = 'SELECT DISTINCT '+_mkcol(fullfields[0][1], fullfields[0][2], fullfields[0][3])
        for target_model, table, column, alias in fullfields[1:]:
            select += ', '+_mkcol(table, column, alias)

        # Find the best way to join the tables.
        target_models = [r[0] for r in fullfields]
        if self.extra_model:
            target_models.append(self.extra_model)
        vector = self.model.get_object_vector_for(target_models)
        join_path = get_join_for(vector)

        # Create the "table1 LEFT JOIN table2 ON table1.col1=table2.col1"
        # part of the SQL.
        select += ' FROM '+join_path[0][0]
        for table, left, right in join_path[1:]:
            select += ' LEFT JOIN {} ON {}={}'.format(table,
                                                      table+'.'+left,
                                                      right)
        return select

    def logical_root_group(self, root_group, terms):
        """
        Assemble the final SQL from the serialized top-level terms.

        :param root_group: The query object; passed to
                           _create_db_column_list() to find the columns.
        :param terms: The SQL fragments of the top-level terms.
        :return: A tuple (sql, args), where args is always an empty list.
        """
        fields = self._create_db_column_list(root_group)

        # Create the SELECT part of the query.
        if self.mode == 'SELECT':
            select = self._create_select(fields)+' WHERE '
        else:
            select = ''

        # Without any terms, use a condition that is always true.
        where = (' AND '.join(terms) if terms else '1')
        if where.startswith('(') and where.endswith(')'):
            select += where
        else:
            select += '('+where+')'
        return select, []

    def logical_group(self, terms):
        """
        Join the non-empty terms with AND, without surrounding parentheses.

        Returns an empty string if there are no non-empty terms.
        """
        terms = [t for t in terms if t]
        if not terms:
            return ''
        return ' AND '.join(terms)

    def logical_and(self, terms):
        """
        Join the non-empty terms with AND, wrapped in parentheses.

        Returns an empty string if there are no non-empty terms, so that
        no empty "()" ends up in the SQL.
        """
        terms = [t for t in terms if t]
        if not terms:
            return ''
        return '(' + self.logical_group(terms) + ')'

    def logical_or(self, terms):
        """
        Join the non-empty terms with OR, wrapped in parentheses.

        Returns an empty string if there are no non-empty terms.
        """
        terms = [t for t in terms if t]
        if not terms:
            return ''
        return '(' + ' OR '.join(terms) + ')'

    def logical_not(self, terms):
        """
        Negate the given terms; multiple terms are AND-ed before negation.

        Empty terms are dropped first (e.g. a date term that could not be
        parsed), so that no "NOT()" is generated. Returns an empty string
        if nothing is left to negate.
        """
        terms = [t for t in terms if t]
        if not terms:
            return ''
        if len(terms) == 1:
            return 'NOT(' + terms[0] + ')'
        return 'NOT ' + self.logical_and(terms)

    def boolean_term(self, db_column, operator, data):
        """
        Create a condition for a boolean column.

        Any data other than the (case-insensitive) string 'true' is
        treated as FALSE.
        """
        value = 'TRUE' if data.lower() == 'true' else 'FALSE'
        return _mk_condition(db_column, operator, value)

    def int_term(self, db_column, operator, data):
        """
        Create a condition for an integer column.

        String-style operators are translated through int_op_map.
        If data cannot be converted to an integer, the always-true
        condition '1' is returned.
        NOTE(review): the date handlers return '' for unparsable input
        instead - confirm that the difference is intended.
        """
        try:
            value = int(data)
        except ValueError:
            return '1'
        operator = int_op_map.get(operator, operator)
        return _mk_condition(db_column, operator, value)

    def str_term(self, db_column, operator, data):
        """
        Create a condition for a string column.

        Ordering operators are translated through str_op_map.
        """
        operator = str_op_map.get(operator, operator)
        return _mk_condition(db_column, operator, data)

    def lcstr_term(self, db_column, operator, data):
        """
        Create a condition for a string column using lowercased data.

        Like str_term(), but 'equals' becomes 'iequals' (a LIKE without
        wildcards) and the data is lowercased before it is inserted.
        """
        operator = str_op_map.get(operator, operator)
        if operator == 'equals':
            operator = 'iequals'
        return _mk_condition(db_column, operator, data.lower())

    def date_datetime_common(self, db_column, operator, thedatetime):
        """
        Create a condition for an already parsed date or datetime.

        :param thedatetime: An object with an isoformat() method, or a
                            false value if parsing failed.
        :return: The condition, or an empty string if thedatetime is false.
        """
        if not thedatetime:
            return ''
        operator = date_op_map.get(operator, operator)
        return _mk_condition(db_column, operator, thedatetime.isoformat())

    def date_term(self, db_column, operator, data):
        """
        Create a condition for a date column; data is parsed by parse_date().
        """
        thedate = parse_date(data)
        return self.date_datetime_common(db_column, operator, thedate)

    def datetime_term(self, db_column, operator, data):
        """
        Create a condition for a datetime column; data is parsed by
        parse_datetime().
        """
        thedatetime = parse_datetime(data)
        return self.date_datetime_common(db_column, operator, thedatetime)

    def term(self, term_name, operator, data):
        """
        Create the condition for a single search term.

        :param term_name: The full name of the term; resolved through
                          get_class_from_fullname() of the model.
        :param operator: The operator name. 'any' always produces the
                         always-true condition '1'.
        :param data: The raw value; it is passed through the prepare()
                     method of the field handler before use.
        :return: The SQL condition as a string.
        :raises TypeError: If the db_type of the field handler is not
                           supported.
        """
        if operator == 'any':
            return '1'

        model, alias = self.model.get_class_from_fullname(term_name)
        selector = model.get_selector_from_alias(alias)
        target_model, field = model.get_field_from_selector(selector)
        db_column = target_model._meta.db_table + '.' + field.column
        handler = model.get_field_handler_from_alias(alias)

        type_map = {'BOOL': self.boolean_term,
                    'INT': self.int_term,
                    'STR': self.str_term,
                    'LCSTR': self.lcstr_term,
                    'DATE': self.date_term,
                    'DATETIME': self.datetime_term}

        func = type_map.get(handler.db_type)
        if not func:
            # Report the type that was actually looked up.
            raise TypeError('unsupported field type: '+repr(handler.db_type))
        return func(db_column, operator, handler.prepare(data))