take-five/acts_as_ordered_tree

View on GitHub
lib/acts_as_ordered_tree/transaction/dsl.rb

Summary

Maintainability
A
35 mins
Test Coverage
# coding: utf-8

require 'active_support/core_ext/string/inflections'

unless defined? Arel::Nodes::Case
  module Arel
    module Nodes
      # CASE expression node.
      #
      # Built fluently: every +when+ appends a branch, +then+ fills in the
      # result of the most recently appended branch and +else+ sets the
      # fallback value. Each builder method returns the node itself.
      #
      # @example
      #   switch.when(table[:x].gt(1), table[:y]).else(table[:z])
      #   # CASE WHEN "table"."x" > 1 THEN "table"."y" ELSE "table"."z" END
      #   switch.when(table[:x].gt(1)).then(table[:y]).else(table[:z])
      class Case < Arel::Nodes::Node
        include Arel::OrderPredications
        include Arel::Predications

        attr_reader :conditions, :default

        def initialize
          @default = nil
          @conditions = []
        end

        # Append a WHEN branch. +expression+ may be omitted here and supplied
        # afterwards through +then+.
        def when(condition, expression = nil)
          branch = When.new(condition, expression)
          @conditions.push(branch)
          self
        end

        # Set the THEN value of the branch added by the latest +when+ call.
        def then(expression)
          latest = @conditions.last
          latest.right = expression
          self
        end

        # Set the ELSE value of the whole expression.
        def else(expression)
          @default = Else.new(expression)
          self
        end
      end

      # One WHEN branch: +left+ holds the condition, +right+ the THEN value.
      class When < Arel::Nodes::Binary; end

      # ELSE part of a Case node: +expr+ holds the fallback value.
      class Else < Arel::Nodes::Unary; end
    end

    module Visitors
      # Teaches the SQL visitor how to render Case, When and Else nodes.
      #
      # Two families of visitors are defined: string-returning ones, and
      # collector-based ones that replace them when Arel::VERSION compares
      # as >= '6.0.0'.
      class ToSql < Arel::Visitors::ToSql.superclass
        private

        # String-returning form: CASE <branches> [ELSE ...] END
        def visit_Arel_Nodes_Case o, *a
          conditions = o.conditions.map { |x| visit x, *a }.join(' ')
          default = o.default && visit(o.default, *a)

          "CASE #{[conditions, default].compact.join(' ')} END"
        end

        def visit_Arel_Nodes_When o, *a
          "WHEN #{visit o.left, *a} THEN #{visit o.right, *a}"
        end

        def visit_Arel_Nodes_Else o, *a
          "ELSE #{visit o.expr, *a}"
        end

        # NOTE(review): this is a lexicographic String comparison, so a
        # two-digit major version such as '10.0.0' would compare as lower
        # than '6.0.0'.
        if Arel::VERSION >= '6.0.0'
          # Collector form: every piece is appended to +collector+, and each
          # branch is followed by a single space so that END stays separated.
          def visit_Arel_Nodes_Case o, collector
            collector << 'CASE '
            o.conditions.each do |x|
              visit x, collector
              collector << ' '
            end
            if o.default
              visit o.default, collector
              collector << ' '
            end
            collector << 'END'
          end

          def visit_Arel_Nodes_When o, collector
            collector << 'WHEN '
            visit o.left, collector
            collector << ' THEN '
            visit o.right, collector
          end

          def visit_Arel_Nodes_Else o, collector
            # The trailing space matters: without it the keyword is glued to
            # the expression (e.g. "ELSENULL"), unlike the string form above.
            collector << 'ELSE '
            visit o.expr, collector
          end

          # A missing THEN/ELSE value (nil) is rendered as SQL NULL.
          def visit_NilClass o, collector
            collector << 'NULL'
          end
        end
      end

      # Lets the depth-first visitor walk through Case, When and Else nodes.
      class DepthFirst < Arel::Visitors::Visitor
        # When and Else are handled by the existing +binary+ and +unary+ walkers.
        alias_method :visit_Arel_Nodes_When, :binary
        alias_method :visit_Arel_Nodes_Else, :unary

        # Visit the list of WHEN branches first, then the ELSE part.
        def visit_Arel_Nodes_Case(node, *args)
          visit(node.conditions, *args)
          visit(node.default, *args)
        end
      end
    end
  end
end

module ActsAsOrderedTree
  module Transaction
    # Simple DSL to generate complex UPDATE queries.
    # Requires +record+ method.
    #
    # @api private
    module DSL
      # Operator sugar for Arel nodes: including this module gives a node
      # infix operators (==, !=, >, >=, <, <=, =~, !~, |, &) that build
      # further nodes which themselves include Shortcuts.
      module Shortcuts
        # Operator => node class. While the operators are generated below,
        # every value is replaced by its Shortcuts-aware subclass.
        INFIX_OPERATIONS = {
          :==   => Arel::Nodes::Equality,
          :'!=' => Arel::Nodes::NotEqual,
          :>    => Arel::Nodes::GreaterThan,
          :>=   => Arel::Nodes::GreaterThanOrEqual,
          :<    => Arel::Nodes::LessThan,
          :<=   => Arel::Nodes::LessThanOrEqual,
          :=~   => Arel::Nodes::Matches,
          :'!~' => Arel::Nodes::DoesNotMatch,
          :|    => Arel::Nodes::Or
        }

        # For each operator: subclass the Arel node, publish the subclass as
        # a constant named after the original class and define the operator.
        INFIX_OPERATIONS.keys.each do |operator|
          base = INFIX_OPERATIONS[operator]
          node_class = Class.new(base) { include Shortcuts }

          const_set(base.name.demodulize, node_class)
          INFIX_OPERATIONS[operator] = node_class

          define_method(operator) { |arg| node_class.new(self, arg) }
        end

        And = Class.new(Arel::Nodes::And) { include Shortcuts }

        # AND is built separately: its node takes a list of operands.
        def &(arg)
          operands = [self, arg]
          And.new(operands)
        end
      end

      # Arel attribute that understands the Shortcuts operators.
      class Attribute < Arel::Attributes::Attribute
        include Shortcuts
      end

      # Raw SQL fragment that understands the Shortcuts operators.
      class SqlLiteral < Arel::Nodes::SqlLiteral
        include Shortcuts
      end

      # SQL function call with Shortcuts operators plus Arel::Math.
      class NamedFunction < Arel::Nodes::NamedFunction
        include Shortcuts
        include Arel::Math
      end

      # Start a new, empty CASE expression (an Arel::Nodes::Case instance).
      def switch
        case_node = Arel::Nodes::Case
        case_node.new
      end

      # Create assignments expression for UPDATE statement.
      #
      # Pairs whose key is blank are skipped entirely, so an optional column
      # (e.g. the nil returned by +attribute(nil)+) never leaves an empty
      # fragment in the generated SQL.
      #
      # @example
      #   Model.where(:parent_id => nil).update_all(set :name => switch.when(x < 10).then('OK').else('TOO LARGE'))
      #
      # @param [Hash] assignments column (name or Arel attribute) => value responding to +to_sql+
      # @return [String] comma-separated "column = (expression)" pairs
      def set(assignments)
        fragments = assignments.map do |attr, value|
          next unless attr.present?

          name = attr.is_a?(Arel::Attributes::Attribute) ? attr.name : attr.to_s

          quoted = record.class.connection.quote_column_name(name)
          "#{quoted} = (#{value.to_sql})"
        end

        # +next+ leaves nil placeholders for skipped pairs; drop them so the
        # result does not contain ", ," sequences.
        fragments.compact.join(', ')
      end

      # Wrap a column name into an Attribute bound to +table+.
      # A nil/false +name+ is handed back untouched.
      def attribute(name)
        return name unless name

        Attribute.new(table, name.to_sym)
      end

      # Turn +expr+ (converted with +to_s+) into a SqlLiteral node.
      def expression(expr)
        sql = expr.to_s
        SqlLiteral.new(sql)
      end

      # Attribute for the column that +record.ordered_tree.columns+ reports as +id+.
      def id
        columns = record.ordered_tree.columns
        attribute(columns.id)
      end

      # Attribute for the column that +record.ordered_tree.columns+ reports as +parent+.
      def parent_id
        columns = record.ordered_tree.columns
        attribute(columns.parent)
      end

      # Attribute for the column that +record.ordered_tree.columns+ reports as +position+.
      def position
        columns = record.ordered_tree.columns
        attribute(columns.position)
      end

      # Attribute for the column that +record.ordered_tree.columns+ reports as +depth+.
      def depth
        columns = record.ordered_tree.columns
        attribute(columns.depth)
      end

      # Arel table of the record's class.
      def table
        model = record.class
        model.arel_table
      end

      # Treat any unknown method called WITH arguments as an SQL function:
      # the upcased method name becomes the function name and the arguments
      # its expressions. Calls without arguments use the default handling.
      def method_missing(id, *args)
        return super if args.empty?

        NamedFunction.new(id.to_s.upcase, args)
      end
    end # module DSL
  end # module Transaction
end # module ActsAsOrderedTree