appbot/kms_rails

View on GitHub
lib/kms_rails/active_record.rb

Summary

Maintainability
B
5 hrs
Test Coverage
A
100%
require 'msgpack'
require 'kms_rails/core'

module KmsRails
  module ActiveRecord
    # Hook run by Ruby when this module is included into +base+: it extends
    # +base+ with ClassMethods so the +kms_attr+ macro becomes available at
    # class level.
    def self.included(base)
      base.extend(ClassMethods)
    end
    
    # Class-level macro that adds KMS-backed encrypted attributes to a model.
    module ClassMethods
      # Declares +field+ as a virtual attribute whose encrypted form is kept
      # in the "<field>_enc" column.
      #
      # Four instance methods are generated:
      # * "<field>="      - encrypts the value and stores it via +store_hash+
      # * "<field>"       - returns the decrypted plaintext (nil when nothing is stored)
      # * "<field>_enc"   - returns the stored hash passed through +Core.to64+
      # * "<field>_clear" - discards any plaintext retained on the instance
      #
      # @param field [Symbol, String] name of the virtual (plaintext) attribute
      # @param key_id forwarded unchanged to +Core.new+
      # @param retain [Boolean] when true, the plaintext is kept on the
      #   instance after assignment / first decryption and reused by the reader
      # @param msgpack forwarded unchanged to +Core.new+
      # @param context_key forwarded unchanged to +Core.new+
      # @param context_value forwarded unchanged to +Core.new+
      # @raise [RuntimeError] from the generated accessors when +field+ is
      #   itself a real column, or when the "<field>_enc" column is missing
      def kms_attr(field, key_id:, retain: false, msgpack: false, context_key: nil, context_value: nil)
        include InstanceMethods

        real_field = "#{field}_enc"
        enc        = Core.new(key_id: key_id, msgpack: msgpack, context_key: context_key, context_value: context_value)

        # Column sanity check shared by every generated accessor. It runs at
        # call time (not at declaration time) because it reads column_names.
        # +purpose+ only varies the tail of the second error message.
        ensure_columns = lambda do |model, purpose|
          columns = model.column_names
          raise RuntimeError, "Field '#{field}' must not be a real column, '#{real_field}' is the real column" if columns.include?(field.to_s)
          raise RuntimeError, "Field '#{real_field}' must exist to #{purpose}" unless columns.include?(real_field)
        end

        define_method "#{field}=" do |data|
          ensure_columns.call(self.class, 'store encrypted data')

          # Any blank value (not only nil) clears the attribute: the retained
          # plaintext is dropped and the encrypted column is set to nil.
          if data.blank?
            clear_retained(field)
            self[real_field] = nil
            return
          end

          # Retain before encrypting; this ordering is kept from the original
          # implementation (set_retained stores its own copy of the value).
          set_retained(field, data) if retain
          store_hash(field, enc.encrypt(data))
        end

        define_method real_field do
          ensure_columns.call(self.class, 'retrieve encrypted data')
          Core.to64(get_hash(field))
        end

        define_method field.to_s do
          ensure_columns.call(self.class, 'retrieve decrypted data')

          stored = get_hash(field)
          return nil unless stored

          # Serve the retained plaintext when there is one; otherwise decrypt
          # and (if retaining) remember the result for the next read.
          retained = retain && get_retained(field)
          return retained if retained

          plaintext = enc.decrypt(stored)
          set_retained(field, plaintext) if retain
          plaintext
        end

        define_method "#{field}_clear" do
          clear_retained(field)
        end
      end
    end

    # Per-record helpers used by the accessors that +kms_attr+ generates.
    # Retained plaintext lives in the +@_retained+ hash, keyed by field.
    module InstanceMethods
      # Serialises +data+ with +to_msgpack+ and writes it to "<field>_enc".
      def store_hash(field, data)
        self["#{field}_enc"] = data.to_msgpack
      end

      # Reads "<field>_enc" and unpacks it with MessagePack.
      #
      # @return the unpacked value, or nil when the attribute holds nothing
      def get_hash(field)
        packed = read_attribute("#{field}_enc")
        packed ? MessagePack.unpack(packed) : nil
      end

      # @return the plaintext currently retained for +field+, or nil
      def get_retained(field)
        @_retained ||= {}
        @_retained[field]
      end

      # Replaces the retained plaintext for +field+ with a copy of
      # +plaintext+. Any previously retained value is shredded first.
      #
      # @return the retained copy
      def set_retained(field, plaintext)
        clear_retained(field)
        # Keep our own copy so shredding it later never touches the caller's object.
        @_retained[field] = plaintext.dup
      end

      # Drops the retained plaintext for +field+. String values (including
      # instances of String subclasses) are passed to +Core.shred_string+
      # before the reference is released.
      #
      # @return [nil]
      def clear_retained(field)
        @_retained ||= {}
        previous = @_retained[field]
        return if previous.nil?

        # is_a? rather than an exact class comparison: a String subclass holds
        # plaintext just the same and must not skip shredding.
        Core.shred_string(previous) if previous.is_a?(String)
        @_retained[field] = nil
      end
    end
  end
end

# Mix the kms_attr macro into every model, but only when ActiveRecord has
# already been loaded; otherwise this file has no effect beyond its definitions.
ActiveRecord::Base.send(:include, KmsRails::ActiveRecord) if Object.const_defined?('ActiveRecord')