montydb/engine/update.py
from collections import OrderedDict
from datetime import datetime
from ..errors import WriteError
from .field_walker import FieldWalker, FieldWriteError
from .weighted import Weighted, _cmp_decimal
from .queries import QueryFilter, ordering
from ..types import (
bson,
string_types,
is_numeric_type,
is_duckument_type,
is_integer_type,
keep,
)
def _update(fieldwalker, field, value, evaluator, array_filters):
fieldwalker.go(field)
try:
fieldwalker.set(value, evaluator, array_filters)
# Take error message and put error code
except FieldWriteError as err:
raise WriteError(str(err), code=err.code)
def _drop(fieldwalker, field, array_filters):
fieldwalker.go(field)
try:
fieldwalker.drop(array_filters)
# Take error message and put error code
except FieldWriteError as err:
raise WriteError(str(err), code=err.code)
class Updator(object):
def __init__(self, spec, array_filters=None):
self.update_ops = {
# field update ops
"$inc": parse_inc,
"$min": parse_min,
"$max": parse_max,
"$mul": parse_mul,
"$rename": parse_rename,
"$set": parse_set,
"$setOnInsert": self.parse_set_on_insert,
"$unset": parse_unset,
"$currentDate": parse_currentDate,
# array update ops
# $ implemented in FieldWalker
# $[] implemented in FieldWalker
# $[<identifier>] implemented in FieldWalker
"$addToSet": parse_add_to_set,
"$pop": parse_pop,
"$pull": parse_pull,
"$push": parse_push,
"$pullAll": parse_pull_all,
# $each implemented in Eacher
# $position implemented in Eacher
# $slice implemented in Eacher
# $sort implemented in Eacher
}
self.fields_to_update = []
self.array_filters = self.array_filter_parser(array_filters or [])
# sort by key (operator)
self.operations = OrderedDict(sorted(self.parser(spec).items()))
self.__insert = None
self.__fieldwalker = None
def __repr__(self):
pass
def __call__(self, fieldwalker, do_insert=False):
"""Update document and return a bool value indicate changed or not"""
self.__fieldwalker = fieldwalker
self.__insert = do_insert
with fieldwalker:
for operator in self.operations.values():
operator(fieldwalker)
return fieldwalker.commit()
@property
def fieldwalker(self):
return self.__fieldwalker
def array_filter_parser(self, array_filters):
filters = {}
for i, filter_ in enumerate(array_filters):
top = ""
conds = {}
for identifier, cond in filter_.items():
id_s = identifier.split(".", 1)
if not top and id_s[0] in filters:
msg = (
"Found multiple array filters with the same "
"top-level field name {}".format(id_s[0])
)
raise WriteError(msg, code=9)
if top and id_s[0] != top:
msg = (
"Error parsing array filter: Expected a single "
"top-level field name, found {0!r} and {1!r}"
"".format(top, id_s[0])
)
raise WriteError(msg, code=9)
top = id_s[0]
conds.update({identifier: cond})
filters[top] = QueryFilter(conds)
return filters
def parser(self, spec):
if not next(iter(spec)).startswith("$"):
raise ValueError("update only works with $ operators")
update_stack = {}
idnt_tops = list(self.array_filters.keys())
for op, cmd_doc in spec.items():
if op not in self.update_ops:
raise WriteError("Unknown modifier: {}".format(op))
if not is_duckument_type(cmd_doc):
msg = (
"Modifiers operate on fields but we found type {0} "
"instead. For example: {{$mod: {{<field>: ...}}}} "
"not {1}".format(type(cmd_doc).__name__, spec)
)
raise WriteError(msg, code=9)
for field, value in cmd_doc.items():
for top in list(idnt_tops):
if "$[{}]".format(top) in field:
idnt_tops.remove(top)
update_stack[field] = self.update_ops[op](
field, value, self.array_filters
)
self.check_conflict(field)
if op == "$rename":
self.check_conflict(value)
if idnt_tops:
msg = (
"The array filter for identifier {0!r} was not "
"used in the update {1}".format(idnt_tops[0], spec)
)
raise WriteError(msg, code=9)
return update_stack
def check_conflict(self, field):
for staged in self.fields_to_update:
if field.startswith(staged) or staged.startswith(field):
msg = (
"Updating the path {0!r} would create a "
"conflict at {1!r}".format(field, staged[: len(field)])
)
raise WriteError(msg, code=40)
self.fields_to_update.append(field)
def parse_set_on_insert(self, field, value, array_filters):
@keep(value)
def _set_on_insert(fieldwalker):
if self.__insert:
parse_set(field, value, array_filters)(fieldwalker)
return _set_on_insert
def parse_inc(field, value, array_filters):
if not is_numeric_type(value):
val_repr_ = "{!r}" if isinstance(value, string_types) else "{}"
val_repr_ = val_repr_.format(value)
msg = "Cannot increment with non-numeric argument: {{{0}: {1}}}".format(
field, val_repr_
)
raise WriteError(msg, code=14)
@keep(value)
def _inc(fieldwalker):
def evaluator(node, inc_val):
old_val = node.value
if node.exists and not is_numeric_type(old_val):
_id = fieldwalker.doc["_id"]
value_type = type(old_val).__name__
msg = (
"Cannot apply $inc to a value of non-numeric type. "
"{{_id: {0}}} has the field {1!r} of non-numeric type "
"{2}".format(_id, str(node), value_type)
)
raise WriteError(msg, code=14)
is_decimal128 = False
if isinstance(old_val, bson.Decimal128):
is_decimal128 = True
old_val = old_val.to_decimal()
if isinstance(inc_val, bson.Decimal128):
is_decimal128 = True
inc_val = inc_val.to_decimal()
if is_decimal128:
return bson.Decimal128((old_val or 0) + inc_val)
else:
return (old_val or 0) + inc_val
_update(fieldwalker, field, value, evaluator, array_filters)
return _inc
def parse_min(field, value, array_filters):
@keep(value)
def _min(fieldwalker):
def evaluator(node, min_val):
old_val = node.value
if node.exists:
old_val = Weighted(old_val)
min_val = Weighted(min_val)
return min_val.value if min_val < old_val else old_val.value
else:
return min_val
_update(fieldwalker, field, value, evaluator, array_filters)
return _min
def parse_max(field, value, array_filters):
@keep(value)
def _max(fieldwalker):
def evaluator(node, max_val):
old_val = node.value
if node.exists:
old_val = Weighted(old_val)
max_val = Weighted(max_val)
return max_val.value if max_val > old_val else old_val.value
else:
return max_val
_update(fieldwalker, field, value, evaluator, array_filters)
return _max
def parse_mul(field, value, array_filters):
if not is_numeric_type(value):
val_repr_ = "{!r}" if isinstance(value, string_types) else "{}"
val_repr_ = val_repr_.format(value)
msg = "Cannot multiply with non-numeric argument: {{{0}: {1}}}".format(
field, val_repr_
)
raise WriteError(msg, code=14)
@keep(value)
def _mul(fieldwalker):
def evaluator(node, mul_val):
old_val = node.value
if node.exists and not is_numeric_type(old_val):
_id = fieldwalker.doc["_id"]
value_type = type(old_val).__name__
msg = (
"Cannot apply $mul to a value of non-numeric type. "
"{{_id: {0}}} has the field {1!r} of non-numeric type "
"{2}".format(_id, str(node), value_type)
)
raise WriteError(msg, code=14)
is_decimal128 = False
if isinstance(old_val, bson.Decimal128):
is_decimal128 = True
old_val = old_val.to_decimal()
if isinstance(mul_val, bson.Decimal128):
is_decimal128 = True
mul_val = mul_val.to_decimal()
if is_decimal128:
return bson.Decimal128((old_val or 0) * mul_val)
else:
return (old_val or 0.0) * mul_val
_update(fieldwalker, field, value, evaluator, array_filters)
return _mul
def _get_array_member(fieldvalues):
for node in fieldvalues.nodes:
if node.in_array:
return node
def parse_rename(field, new_field, array_filters):
if not isinstance(new_field, string_types):
msg = "The 'to' field for $rename must be a string: {0}: {1}".format(
field, new_field
)
raise WriteError(msg, code=2)
if field == new_field:
msg = (
"The source and target field for $rename must differ: "
"{0}: {1!r}".format(field, new_field)
)
raise WriteError(msg, code=2)
if field.startswith(new_field) or new_field.startswith(field):
msg = (
"The source and target field for $rename must not be on the "
"same path: {0}: {1!r}".format(field, new_field)
)
raise WriteError(msg, code=2)
@keep(new_field)
def _rename(fieldwalker):
probe = FieldWalker(fieldwalker.doc)
probe.go(field).get()
fieldvalues = probe.value
if not fieldvalues.is_exists():
return
value = next(fieldvalues.iter_plain())
array_member = _get_array_member(fieldvalues)
if array_member is not None:
_id = probe.doc["_id"]
array_field = str(array_member.parent)
msg = (
"The source field cannot be an array element, "
"{0!r} in doc with _id: {1} has an array field "
"called {2!r}".format(field, _id, array_field)
)
raise WriteError(msg, code=2)
probe.go(new_field).get()
fieldvalues = probe.value
array_member = _get_array_member(fieldvalues)
if array_member is not None:
_id = probe.doc["_id"]
array_field = str(array_member.parent)
msg = (
"The destination field cannot be an array element, "
"{0!r} in doc with _id: {1} has an array field "
"called {2!r}".format(new_field, _id, array_field)
)
raise WriteError(msg, code=2)
_drop(fieldwalker, field, array_filters)
_update(fieldwalker, new_field, value, None, array_filters)
return _rename
def parse_set(field, value, array_filters):
@keep(value)
def _set(fieldwalker):
_update(fieldwalker, field, value, None, array_filters)
return _set
def parse_unset(field, _, array_filters):
@keep(field)
def _unset(fieldwalker):
_drop(fieldwalker, field, array_filters)
return _unset
def parse_currentDate(field, value, array_filters):
date_type = {
"date": datetime.utcnow(),
"timestamp": bson.Timestamp(datetime.utcnow(), 1),
}
if not isinstance(value, bool):
if not is_duckument_type(value):
msg = (
"{} is not valid type for $currentDate. Please use a "
"boolean ('true') or a $type expression ({{$type: "
"'timestamp/date'}}).".format(type(value).__name__)
)
raise WriteError(msg, code=2)
for k, v in value.items():
if k != "$type":
msg = "Unrecognized $currentDate option: {}".format(k)
raise WriteError(msg, code=2)
if v not in date_type:
msg = (
"The '$type' string field is required to be 'date' "
"or 'timestamp': {$currentDate: {field : {$type: "
"'date'}}}"
)
raise WriteError(msg, code=2)
value = date_type[v]
else:
value = date_type["date"]
@keep(value)
def _currentDate(fieldwalker):
parse_set(field, value, array_filters)(fieldwalker)
return _currentDate
def parse_add_to_set(field, value_or_each, array_filters):
if is_duckument_type(value_or_each) and next(iter(value_or_each)) == "$each":
value = EachAdder(value_or_each)
run_each = True
else:
value = value_or_each
run_each = False
@keep(value)
def _add_to_set(fieldwalker):
def evaluator(node, new_elem):
old_val = node.value
if node.exists and not isinstance(old_val, list):
value_type = type(old_val).__name__
msg = (
"Cannot apply $addToSet to non-array field. Field "
"named {0!r} has non-array type {1}"
"".format(str(node), value_type)
)
raise WriteError(msg, code=2)
if run_each:
eacher = new_elem
new_array = eacher(old_val)
else:
new_array = (old_val or [])[:]
if new_elem not in new_array:
new_array.append(new_elem)
return new_array
_update(fieldwalker, field, value, evaluator, array_filters)
return _add_to_set
def parse_pop(field, value, array_filters):
if not is_numeric_type(value):
msg = "Expected a number in: {0}: {1!r}".format(field, value)
raise WriteError(msg, code=9)
else:
try:
value = float(value)
msg_raw = "Expected an integer: {0}: {1!r}"
except TypeError:
msg_raw = "Cannot represent as a 64-bit integer: {0}: {1!r}"
value = float(value.to_decimal())
if value not in (1.0, -1.0):
raise WriteError(msg_raw.format(field, value), code=9)
@keep(value)
def _pop(fieldwalker):
def evaluator(node, pop_ind):
old_val = node.value
if node.exists and not isinstance(old_val, list):
value_type = type(old_val).__name__
msg = (
"Path {0!r} contains an element of non-array type "
"{1!r}".format(str(node), value_type)
)
raise WriteError(msg, code=14)
if not node.exists:
# do nothing
return old_val
if pop_ind == 1:
return old_val[:-1]
else:
return old_val[1:]
_update(fieldwalker, field, value, evaluator, array_filters)
return _pop
def parse_pull(field, value_or_conditions, array_filters):
if is_duckument_type(value_or_conditions):
query_spec = {}
for k, v in value_or_conditions.items():
if not k[:1] == "$":
query_spec[".".join((field, k))] = v
else:
query_spec[field] = {k: v}
queryfilter = QueryFilter(query_spec)
else:
queryfilter = QueryFilter({field: value_or_conditions})
@keep(queryfilter)
def _pull(fieldwalker):
def evaluator(node, _):
old_val = node.value
if node.exists and not isinstance(old_val, list):
msg = "Cannot apply $pull to a non-array value"
raise WriteError(msg, code=2)
if not node.exists:
# do nothing
return old_val
new_array = []
for elem in old_val:
result = queryfilter({field: elem})
if not result:
new_array.append(elem)
return new_array
_update(fieldwalker, field, None, evaluator, array_filters)
return _pull
def parse_push(field, value_or_each, array_filters):
if is_duckument_type(value_or_each) and "$each" in value_or_each:
value = EachPusher(value_or_each)
run_each = True
else:
value = value_or_each
run_each = False
@keep(value)
def _push(fieldwalker):
def evaluator(node, new_elem):
old_val = node.value
if node.exists and not isinstance(old_val, list):
value_type = type(old_val).__name__
_id = fieldwalker.doc["_id"]
msg = (
"The field {0!r} must be an array but is of type "
"{1} in document {{_id: {2}}}"
"".format(str(node), value_type, _id)
)
raise WriteError(msg, code=2)
if run_each:
eacher = new_elem
new_array = eacher(old_val)
else:
new_array = (old_val or [])[:]
new_array.append(new_elem)
return new_array
_update(fieldwalker, field, value, evaluator, array_filters)
return _push
def parse_pull_all(field, value, array_filters):
if not isinstance(value, list):
value_type = type(value).__name__
msg = "$pullAll requires an array argument but was given a {}".format(
value_type
)
raise WriteError(msg, code=2)
@keep(value)
def _pull_all(fieldwalker):
def evaluator(node, pull_list):
old_val = node.value
if node.exists and not isinstance(old_val, list):
msg = "Cannot apply $pull to a non-array value"
raise WriteError(msg, code=2)
if not node.exists:
# do nothing
return old_val
def convert(lst):
for val in lst:
if isinstance(val, bson.Decimal128):
yield _cmp_decimal(val)
else:
yield val
pull_list = list(convert(pull_list))
old_val = list(convert(old_val))
new_array = [elem for elem in old_val if elem not in pull_list]
return new_array
_update(fieldwalker, field, value, evaluator, array_filters)
return _pull_all
class EachAdder(object):
def __init__(self, spec):
spec = spec.copy()
self.mods = {
"$each": None,
}
for mod, value in spec.items():
try:
type_check = self.validators[mod]
except KeyError:
raise WriteError(
"Found unexpected fields after $each in $addToSet: %s" % spec,
code=2,
)
self.mods[mod] = type_check(self, value)
def __call__(self, array):
new_array = (array or [])[:]
new_elems = self.mods["$each"][:]
new_array[0:] += [e for e in new_elems if e not in new_array]
return new_array
def _validate_each(self, each):
try:
each[:]
except TypeError:
type_name = type(each).__name__
raise WriteError(
"The argument to $each in $addToSet must be an "
"array but it was of type %s" % type_name,
code=14,
)
return each
validators = {
"$each": _validate_each,
}
class EachPusher(object):
def __init__(self, spec):
spec = spec.copy()
self.mods = {
"$each": None,
"$position": None,
"$slice": None,
"$sort": None,
}
for mod, value in spec.items():
try:
type_check = self.validators[mod]
except KeyError:
raise WriteError("Unrecognized clause in $push: %s" % mod, code=2)
self.mods[mod] = type_check(self, value)
def __call__(self, array):
new_array = (array or [])[:]
new_elems = self.mods["$each"][:]
position = self.mods["$position"]
slice = self.mods["$slice"]
sort = self.mods["$sort"]
if position is None:
new_array += new_elems
else:
new_array[:position] += new_elems
if slice is not None:
if slice >= 0:
new_array = new_array[:slice]
else:
new_array = new_array[slice:]
if sort is not None:
if is_duckument_type(sort):
fieldwalkers = list()
unsortable = list()
for elem in new_array:
if is_duckument_type(elem):
fieldwalkers.append(FieldWalker(elem))
else:
unsortable.append(elem)
ordered = ordering(fieldwalkers, sort)
new_array = [f.doc for f in ordered]
if unsortable:
is_reverse = bool(1 - next(iter(sort.values())))
if is_reverse:
new_array += unsortable
else:
new_array[:0] += unsortable
else:
is_reverse = bool(1 - sort)
ordered = sorted((Weighted(e) for e in new_array), reverse=is_reverse)
new_array = [w.value for w in ordered]
return new_array
def _validate_each(self, each):
try:
each[:]
except TypeError:
type_name = type(each).__name__
raise WriteError(
"The argument to $each in $push must be an "
"array but it was of type: %s" % type_name,
code=2,
)
return each
def _validate_position(self, position):
if not is_integer_type(position):
type_name = type(position).__name__
raise WriteError(
"The value for $position must be an integer "
"value, not of type: %s" % type_name,
code=2,
)
return position
def _validate_slice(self, slice):
if not is_integer_type(slice):
type_name = type(slice).__name__
raise WriteError(
"The value for $slice must be an integer "
"value but was given type: %s" % type_name,
code=2,
)
return slice
def _validate_sort(self, sort, int_only=False):
if is_integer_type(sort) or int_only:
if sort not in (1, -1):
raise WriteError(
"The $sort element value must be either 1 or -1", code=2
)
return sort
if is_duckument_type(sort):
for key, value in sort.items():
self._validate_sort(value, int_only=True)
return sort
raise WriteError(
"The $sort is invalid: use 1/-1 to sort the whole "
"element, or {field:1/-1} to sort embedded fields",
code=2,
)
validators = {
"$each": _validate_each,
"$position": _validate_position,
"$slice": _validate_slice,
"$sort": _validate_sort,
}