saltstack/salt

View on GitHub
salt/utils/dictupdate.py

Summary

Maintainability
C
1 day
Test Coverage
# -*- coding: utf-8 -*-
'''
Alex Martelli's solution for recursive dict update from
http://stackoverflow.com/a/3233356
'''

# Import python libs
from __future__ import absolute_import, print_function, unicode_literals

try:
    from collections.abc import Mapping
except ImportError:
    from collections import Mapping

# Import more python libs
import copy
import logging

# Import salt libs
import salt.ext.six as six
from salt.defaults import DEFAULT_TARGET_DELIM
import salt.utils.data
from salt.exceptions import SaltInvocationError
from salt.utils.odict import OrderedDict
from salt.utils.decorators.jinja import jinja_filter

log = logging.getLogger(__name__)


def update(dest, upd, recursive_update=True, merge_lists=False):
    '''
    Recursive version of the default dict.update

    Merges ``upd`` recursively into ``dest``. ``dest`` is modified in place
    and is also the return value.

    If recursive_update=False, or if ``dest`` and ``upd`` share no top-level
    key, every key of ``upd`` is assigned onto ``dest`` one by one. Plain item
    assignment is used so that non-dict mapping types (like FunctionWrapper)
    work as well.

    If merge_lists=True, will aggregate list object types instead of replace.
    The list in ``upd`` is added to the list in ``dest``, so the resulting list
    is ``dest[key] + upd[key]``. This behavior is only activated when
    recursive_update=True. By default merge_lists=False.

    .. versionchanged:: 2016.11.6
        When merging lists, duplicate values are removed. Values already
        present in the ``dest`` list are not added from the ``upd`` list.

    :param dest: The mapping to update (modified in place).
    :param upd: The mapping holding the values to merge into ``dest``.
    :param bool recursive_update: Merge nested mappings instead of replacing
                                  them. Default: True.
    :param bool merge_lists: Aggregate lists instead of replacing them.
                             Default: False.

    :return: ``dest``
    :raises TypeError: If ``dest`` or ``upd`` is not a Mapping.
    '''
    if not (isinstance(dest, Mapping) and isinstance(upd, Mapping)):
        raise TypeError('Cannot update using non-dict types in dictupdate.update()')

    # Snapshot the keys once; they are used for the overlap test and the merge
    updkeys = list(upd.keys())

    # Without a shared top-level key there is nothing to recurse into
    if set(dest.keys()).isdisjoint(updkeys):
        recursive_update = False

    if not recursive_update:
        # Item assignment instead of dest.update() so that any mapping which
        # supports __setitem__ can be used as destination
        for key in upd:
            dest[key] = upd[key]
        return dest

    for key in updkeys:
        val = upd[key]
        try:
            dest_subkey = dest.get(key, None)
        except AttributeError:
            # Mapping-like destination without a get() method
            dest_subkey = None

        if isinstance(dest_subkey, Mapping) and isinstance(val, Mapping):
            dest[key] = update(dest_subkey, val, merge_lists=merge_lists)
        elif merge_lists and isinstance(dest_subkey, list) and isinstance(val, list):
            merged = copy.deepcopy(dest_subkey)
            # The list is built completely before extend() runs, so only
            # values already present in dest are filtered out of upd
            merged.extend([x for x in val if x not in merged])
            dest[key] = merged
        else:
            dest[key] = val
    return dest


def merge_list(obj_a, obj_b):
    '''
    Build a new dict from the keys of ``obj_a``, pairing values of shared keys.

    For each key of ``obj_a`` that is also present in ``obj_b`` the result
    holds the two-item list ``[obj_a[key], obj_b[key]]``. For every other key
    of ``obj_a`` the result holds the value of ``obj_a`` as-is. Keys that exist
    only in ``obj_b`` do not appear in the result.

    Neither argument is modified, but values are not copied: the result refers
    to the same value objects as the inputs.

    :param dict obj_a: The dict whose keys define the result.
    :param dict obj_b: The dict providing the second value for shared keys.

    :return dict: The newly built dict.
    '''
    return {
        name: [left, obj_b[name]] if name in obj_b else left
        for name, left in six.iteritems(obj_a)
    }


def merge_recurse(obj_a, obj_b, merge_lists=False):
    '''
    Recursively merge ``obj_b`` into a deep copy of ``obj_a``.

    ``obj_a`` is left untouched because the merge is done by ``update()`` on a
    deep copy of it.

    :param dict obj_a: The base dict.
    :param dict obj_b: The dict whose values are merged on top of ``obj_a``.
    :param bool merge_lists: Passed on to ``update()``: aggregate lists instead
                             of replacing them. Default: False.

    :return dict: The merged copy.
    '''
    return update(copy.deepcopy(obj_a), obj_b, merge_lists=merge_lists)


def merge_aggregate(obj_a, obj_b):
    '''
    Merge ``obj_a`` and ``obj_b`` with the yamlex serializer's
    ``merge_recursive``, called with ``level=1``.

    :param obj_a: First object handed to ``merge_recursive``.
    :param obj_b: Second object handed to ``merge_recursive``.

    :return: Whatever ``merge_recursive`` returns.
    '''
    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid an import cycle or load cost - confirm.
    from salt.serializers.yamlex import merge_recursive as yamlex_merge

    return yamlex_merge(obj_a, obj_b, level=1)


def merge_overwrite(obj_a, obj_b, merge_lists=False):
    '''
    Merge ``obj_b`` into ``obj_a``, letting ``obj_b`` win for every top-level
    key both of them have.

    First, each top-level key present in both objects is overwritten in
    ``obj_a`` with the value from ``obj_b``. Then the result of
    ``merge_recurse(obj_a, obj_b)`` is returned.

    NOTE(review): the first step modifies the caller's ``obj_a`` in place,
    while ``merge_recurse`` itself works on a copy. Confirm whether callers
    rely on that side effect before changing it.

    :param dict obj_a: The base dict (shared top-level keys are overwritten
                       in place).
    :param dict obj_b: The dict whose values take precedence.
    :param bool merge_lists: Passed on to ``merge_recurse``. Default: False.

    :return dict: The merged result produced by ``merge_recurse``.
    '''
    for name in obj_b:
        if name not in obj_a:
            continue
        obj_a[name] = obj_b[name]
    return merge_recurse(obj_a, obj_b, merge_lists=merge_lists)


def merge(obj_a, obj_b, strategy='smart', renderer='yaml', merge_lists=False):
    '''
    Merge ``obj_b`` into ``obj_a`` using the selected strategy.

    Supported strategies:

    - ``smart``: resolves to ``aggregate`` when the last element of the
      ``|``-separated ``renderer`` is ``yamlex`` or when ``renderer`` starts
      with ``yamlex_``. Resolves to ``recurse`` otherwise.
    - ``list``: uses ``merge_list``.
    - ``recurse``: uses ``merge_recurse``, honoring ``merge_lists``.
    - ``aggregate``: uses ``merge_aggregate``.
    - ``overwrite``: uses ``merge_overwrite``, honoring ``merge_lists``.
    - ``none``: uses ``merge_recurse`` without logging anything.

    Any other value logs a warning and falls back to ``recurse``, including
    its handling of ``merge_lists``.

    :param obj_a: The base object.
    :param obj_b: The object merged on top of ``obj_a``.
    :param str strategy: Name of the merging strategy. Default: ``smart``.
    :param str renderer: Renderer (pipe) name, only used by ``smart``.
    :param bool merge_lists: Aggregate lists instead of replacing them, for
                             the strategies that support it. Default: False.

    :return: The merged object.
    '''
    if strategy == 'smart':
        if renderer.split('|')[-1] == 'yamlex' or renderer.startswith('yamlex_'):
            strategy = 'aggregate'
        else:
            strategy = 'recurse'

    if strategy == 'list':
        merged = merge_list(obj_a, obj_b)
    elif strategy == 'recurse':
        merged = merge_recurse(obj_a, obj_b, merge_lists)
    elif strategy == 'aggregate':
        #: level = 1 merge at least root data
        merged = merge_aggregate(obj_a, obj_b)
    elif strategy == 'overwrite':
        merged = merge_overwrite(obj_a, obj_b, merge_lists)
    elif strategy == 'none':
        # If we do not want to merge, there is only one pillar passed, so we can safely use the default recurse,
        # we just do not want to log an error
        merged = merge_recurse(obj_a, obj_b)
    else:
        log.warning(
            'Unknown merging strategy \'%s\', fallback to recurse',
            strategy
        )
        # The fallback must behave like the real 'recurse' strategy, so the
        # caller's merge_lists choice is passed along instead of dropped
        merged = merge_recurse(obj_a, obj_b, merge_lists)

    return merged


def ensure_dict_key(
        in_dict,
        keys,
        delimiter=DEFAULT_TARGET_DELIM,
        ordered_dict=False):
    '''
    Make sure the nested chain of keys described by ``keys`` exists in
    ``in_dict``, with a dict at every level.

    Walks down ``in_dict`` one key at a time. Whenever a key is missing, or
    its current value is not a dict, an empty dict is stored under that key
    (replacing any non-dict value that was there).

    :param dict in_dict: The dict to work with. It is modified in place.
    :param str keys: The delimited string with one or more keys.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.

    :return dict: ``in_dict`` itself.
    '''
    path = keys.split(delimiter) if delimiter in keys else [keys]
    node = in_dict
    for name in path:
        if not (name in node and isinstance(node[name], dict)):
            node[name] = OrderedDict() if ordered_dict else {}
        node = node[name]
    return in_dict


def _dict_rpartition(
        in_dict,
        keys,
        delimiter=DEFAULT_TARGET_DELIM,
        ordered_dict=False):
    '''
    Split ``keys`` at its last delimiter and locate the dict that should hold
    the last key.

    - Without a delimiter in ``keys``, ``in_dict`` itself and ``keys`` are
      returned and nothing is created.
    - Otherwise all keys before the last delimiter are created in ``in_dict``
      via ``ensure_dict_key`` and then looked up with
      ``salt.utils.data.traverse_dict``.

    :param dict in_dict: The dict to work with.
    :param str keys: The delimited string with one or more keys.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.

    :return tuple(dict, str): The dict at the next-to-last key, and the last
                              key.
    '''
    if delimiter not in keys:
        return in_dict, keys

    parent_path, _, last_key = keys.rpartition(delimiter)
    ensure_dict_key(
        in_dict,
        parent_path,
        delimiter=delimiter,
        ordered_dict=ordered_dict)
    parent = salt.utils.data.traverse_dict(
        in_dict,
        parent_path,
        default=None,
        delimiter=delimiter)
    return parent, last_key


@jinja_filter('set_dict_key_value')
def set_dict_key_value(
        in_dict,
        keys,
        value,
        delimiter=DEFAULT_TARGET_DELIM,
        ordered_dict=False):
    '''
    Assign ``value`` to the nested key of ``in_dict`` described by ``keys``.

    All keys leading up to the last one are created as needed. Whatever was
    stored under the last key before is replaced by ``value``.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to assign to the nested dict-key.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.

    :return dict: Though it updates in_dict in-place.
    '''
    parent, leaf = _dict_rpartition(
        in_dict,
        keys,
        delimiter=delimiter,
        ordered_dict=ordered_dict)
    parent[leaf] = value
    return in_dict


@jinja_filter('update_dict_key_value')
def update_dict_key_value(
        in_dict,
        keys,
        value,
        delimiter=DEFAULT_TARGET_DELIM,
        ordered_dict=False):
    '''
    Update the dict stored at the nested key of ``in_dict`` described by
    ``keys`` with ``value``.

    All keys leading up to the last one are created as needed. If the last key
    is missing or holds ``None``, an empty dict is stored there first. The
    update itself is done with the ``update()`` method of the stored object.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to update the nested dict-key with.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.

    :return dict: Though it updates in_dict in-place.
    :raises SaltInvocationError: If the object at the last key has no
        ``update`` method, or if updating it with ``value`` raises a
        ValueError or TypeError.
    '''
    parent, leaf = _dict_rpartition(
        in_dict,
        keys,
        delimiter=delimiter,
        ordered_dict=ordered_dict)
    # Start from an empty mapping when there is nothing to update yet
    if not (leaf in parent and parent[leaf] is not None):
        parent[leaf] = OrderedDict() if ordered_dict else {}
    try:
        parent[leaf].update(value)
    except AttributeError:
        raise SaltInvocationError(
            'The last key contains a {}, which cannot update.'.format(
                type(parent[leaf])))
    except (ValueError, TypeError):
        raise SaltInvocationError(
            'Cannot update {} with a {}.'.format(
                type(parent[leaf]), type(value)))
    return in_dict


@jinja_filter('append_dict_key_value')
def append_dict_key_value(
        in_dict,
        keys,
        value,
        delimiter=DEFAULT_TARGET_DELIM,
        ordered_dict=False):
    '''
    Append ``value`` to the list stored at the nested key of ``in_dict``
    described by ``keys``.

    All keys leading up to the last one are created as needed. If the last key
    is missing or holds ``None``, an empty list is stored there first. The
    value is added with the ``append()`` method of the stored object.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to append to the nested dict-key.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.

    :return dict: Though it updates in_dict in-place.
    :raises SaltInvocationError: If the object at the last key has no
        ``append`` method.
    '''
    parent, leaf = _dict_rpartition(
        in_dict,
        keys,
        delimiter=delimiter,
        ordered_dict=ordered_dict)
    # Start from an empty list when there is nothing to append to yet
    if not (leaf in parent and parent[leaf] is not None):
        parent[leaf] = []
    try:
        parent[leaf].append(value)
    except AttributeError:
        raise SaltInvocationError(
            'The last key contains a {}, which cannot append.'.format(
                type(parent[leaf])))
    return in_dict


@jinja_filter('extend_dict_key_value')
def extend_dict_key_value(
        in_dict,
        keys,
        value,
        delimiter=DEFAULT_TARGET_DELIM,
        ordered_dict=False):
    '''
    Extend the list stored at the nested key of ``in_dict`` described by
    ``keys`` with the items of ``value``.

    All keys leading up to the last one are created as needed. If the last key
    is missing or holds ``None``, an empty list is stored there first. The
    items are added with the ``extend()`` method of the stored object.

    :param dict in_dict: The dictionary to work with
    :param str keys: The delimited string with one or more keys.
    :param any value: The value to extend the nested dict-key with.
    :param str delimiter: The delimiter to use in `keys`. Defaults to ':'.
    :param bool ordered_dict: Create OrderedDicts if keys are missing.
                              Default: create regular dicts.

    :return dict: Though it updates in_dict in-place.
    :raises SaltInvocationError: If the object at the last key has no
        ``extend`` method, or if extending it with ``value`` raises a
        TypeError.
    '''
    parent, leaf = _dict_rpartition(
        in_dict,
        keys,
        delimiter=delimiter,
        ordered_dict=ordered_dict)
    # Start from an empty list when there is nothing to extend yet
    if not (leaf in parent and parent[leaf] is not None):
        parent[leaf] = []
    try:
        parent[leaf].extend(value)
    except AttributeError:
        raise SaltInvocationError(
            'The last key contains a {}, which cannot extend.'.format(
                type(parent[leaf])))
    except TypeError:
        raise SaltInvocationError(
            'Cannot extend {} with a {}.'.format(
                type(parent[leaf]), type(value)))
    return in_dict