django/django

View on GitHub
django/db/backends/mysql/operations.py

Summary

Maintainability: B (4 hrs)
Test Coverage
import re
import uuid

from django.conf import settings
from django.db.backends.base.operations import BaseDatabaseOperations
from django.utils import timezone
from django.utils.duration import duration_microseconds
from django.utils.encoding import force_str


class DatabaseOperations(BaseDatabaseOperations):
    """
    MySQL/MariaDB-specific SQL generation and value adaptation.

    Overrides hooks of BaseDatabaseOperations to emit MySQL-dialect SQL for
    date/time extraction and truncation, quoting, table flushing, EXPLAIN,
    regex lookups and INSERT IGNORE, and to adapt/convert values passed to
    and read from the database.
    """
    compiler_module = "django.db.backends.mysql.compiler"

    # MySQL stores positive fields as UNSIGNED ints.
    integer_field_ranges = {
        **BaseDatabaseOperations.integer_field_ranges,
        'PositiveSmallIntegerField': (0, 65535),
        'PositiveIntegerField': (0, 4294967295),
    }
    cast_data_types = {
        'AutoField': 'signed integer',
        'BigAutoField': 'signed integer',
        'CharField': 'char(%(max_length)s)',
        'DecimalField': 'decimal(%(max_digits)s, %(decimal_places)s)',
        'TextField': 'char',
        'IntegerField': 'signed integer',
        'BigIntegerField': 'signed integer',
        'SmallIntegerField': 'signed integer',
        'PositiveIntegerField': 'unsigned integer',
        'PositiveSmallIntegerField': 'unsigned integer',
    }
    cast_char_field_without_max_length = 'char'
    explain_prefix = 'EXPLAIN'

    # Allowed shape of the unit spliced into EXTRACT(<unit> FROM ...).
    # The unit is interpolated verbatim into the SQL text (it cannot be a
    # query parameter), so only bare upper-case identifiers such as YEAR or
    # DAY_MICROSECOND are accepted. Anything else would allow SQL injection
    # through a user-controlled lookup name (cf. Django CVE-2022-34265).
    _extract_format_re = re.compile(r'[A-Z_]+')

    def date_extract_sql(self, lookup_type, field_name):
        """
        Return SQL extracting the ``lookup_type`` component of the date
        expression ``field_name``.

        'week_day', 'week' and 'iso_year' get dedicated MySQL functions. Any
        other lookup type is upper-cased and passed to EXTRACT().

        Raise ValueError if that upper-cased lookup type is not a bare unit
        identifier (letters and underscores only).
        """
        # https://dev.mysql.com/doc/mysql/en/date-and-time-functions.html
        if lookup_type == 'week_day':
            # DAYOFWEEK() returns an integer, 1-7, Sunday=1.
            # Note: WEEKDAY() returns 0-6, Monday=0.
            return "DAYOFWEEK(%s)" % field_name
        if lookup_type == 'week':
            # Override the value of default_week_format for consistency with
            # other database backends.
            # Mode 3: Monday, 1-53, with 4 or more days this year.
            return "WEEK(%s, 3)" % field_name
        if lookup_type == 'iso_year':
            # Get the year part from the YEARWEEK function, which returns a
            # number as year * 100 + week.
            return "TRUNCATE(YEARWEEK(%s, 3), -2) / 100" % field_name
        # EXTRACT returns 1-53 based on ISO-8601 for the week number.
        unit = lookup_type.upper()
        # The unit lands in the SQL text unescaped; validate it instead of
        # trusting the caller.
        if not self._extract_format_re.fullmatch(unit):
            raise ValueError('Invalid lookup type: %r' % lookup_type)
        return "EXTRACT(%s FROM %s)" % (unit, field_name)

    def date_trunc_sql(self, lookup_type, field_name):
        """
        Return SQL truncating the date expression ``field_name`` to
        ``lookup_type`` ('year', 'month', 'quarter' or 'week').

        Any other lookup type falls back to DATE(field_name).
        """
        fields = {
            'year': '%%Y-01-01',
            'month': '%%Y-%%m-01',
        }  # Use double percents to escape.
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return "CAST(DATE_FORMAT(%s, '%s') AS DATE)" % (field_name, format_str)
        elif lookup_type == 'quarter':
            # First day of the year, advanced by (quarter - 1) quarters.
            return "MAKEDATE(YEAR(%s), 1) + INTERVAL QUARTER(%s) QUARTER - INTERVAL 1 QUARTER" % (
                field_name, field_name
            )
        elif lookup_type == 'week':
            # WEEKDAY() is 0 for Monday, so this steps back to that Monday.
            return "DATE_SUB(%s, INTERVAL WEEKDAY(%s) DAY)" % (
                field_name, field_name
            )
        else:
            return "DATE(%s)" % (field_name)

    def _convert_field_to_tz(self, field_name, tzname):
        """
        Wrap ``field_name`` in CONVERT_TZ() from the connection's time zone to
        ``tzname``.

        The wrapping happens only when USE_TZ is enabled and the two zone
        names differ; otherwise ``field_name`` is returned unchanged.
        """
        if settings.USE_TZ and self.connection.timezone_name != tzname:
            # NOTE(review): both zone names are placed inside single quotes
            # without escaping; presumably they come from trusted time zone
            # names rather than user input -- confirm at the call sites.
            field_name = "CONVERT_TZ(%s, '%s', '%s')" % (field_name, self.connection.timezone_name, tzname)
        return field_name

    def datetime_cast_date_sql(self, field_name, tzname):
        """Return SQL casting the datetime ``field_name`` (in ``tzname``) to a date."""
        field_name = self._convert_field_to_tz(field_name, tzname)
        return "DATE(%s)" % field_name

    def datetime_cast_time_sql(self, field_name, tzname):
        """Return SQL casting the datetime ``field_name`` (in ``tzname``) to a time."""
        field_name = self._convert_field_to_tz(field_name, tzname)
        return "TIME(%s)" % field_name

    def datetime_extract_sql(self, lookup_type, field_name, tzname):
        """
        Return SQL extracting ``lookup_type`` from the datetime
        ``field_name`` after converting it to ``tzname``.

        Raise ValueError for invalid lookup types (see date_extract_sql).
        """
        field_name = self._convert_field_to_tz(field_name, tzname)
        return self.date_extract_sql(lookup_type, field_name)

    def datetime_trunc_sql(self, lookup_type, field_name, tzname):
        """
        Return SQL truncating the datetime ``field_name`` (converted to
        ``tzname``) to ``lookup_type``.

        'quarter' and 'week' get dedicated SQL. 'year' through 'second' are
        handled by DATE_FORMAT with the finer components reset to their
        minimum. Any other lookup type returns the converted expression
        unchanged.
        """
        field_name = self._convert_field_to_tz(field_name, tzname)
        if lookup_type == 'quarter':
            return (
                "CAST(DATE_FORMAT(MAKEDATE(YEAR({field_name}), 1) + "
                "INTERVAL QUARTER({field_name}) QUARTER - "
                "INTERVAL 1 QUARTER, '%%Y-%%m-01 00:00:00') AS DATETIME)"
            ).format(field_name=field_name)
        if lookup_type == 'week':
            return (
                "CAST(DATE_FORMAT(DATE_SUB({field_name}, "
                "INTERVAL WEEKDAY({field_name}) DAY), "
                "'%%Y-%%m-%%d 00:00:00') AS DATETIME)"
            ).format(field_name=field_name)
        fields = ['year', 'month', 'day', 'hour', 'minute', 'second']
        # Format pieces for each unit, and the constant that replaces that
        # piece once it is finer than the requested truncation.
        format_parts = ('%%Y-', '%%m', '-%%d', ' %%H:', '%%i', ':%%s')  # Use double percents to escape.
        default_parts = ('0000-', '01', '-01', ' 00:', '00', ':00')
        try:
            i = fields.index(lookup_type) + 1
        except ValueError:
            sql = field_name
        else:
            format_str = ''.join(format_parts[:i] + default_parts[i:])
            sql = "CAST(DATE_FORMAT(%s, '%s') AS DATETIME)" % (field_name, format_str)
        return sql

    def time_trunc_sql(self, lookup_type, field_name):
        """
        Return SQL truncating the time expression ``field_name`` to 'hour',
        'minute' or 'second'.

        Any other lookup type falls back to TIME(field_name).
        """
        fields = {
            'hour': '%%H:00:00',
            'minute': '%%H:%%i:00',
            'second': '%%H:%%i:%%s',
        }  # Use double percents to escape.
        if lookup_type in fields:
            format_str = fields[lookup_type]
            return "CAST(DATE_FORMAT(%s, '%s') AS TIME)" % (field_name, format_str)
        else:
            return "TIME(%s)" % (field_name)

    def date_interval_sql(self, timedelta):
        """Return ``timedelta`` as a MySQL INTERVAL expressed in microseconds."""
        return 'INTERVAL %s MICROSECOND' % duration_microseconds(timedelta)

    def format_for_duration_arithmetic(self, sql):
        """Wrap the microsecond-count expression ``sql`` in an INTERVAL."""
        return 'INTERVAL %s MICROSECOND' % sql

    def force_no_ordering(self):
        """
        "ORDER BY NULL" prevents MySQL from implicitly ordering by grouped
        columns. If no ordering would otherwise be applied, we don't want any
        implicit sorting going on.
        """
        return [(None, ("NULL", [], False))]

    def last_executed_query(self, cursor, sql, params):
        """Return the exact query text the driver recorded on ``cursor``."""
        # With MySQLdb, cursor objects have an (undocumented) "_executed"
        # attribute where the exact query sent to the database is saved.
        # See MySQLdb/cursors.py in the source distribution.
        # MySQLdb returns string, PyMySQL bytes.
        return force_str(getattr(cursor, '_executed', None), errors='replace')

    def no_limit_value(self):
        """Return the LIMIT value used to mean "no limit"."""
        # 2**64 - 1, as recommended by the MySQL documentation
        return 18446744073709551615

    def quote_name(self, name):
        """Quote ``name`` with backticks unless it is already backtick-quoted."""
        if name.startswith("`") and name.endswith("`"):
            return name  # Quoting once is enough.
        return "`%s`" % name

    def random_function_sql(self):
        """Return MySQL's random-number function."""
        return 'RAND()'

    def sql_flush(self, style, tables, sequences, allow_cascade=False):
        """
        Return the list of SQL statements that empty ``tables``.

        Each table is TRUNCATEd with foreign key checks temporarily disabled,
        followed by the statements from sequence_reset_by_name_sql() for
        ``sequences``. Return an empty list when there are no tables.
        """
        # NB: The generated SQL below is specific to MySQL
        # 'TRUNCATE x;', 'TRUNCATE y;', 'TRUNCATE z;'... style SQL statements
        # to clear all tables of all data
        if not tables:
            return []
        sql = ['SET FOREIGN_KEY_CHECKS = 0;']
        sql.extend(
            '%s %s;' % (
                style.SQL_KEYWORD('TRUNCATE'),
                style.SQL_FIELD(self.quote_name(table)),
            )
            for table in tables
        )
        sql.append('SET FOREIGN_KEY_CHECKS = 1;')
        sql.extend(self.sequence_reset_by_name_sql(style, sequences))
        return sql

    def validate_autopk_value(self, value):
        """Reject 0 as an AutoField value; return ``value`` otherwise."""
        # MySQLism: zero in AUTO_INCREMENT field does not work. Refs #17653.
        if value == 0:
            raise ValueError('The database backend does not accept 0 as a '
                             'value for AutoField.')
        return value

    def adapt_datetimefield_value(self, value):
        """
        Convert a datetime to the string sent to MySQL.

        None and expressions (objects with ``resolve_expression``) pass
        through unchanged. Aware datetimes are made naive in the connection's
        time zone when USE_TZ is True.

        Raise ValueError for aware datetimes when USE_TZ is False.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, 'resolve_expression'):
            return value

        # MySQL doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError("MySQL backend does not support timezone-aware datetimes when USE_TZ is False.")
        return str(value)

    def adapt_timefield_value(self, value):
        """
        Convert a time to the string sent to MySQL.

        None and expressions pass through unchanged.

        Raise ValueError for timezone-aware times.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, 'resolve_expression'):
            return value

        # MySQL doesn't support tz-aware times
        if timezone.is_aware(value):
            raise ValueError("MySQL backend does not support timezone-aware times.")

        return str(value)

    def max_name_length(self):
        """Return the maximum identifier length used for MySQL names."""
        return 64

    def bulk_insert_sql(self, fields, placeholder_rows):
        """Return a multi-row ``VALUES (...), (...)`` clause from ``placeholder_rows``."""
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join("(%s)" % sql for sql in placeholder_rows_sql)
        return "VALUES " + values_sql

    def combine_expression(self, connector, sub_expressions):
        """
        Combine ``sub_expressions`` with ``connector`` in MySQL syntax.

        '^' becomes POW(). '&', '|' and '<<' results are cast to SIGNED.
        '>>' is emulated as FLOOR(lhs / POW(2, rhs)). Other connectors are
        delegated to the base implementation.
        """
        if connector == '^':
            return 'POW(%s)' % ','.join(sub_expressions)
        elif connector in ('&', '|', '<<'):
            # Convert the result to a signed integer since MySQL's binary
            # operators return an unsigned integer.
            return 'CONVERT(%s, SIGNED)' % connector.join(sub_expressions)
        elif connector == '>>':
            lhs, rhs = sub_expressions
            return 'FLOOR(%(lhs)s / POW(2, %(rhs)s))' % {'lhs': lhs, 'rhs': rhs}
        return super().combine_expression(connector, sub_expressions)

    def get_db_converters(self, expression):
        """
        Extend the base converters with MySQL-specific ones.

        Booleans are converted from 0/1, datetimes are made aware (only
        under USE_TZ), and UUIDs are parsed from strings.
        """
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type in ('BooleanField', 'NullBooleanField'):
            converters.append(self.convert_booleanfield_value)
        elif internal_type == 'DateTimeField':
            if settings.USE_TZ:
                converters.append(self.convert_datetimefield_value)
        elif internal_type == 'UUIDField':
            converters.append(self.convert_uuidfield_value)
        return converters

    def convert_booleanfield_value(self, value, expression, connection):
        """Map 0/1 to False/True; leave any other value (e.g. None) as is."""
        if value in (0, 1):
            value = bool(value)
        return value

    def convert_datetimefield_value(self, value, expression, connection):
        """Make a naive datetime from the database aware in the connection's time zone."""
        if value is not None:
            value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_uuidfield_value(self, value, expression, connection):
        """Parse a non-None database value into ``uuid.UUID``."""
        if value is not None:
            value = uuid.UUID(value)
        return value

    def binary_placeholder_sql(self, value):
        """
        Return the placeholder for a binary value.

        Concrete values get the ``_binary`` introducer. None and compilable
        expressions (objects with ``as_sql``) use a plain ``%s``.
        """
        return '_binary %s' if value is not None and not hasattr(value, 'as_sql') else '%s'

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return ``(sql, params)`` computing ``lhs - rhs`` in microseconds.

        ``lhs`` and ``rhs`` are ``(sql, params)`` pairs. TimeField uses
        TIME_TO_SEC arithmetic, with a MariaDB-specific variant. Other types
        use TIMESTAMPDIFF.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        if internal_type == 'TimeField':
            if self.connection.mysql_is_mariadb:
                # MariaDB includes the microsecond component in TIME_TO_SEC as
                # a decimal. MySQL returns an integer without microseconds.
                return 'CAST((TIME_TO_SEC(%(lhs)s) - TIME_TO_SEC(%(rhs)s)) * 1000000 AS SIGNED)' % {
                    'lhs': lhs_sql, 'rhs': rhs_sql
                }, lhs_params + rhs_params
            # Each operand appears twice in the SQL, so its params are
            # repeated in the same order: lhs, lhs, rhs, rhs.
            return (
                "((TIME_TO_SEC(%(lhs)s) * 1000000 + MICROSECOND(%(lhs)s)) -"
                " (TIME_TO_SEC(%(rhs)s) * 1000000 + MICROSECOND(%(rhs)s)))"
            ) % {'lhs': lhs_sql, 'rhs': rhs_sql}, lhs_params * 2 + rhs_params * 2
        else:
            # TIMESTAMPDIFF(unit, start, end) computes end - start, so rhs
            # comes first in both the SQL and the params.
            return "TIMESTAMPDIFF(MICROSECOND, %s, %s)" % (rhs_sql, lhs_sql), rhs_params + lhs_params

    def explain_query_prefix(self, format=None, **options):
        """
        Return the EXPLAIN prefix, adding FORMAT=... or EXTENDED as needed.

        'TEXT' (any case) is translated to MySQL's 'TRADITIONAL'. EXTENDED is
        added only when no format is given and the connection's features
        report ``needs_explain_extended``.
        """
        # Alias MySQL's TRADITIONAL to TEXT for consistency with other backends.
        if format and format.upper() == 'TEXT':
            format = 'TRADITIONAL'
        prefix = super().explain_query_prefix(format, **options)
        if format:
            prefix += ' FORMAT=%s' % format
        if self.connection.features.needs_explain_extended and format is None:
            # EXTENDED and FORMAT are mutually exclusive options.
            prefix += ' EXTENDED'
        return prefix

    def regex_lookup(self, lookup_type):
        """
        Return the SQL template (two ``%s`` slots) for 'regex'/'iregex'.

        MySQL < 8.0 and MariaDB use REGEXP (BINARY for 'regex'). MySQL 8+
        uses REGEXP_LIKE with match type 'c' for 'regex' and 'i' otherwise.
        """
        # REGEXP BINARY doesn't work correctly in MySQL 8+ and REGEXP_LIKE
        # doesn't exist in MySQL 5.6 or in MariaDB.
        if self.connection.mysql_version < (8, 0, 0) or self.connection.mysql_is_mariadb:
            if lookup_type == 'regex':
                return '%s REGEXP BINARY %s'
            return '%s REGEXP %s'

        match_option = 'c' if lookup_type == 'regex' else 'i'
        return "REGEXP_LIKE(%%s, %%s, '%s')" % match_option

    def insert_statement(self, ignore_conflicts=False):
        """Return 'INSERT IGNORE INTO' when ignoring conflicts, else the base statement."""
        return 'INSERT IGNORE INTO' if ignore_conflicts else super().insert_statement(ignore_conflicts)