# anubis/query.py
# Copyright (C) 2014, Ugo Pozo 2014, Câmara Municipal de São Paulo
# query.py - extensões para queries do Anubis.
# Este arquivo é parte do software Anubis.
# Anubis é um software livre: você pode redistribuí-lo e/ou
# modificá-lo sob os termos da Licença Pública Geral GNU (GNU General Public
# License), tal como é publicada pela Free Software Foundation, na
# versão 3 da licença, ou (sua decisão) qualquer versão posterior.
# Anubis é distribuído na esperança de que seja útil, mas SEM NENHUMA
# GARANTIA; nem mesmo a garantia implícita de VALOR COMERCIAL ou ADEQUAÇÃO PARA
# UM PROPÓSITO EM PARTICULAR. Veja a Licença Pública Geral GNU para mais
# detalhes.
# Você deve ter recebido uma cópia da Licença Pública Geral GNU junto com este
# programa. Se não, consulte <http://www.gnu.org/licenses/>.
from collections import OrderedDict
from operator import itemgetter
from anubis.sql_aggregators import ProcedureOrderingAnnotation
from django.db import connection as base_connection, connections
from django.db import models
from django.db.models.query import QuerySet
from django.db.models.expressions import Ref, F
class ProcedureQuerySet(QuerySet):
    """Procedure-aware QuerySet for Django models.

    Provides a QuerySet (and, through :meth:`as_manager`, a Manager) whose
    methods call database stored procedures named
    ``"[table_name]_[procname]"``.

    Example:
        Use this class as follows to enable procedure queries on your
        models::

            from django.db.models import Model, CharField
            from anubis.query import ProcedureQuerySet

            class MyModel(Model):
                myText = CharField()

                class QuerySet(ProcedureQuerySet):
                    pass

                objects = QuerySet.as_manager()
    """

    # Not referenced inside this class; presumably consumed by the ordering
    # annotation machinery -- TODO confirm against anubis.sql_aggregators.
    order_by_procedure_column = "anubis_index"

    def _get_connection(self):
        """Return the database connection this queryset should use.

        Returns:
            The entry of ``django.db.connections`` for ``self.db``, or the
            default connection when ``self.db`` is ``None``.
        """
        if self.db is not None:
            return connections[self.db]

        return base_connection

    @staticmethod
    def _test_order_direction(source, dest):
        """Move a leading ordering sign from `source` onto `dest`.

        Args:
            source (str): A field name, optionally prefixed by "-" or "+".
            dest (str): The name that should receive the same prefix.

        Returns:
            tuple: ``(source, dest)`` with the sign stripped from `source`
            and prepended to `dest`; both unchanged when there is no sign.
        """
        for sign in ("-", "+"):
            if source.startswith(sign):
                return source[1:], "{}{}".format(sign, dest)

        return source, dest

    def order_by_procedure(self, procname, *args, field='id',
                           extra_fields=None, **kwargs):
        """Order the queryset by the result of a stored procedure.

        Args:
            procname (str): Procedure name, minus the "[table_name]_"
                prefix. See the warning in :meth:`procedure`.
            *args: Positional arguments for the ordering annotation.
            field (str): Field used as secondary ordering key when
                `extra_fields` is not given.
            extra_fields (iterable of str): Fields used as secondary
                ordering keys; when given, `field` is ignored.
            **kwargs: Forwarded to :meth:`procedure_aggregate`.

        Returns:
            ProcedureQuerySet: The annotated and ordered queryset.
        """
        aggregate = self.procedure_aggregate(procname, *args, **kwargs)

        return self.order_by_aggregates(aggregate, field=field,
                                        extra_fields=extra_fields)

    def order_by_aggregates(self, *aggregates, field="id", extra_fields=None):
        """Annotate the queryset with `aggregates` and order by them.

        Each aggregate is annotated under the alias
        ``"{default_alias}_{position}"`` and the queryset is ordered by those
        aliases in the order given, followed by the secondary keys.

        Args:
            *aggregates: Annotation expressions exposing ``default_alias``.
            field (str): Secondary ordering key used when `extra_fields` is
                not given.
            extra_fields (iterable of str): Secondary ordering keys; when
                given, `field` is ignored.

        Returns:
            ProcedureQuerySet: The annotated and ordered queryset.
        """
        annotation = OrderedDict(
            ("{}_{}".format(agg.default_alias, i), agg)
            for i, agg in enumerate(aggregates)
        )

        ordering = [Ref(alias, agg) for alias, agg in annotation.items()]

        if extra_fields is not None:
            ordering.extend(F(name) for name in extra_fields)
        else:
            ordering.append(F(field))

        return self.annotate(**annotation).order_by(*ordering)

    def procedure_aggregate(self, procname, *args, **kwargs):
        """Helper method for generating ordering annotations.

        Args:
            procname (str): Procedure name, minus the "[table_name]_" prefix.
            *args: Forwarded to the annotation constructor.
            **kwargs: Forwarded to the annotation constructor.

        Returns:
            anubis.sql_aggregators.ProcedureOrderingAnnotation: An ordering
            annotation with the correct procedure name (including table
            name).
        """
        procname = "{}_{}".format(self.model._meta.db_table, procname)

        return ProcedureOrderingAnnotation(procname, *args, **kwargs)

    def _run_procedure(self, procname, args, fetch):
        """Execute ``select * from [table_name]_[procname](...)``.

        Shared implementation of the three public procedure methods.

        Args:
            procname (str): Procedure name, minus the "[table_name]_"
                prefix. It is only passed through the backend's
                ``quote_name``; never feed it untrusted input.
            args (sequence): Values sent as bound query parameters.
            fetch (callable): Receives the executed cursor and returns the
                value to hand back to the caller.

        Returns:
            Whatever `fetch` returns.
        """
        connection = self._get_connection()

        full_name = "{}_{}".format(self.model._meta.db_table, procname)
        quoted_name = connection.ops.quote_name(full_name)
        placeholders = ", ".join(["%s"] * len(args))
        query = "select * from {}({});".format(quoted_name, placeholders)

        # The context manager closes the cursor even if execution or
        # fetching raises.
        with connection.cursor() as cursor:
            cursor.execute(query, list(args))
            return fetch(cursor)

    def procedure(self, procname, *args):
        """Call a stored procedure and filter the queryset by its result.

        The procedure is expected to be named in the format
        "[table_name]_[procname]" and to return an array of primary keys.
        This is for composability with other queryset methods. For
        procedures that return arbitrary values, use
        :meth:`unchainable_procedure`.

        Note:
            Don't **ever** pass untreated user input values (or any values
            that cannot be implicitly trusted) to `procname`. Its value is
            not sanitized and **can be used for SQL-injection attacks**. If
            you must get the `procname` dynamically, set up a list of
            allowed names and check against that before calling this
            method.

        Args:
            procname (str): The name of the procedure, minus the starting
                "[table_name]_" prefix. **This value is not sanitized and
                can be used for SQL injection attacks**. Never pass
                untreated user input or any values that cannot be implicitly
                trusted to it.
            *args: Arguments to be passed to the procedure. These values are
                sent as bound parameters of the SQL query rather than being
                interpolated into its text.

        Returns:
            ProcedureQuerySet: A queryset containing the records whose
            primary keys match the ones in the array returned by the
            database stored procedure.
        """
        ids = self._run_procedure(
            procname, args,
            lambda cursor: [row[0] for row in cursor.fetchall()])

        # `pk` rather than `id`, so models with a custom primary key work.
        return self.filter(pk__in=ids)

    def unchainable_procedure(self, procname, *args):
        """Call a stored procedure whose result is not a set of records.

        The result cannot be chained with other queryset methods. See the
        warning about `procname` in :meth:`procedure`.

        Args:
            procname (str): Procedure name, minus the "[table_name]_" prefix.
            *args: Values sent as bound query parameters.

        Returns:
            list of dict: One dictionary per returned row, mapping each
            column name to its value.
        """
        def rows_as_dicts(cursor):
            # Column names are the same for every row; read them once.
            columns = [column[0] for column in cursor.description]

            return [dict(zip(columns, row)) for row in cursor.fetchall()]

        return self._run_procedure(procname, args, rows_as_dicts)

    def single_valued_procedure(self, procname, *args):
        """Call a stored procedure and return a single value.

        See the warning about `procname` in :meth:`procedure`.

        Args:
            procname (str): Procedure name, minus the "[table_name]_" prefix.
            *args: Values sent as bound query parameters.

        Returns:
            The first column of the first returned row, or ``None`` when the
            procedure returns no rows.
        """
        def first_value(cursor):
            row = cursor.fetchone()

            return None if row is None else row[0]

        return self._run_procedure(procname, args, first_value)
def call_procedure(procname):
    """Build a function that delegates to ``self.procedure``.

    Args:
        procname (str): Procedure name handed to ``self.procedure``; it also
            becomes the ``__name__`` of the returned function.

    Returns:
        function: A callable with signature ``(self, *args)`` that returns
        ``self.procedure(procname, *args)``.
    """
    def wrapper(self, *proc_args):
        # Positional arguments are forwarded untouched.
        queryset = self.procedure(procname, *proc_args)
        return queryset

    setattr(wrapper, "__name__", procname)

    return wrapper
def call_order_by_procedure(procname):
    """Build a function that delegates to ``self.order_by_procedure``.

    Args:
        procname (str): Procedure name handed to
            ``self.order_by_procedure``; it also becomes the ``__name__`` of
            the returned function.

    Returns:
        function: A callable with signature
        ``(self, *args, field='id', extra_fields=None)`` that returns
        ``self.order_by_procedure(procname, *args, field=field,
        extra_fields=extra_fields)``.
    """
    def wrapper(self, *proc_args, field='id', extra_fields=None):
        # Both ordering options are always passed on explicitly.
        ordering_options = {"field": field, "extra_fields": extra_fields}
        return self.order_by_procedure(procname, *proc_args,
                                       **ordering_options)

    setattr(wrapper, "__name__", procname)

    return wrapper
def call_unchainable_procedure(procname):
    """Build a function that delegates to ``self.unchainable_procedure``.

    Args:
        procname (str): Procedure name handed to
            ``self.unchainable_procedure``; it also becomes the ``__name__``
            of the returned function.

    Returns:
        function: A callable with signature ``(self, *args)`` that returns
        ``self.unchainable_procedure(procname, *args)``.
    """
    def wrapper(self, *proc_args):
        # Positional arguments are forwarded untouched.
        rows = self.unchainable_procedure(procname, *proc_args)
        return rows

    setattr(wrapper, "__name__", procname)

    return wrapper
def call_single_valued_procedure(procname):
    """Build a function that delegates to ``self.single_valued_procedure``.

    Args:
        procname (str): Procedure name handed to
            ``self.single_valued_procedure``; it also becomes the
            ``__name__`` of the returned function.

    Returns:
        function: A callable with signature ``(self, *args)`` that returns
        ``self.single_valued_procedure(procname, *args)``.
    """
    def wrapper(self, *proc_args):
        # Positional arguments are forwarded untouched.
        value = self.single_valued_procedure(procname, *proc_args)
        return value

    setattr(wrapper, "__name__", procname)

    return wrapper