danmcclain/postgres_ext

View on GitHub
lib/postgres_ext/active_record/relation/query_methods.rb

Summary

Maintainability
C
1 day
Test Coverage
module ActiveRecord
  module QueryMethods
    # Additional predicate builders for WhereChain.
    #
    # Every public method takes the same arguments as +where+, builds the
    # regular predicates through the wrapped scope's +build_where+, swaps each
    # of them for a different Arel node, appends the result to the scope's
    # +where_values+ and returns the scope.
    #
    # All of them raise ArgumentError when a built predicate is not of a
    # supported node type.
    class WhereChain
      # Swaps each Equality/In predicate for an Arel::Nodes::Overlap node.
      def overlap(opts, *rest)
        substitute_comparisons(opts, rest, Arel::Nodes::Overlap, 'overlap')
      end

      # Swaps each Equality/In predicate for an Arel::Nodes::ContainedWithin node.
      def contained_within(opts, *rest)
        substitute_comparisons(opts, rest, Arel::Nodes::ContainedWithin, 'contained_within')
      end

      # Swaps each Equality/In predicate for an
      # Arel::Nodes::ContainedWithinEquals node.
      def contained_within_or_equals(opts, *rest)
        substitute_comparisons(opts, rest, Arel::Nodes::ContainedWithinEquals, 'contained_within_or_equals')
      end

      # Swaps each Equality/In predicate for a "contains" node chosen from the
      # type of the targeted column: ContainsHStore for hstore columns,
      # ContainsArray for array columns and ContainsINet for everything else.
      #
      # Raises ArgumentError for other predicate types or when the column
      # cannot be resolved.
      def contains(opts, *rest)
        build_where_chain(opts, rest) do |rel|
          node_class =
            case column_kind(rel, 'contains')
            when :hstore then Arel::Nodes::ContainsHStore
            when :array  then Arel::Nodes::ContainsArray
            else              Arel::Nodes::ContainsINet
            end
          # Read left/right only now: column_kind may have rewritten them.
          node_class.new(rel.left, rel.right)
        end
      end

      # Swaps each Equality/In predicate for a "contained in" node chosen from
      # the type of the targeted column: ContainedInHStore for hstore columns
      # and ContainedInArray for array columns.
      #
      # Raises ArgumentError for other predicate types or when the column
      # cannot be resolved.
      def contained_in_array(opts, *rest)
        build_where_chain(opts, rest) do |rel|
          node_class =
            case column_kind(rel, 'contained_in_array')
            when :hstore then Arel::Nodes::ContainedInHStore
            when :array  then Arel::Nodes::ContainedInArray
            else
              # NOTE(review): same fallback node as #contains (ContainsINet);
              # kept as found - confirm this is intended and not a copy/paste slip.
              Arel::Nodes::ContainsINet
            end
          # Read left/right only now: column_kind may have rewritten them.
          node_class.new(rel.left, rel.right)
        end
      end

      # Swaps each Equality/In predicate for an Arel::Nodes::ContainsEquals node.
      def contains_or_equals(opts, *rest)
        substitute_comparisons(opts, rest, Arel::Nodes::ContainsEquals, 'contains_or_equals')
      end

      # Rewrites each Equality predicate <tt>column = value</tt> into
      # <tt>value = ANY(column)</tt>.
      def any(opts, *rest)
        equality_to_function('ANY', opts, rest)
      end

      # Rewrites each Equality predicate <tt>column = value</tt> into
      # <tt>value = ALL(column)</tt>.
      def all(opts, *rest)
        equality_to_function('ALL', opts, rest)
      end

      private

      # True when +col+ is named after either the attribute or the relation
      # on the left-hand side of +rel+.
      def find_column(col, rel)
        col.name == rel.left.name.to_s || col.name == rel.left.relation.name.to_s
      end

      # Looks the column up on the engine of the predicate's own relation.
      def left_column(rel)
        rel.left.relation.engine.columns.find { |col| find_column(col, rel) }
      end

      # Looks the column up on the class of the association named after the
      # predicate's relation; nil when there is no such association or column.
      def column_from_association(rel)
        assoc = assoc_from_related_table(rel)
        assoc.klass.columns.find { |col| find_column(col, rel) } if assoc
      end

      # Validates +rel+, resolves the column it targets and classifies it as
      # :hstore, :array or :other. For hstore columns the predicate is first
      # normalized in place through #equality_for_hstore.
      #
      # +method+ is the public method name used in error messages.
      def column_kind(rel, method)
        case rel
        when Arel::Nodes::In, Arel::Nodes::Equality
          column = left_column(rel) || column_from_association(rel)
          if column.nil?
            raise ArgumentError,
                  "Invalid argument for .where.#{method}(), no column found for #{rel.left.name}"
          end

          if column.type == :hstore
            equality_for_hstore(rel)
            :hstore
          elsif column.respond_to?(:array) && column.array
            :array
          else
            :other
          end
        else
          raise ArgumentError, "Invalid argument for .where.#{method}(), got #{rel.class}"
        end
      end

      # Normalizes, in place, a predicate on an hstore column whose right-hand
      # side is not already a Hash: the value is wrapped as
      # <tt>{ attribute_name => value }</tt>, the attribute is renamed after
      # its relation and the relation is renamed to the engine's table name.
      # Presumably this undoes a nested-hash condition having been read as
      # "table => column" - TODO confirm against the predicate builder.
      def equality_for_hstore(rel)
        new_right_name = rel.left.name.to_s
        if rel.right.respond_to?(:val)
          return if rel.right.val.is_a?(Hash)
          rel.right = Arel::Nodes.build_quoted({new_right_name => rel.right.val},
                                               rel.left)
        else
          return if rel.right.is_a?(Hash)
          rel.right = {new_right_name => rel.right }
        end

        rel.left.name = rel.left.relation.name.to_sym
        rel.left.relation.name = rel.left.relation.engine.table_name
      end

      # Reflection for the association named after the predicate's relation,
      # trying the name as is first and its singular form second.
      def assoc_from_related_table(rel)
        engine = rel.left.relation.engine
        engine.reflect_on_association(rel.left.relation.name.to_sym) ||
          engine.reflect_on_association(rel.left.relation.name.singularize.to_sym)
      end

      # Builds the regular where predicates for +opts+/+rest+, maps each one
      # through +block+, appends the results to the scope and returns the scope.
      def build_where_chain(opts, rest, &block)
        where_value = @scope.send(:build_where, opts, rest).map(&block)
        @scope.references!(PredicateBuilder.references(opts)) if Hash === opts
        @scope.where_values += where_value
        @scope
      end

      # Swaps every Equality/In predicate for +arel_node_class+, keeping its
      # operands; +method+ is the public method name used in error messages.
      def substitute_comparisons(opts, rest, arel_node_class, method)
        build_where_chain(opts, rest) do |rel|
          case rel
          when Arel::Nodes::In, Arel::Nodes::Equality
            arel_node_class.new(rel.left, rel.right)
          else
            raise ArgumentError, "Invalid argument for .where.#{method}(), got #{rel.class}"
          end
        end
      end

      # Turns every Equality predicate into an equality between its right-hand
      # side and +function_name+ applied to its left-hand side.
      def equality_to_function(function_name, opts, rest)
        build_where_chain(opts, rest) do |rel|
          case rel
          when Arel::Nodes::Equality
            Arel::Nodes::Equality.new(rel.right, Arel::Nodes::NamedFunction.new(function_name, [rel.left]))
          else
            raise ArgumentError, "Invalid argument for .where.#{function_name.downcase}(), got #{rel.class}"
          end
        end
      end
    end

    # Placeholder handed back by #with when it is called without arguments.
    # It only becomes useful once #recursive is chained onto it, which hands
    # back the wrapped relation.
    class WithChain
      def initialize(scope)
        @scope = scope
      end

      # Appends +args+ to the wrapped relation's WITH values, flags it as
      # recursive (WITH RECURSIVE) and returns that relation.
      def recursive(*args)
        @scope.tap do |relation|
          relation.with_values += args
          relation.recursive_value = true
        end
      end
    end

    # Generates the multi-value accessors for :with:
    # * +with_values+ reads @values[:with], falling back to an empty array;
    # * +with_values=+ stores into @values[:with] and raises ImmutableRelation
    #   when @loaded is set.
    # The comments on the right-hand side of the heredoc show the shape of the
    # generated code, using :select as the sample name.
    [:with].each do |name|
      class_eval <<-CODE, __FILE__, __LINE__ + 1
       def #{name}_values                   # def select_values
         @values[:#{name}] || []            #   @values[:select] || []
       end                                  # end
                                            #
       def #{name}_values=(values)          # def select_values=(values)
         raise ImmutableRelation if @loaded #   raise ImmutableRelation if @loaded
         @values[:#{name}] = values         #   @values[:select] = values
       end                                  # end
      CODE
    end

    # Generates the single-value accessors for :rank and :recursive:
    # * +rank_value=+ / +recursive_value=+ store into @values and raise
    #   ImmutableRelation when @loaded is set;
    # * +rank_value+ / +recursive_value+ read the stored value (nil when unset).
    # The comments on the right-hand side of the heredoc show the shape of the
    # generated code, using :readonly as the sample name.
    [:rank, :recursive].each do |name|
      class_eval <<-CODE, __FILE__, __LINE__ + 1
        def #{name}_value=(value)            # def readonly_value=(value)
          raise ImmutableRelation if @loaded #   raise ImmutableRelation if @loaded
          @values[:#{name}] = value          #   @values[:readonly] = value
        end                                  # end

        def #{name}_value                    # def readonly_value
          @values[:#{name}]                  #   @values[:readonly]
        end                                  # end
      CODE
    end

    # Adds WITH statements to a copy of the relation.
    #
    # * Without arguments, returns a WithChain wrapping a spawned relation.
    # * With a blank +opts+, returns the receiver itself.
    # * Otherwise forwards +opts+ and +rest+ to #with! on a spawned relation.
    def with(opts = :chain, *rest)
      return WithChain.new(spawn) if opts == :chain
      return self if opts.blank?

      spawn.with!(opts, *rest)
    end

    # In-place variant of #with: without arguments returns a WithChain
    # wrapping the receiver, otherwise appends +opts+ and +rest+ to
    # +with_values+ and returns the receiver.
    def with!(opts = :chain, *rest) # :nodoc:
      return WithChain.new(self) if opts == :chain

      self.with_values += [opts, *rest]
      self
    end

    # Returns a spawned relation on which #ranked! has been called with
    # +options+ (:order when omitted).
    def ranked(options = :order)
      spawn.ranked!(options)
    end

    # In-place variant of #ranked: stores +value+ as the relation's
    # +rank_value+ and returns the receiver.
    def ranked!(value)
      tap { self.rank_value = value }
    end

    # Builds the Arel AST through +build_arel_without_extensions+, then adds
    # the WITH statements and, when a rank value is set, the rank projection.
    # Returns the resulting AST.
    def build_arel_with_extensions
      build_arel_without_extensions.tap do |arel|
        build_with(arel)

        window_options = rank_value
        build_rank(arel, window_options) if window_options
      end
    end

    # Adds the accumulated WITH statements to +arel+.
    #
    # Each entry of +with_values+ may be:
    # * a String or an Arel::Nodes::As - used as is;
    # * a Hash of name => expression, where the expression is a String, an
    #   ActiveRecord::Relation or an Arel::SelectManager; every pair becomes
    #   an Arel::Nodes::As of the double-quoted name and the parenthesized
    #   expression SQL.
    #
    # Nothing is added when there are no statements. When +recursive_value+
    # is set the statements are added as WITH RECURSIVE.
    #
    # Raises ArgumentError for any other entry or expression type, instead of
    # handing a nil statement (or a nil body) over to Arel.
    def build_with(arel)
      with_statements = with_values.flat_map do |with_value|
        case with_value
        when String, Arel::Nodes::As
          with_value
        when Hash
          with_value.map do |name, expression|
            select =
              case expression
              when String
                Arel::Nodes::SqlLiteral.new "(#{expression})"
              when ActiveRecord::Relation, Arel::SelectManager
                Arel::Nodes::SqlLiteral.new "(#{expression.to_sql})"
              else
                raise ArgumentError,
                      "Invalid expression for .with() entry #{name.inspect}, got #{expression.class}"
              end
            Arel::Nodes::As.new Arel::Nodes::SqlLiteral.new("\"#{name}\""), select
          end
        else
          raise ArgumentError, "Invalid argument for .with(), got #{with_value.class}"
        end
      end
      return if with_statements.empty?

      if recursive_value
        arel.with :recursive, with_statements
      else
        arel.with with_statements
      end
    end

    # Projects <tt>rank() OVER (...)</tt> onto +arel+.
    #
    # Nothing happens when the only projection is a Count node, or when the
    # resulting window definition is blank. +rank_window_options+ selects the
    # window:
    # * :order        - reuse the orders already on +arel+;
    # * another Symbol - order by that column of +table+, ascending;
    # * a Hash        - field => direction pairs, sent to the column of +table+;
    # * anything else - interpolated, in parentheses, as a literal window frame.
    def build_rank(arel, rank_window_options)
      projections = arel.projections
      return if projections.count == 1 && Arel::Nodes::Count === projections.first

      window_spec =
        case rank_window_options
        when :order then arel.orders
        when Symbol then table[rank_window_options].asc
        when Hash   then rank_window_options.map { |field, dir| table[field].send(dir) }
        else Arel::Nodes::SqlLiteral.new "(#{rank_window_options})"
        end
      return if window_spec.blank?

      rank_function = Arel::Nodes::SqlLiteral.new 'rank()'
      empty_window = Arel::Nodes::Window.new
      # A String here is the SqlLiteral built above: use it as the frame
      # rather than as an ordering.
      window = String === window_spec ? empty_window.frame(window_spec) : empty_window.order(window_spec)

      arel.project(Arel::Nodes::Over.new(rank_function, window))
    end

    # Routes +build_arel+ through +build_arel_with_extensions+; the previous
    # implementation stays reachable as +build_arel_without_extensions+
    # (the with/without naming used by alias_method_chain).
    alias_method_chain :build_arel, :extensions
  end
end