Crystalnix/termius-cli

View on GitHub
termius/core/commands/mixins.py

Summary

Maintainability
B
4 hrs
Test Coverage
# -*- coding: utf-8 -*-
"""Module with different CLI commands mixins."""
import getpass
import os
from operator import attrgetter
from functools import partial
from cached_property import cached_property
from ..exceptions import (
    DoesNotExistException, ArgumentRequiredException,
    TooManyEntriesException, SkipField
)
from ..models.terminal import SshConfig, Identity
from .utils import parse_ids_names, DefaultAttrGetter
from ..models.utils import GroupStackGenerator, Merger


# pylint: disable=too-few-public-methods
class PasswordPromptMixin(object):
    """Command mixin that asks the user for the account password."""

    # pylint: disable=no-self-use
    def prompt_password(self):
        """Prompt for the password through ``getpass`` and return it.

        :returns: the value returned by ``getpass.getpass``.
        """
        prompt = 'Password:'
        return getpass.getpass(prompt)


class GetRelationMixin(object):
    """Mixin that adds a way to retrieve an entry by id or by name."""

    def get_relation(self, model_class, arg):
        """Fetch a single ``model_class`` entry from storage.

        ``arg`` is parsed as an integer to obtain the id (``None`` when it
        is not a number) and is also passed unchanged as the label; both
        are handed to ``self.storage.get`` with ``query_union=any``.

        :param model_class: model class to look up.
        :param arg: id or label given by the user.
        :returns: whatever ``self.storage.get`` returns.
        """
        try:
            relation_id = int(arg)
        except ValueError:
            relation_id = None
        lookup = {'id': relation_id, 'label': arg}
        try:
            found = self.storage.get(model_class, query_union=any, **lookup)
        except DoesNotExistException:
            self.fail_not_exist(model_class)
        except TooManyEntriesException:
            self.fail_too_many(model_class)
        else:
            return found

    # pylint: disable=no-self-use
    def fail_not_exist(self, model_class):
        """Raise ``ArgumentRequiredException`` for a missing instance."""
        model_name = model_class.__name__.lower()
        raise ArgumentRequiredException(
            'Not found any {} instance.'.format(model_name)
        )

    # pylint: disable=no-self-use
    def fail_too_many(self, model_class):
        """Raise ``ArgumentRequiredException`` for an ambiguous lookup."""
        model_name = model_class.__name__.lower()
        raise ArgumentRequiredException(
            'Found too many {} instances.'.format(model_name)
        )

    def get_safely_instance(self, model_class, arg):
        """Return ``arg`` itself when it is falsy, else the relation.

        :param model_class: model class to look up.
        :param arg: id or label; a falsy value skips the lookup.
        """
        if not arg:
            return arg
        return self.get_relation(model_class, arg)

    def get_safely_instance_partial(self, model, arg_name):
        """Return a ``partial`` that resolves ``arg_name`` of passed args.

        The returned callable takes the parsed args object as its only
        positional argument.
        """
        bound = {'model': model, 'arg_name': arg_name}
        return partial(self._safely_instance, **bound)

    def _safely_instance(self, args, model, arg_name):
        # Read the named attribute from args and resolve it safely.
        return self.get_safely_instance(model, getattr(args, arg_name))


class PrepareResultMixin(object):
    """Mixin that turns a list of entries into a (fields, rows) pair."""

    @property
    def prepare_fields(self):
        """Return the fields reported by ``model_class.allowed_fields()``."""
        model = self.model_class
        return model.allowed_fields()

    def prepare_result(self, found_list):
        """Return a 2-tuple with data in the format expected by Lister.

        :param found_list: iterable of entries to render.
        :returns: ``(fields, rows)`` where ``fields`` is the sorted list of
            ``prepare_fields`` minus ``skip_fields`` and ``rows`` holds the
            ``DefaultAttrGetter`` result for every entry.
        """
        available = set(self.prepare_fields)
        hidden = set(self.skip_fields)
        fields = sorted(available - hidden)
        getter = DefaultAttrGetter(*fields)
        rows = list(map(getter, found_list))
        return fields, rows


class SshConfigPrepareMixin(PrepareResultMixin):
    """Mixin that also renders ssh config and identity fields."""

    @property
    def prepare_fields(self):
        """Return instance, ssh config and identity fields, in that order."""
        combined = self.instance_fields + self.ssh_config_fields
        return combined + self.identity_fields

    @property
    def instance_fields(self):
        """Return the model's allowed fields except ``ssh_config``."""
        own_fields = []
        for name in self.model_class.allowed_fields():
            if name != 'ssh_config':
                own_fields.append(name)
        return own_fields

    @property
    def ssh_config_fields(self):
        """Return ``SshConfig`` fields prefixed with ``ssh_config.``.

        The ``identity`` field is left out.
        """
        return [
            'ssh_config.{}'.format(name)
            for name in SshConfig.allowed_fields()
            if name != 'identity'
        ]

    @property
    def identity_fields(self):
        """Return ``Identity`` fields prefixed with ``ssh_config.identity.``."""
        # NOTE(review): the 'identity' filter mirrors ssh_config_fields;
        # confirm whether Identity can actually expose such a field.
        return [
            'ssh_config.identity.{}'.format(name)
            for name in Identity.allowed_fields()
            if name != 'identity'
        ]


class GetObjectsMixin(object):
    """Mixin that lists objects matching a list of ids or names."""

    def get_objects(self, ids__names):
        """Return the stored models selected by ``ids__names``.

        The list is split with ``parse_ids_names`` into ids and names,
        which are passed to ``self.storage.filter`` as the
        ``id.rcontains`` and ``label.rcontains`` conditions joined by
        ``any``.

        :param ids__names: mixed list of ids and labels.
        :raises DoesNotExistException: when the filter result is empty.
        """
        ids, names = parse_ids_names(ids__names)
        conditions = {'id.rcontains': ids, 'label.rcontains': names}
        found = self.storage.filter(self.model_class, any, **conditions)
        if found:
            return found
        raise DoesNotExistException("There aren't any instance.")


class ArgModelSerializerMixin(object):
    """Mixin keeping the logic that serializes command line args to a model."""

    @cached_property
    def fields(self):
        """Return a dictionary mapping model field names to serializers.

        Every serializer is a callable that receives the parsed args; the
        default one is ``attrgetter(<field name>)``.
        """
        return {
            i: attrgetter(i) for i in self.model_class.fields
        }

    # pylint: disable=no-self-use
    def serialize_args(self, args, instance=None):
        """Copy values from parsed args onto a model instance.

        :param args: parsed command line arguments.
        :param instance: existing instance to update; a new
            ``self.model_class()`` is created only when this is ``None``.
        :returns: the populated instance, after ``self.validate`` ran.

        Fields whose serializer raises ``SkipField`` or ``KeyError`` and
        fields whose value is ``None`` are left untouched.
        """
        # Compare against None explicitly: a valid instance may be falsy
        # (for example an empty dict-like model) and must not be replaced.
        if instance is None:
            instance = self.model_class()
        for name in self.model_class.fields:
            try:
                value = self.fields[name](args)
            except (SkipField, KeyError):
                continue
            if value is not None:
                setattr(instance, name, value)
        self.validate(instance)
        return instance

    # pylint: disable=unused-argument
    def validate(self, instance):
        """Validate model fields before saving; returns the instance."""
        return instance

    # pylint: disable=unused-argument
    def skip(self, args):
        """Serializer that skips a field by raising ``SkipField``."""
        raise SkipField()


class InstanceOperationMixin(ArgModelSerializerMixin, object):
    """Mixin implementing create, update and delete of model entries."""

    def create_instance(self, args):
        """Serialize args into a new entry, save it and log the creation."""
        new_entry = self.serialize_args(args)
        with self.storage:
            self.pre_save(new_entry)
            stored = self.storage.save(new_entry)
            # Copy the saved entry's id onto the instance that is handed
            # to update_children().
            new_entry.id = stored.id
            self.update_children(new_entry, args)
        self.log_create(stored)

    def update_instance(self, args, instance):
        """Apply args to an existing entry, save it and log the update."""
        updated = self.serialize_args(args, instance)
        with self.storage:
            self.pre_save(updated)
            self.storage.save(updated)
            self.update_children(updated, args)
        self.log_update(updated)

    # pylint: disable=no-self-use,unused-argument
    def pre_save(self, instance):
        """Hook to patch instance fields before saving; no-op here."""

    def update_children(self, instance, args):
        """Hook to update children of the instance; no-op here.

        Called from both create_instance() and update_instance().
        """

    def delete_instance(self, instance):
        """Delete the entry from storage and log the deletion."""
        with self.storage:
            self.storage.delete(instance)
        self.log_delete(instance)

    def log_create(self, entry):
        """Log that a new model entry was created."""
        self._general_log(entry, 'Entry created.')

    def log_update(self, entry):
        """Log that a model entry was updated."""
        self._general_log(entry, 'Entry updated.')

    def log_delete(self, entry):
        """Log that a model entry was deleted."""
        self._general_log(entry, 'Entry deleted.')

    def _general_log(self, entry, message):
        # Always log the message; print the entry id only when the
        # TERMIUS_CLI_DEBUG environment variable is set to a non-empty value.
        self.log.info(message)

        debug_enabled = os.getenv('TERMIUS_CLI_DEBUG')
        if not debug_enabled:
            return
        self.app.stdout.write('{}\n'.format(entry.id))


class GroupStackGetterMixin(object):
    """Mixin that retrieves the whole stack of parent groups."""

    # pylint: disable=no-self-use
    def get_group_stack(self, instance):
        """Return the result of ``GroupStackGenerator(instance).generate()``.

        :param instance: entry whose parent group stack is requested.
        """
        return GroupStackGenerator(instance).generate()


class SshConfigMergerMixin(GroupStackGetterMixin, object):
    """Mixin that squashes (merges) a stack into a single ssh config."""

    def get_merged_ssh_config(self, instance):
        """Return the merged ssh config for instance and its group stack.

        :param instance: Host or Group instance.
        """
        parents = self.get_group_stack(instance)
        return self.merge_ssh_config([instance] + parents)

    def merge_ssh_config(self, full_stack):
        """Squash full_stack into a single ssh_config instance.

        The identity of the result is the first visible identity when one
        exists, otherwise the merge of the non-visible identities.
        """
        config_merger = self.get_ssh_config_merger(full_stack)
        # The identity merger is created before the ssh config merge runs,
        # even though it is only used when no visible identity is found.
        identity_merger = self.get_identity_merger(config_merger)
        merged = config_merger.merge()
        shown = self.get_visible_identity(config_merger)
        merged.identity = shown if shown else identity_merger.merge()
        return merged

    # pylint: disable=no-self-use
    def get_ssh_config_merger(self, stack):
        """Create the ``ssh_config`` merger for the passed stack."""
        default_config = SshConfig()
        return Merger(stack, 'ssh_config', default_config)

    # pylint: disable=no-self-use
    def get_visible_identity(self, ssh_config_merger):
        """Return the first identity flagged ``is_visible``, or ``None``."""
        visible = []
        for entry in ssh_config_merger.get_entry_stack():
            identity = entry.identity
            if identity and identity.get('is_visible'):
                visible.append(identity)
        if not visible:
            return None
        return visible[0]

    # pylint: disable=no-self-use
    def get_identity_merger(self, ssh_config_merger):
        """Create the ``identity`` merger from entries of the passed merger.

        Only entries that have an identity not flagged ``is_visible`` are
        included.
        """
        hidden = []
        for entry in ssh_config_merger.get_entry_stack():
            identity = entry.identity
            if identity and not identity.get('is_visible'):
                hidden.append(entry)
        return Merger(hidden, 'identity', Identity())