renalreg/cornflake

View on GitHub
cornflake/sqlalchemy_orm.py

Summary

Maintainability
A
2 hrs
Test Coverage
from sqlalchemy.sql import sqltypes
from sqlalchemy.dialects import postgresql
from sqlalchemy import inspect
from sqlalchemy.orm import ColumnProperty

from cornflake import fields, serializers


class ModelSerializer(serializers.Serializer):
    """Serializer whose fields are generated from a SQLAlchemy model's columns.

    Configure it with an inner ``Meta`` class:

    * ``model_class`` -- the mapped model class (required).
    * ``fields`` -- optional iterable of column attribute names to include;
      when omitted every column attribute is considered.
    * ``exclude`` -- optional iterable of attribute names to skip.
    * ``read_only`` -- optional iterable of names that are serialized but not
      deserialized.
    * ``write_only`` -- optional iterable of names that are deserialized but
      not serialized.

    Fields declared explicitly on the serializer take precedence over the
    generated ones, and columns whose SQL type has no ``type_map`` entry are
    skipped.
    """

    # SQLAlchemy column type -> cornflake field class used for that column.
    # Subclasses may replace or extend this mapping. get_field_class() prefers
    # the most specific matching type, so an entry for a subtype overrides the
    # entry for its base type regardless of dict order.
    type_map = {
        sqltypes.String: fields.StringField,
        sqltypes.Integer: fields.IntegerField,
        sqltypes.BigInteger: fields.IntegerField,
        sqltypes.Date: fields.DateField,
        sqltypes.DateTime: fields.DateTimeField,
        sqltypes.Boolean: fields.BooleanField,
        sqltypes.Numeric: fields.FloatField,
        postgresql.INET: fields.StringField,
        postgresql.UUID: fields.UUIDField,
        postgresql.JSONB: fields.Field
    }

    class Meta(object):
        # The mapped model class the fields are generated from (required).
        model_class = None

    def get_model_class(self):
        """Return the model class configured in ``Meta.model_class``.

        Raises:
            AssertionError: if ``Meta.model_class`` is missing or ``None``.
        """
        # getattr() so that a subclass Meta which does not inherit from this
        # Meta (and so lacks the attribute entirely) fails the same way as one
        # that leaves it None. An explicit raise rather than ``assert`` keeps
        # the check active under ``python -O``. AssertionError is kept for
        # backward compatibility.
        model_class = getattr(self.Meta, 'model_class', None)

        if model_class is None:
            raise AssertionError(
                '%s.Meta.model_class must be set' % type(self).__name__
            )

        return model_class

    def get_model_fields(self):
        """Return the set of column attribute names to include.

        Returns ``None`` when ``Meta.fields`` is not set, meaning all column
        attributes are included.
        """
        model_fields = getattr(self.Meta, 'fields', None)

        if model_fields is not None:
            model_fields = set(model_fields)

        return model_fields

    def get_model_exclude(self):
        """Return the set of names from ``Meta.exclude`` (empty by default)."""
        return set(getattr(self.Meta, 'exclude', []))

    def get_model_read_only(self):
        """Return the names from ``Meta.read_only`` (empty by default).

        These fields are serialized but not deserialized.
        """
        return set(getattr(self.Meta, 'read_only', []))

    def get_model_write_only(self):
        """Return the names from ``Meta.write_only`` (empty by default).

        These fields are deserialized but not serialized.
        """
        return set(getattr(self.Meta, 'write_only', []))

    def get_field_class(self, col_type):
        """Return the field class for the SQL type instance *col_type*.

        The most specific entry in ``type_map`` wins. The type's MRO is walked
        from the concrete class towards its bases, so the result does not
        depend on the order of ``type_map``. Keys that only match through
        ``isinstance`` (for example a tuple of types) are tried afterwards, in
        mapping order.

        Returns ``None`` when nothing matches; the caller then skips the column.
        """
        type_map = self.type_map

        for sql_type in type(col_type).__mro__:
            if sql_type in type_map:
                return type_map[sql_type]

        for sql_type, field_class in type_map.items():
            if isinstance(col_type, sql_type):
                return field_class

        return None

    def get_fields(self):
        """Return the declared fields plus fields generated from model columns.

        Only ``ColumnProperty`` attributes are considered; relationships and
        other mapper properties are ignored. A column is skipped when:

        * a field with the same name is declared explicitly,
        * ``Meta.fields`` is set and does not list it,
        * ``Meta.exclude`` lists it, or
        * its SQL type has no entry in ``type_map``.
        """
        # Not named ``fields``, which would shadow the module imported at
        # file level.
        serializer_fields = super(ModelSerializer, self).get_fields()

        model_fields = self.get_model_fields()
        model_exclude = self.get_model_exclude()
        model_read_only = self.get_model_read_only()
        model_write_only = self.get_model_write_only()

        for prop in inspect(self.get_model_class()).attrs:
            if not isinstance(prop, ColumnProperty):
                continue

            key = prop.key

            # Field explicitly defined
            if key in serializer_fields:
                continue

            # Not in field list
            if model_fields is not None and key not in model_fields:
                continue

            # Field excluded
            if key in model_exclude:
                continue

            # Only the first column of the property determines the field type.
            col_type = prop.columns[0].type

            field_kwargs = {}

            # Read only field
            # Don't allow id column to be updated
            # TODO(rupert) default to read only if primary key (remove 'id' check)
            if key in model_read_only or key == 'id':
                field_kwargs['read_only'] = True

            # Write only field
            if key in model_write_only:
                field_kwargs['write_only'] = True

            field_class = self.get_field_class(col_type)

            # This will skip column types we can't handle
            if field_class is not None:
                field = field_class(**field_kwargs)
                field.bind(self, key)
                serializer_fields[key] = field

        return serializer_fields

    def create(self, validated_data):
        """Instantiate the model class and populate it from *validated_data*.

        The instance is constructed with no arguments. This method does not
        add it to a session.
        """
        instance = self.get_model_class()()
        self._set_attributes(instance, validated_data)
        return instance

    def update(self, instance, validated_data):
        """Copy *validated_data* onto the existing *instance* and return it."""
        self._set_attributes(instance, validated_data)
        return instance

    @staticmethod
    def _set_attributes(instance, validated_data):
        """Set each key of *validated_data* on *instance* as an attribute.

        Keys for which *instance* has no attribute are ignored rather than
        raising.
        """
        for attr, value in validated_data.items():
            if hasattr(instance, attr):
                setattr(instance, attr, value)


class ReferenceField(fields.Field):
    """Field that (de)serializes a reference to a SQLAlchemy model instance.

    Deserialization accepts any of:

    * an instance of ``model_class`` (returned unchanged),
    * a dict carrying the identifier under ``model_id``, or
    * a bare identifier.

    Identifiers are then resolved to a row through ``model_class.query``.

    Serialization emits the nested representation from ``serializer_class``
    when one is configured, and otherwise just the identifier.
    """

    # Type of the identifier column -> field used to convert raw identifiers.
    # Unmapped types fall back to StringField (see get_field_class).
    type_map = {
        sqltypes.String: fields.StringField,
        sqltypes.Integer: fields.IntegerField,
        postgresql.UUID: fields.UUIDField,
    }

    error_messages = {
        'not_found': 'Object not found.',
        'no_id': 'No ID supplied.'
    }

    # Mapped model class being referenced (required). It can be set as a
    # class attribute or passed as the ``model_class`` keyword argument.
    model_class = None

    # Name of the model attribute used to look up and represent the row.
    # TODO(rupert) use Model.id instead, default to getattr(model_class, 'id')
    model_id = 'id'

    # Optional serializer class used to represent the referenced instance.
    serializer_class = None

    def __init__(self, **kwargs):
        """Configure the reference.

        ``model_class``, ``model_id`` and ``serializer_class`` may be given as
        keyword arguments, overriding the class attributes. All remaining
        keyword arguments are passed to ``fields.Field``.

        Raises:
            AssertionError: if ``model_class`` or ``model_id`` ends up ``None``.
        """
        self.model_class = kwargs.pop('model_class', self.model_class)
        self.model_id = kwargs.pop('model_id', self.model_id)
        self.serializer_class = kwargs.pop('serializer_class', self.serializer_class)

        # Explicit raises instead of ``assert`` so that misconfiguration is
        # still caught under ``python -O``. AssertionError is kept for
        # backward compatibility.
        if self.model_class is None:
            raise AssertionError(
                '%s requires a model_class' % type(self).__name__
            )

        if self.model_id is None:
            raise AssertionError(
                '%s requires a model_id' % type(self).__name__
            )

        super(ReferenceField, self).__init__(**kwargs)

        # Field that converts the raw identifier, plus an optional serializer
        # used for the nested representation.
        self.field = self.get_field()
        self.serializer = self.get_serializer()

    def get_serializer(self):
        """Return an instance of ``serializer_class``, or ``None`` if unset."""
        serializer_class = self.serializer_class

        if serializer_class is None:
            return None

        return serializer_class()

    def get_field_class(self):
        """Return the field class used to convert this reference's identifier.

        Looks up the first column behind ``model_id`` on ``model_class`` and
        maps its SQL type through ``type_map``. As in ModelSerializer, the most
        specific type wins because the type's MRO is walked first. Keys that
        only match through ``isinstance`` are tried next.

        Falls back to ``fields.StringField`` for unmapped types.
        """
        prop = getattr(inspect(self.model_class).attrs, self.model_id)
        col_type = prop.columns[0].type
        type_map = self.type_map

        for sql_type in type(col_type).__mro__:
            if sql_type in type_map:
                return type_map[sql_type]

        for sql_type, field_type in type_map.items():
            if isinstance(col_type, sql_type):
                return field_type

        return fields.StringField

    def get_field(self):
        """Instantiate the identifier field class chosen by get_field_class()."""
        return self.get_field_class()()

    def bind(self, parent, field_name=None):
        """Bind this field to *parent*.

        The inner identifier field and the optional nested serializer are
        bound to this field under the same *field_name*.
        """
        super(ReferenceField, self).bind(parent, field_name)
        self.field.bind(self, field_name)

        if self.serializer is not None:
            self.serializer.bind(self, field_name)

    def get_instance(self, id):
        """Return the first ``model_class`` row whose ``model_id`` equals *id*.

        Calls ``self.fail('not_found')`` when no row matches.
        """
        # NOTE(review): assumes model_class exposes a Flask-SQLAlchemy style
        # ``query`` attribute. Also assumes self.fail() raises (it is defined
        # on fields.Field, outside this file); otherwise None would be
        # returned. Confirm both.
        attribute = getattr(self.model_class, self.model_id)
        instance = self.model_class.query.filter(attribute == id).first()

        if instance is None:
            self.fail('not_found')

        return instance

    def to_internal_value(self, data):
        """Resolve *data* to a ``model_class`` instance.

        *data* may be a model instance, a dict holding the identifier under
        ``model_id``, or a bare identifier. A dict without that key, or with a
        ``None`` value for it, triggers ``self.fail('no_id')``.
        """
        if isinstance(data, self.model_class):
            return data

        if isinstance(data, dict):
            raw_id = data.get(self.model_id)

            if raw_id is None:
                self.fail('no_id')
        else:
            raw_id = data

        # Convert the raw value with the identifier field before querying.
        instance_id = self.field.to_internal_value(raw_id)

        return self.get_instance(instance_id)

    def to_representation(self, instance):
        """Serialize *instance*.

        Returns the nested serializer's representation when
        ``serializer_class`` was configured, otherwise the converted
        ``model_id`` value.
        """
        if self.serializer is not None:
            return self.serializer.to_representation(instance)

        instance_id = getattr(instance, self.model_id)
        return self.field.to_representation(instance_id)