django/django

View on GitHub
django/db/backends/ddl_references.py

Summary

Maintainability
A
3 hrs
Test Coverage
"""
Helpers to manipulate deferred DDL statements that might need to be adjusted or
discarded within when executing a migration.
"""

from copy import deepcopy


class Reference:
    """Base class that defines the reference interface."""

    def references_table(self, table):
        """
        Return whether or not this instance references the specified table.
        """
        return False

    def references_column(self, table, column):
        """
        Return whether or not this instance references the specified column.
        """
        return False

    def rename_table_references(self, old_table, new_table):
        """
        Rename all references to the old_name to the new_table.
        """
        pass

    def rename_column_references(self, table, old_column, new_column):
        """
        Rename all references to the old_column to the new_column.
        """
        pass

    def __repr__(self):
        return "<%s %r>" % (self.__class__.__name__, str(self))

    def __str__(self):
        raise NotImplementedError(
            "Subclasses must define how they should be converted to string."
        )


class Table(Reference):
    """Hold a reference to a table."""

    def __init__(self, table, quote_name):
        self.table = table
        self.quote_name = quote_name

    def references_table(self, table):
        return self.table == table

    def rename_table_references(self, old_table, new_table):
        if self.table == old_table:
            self.table = new_table

    def __str__(self):
        return self.quote_name(self.table)


class TableColumns(Table):
    """Base class for references to multiple columns of a table."""

    def __init__(self, table, columns):
        self.table = table
        self.columns = columns

    def references_column(self, table, column):
        return self.table == table and column in self.columns

    def rename_column_references(self, table, old_column, new_column):
        if self.table == table:
            for index, column in enumerate(self.columns):
                if column == old_column:
                    self.columns[index] = new_column


class Columns(TableColumns):
    """Hold a reference to one or many columns."""

    def __init__(self, table, columns, quote_name, col_suffixes=()):
        self.quote_name = quote_name
        self.col_suffixes = col_suffixes
        super().__init__(table, columns)

    def __str__(self):
        def col_str(column, idx):
            col = self.quote_name(column)
            try:
                suffix = self.col_suffixes[idx]
                if suffix:
                    col = "{} {}".format(col, suffix)
            except IndexError:
                pass
            return col

        return ", ".join(
            col_str(column, idx) for idx, column in enumerate(self.columns)
        )


class IndexName(TableColumns):
    """Hold a reference to an index name."""

    def __init__(self, table, columns, suffix, create_index_name):
        self.suffix = suffix
        self.create_index_name = create_index_name
        super().__init__(table, columns)

    def __str__(self):
        return self.create_index_name(self.table, self.columns, self.suffix)


class IndexColumns(Columns):
    def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
        self.opclasses = opclasses
        super().__init__(table, columns, quote_name, col_suffixes)

    def __str__(self):
        def col_str(column, idx):
            # Index.__init__() guarantees that self.opclasses is the same
            # length as self.columns.
            col = "{} {}".format(self.quote_name(column), self.opclasses[idx])
            try:
                suffix = self.col_suffixes[idx]
                if suffix:
                    col = "{} {}".format(col, suffix)
            except IndexError:
                pass
            return col

        return ", ".join(
            col_str(column, idx) for idx, column in enumerate(self.columns)
        )


class ForeignKeyName(TableColumns):
    """Hold a reference to a foreign key name."""

    def __init__(
        self,
        from_table,
        from_columns,
        to_table,
        to_columns,
        suffix_template,
        create_fk_name,
    ):
        self.to_reference = TableColumns(to_table, to_columns)
        self.suffix_template = suffix_template
        self.create_fk_name = create_fk_name
        super().__init__(
            from_table,
            from_columns,
        )

    def references_table(self, table):
        return super().references_table(table) or self.to_reference.references_table(
            table
        )

    def references_column(self, table, column):
        return super().references_column(
            table, column
        ) or self.to_reference.references_column(table, column)

    def rename_table_references(self, old_table, new_table):
        super().rename_table_references(old_table, new_table)
        self.to_reference.rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        super().rename_column_references(table, old_column, new_column)
        self.to_reference.rename_column_references(table, old_column, new_column)

    def __str__(self):
        suffix = self.suffix_template % {
            "to_table": self.to_reference.table,
            "to_column": self.to_reference.columns[0],
        }
        return self.create_fk_name(self.table, self.columns, suffix)


class Statement(Reference):
    """
    Statement template and formatting parameters container.

    Allows keeping a reference to a statement without interpolating identifiers
    that might have to be adjusted if they're referencing a table or column
    that is removed
    """

    def __init__(self, template, **parts):
        self.template = template
        self.parts = parts

    def references_table(self, table):
        return any(
            hasattr(part, "references_table") and part.references_table(table)
            for part in self.parts.values()
        )

    def references_column(self, table, column):
        return any(
            hasattr(part, "references_column") and part.references_column(table, column)
            for part in self.parts.values()
        )

    def rename_table_references(self, old_table, new_table):
        for part in self.parts.values():
            if hasattr(part, "rename_table_references"):
                part.rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        for part in self.parts.values():
            if hasattr(part, "rename_column_references"):
                part.rename_column_references(table, old_column, new_column)

    def __str__(self):
        return self.template % self.parts


class Expressions(TableColumns):
    def __init__(self, table, expressions, compiler, quote_value):
        self.compiler = compiler
        self.expressions = expressions
        self.quote_value = quote_value
        columns = [
            col.target.column
            for col in self.compiler.query._gen_cols([self.expressions])
        ]
        super().__init__(table, columns)

    def rename_table_references(self, old_table, new_table):
        if self.table != old_table:
            return
        self.expressions = self.expressions.relabeled_clone({old_table: new_table})
        super().rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        if self.table != table:
            return
        expressions = deepcopy(self.expressions)
        self.columns = []
        for col in self.compiler.query._gen_cols([expressions]):
            if col.target.column == old_column:
                col.target.column = new_column
            self.columns.append(col.target.column)
        self.expressions = expressions

    def __str__(self):
        sql, params = self.compiler.compile(self.expressions)
        params = map(self.quote_value, params)
        return sql % tuple(params)