model_values/__init__.py
from __future__ import annotations
import collections
import functools
import itertools
import math
import operator
import types
from collections.abc import Callable, Iterable, Mapping
import django
from django.db import IntegrityError, models, transaction
from django.db.models import functions
try: # pragma: no cover
import django.contrib.gis.db.models.functions # noqa: F401
import django.contrib.gis.db.models as gis
except Exception:
gis = None
__version__ = '1.6'
def update_wrapper(wrapper, name):
wrapper.__name__ = wrapper.__doc__ = name
return wrapper
def eq(lookup):
return update_wrapper(lambda self, value: self.__eq__(value, '__' + lookup), lookup)
class Lookup:
"""Mixin for field lookups."""
__ne__ = eq('ne')
__lt__ = eq('lt')
__le__ = eq('lte')
__gt__ = eq('gt')
__ge__ = eq('gte')
iexact = eq('iexact')
icontains = eq('icontains')
startswith = eq('startswith')
istartswith = eq('istartswith')
endswith = eq('endswith')
iendswith = eq('iendswith')
regex = eq('regex')
iregex = eq('iregex')
isin = eq('in')
# spatial lookups
contained = eq('contained')
coveredby = eq('coveredby')
covers = eq('covers')
crosses = eq('crosses')
disjoint = eq('disjoint')
equals = eq('equals') # __eq__ is taken
intersects = eq('intersects') # __and__ is ambiguous
touches = eq('touches')
__lshift__ = left = eq('left')
__rshift__ = right = eq('right')
above = eq('strictly_above')
below = eq('strictly_below')
def range(self, *values):
"""range"""
return self.__eq__(values, '__range')
def relate(self, *values):
"""relate"""
return self.__eq__(values, '__relate')
@property
def is_valid(self):
"""Whether field `isvalid`."""
return self.__eq__(True, '__isvalid')
def contains(self, value, properly=False, bb=False):
"""Return whether field `contains` the value. Options apply only to geom fields.
Args:
properly: `contains_properly`
bb: bounding box, `bbcontains`
"""
properly = '_properly' * bool(properly)
bb = 'bb' * bool(bb)
return self.__eq__(value, f'__{bb}contains{properly}')
def overlaps(self, geom, position='', bb=False):
"""Return whether field `overlaps` with geometry .
Args:
position: `overlaps_{left, right, above, below}`
bb: bounding box, `bboverlaps`
"""
bb = 'bb' * bool(bb)
return self.__eq__(geom, f'__{bb}overlaps_{position}'.rstrip('_'))
def within(self, geom, distance=None):
"""Return whether field is `within` geometry.
Args:
distance: `dwithin`
"""
if distance is None:
return self.__eq__(geom, '__within')
return self.__eq__((geom, distance), '__dwithin')
class method(functools.partial):
def __init__(self, func, *args):
self.__doc__ = func.__doc__ or func.__name__
def __get__(self, instance, owner):
return self if instance is None else types.MethodType(self, instance)
def transform(lookup, func, value):
field, expr = func.source_expressions
expr = expr if isinstance(expr, models.F) else expr.value
return field.__eq__((expr, value), '__' + lookup)
class MetaF(type):
def __getattr__(cls, name: str) -> F:
if name in ('name', '__slots__', '__wrapped__'):
raise AttributeError(f"'{name}' is a reserved attribute")
return cls(name)
def any(cls, exprs: Iterable[models.Q]) -> models.Q:
"""Return ``Q`` OR object."""
return functools.reduce(operator.or_, exprs)
def all(cls, exprs: Iterable[models.Q]) -> models.Q:
"""Return ``Q`` AND object."""
return functools.reduce(operator.and_, exprs)
class F(models.F, Lookup, metaclass=MetaF):
"""Create ``F``, ``Q``, and ``Func`` objects with expressions.
``F`` creation supported as attributes:
``F.user`` == ``F('user')``,
``F.user.created`` == ``F('user__created')``.
``Q`` lookups supported as methods or operators:
``F.text.iexact(...)`` == ``Q(text__iexact=...)``,
``F.user.created >= ...`` == ``Q(user__created__gte=...)``.
``Func`` objects also supported as methods:
``F.user.created.min()`` == ``Min('user__created')``.
Some ``Func`` objects can also be transformed into lookups,
if [registered](https://docs.djangoproject.com/en/stable/ref/models/database-functions/#length):
``F.text.length()`` == ``Length(F('text'))``,
``F.text.length > 0`` == ``Q(text__length__gt=0)``.
"""
lookups = dict(
length=functions.Length,
lower=functions.Lower,
upper=functions.Upper,
chr=functions.Chr,
ord=functions.Ord,
acos=functions.ACos,
asin=functions.ASin,
atan=functions.ATan,
atan2=functions.ATan2,
cos=functions.Cos,
cot=functions.Cot,
degrees=functions.Degrees,
exp=functions.Exp,
radians=functions.Radians,
sin=functions.Sin,
sqrt=functions.Sqrt,
tan=functions.Tan,
sign=functions.Sign,
md5=functions.MD5,
)
coalesce = method(functions.Coalesce)
concat = method(functions.Concat) # __add__ is taken
min = method(models.Min)
max = method(models.Max)
sum = method(models.Sum)
mean = method(models.Avg)
var = method(models.Variance)
std = method(models.StdDev)
greatest = method(functions.Greatest)
least = method(functions.Least)
now = staticmethod(functions.Now)
cast = method(functions.Cast)
extract = method(functions.Extract)
trunc = method(functions.Trunc)
cume_dist = method(functions.CumeDist)
dense_rank = method(functions.DenseRank)
first_value = method(functions.FirstValue)
lag = method(functions.Lag)
last_value = method(functions.LastValue)
lead = method(functions.Lead)
nth_value = method(functions.NthValue)
ntile = staticmethod(functions.Ntile)
percent_rank = method(functions.PercentRank)
rank = method(functions.Rank)
row_number = method(functions.RowNumber)
strip = method(functions.Trim)
lstrip = method(functions.LTrim)
rstrip = method(functions.RTrim)
repeat = method(functions.Repeat)
nullif = method(functions.NullIf)
__reversed__ = method(functions.Reverse)
__abs__ = method(functions.Abs)
__ceil__ = method(functions.Ceil)
__floor__ = method(functions.Floor)
__mod__ = method(functions.Mod)
pi = functions.Pi()
__pow__ = method(functions.Power)
__round__ = method(functions.Round)
sha1 = method(functions.SHA1)
sha224 = method(functions.SHA224)
sha256 = method(functions.SHA256)
sha384 = method(functions.SHA384)
sha512 = method(functions.SHA512)
collate = method(functions.Collate)
json = staticmethod(functions.JSONObject)
random = staticmethod(functions.Random)
if gis: # pragma: no cover
area = property(gis.functions.Area)
geojson = method(gis.functions.AsGeoJSON)
gml = method(gis.functions.AsGML)
kml = method(gis.functions.AsKML)
svg = method(gis.functions.AsSVG)
bounding_circle = method(gis.functions.BoundingCircle)
centroid = property(gis.functions.Centroid)
difference = method(gis.functions.Difference)
envelope = property(gis.functions.Envelope)
geohash = method(gis.functions.GeoHash) # __hash__ requires an int
intersection = method(gis.functions.Intersection)
make_valid = method(gis.functions.MakeValid)
mem_size = property(gis.functions.MemSize)
num_geometries = property(gis.functions.NumGeometries)
num_points = property(gis.functions.NumPoints)
perimeter = property(gis.functions.Perimeter)
point_on_surface = property(gis.functions.PointOnSurface)
reverse = method(gis.functions.Reverse)
scale = method(gis.functions.Scale)
snap_to_grid = method(gis.functions.SnapToGrid)
symmetric_difference = method(gis.functions.SymDifference)
transform = method(gis.functions.Transform)
translate = method(gis.functions.Translate)
union = method(gis.functions.Union)
azimuth = method(gis.functions.Azimuth)
line_locate_point = method(gis.functions.LineLocatePoint)
force_polygon_cw = method(gis.functions.ForcePolygonCW)
@method
class distance(gis.functions.Distance):
"""Return ``Distance`` with support for lookups: <, <=, >, >=, within."""
__lt__ = method(transform, 'distance_lt')
__le__ = method(transform, 'distance_lte')
__gt__ = method(transform, 'distance_gt')
__ge__ = method(transform, 'distance_gte')
within = method(transform, 'dwithin')
def __getattr__(self, name: str) -> F:
"""Return new [F][model_values.F] object with chained attribute."""
return type(self)(f'{self.name}__{name}')
def __eq__(self, value, lookup: str = '') -> models.Q:
"""Return ``Q`` object with lookup."""
if not lookup and type(value) is models.F:
return self.name == value.name
return models.Q(**{self.name + lookup: value})
def __ne__(self, value) -> models.Q:
"""Allow __ne=None lookup without custom queryset."""
if value is None:
return self.__eq__(False, '__isnull')
return self.__eq__(value, '__ne')
__hash__ = models.F.__hash__
def __call__(self, *args, **extra) -> models.Func:
name, _, func = self.name.rpartition('__')
return self.lookups[func](name, *args, **extra)
def __iter__(self):
raise TypeError("'F' object is not iterable")
def __getitem__(self, slc: slice) -> models.Func:
"""Return field ``Substr`` or ``Right``."""
assert (slc.stop or 0) >= 0 and slc.step is None
start = slc.start or 0
if start < 0:
assert slc.stop is None
return functions.Right(self, -start)
size = slc.stop and max(slc.stop - start, 0)
return functions.Substr(self, start + 1, size)
def __rmod__(self, value):
return functions.Mod(value, self)
def __rpow__(self, value):
return functions.Power(value, self)
@method
def count(self='*', **extra):
"""Return ``Count`` with optional field."""
return models.Count(getattr(self, 'name', self), **extra)
def find(self, sub, **extra) -> models.Expression:
"""Return ``StrIndex`` with ``str.find`` semantics."""
return functions.StrIndex(self, Value(sub), **extra) - 1
def replace(self, old, new='', **extra) -> models.Func:
"""Return ``Replace`` with wrapped values."""
return functions.Replace(self, Value(old), Value(new), **extra)
def ljust(self, width: int, fill=' ', **extra) -> models.Func:
"""Return ``LPad`` with wrapped values."""
return functions.LPad(self, width, Value(fill), **extra)
def rjust(self, width: int, fill=' ', **extra) -> models.Func:
"""Return ``RPad`` with wrapped values."""
return functions.RPad(self, width, Value(fill), **extra)
def log(self, base=math.e, **extra) -> models.Func:
"""Return ``Log``, by default ``Ln``."""
return functions.Log(self, base, **extra)
def reduce(func):
return update_wrapper(lambda self, **extra: self.reduce(func, **extra), func.__name__)
def field(func):
return update_wrapper(lambda self, value: func(models.F(*self._fields), value), func.__name__)
class QuerySet(models.QuerySet, Lookup):
min = reduce(models.Min)
max = reduce(models.Max)
sum = reduce(models.Sum)
mean = reduce(models.Avg)
var = reduce(models.Variance)
std = reduce(models.StdDev)
__add__ = field(operator.add)
__sub__ = field(operator.sub)
__mul__ = field(operator.mul)
__truediv__ = field(operator.truediv)
__mod__ = field(operator.mod)
__pow__ = field(operator.pow)
if gis: # pragma: no cover
collect = reduce(gis.Collect)
extent = reduce(gis.Extent)
extent3d = reduce(gis.Extent3D)
make_line = reduce(gis.MakeLine)
union = reduce(gis.Union)
@property
def _flat(self):
return issubclass(self._iterable_class, models.query.FlatValuesListIterable)
@property
def _named(self):
return issubclass(self._iterable_class, models.query.NamedValuesListIterable)
def __getitem__(self, key):
"""Allow column access by field names, expressions, or ``F`` objects.
``qs[field]`` returns flat ``values_list``
``qs[field, ...]`` returns tupled ``values_list``
``qs[Q_obj]`` provisionally returns filtered [QuerySet][model_values.QuerySet]
"""
if isinstance(key, tuple):
return self.values_list(*map(extract, key), named=True)
key = extract(key)
if isinstance(key, (str, models.Expression)):
return self.values_list(key, flat=True)
if isinstance(key, models.Q):
return self.filter(key)
return super().__getitem__(key)
def __setitem__(self, key, value):
"""Update a single column."""
self.update(**{key: value})
def __eq__(self, value, lookup: str = '') -> QuerySet:
"""Return [QuerySet][model_values.QuerySet] filtered by comparison to given value."""
(field,) = self._fields
return self.filter(**{field + lookup: value})
def __contains__(self, value):
"""Return whether value is present using ``exists``."""
if self._result_cache is None and self._flat:
return (self == value).exists()
return value in iter(self)
def __iter__(self):
"""Iteration extended to support [groupby][model_values.QuerySet.groupby]."""
if not hasattr(self, '_groupby'):
return super().__iter__()
size = len(self._groupby)
rows = self[self._groupby + self._fields].order_by(*self._groupby).iterator()
groups = itertools.groupby(rows, key=operator.itemgetter(*range(size)))
getter = operator.itemgetter(size if self._flat else slice(size, None))
if self._named:
Row = collections.namedtuple('Row', self._fields)
getter = lambda tup: Row(*tup[size:]) # noqa: E731
return ((key, map(getter, values)) for key, values in groups)
def items(self, *fields, **annotations) -> QuerySet:
"""Return annotated ``values_list``."""
return self.annotate(**annotations)[fields + tuple(annotations)]
def groupby(self, *fields, **annotations) -> QuerySet:
"""Return a grouped [QuerySet][model_values.QuerySet].
The queryset is iterable in the same manner as ``itertools.groupby``.
Additionally the [reduce][model_values.QuerySet.reduce] functions will return annotated querysets.
"""
qs = self.annotate(**annotations)
qs._groupby = fields + tuple(annotations)
return qs
def annotate(self, *args, **kwargs) -> QuerySet:
"""Annotate extended to also handle mapping values, as a [Case][model_values.Case] expression.
Args:
**kwargs: ``field={Q_obj: value, ...}, ...``
As a provisional feature, an optional ``default`` key may be specified.
"""
for field, value in kwargs.items():
if Case.isa(value):
kwargs[field] = Case.defaultdict(value)
return super().annotate(*args, **kwargs)
def alias(self, *args, **kwargs) -> QuerySet:
"""Alias extended to also handle mapping values, as a [Case][model_values.Case] expression.
Args:
**kwargs: ``field={Q_obj: value, ...}, ...``
"""
for field, value in kwargs.items():
if Case.isa(value):
kwargs[field] = Case.defaultdict(value)
return super().alias(*args, **kwargs)
def value_counts(self, alias: str = 'count') -> QuerySet:
"""Return annotated value counts."""
return self.items(*self._fields, **{alias: F.count()})
def sort_values(self, reverse=False) -> QuerySet:
"""Return [QuerySet][model_values.QuerySet] ordered by selected values."""
qs = self.order_by(*self._fields)
return qs.reverse() if reverse else qs
def reduce(self, *funcs, **extra):
"""Return aggregated values, or an annotated [QuerySet][model_values.QuerySet].
Args:
*funcs: aggregation function classes
"""
funcs = [func(field, **extra) for field, func in zip(self._fields, itertools.cycle(funcs))]
if hasattr(self, '_groupby'):
return self[self._groupby].annotate(*funcs)
names = [func.default_alias for func in funcs]
row = self.aggregate(*funcs)
if self._named:
return collections.namedtuple('Row', names)(**row)
return row[names[0]] if self._flat else tuple(map(row.__getitem__, names))
def update(self, **kwargs) -> int:
"""Update extended to also handle mapping values, as a [Case][model_values.Case] expression.
Args:
**kwargs: ``field={Q_obj: value, ...}, ...``
"""
for field, value in kwargs.items():
if Case.isa(value):
kwargs[field] = Case(value, default=F(field))
return super().update(**kwargs)
def change(self, defaults: Mapping = {}, **kwargs) -> int:
"""Update and return number of rows that actually changed.
For triggering on-change logic without fetching first.
``if qs.change(status=...):`` status actually changed
``qs.change({'last_modified': now}, status=...)`` last_modified only updated if status updated
Args:
defaults: optional mapping which will be updated conditionally, as with ``update_or_create``.
"""
return self.exclude(**kwargs).update(**dict(defaults, **kwargs))
def changed(self, **kwargs) -> dict:
"""Return first mapping of fields and values which differ in the db.
Also efficient enough to be used in boolean contexts, instead of ``exists``.
"""
row = self.exclude(**kwargs).values(*kwargs).first() or {}
return {field: value for field, value in row.items() if value != kwargs[field]}
def exists(self, count: int = 1) -> bool:
"""Return whether there are at least the specified number of rows."""
if count == 1:
return super().exists()
return (self[:count].count() if self._result_cache is None else len(self)) >= count
@models.Field.register_lookup
class NotEqual(models.Lookup):
"""Missing != operator."""
lookup_name = 'ne'
def as_sql(self, *args):
lhs, lhs_params = self.process_lhs(*args)
rhs, rhs_params = self.process_rhs(*args)
return f'{lhs} <> {rhs}', (lhs_params + rhs_params)
class Query(models.sql.Query):
"""Allow __ne=None lookup."""
def build_lookup(self, lookups, lhs, rhs):
if rhs is None and lookups[-1:] == ['ne']:
rhs, lookups[-1] = False, 'isnull'
return super().build_lookup(lookups, lhs, rhs)
class Manager(models.Manager):
def get_queryset(self):
return QuerySet(self.model, Query(self.model), self._db, self._hints)
def __getitem__(self, pk) -> QuerySet:
"""Return [QuerySet][model_values.QuerySet] which matches primary key.
To encourage direct db access, instead of always using get and save.
"""
return self.filter(pk=pk)
def __delitem__(self, pk):
"""Delete row with primary key."""
self[pk].delete()
def __contains__(self, pk):
"""Return whether primary key is present using ``exists``."""
return self[pk].exists()
def upsert(self, defaults: Mapping = {}, **kwargs) -> int | models.Model:
"""Update or insert returning number of rows or created object.
Faster and safer than ``update_or_create``.
Supports combined expression updates by assuming the identity element on insert: ``F(...) + 1``.
Args:
defaults: optional mapping which will be updated, as with ``update_or_create``.
"""
update = getattr(self.filter(**kwargs), 'update' if defaults else 'count')
for field, value in defaults.items():
expr = isinstance(value, models.expressions.CombinedExpression)
kwargs[field] = value.rhs.value if expr else value
try:
with transaction.atomic():
return update(**defaults) or self.create(**kwargs)
except IntegrityError:
return update(**defaults)
def bulk_changed(self, field, data: Mapping, key: str = 'pk') -> dict:
"""Return mapping of values which differ in the db.
Args:
field: value column
data: ``{pk: value, ...}``
key: unique key column
"""
rows = self.filter(F(key).isin(data))[key, field].iterator()
return {pk: value for pk, value in rows if value != data[pk]}
def bulk_change(
self, field, data: Mapping, key: str = 'pk', conditional=False, **kwargs
) -> int:
"""Update changed rows with a minimal number of queries, by inverting the data to use ``pk__in``.
Args:
field: value column
data: ``{pk: value, ...}``
key: unique key column
conditional: execute select query and single conditional update;
may be more efficient if the percentage of changed rows is relatively small
**kwargs: additional fields to be updated
"""
if conditional:
data = {pk: data[pk] for pk in self.bulk_changed(field, data, key)}
updates = collections.defaultdict(list)
for pk in data:
updates[data[pk]].append(pk)
if conditional:
kwargs[field] = {F(key).isin(tuple(updates[value])): value for value in updates}
return self.filter(F(key).isin(data)).update(**kwargs)
count = 0
for value in updates:
kwargs[field] = value
count += self.filter((F(field) != value) & F(key).isin(updates[value])).update(**kwargs)
return count
class classproperty(property):
"""A property bound to a class."""
def __get__(self, instance, owner):
return self.fget(owner)
def Value(value):
return value if isinstance(value, models.F) else models.Value(value)
def extract(field):
if isinstance(field, models.F):
return field.name
return Case.defaultdict(field) if Case.isa(field) else field
class Case(models.Case):
"""``Case`` expression from mapping of when conditionals.
Args:
conds: ``{Q_obj: value, ...}``
default: optional default value or ``F`` object
"""
types = {
str: models.CharField,
int: models.IntegerField,
float: models.FloatField,
bool: models.BooleanField,
}
def __new__(cls, conds: Mapping, default=None, **extra):
cases = (models.When(cond, Value(conds[cond])) for cond in conds)
return models.Case(*cases, default=Value(default), **extra)
@classmethod
def defaultdict(cls, conds):
conds = dict(conds)
return cls(conds, default=conds.pop('default', None))
@classmethod
def isa(cls, value):
return isinstance(value, Mapping) and any(isinstance(key, models.Q) for key in value)
def EnumField(enum, display: Callable | None = None, **options) -> models.Field:
"""Return a ``CharField`` or ``IntegerField`` with choices from given enum.
By default, enum names and values are used as db values and display labels respectively,
returning a ``CharField`` with computed ``max_length``.
Args:
display: optional callable to transform enum names to display labels,
thereby using enum values as db values and also supporting integers.
"""
choices = tuple((choice.name, choice.value) for choice in enum)
if display is not None:
choices = tuple((choice.value, display(choice.name)) for choice in enum)
try:
max_length = max(map(len, dict(choices)))
except TypeError:
return models.IntegerField(choices=choices, **options)
return models.CharField(max_length=max_length, choices=choices, **options)