jedymatt/sqlalchemyseed

View on GitHub
src/sqlalchemyseed/seeder.py

Summary

Maintainability: A (2 hrs)
Test Coverage
"""
Seeder module
"""

from typing import Union

import sqlalchemy
from sqlalchemy.orm import Session

from sqlalchemyseed import util
from sqlalchemyseed.attribute import instrumented_attribute, referenced_class, set_instance_attribute


class Seeder:
    """
    Seeder class.

    Builds model instances from a plain ``dict``/``list`` description, adds
    them to the given session, and records the top-level instances produced
    by the most recent :meth:`seed` call. :meth:`seed` itself does not call
    ``session.commit()``; use :meth:`commit` for that.

    Each entry is a dict with a ``'model'`` key (resolved through
    ``util.get_model_class``) and either:

    * ``'where'`` -- a dict or list of dicts used as ``filter_by`` kwargs to
      look up existing rows, or
    * ``'data'`` -- a dict or list of dicts used as constructor kwargs for
      new instances.

    ``'where'`` wins when both are present. Any of those dicts may carry a
    ``'$rel'`` key mapping attribute names to nested entries of the same
    shape; it is stripped before the filter/constructor call and seeded
    separately onto the resulting instance.
    """

    def __init__(self, session: "Session") -> None:
        self.session = session

        # Top-level instances produced by the latest `seed` call.
        self._instances = []

    def reset(self):
        """
        Reset seeder.

        Forget the instances recorded by a previous :meth:`seed` call. The
        session is left untouched.
        """
        self._instances = []

    def seed(self, data: Union[dict, list]) -> None:
        """
        Seed data into database.

        Clears previously recorded instances, then builds every entity
        described by ``data`` and adds it to the session (without committing).

        Args:
            data: A single entry dict or a list of entry dicts.

        Raises:
            TypeError: If ``data`` is neither a ``dict`` nor a ``list``.
        """
        self.reset()

        if isinstance(data, dict):
            self._seed_dict(data)
        elif isinstance(data, list):
            self._seed_list(data)
        else:
            raise TypeError("'data' should be 'dict' or 'list'.")

    @property
    def instances(self) -> list:
        """
        Instances that are created by seeding.

        Only top-level entries are recorded; instances created for ``'$rel'``
        relationships are not. The internal list is returned (not a copy).

        Returns:
            list: Instances that are created by seeding.
        """
        return self._instances

    def commit(self):
        """
        Commit seeding data into database by calling ``session.commit()``.
        """
        self.session.commit()

    def _seed_dict(self, data: dict):
        # Seed one top-level entry and record the instances it yields.
        model_class = util.get_model_class(data.get('model'))
        where = data.get('where')

        # 'where' takes precedence over 'data' when both are present.
        if where is not None:
            instances = self._seed_where(model_class, where)
        else:
            instances = self._seed_data_(model_class, data.get('data'))

        self._instances.extend(instances)

    def _seed_list(self, data: list):
        # Seed each top-level entry in order.
        for item in data:
            self._seed_dict(item)

    def _seed_dict_where(self, model_class: type, where: dict):
        # Fetch the first existing row matching `where` (minus '$rel').
        kwargs = where.copy()
        kwargs.pop('$rel', None)
        instance = self.session.query(model_class).filter_by(**kwargs).first()
        if instance is None:
            # Fail here with a descriptive error instead of letting
            # `session.add(None)` raise an opaque "NoneType is not mapped".
            raise sqlalchemy.orm.exc.NoResultFound(
                f"No row of {model_class!r} matches filter {kwargs!r}"
            )
        return instance

    def _instantiate_model_class(self, model_class: type, data_: dict):
        # Construct a new instance from `data_` (minus '$rel').
        kwargs = data_.copy()
        kwargs.pop('$rel', None)
        return model_class(**kwargs)

    def _seed_rel(self, rel: dict, instance: object):
        # `rel` maps attribute names to one nested entry (dict) or several
        # (list); values of any other type are ignored.
        for key, value in rel.items():
            if isinstance(value, dict):
                self._seed_rel_attribute_value(instance, key, value)
            elif isinstance(value, list):
                for item in value:
                    self._seed_rel_attribute_value(instance, key, item)

    def _seed_rel_attribute_value(self, instance: object, key: str, value: dict):
        # Seed one nested entry and attach the result to `instance` under `key`.
        model = value.get('model')
        if model is not None:
            model_class = util.get_model_class(model)
        else:
            # No explicit 'model': presumably derives the target class from
            # the relationship attribute `key` on `instance`.
            model_class = referenced_class(
                instrumented_attribute(instance, key)
            )

        where = value.get('where')
        if where is not None:
            rel_instances = self._seed_where(model_class, where)
        else:
            rel_instances = self._seed_data_(model_class, value.get('data'))

        # Whether a list is assigned or appended is decided by
        # `set_instance_attribute` -- NOTE(review): confirm uselist handling.
        set_instance_attribute(instance, key, rel_instances)

    def _seed_where(self, model_class: type, where: Union[dict, list]) -> list:
        # Resolve existing rows for a 'where' clause (one dict or a list).
        return self._seed_entries(model_class, where, self._seed_dict_where)

    def _seed_data_(self, model_class: type, data_: Union[dict, list]) -> list:
        # Build new instances for a 'data' clause (one dict or a list).
        return self._seed_entries(
            model_class, data_, self._instantiate_model_class
        )

    def _seed_entries(self, model_class: type, entries, make_instance) -> list:
        """
        Turn each entry into an instance, add it to the session and seed its
        ``'$rel'`` relationships.

        Args:
            model_class: Class handed to ``make_instance``.
            entries: A single entry dict or an iterable of entry dicts.
            make_instance: ``(model_class, entry) -> instance`` factory.

        Returns:
            list: The instances, in entry order.

        Raises:
            TypeError: If ``entries`` is ``None`` (neither 'data' nor 'where'
                was supplied).
        """
        if entries is None:
            raise TypeError(
                "expected 'data' or 'where' to be a dict or list, got None"
            )
        if isinstance(entries, dict):
            entries = [entries]

        instances = []
        for entry in entries:
            instance = make_instance(model_class, entry)
            instances.append(instance)
            # Add only this instance, so every instance is added exactly once.
            self.session.add(instance)

            rel = entry.get('$rel')
            if rel is not None:
                self._seed_rel(rel, instance)

        return instances