django/django

View on GitHub
django/contrib/gis/db/models/lookups.py

Summary

Maintainability
A
3 hrs
Test Coverage
from django.contrib.gis.db.models.fields import BaseSpatialField
from django.contrib.gis.measure import Distance
from django.db import NotSupportedError
from django.db.models import Expression, Lookup, Transform
from django.db.models.sql.query import Query
from django.utils.regex_helper import _lazy_re_compile


class RasterBandTransform(Transform):
    """
    Transform marking that a specific raster band was selected.

    It adds no SQL of its own: compiling it just compiles its ``lhs``.
    Lookups in this module detect it with ``isinstance()`` and read its
    ``band_index`` attribute.
    """

    def as_sql(self, compiler, connection):
        compiled_lhs = compiler.compile(self.lhs)
        return compiled_lhs


class GISLookup(Lookup):
    """
    Base class for the spatial lookups registered on BaseSpatialField.

    The lookup value may be a single object or a list/tuple. For a sequence,
    the first element becomes ``rhs`` and the remaining elements are kept in
    ``rhs_params`` (e.g. a raster band index, or the pattern for ``relate``).
    SQL rendering is delegated to the backend operator found in
    ``connection.ops.gis_operators[lookup_name]``.
    """

    # Class-level defaults that subclasses override. ``band_rhs`` and
    # ``band_lhs`` are set per instance by process_band_indices().
    sql_template = None
    transform_func = None
    distance = False
    band_rhs = None
    band_lhs = None

    def __init__(self, lhs, rhs):
        # Split the lookup value: the first element is the real rhs value and
        # any further elements are stored as extra parameters.
        # NOTE(review): an empty list/tuple makes this unpacking raise
        # ValueError ("not enough values to unpack").
        rhs, *self.rhs_params = rhs if isinstance(rhs, (list, tuple)) else [rhs]
        super().__init__(lhs, rhs)
        # Extra template values merged into the SQL template in as_sql();
        # subclasses may add entries (they override the defaults there).
        self.template_params = {}
        self.process_rhs_params()

    def process_rhs_params(self):
        """
        Validate ``rhs_params`` and extract raster band indices.

        One extra element (two for ``relate``, whose last element is the
        pattern) is treated as an rhs band index. Any other count greater
        than one raises ValueError. When there are no extra elements, a band
        index is still taken from the lhs if it is a RasterBandTransform.
        Note that when extra elements exist but are not a band index (e.g.
        ``relate`` with only a pattern), the lhs is not inspected.
        """
        if self.rhs_params:
            # Check if a band index was passed in the query argument.
            if len(self.rhs_params) == (2 if self.lookup_name == "relate" else 1):
                self.process_band_indices()
            elif len(self.rhs_params) > 1:
                raise ValueError("Tuple too long for lookup %s." % self.lookup_name)
        elif isinstance(self.lhs, RasterBandTransform):
            self.process_band_indices(only_lhs=True)

    def process_band_indices(self, only_lhs=False):
        """
        Extract the lhs band index from the band transform class and the rhs
        band index from the input tuple.

        With ``only_lhs=True`` the rhs band defaults to 1. Otherwise the lhs
        band defaults to 1 when the lhs is not a RasterBandTransform, and the
        first element of ``rhs_params`` is popped off as the rhs band index.
        """
        # PostGIS band indices are 1-based, so the band index needs to be
        # increased to be consistent with the GDALRaster band indices.
        if only_lhs:
            self.band_rhs = 1
            self.band_lhs = self.lhs.band_index + 1
            return

        if isinstance(self.lhs, RasterBandTransform):
            self.band_lhs = self.lhs.band_index + 1
        else:
            self.band_lhs = 1

        # NOTE(review): the rhs band index is stored as given, without the +1
        # applied to the lhs index above; confirm the expected input base.
        self.band_rhs, *self.rhs_params = self.rhs_params

    def get_db_prep_lookup(self, value, connection):
        """Wrap the value in the backend's geometry Adapter behind a ``%s``."""
        # get_db_prep_lookup is called by process_rhs from super class
        return ("%s", [connection.ops.Adapter(value)])

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs and wrap its SQL in the backend geometry placeholder.

        Subqueries are compiled by the parent class and returned unchanged.
        """
        if isinstance(self.rhs, Query):
            # If rhs is some Query, don't touch it.
            return super().process_rhs(compiler, connection)
        if isinstance(self.rhs, Expression):
            self.rhs = self.rhs.resolve_expression(compiler.query)
        rhs, rhs_params = super().process_rhs(compiler, connection)
        placeholder = connection.ops.get_geom_placeholder(
            self.lhs.output_field, self.rhs, compiler
        )
        return placeholder % rhs, rhs_params

    def get_rhs_op(self, connection, rhs):
        # Unlike BuiltinLookup, the GIS get_rhs_op() implementation should return
        # an object (SpatialOperator) with an as_sql() method to allow for more
        # complex computations (where the lhs part can be mixed in).
        return connection.ops.gis_operators[self.lookup_name]

    def as_sql(self, compiler, connection):
        """
        Compile both sides and let the backend spatial operator render the SQL.
        """
        lhs_sql, lhs_params = self.process_lhs(compiler, connection)
        rhs_sql, rhs_params = self.process_rhs(compiler, connection)
        # Unpacking accepts lists or tuples from either side.
        sql_params = (*lhs_params, *rhs_params)

        # "value" defaults to a bare placeholder; entries from
        # self.template_params come last and therefore override the defaults.
        template_params = {
            "lhs": lhs_sql,
            "rhs": rhs_sql,
            "value": "%s",
            **self.template_params,
        }
        rhs_op = self.get_rhs_op(connection, rhs_sql)
        return rhs_op.as_sql(connection, self, template_params, sql_params)


# ------------------
# Geometry operators
# ------------------


class OverlapsLeftLookup(GISLookup):
    """
    'overlaps_left': true when A's bounding box overlaps B's bounding box or
    lies to its left.
    """

    lookup_name = "overlaps_left"


BaseSpatialField.register_lookup(OverlapsLeftLookup)


class OverlapsRightLookup(GISLookup):
    """
    'overlaps_right': true when A's bounding box overlaps B's bounding box or
    lies to its right.
    """

    lookup_name = "overlaps_right"


BaseSpatialField.register_lookup(OverlapsRightLookup)


class OverlapsBelowLookup(GISLookup):
    """
    'overlaps_below': true when A's bounding box overlaps B's bounding box or
    lies below it.
    """

    lookup_name = "overlaps_below"


BaseSpatialField.register_lookup(OverlapsBelowLookup)


class OverlapsAboveLookup(GISLookup):
    """
    'overlaps_above': true when A's bounding box overlaps B's bounding box or
    lies above it.
    """

    lookup_name = "overlaps_above"


BaseSpatialField.register_lookup(OverlapsAboveLookup)


class LeftLookup(GISLookup):
    """
    'left': true when A's bounding box lies strictly to the left of B's
    bounding box.
    """

    lookup_name = "left"


BaseSpatialField.register_lookup(LeftLookup)


class RightLookup(GISLookup):
    """
    'right': true when A's bounding box lies strictly to the right of B's
    bounding box.
    """

    lookup_name = "right"


BaseSpatialField.register_lookup(RightLookup)


class StrictlyBelowLookup(GISLookup):
    """
    'strictly_below': true when A's bounding box lies strictly below B's
    bounding box.
    """

    lookup_name = "strictly_below"


BaseSpatialField.register_lookup(StrictlyBelowLookup)


class StrictlyAboveLookup(GISLookup):
    """
    'strictly_above': true when A's bounding box lies strictly above B's
    bounding box.
    """

    lookup_name = "strictly_above"


BaseSpatialField.register_lookup(StrictlyAboveLookup)


class SameAsLookup(GISLookup):
    """
    'same_as' (the "~=" operator): tests actual geometric equality of two
    features, so it is true when A and B are the same feature vertex by
    vertex.
    """

    lookup_name = "same_as"


BaseSpatialField.register_lookup(SameAsLookup)
# ``exact`` on a spatial field is an alias for ``same_as``.
BaseSpatialField.register_lookup(SameAsLookup, "exact")


@BaseSpatialField.register_lookup
class BBContainsLookup(GISLookup):
    """
    The 'bbcontains' operator returns true if A's bounding box completely
    contains B's bounding box.
    """

    lookup_name = "bbcontains"


class BBOverlapsLookup(GISLookup):
    """
    'bboverlaps': true when A's bounding box overlaps B's bounding box.
    """

    lookup_name = "bboverlaps"


BaseSpatialField.register_lookup(BBOverlapsLookup)


class ContainedLookup(GISLookup):
    """
    'contained': true when A's bounding box is completely contained by B's
    bounding box.
    """

    lookup_name = "contained"


BaseSpatialField.register_lookup(ContainedLookup)


# ------------------
# Geometry functions
# ------------------


class ContainsLookup(GISLookup):
    """'contains' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "contains"


BaseSpatialField.register_lookup(ContainsLookup)


class ContainsProperlyLookup(GISLookup):
    """'contains_properly' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "contains_properly"


BaseSpatialField.register_lookup(ContainsProperlyLookup)


class CoveredByLookup(GISLookup):
    """'coveredby' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "coveredby"


BaseSpatialField.register_lookup(CoveredByLookup)


class CoversLookup(GISLookup):
    """'covers' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "covers"


BaseSpatialField.register_lookup(CoversLookup)


class CrossesLookup(GISLookup):
    """'crosses' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "crosses"


BaseSpatialField.register_lookup(CrossesLookup)


class DisjointLookup(GISLookup):
    """'disjoint' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "disjoint"


BaseSpatialField.register_lookup(DisjointLookup)


class EqualsLookup(GISLookup):
    """'equals' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "equals"


BaseSpatialField.register_lookup(EqualsLookup)


class IntersectsLookup(GISLookup):
    """'intersects' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "intersects"


BaseSpatialField.register_lookup(IntersectsLookup)


class OverlapsLookup(GISLookup):
    """'overlaps' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "overlaps"


BaseSpatialField.register_lookup(OverlapsLookup)


@BaseSpatialField.register_lookup
class RelateLookup(GISLookup):
    """
    'relate' lookup: compares the lhs with the rhs using an intersection
    matrix pattern given as the last element of the lookup tuple.
    """

    lookup_name = "relate"
    sql_template = "%(func)s(%(lhs)s, %(rhs)s, %%s)"
    # Default pattern check: exactly nine characters from 0, 1, 2, T, F, *.
    pattern_regex = _lazy_re_compile(r"^[012TF*]{9}$")

    def process_rhs(self, compiler, connection):
        """
        Validate the pattern and append it to the compiled rhs parameters.

        Backends whose operator defines ``check_relate_argument`` validate the
        pattern themselves; otherwise it must be a string matching
        ``pattern_regex``.

        Raises:
            ValueError: if the default check rejects the pattern.
        """
        # Check the pattern argument
        pattern = self.rhs_params[0]
        backend_op = connection.ops.gis_operators[self.lookup_name]
        if hasattr(backend_op, "check_relate_argument"):
            backend_op.check_relate_argument(pattern)
        elif not isinstance(pattern, str) or not self.pattern_regex.match(pattern):
            # Wrap in a 1-tuple so a tuple-valued pattern is reported instead
            # of breaking the %-formatting with a TypeError.
            raise ValueError('Invalid intersection matrix pattern "%s".' % (pattern,))
        sql, params = super().process_rhs(compiler, connection)
        # The compiled params may be a tuple (e.g. RawSQL returns the params
        # sequence it was given), which ``params + [pattern]`` rejects; build
        # a new list so either sequence type works.
        return sql, [*params, pattern]


class TouchesLookup(GISLookup):
    """'touches' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "touches"


BaseSpatialField.register_lookup(TouchesLookup)


class WithinLookup(GISLookup):
    """'within' lookup, rendered by the backend's gis_operators entry."""

    lookup_name = "within"


BaseSpatialField.register_lookup(WithinLookup)


class DistanceLookupBase(GISLookup):
    """
    Shared base for distance lookups.

    After the geometry/raster, the lookup tuple carries the distance value,
    optionally preceded by a raster band index and optionally followed by
    the 'spheroid' directive.
    """

    distance = True
    sql_template = "%(func)s(%(lhs)s, %(rhs)s) %(op)s %(value)s"

    def process_rhs_params(self):
        """Validate the extra tuple elements and extract a band index if present."""
        count = len(self.rhs_params)
        if count < 1 or count > 3:
            raise ValueError(
                "2, 3, or 4-element tuple required for '%s' lookup." % self.lookup_name
            )
        if count == 3 and self.rhs_params[2] != "spheroid":
            raise ValueError(
                "For 4-element tuples the last argument must be the 'spheroid' "
                "directive."
            )

        # A second extra element that isn't the spheroid flag means the first
        # extra element is a band index.
        if count > 1 and self.rhs_params[1] != "spheroid":
            self.process_band_indices()

    def process_distance(self, compiler, connection):
        """
        Return ``(sql, params)`` for the distance value.

        Expressions are resolved and compiled; plain values go through the
        backend's ``get_distance()`` behind a single placeholder.
        """
        dist_param = self.rhs_params[0]
        if not hasattr(dist_param, "resolve_expression"):
            distance_params = connection.ops.get_distance(
                self.lhs.output_field, self.rhs_params, self.lookup_name
            )
            return ("%s", distance_params)
        resolved = dist_param.resolve_expression(compiler.query)
        return compiler.compile(resolved)


@BaseSpatialField.register_lookup
class DWithinLookup(DistanceLookupBase):
    """
    'dwithin' lookup: the distance is passed as the third argument of the
    backend function rather than compared with an operator.
    """

    lookup_name = "dwithin"
    sql_template = "%(func)s(%(lhs)s, %(rhs)s, %(value)s)"

    def process_distance(self, compiler, connection):
        """
        Return ``(sql, params)`` for the distance value.

        Raises:
            NotSupportedError: if the backend lacks
                ``supports_dwithin_distance_expr`` and the distance is an
                expression other than a Distance object.
        """
        dist_param = self.rhs_params[0]
        if (
            not connection.features.supports_dwithin_distance_expr
            and hasattr(dist_param, "resolve_expression")
            and not isinstance(dist_param, Distance)
        ):
            raise NotSupportedError(
                "This backend does not support expressions for specifying "
                "distance in the dwithin lookup."
            )
        return super().process_distance(compiler, connection)

    def process_rhs(self, compiler, connection):
        """
        Compile the rhs and the distance, placing the distance SQL in the
        ``value`` template slot; distance params follow the rhs params.
        """
        dist_sql, dist_params = self.process_distance(compiler, connection)
        self.template_params["value"] = dist_sql
        rhs_sql, params = super().process_rhs(compiler, connection)
        # Either sequence may be a tuple (a RawSQL distance returns the params
        # it was given), and list + tuple raises TypeError; unpack instead.
        return rhs_sql, [*params, *dist_params]


class DistanceLookupFromFunction(DistanceLookupBase):
    """
    Base for lookups comparing a computed distance expression with the given
    distance using the subclass's ``op`` comparison operator.
    """

    def as_sql(self, compiler, connection):
        # After band extraction, two remaining params can only be
        # (distance, 'spheroid'); pass None when no spheroid was requested.
        spheroid = (
            len(self.rhs_params) == 2 and self.rhs_params[-1] == "spheroid"
        ) or None
        distance_expr = connection.ops.distance_expr_for_lookup(
            self.lhs, self.rhs, spheroid=spheroid
        )
        sql, params = compiler.compile(distance_expr.resolve_expression(compiler.query))
        dist_sql, dist_params = self.process_distance(compiler, connection)
        # Unpack rather than ``params + dist_params``: an expression distance
        # (e.g. RawSQL) can yield tuple params, which can't be added to a list.
        return (
            "%(func)s %(op)s %(dist)s" % {"func": sql, "op": self.op, "dist": dist_sql},
            [*params, *dist_params],
        )


class DistanceGTLookup(DistanceLookupFromFunction):
    """'distance_gt': the computed distance must be > the given value."""

    lookup_name = "distance_gt"
    op = ">"


BaseSpatialField.register_lookup(DistanceGTLookup)


class DistanceGTELookup(DistanceLookupFromFunction):
    """'distance_gte': the computed distance must be >= the given value."""

    lookup_name = "distance_gte"
    op = ">="


BaseSpatialField.register_lookup(DistanceGTELookup)


class DistanceLTLookup(DistanceLookupFromFunction):
    """'distance_lt': the computed distance must be < the given value."""

    lookup_name = "distance_lt"
    op = "<"


BaseSpatialField.register_lookup(DistanceLTLookup)


class DistanceLTELookup(DistanceLookupFromFunction):
    """'distance_lte': the computed distance must be <= the given value."""

    lookup_name = "distance_lte"
    op = "<="


BaseSpatialField.register_lookup(DistanceLTELookup)