lib/postgres_ext/active_record/relation/query_methods.rb
module ActiveRecord
module QueryMethods
# PostgreSQL-specific predicates for ActiveRecord's WhereChain (the object
# returned by `where` called with no arguments), e.g. `Model.where.overlap(...)`.
#
# Every public method takes the same arguments as `where`. It lets the scope
# build ordinary predicates from them (build_where) and rewrites each
# resulting Arel::Nodes::Equality / Arel::Nodes::In node into a postgres_ext
# operator node. It then appends the rewritten nodes to the scope's
# where_values and returns the scope. Any other node type raises
# ArgumentError naming the method that was called.
class WhereChain
  # Predicate built from Arel::Nodes::Overlap.
  def overlap(opts, *rest)
    substitute_comparisons(opts, rest, Arel::Nodes::Overlap, 'overlap')
  end

  # Predicate built from Arel::Nodes::ContainedWithin.
  def contained_within(opts, *rest)
    substitute_comparisons(opts, rest, Arel::Nodes::ContainedWithin, 'contained_within')
  end

  # Predicate built from Arel::Nodes::ContainedWithinEquals.
  def contained_within_or_equals(opts, *rest)
    substitute_comparisons(opts, rest, Arel::Nodes::ContainedWithinEquals, 'contained_within_or_equals')
  end

  # "Contains" predicate whose node depends on the referenced column:
  # - ContainsHStore for hstore columns; the compared value is first
  #   normalized in place by equality_for_hstore.
  # - ContainsArray for array columns.
  # - ContainsINet for anything else.
  #
  # Raises ArgumentError for unsupported nodes, or when the column cannot
  # be resolved.
  def contains(opts, *rest)
    build_where_chain(opts, rest) do |rel|
      case rel
      when Arel::Nodes::In, Arel::Nodes::Equality
        column = containment_column(rel, 'contains')
        if column.type == :hstore
          equality_for_hstore(rel)
          Arel::Nodes::ContainsHStore.new(rel.left, rel.right)
        elsif column.respond_to?(:array) && column.array
          Arel::Nodes::ContainsArray.new(rel.left, rel.right)
        else
          Arel::Nodes::ContainsINet.new(rel.left, rel.right)
        end
      else
        raise ArgumentError, "Invalid argument for .where.contains(), got #{rel.class}"
      end
    end
  end

  # "Contained in" predicate whose node depends on the referenced column:
  # - ContainedInHStore for hstore columns (value normalized by
  #   equality_for_hstore).
  # - ContainedInArray for array columns.
  # - ContainsINet otherwise.
  #
  # Raises ArgumentError for unsupported nodes, or when the column cannot
  # be resolved.
  def contained_in_array(opts, *rest)
    build_where_chain(opts, rest) do |rel|
      case rel
      when Arel::Nodes::In, Arel::Nodes::Equality
        column = containment_column(rel, 'contained_in_array')
        if column.type == :hstore
          equality_for_hstore(rel)
          Arel::Nodes::ContainedInHStore.new(rel.left, rel.right)
        elsif column.respond_to?(:array) && column.array
          Arel::Nodes::ContainedInArray.new(rel.left, rel.right)
        else
          # NOTE(review): this fallback reuses ContainsINet, the same node
          # #contains uses, which looks inverted for a "contained in" query
          # (ContainedWithin may have been intended). Confirm before relying
          # on this for inet columns.
          Arel::Nodes::ContainsINet.new(rel.left, rel.right)
        end
      else
        raise ArgumentError, "Invalid argument for .where.contained_in_array(), got #{rel.class}"
      end
    end
  end

  # Predicate built from Arel::Nodes::ContainsEquals.
  def contains_or_equals(opts, *rest)
    substitute_comparisons(opts, rest, Arel::Nodes::ContainsEquals, 'contains_or_equals')
  end

  # Rewrites `column = value` into Equality(value, ANY(column)).
  def any(opts, *rest)
    equality_to_function('ANY', opts, rest)
  end

  # Rewrites `column = value` into Equality(value, ALL(column)).
  def all(opts, *rest)
    equality_to_function('ALL', opts, rest)
  end

  private

  # True when +col+ is the column referenced by rel.left. The column is
  # matched either by the attribute's own name or by its relation's name.
  # The second form presumably covers nested-hash conditions, where the
  # "relation" is really the hstore column (TODO confirm).
  def find_column(col, rel)
    col.name == rel.left.name.to_s || col.name == rel.left.relation.name.to_s
  end

  # Column of the attribute's own model (rel.left.relation.engine), or nil.
  def left_column(rel)
    rel.left.relation.engine.columns.find { |col| find_column(col, rel) }
  end

  # Column looked up on the model of an association named after
  # rel.left.relation, or nil when there is no such association or column.
  def column_from_association(rel)
    assoc = assoc_from_related_table(rel)
    assoc.klass.columns.find { |col| find_column(col, rel) } if assoc
  end

  # Resolves the column a containment predicate refers to: first among the
  # model's own columns, then through column_from_association.
  #
  # Raises ArgumentError naming +method+ when neither lookup finds one,
  # instead of a NoMethodError on nil further down.
  def containment_column(rel, method)
    column = left_column(rel) || column_from_association(rel)
    return column if column
    raise ArgumentError, "Could not find column #{rel.left.name} for .where.#{method}()"
  end

  # Rewrites +rel+ in place for an hstore comparison. If the right-hand
  # value is already a Hash, nothing changes. Otherwise:
  # - the value is wrapped as { rel.left.name => value }, quoted through
  #   Arel::Nodes.build_quoted when it arrived as a node exposing #val;
  # - the attribute is renamed to its relation's name;
  # - the relation is renamed to the engine's table name.
  def equality_for_hstore(rel)
    new_right_name = rel.left.name.to_s
    if rel.right.respond_to?(:val)
      return if rel.right.val.is_a?(Hash)
      rel.right = Arel::Nodes.build_quoted({ new_right_name => rel.right.val }, rel.left)
    else
      return if rel.right.is_a?(Hash)
      rel.right = { new_right_name => rel.right }
    end
    rel.left.name = rel.left.relation.name.to_sym
    rel.left.relation.name = rel.left.relation.engine.table_name
  end

  # Association on the attribute's model whose name matches the relation
  # name, trying it as given and then singularized.
  def assoc_from_related_table(rel)
    engine = rel.left.relation.engine
    engine.reflect_on_association(rel.left.relation.name.to_sym) ||
      engine.reflect_on_association(rel.left.relation.name.singularize.to_sym)
  end

  # Builds predicates with the scope's build_where and maps each one
  # through +block+. If +opts+ is a Hash, the referenced tables are
  # registered with the scope. The mapped nodes are appended to
  # where_values, and the (mutated) scope is returned.
  def build_where_chain(opts, rest, &block)
    where_value = @scope.send(:build_where, opts, rest).map(&block)
    @scope.references!(PredicateBuilder.references(opts)) if Hash === opts
    @scope.where_values += where_value
    @scope
  end

  # Replaces each Equality/In predicate with arel_node_class.new(left, right).
  # +method+ only names the public method in the error message.
  def substitute_comparisons(opts, rest, arel_node_class, method)
    build_where_chain(opts, rest) do |rel|
      case rel
      when Arel::Nodes::In, Arel::Nodes::Equality
        arel_node_class.new(rel.left, rel.right)
      else
        raise ArgumentError, "Invalid argument for .where.#{method}(), got #{rel.class}"
      end
    end
  end

  # Turns each Equality predicate into Equality(value, FUNCTION(column)).
  # Only Equality nodes are accepted (In nodes raise).
  def equality_to_function(function_name, opts, rest)
    build_where_chain(opts, rest) do |rel|
      case rel
      when Arel::Nodes::Equality
        Arel::Nodes::Equality.new(rel.right, Arel::Nodes::NamedFunction.new(function_name, [rel.left]))
      else
        raise ArgumentError, "Invalid argument for .where.#{function_name.downcase}(), got #{rel.class}"
      end
    end
  end
end
# WithChain objects act as placeholders for queries in which #with (or #with!)
# is called without any parameter. In that case, #with must be chained with
# #recursive to produce the relation, e.g. `Model.with.recursive(tree: subquery)`.
class WithChain
  # +scope+ is the relation that #recursive will modify and return.
  def initialize(scope)
    @scope = scope
  end

  # Appends +args+ to the relation's with_values, marks it as
  # WITH RECURSIVE and returns that same relation (not a copy).
  def recursive(*args)
    @scope.tap do |relation|
      relation.with_values += args
      relation.recursive_value = true
    end
  end
end
# Multi-valued relation state kept in @values. For :with this defines
# with_values, which returns the stored Array or [] when nothing has been
# set. It also defines with_values=, which raises ImmutableRelation once
# the relation is loaded.
[:with].each do |attribute|
  define_method("#{attribute}_values") do
    @values[attribute] || []
  end

  define_method("#{attribute}_values=") do |values|
    raise ImmutableRelation if @loaded
    @values[attribute] = values
  end
end
# Single-valued relation state kept in @values. Defines rank_value /
# rank_value= and recursive_value / recursive_value=. Each writer raises
# ImmutableRelation once the relation is loaded; each reader returns the
# stored value (nil when unset).
[:rank, :recursive].each do |attribute|
  define_method("#{attribute}_value=") do |value|
    raise ImmutableRelation if @loaded
    @values[attribute] = value
  end

  define_method("#{attribute}_value") do
    @values[attribute]
  end
end
# Adds a WITH (common table expression) clause to a copy of this relation.
# Called with no arguments, it returns a WithChain wrapping the copy so
# that #recursive can be chained. A blank argument returns this relation
# unchanged.
def with(opts = :chain, *rest)
  return WithChain.new(spawn) if opts == :chain
  return self if opts.blank?

  spawn.with!(opts, *rest)
end
# In-place counterpart of #with. It appends +opts+ and +rest+ to
# with_values and returns self. Called without arguments, it instead
# returns a WithChain wrapping self.
def with!(opts = :chain, *rest) # :nodoc:
  return WithChain.new(self) if opts == :chain

  self.with_values += [opts, *rest]
  self
end
# Returns a copy of this relation with a rank() window requested. The
# default +options+ of :order ranks by the relation's own ORDER BY.
def ranked(options = :order)
  relation = spawn
  relation.ranked!(options)
end
# In-place counterpart of #ranked: stores +value+ as rank_value and
# returns self.
def ranked!(value)
  tap { self.rank_value = value }
end
# Builds the relation's Arel with the original build_arel, then adds the
# WITH clause and, when rank_value is set, the rank() projection. Returns
# that same Arel object.
def build_arel_with_extensions
  build_arel_without_extensions.tap do |arel|
    build_with(arel)
    build_rank(arel, rank_value) if rank_value
  end
end
# Adds the relation's WITH (common table expression) clause to +arel+.
#
# Each entry of with_values may be:
# - a String, passed to arel.with unchanged;
# - a Hash of { name => expression }, where expression is a SQL String or an
#   ActiveRecord::Relation / Arel::SelectManager (rendered with #to_sql).
#   Each pair becomes `"name" AS (expression)`;
# - an Arel::Nodes::As, passed through unchanged.
#
# When recursive_value is set the clause is built as WITH RECURSIVE.
# Nothing is added when with_values yields no statements.
#
# Raises ArgumentError for any other entry or expression type. Such values
# used to become nil and produce invalid SQL at execution time.
def build_with(arel)
  with_statements = with_values.flat_map do |with_value|
    case with_value
    when String
      with_value
    when Hash
      with_value.map do |name, expression|
        body = case expression
               when String
                 expression
               when ActiveRecord::Relation, Arel::SelectManager
                 expression.to_sql
               else
                 raise ArgumentError, "Invalid expression for .with() CTE #{name.inspect}, got #{expression.class}"
               end
        # Quote the CTE name as a PostgreSQL identifier, doubling embedded
        # double quotes so the name cannot terminate the quoting early.
        quoted_name = "\"#{name.to_s.gsub('"', '""')}\""
        Arel::Nodes::As.new Arel::Nodes::SqlLiteral.new(quoted_name),
                            Arel::Nodes::SqlLiteral.new("(#{body})")
      end
    when Arel::Nodes::As
      with_value
    else
      raise ArgumentError, "Invalid argument for .with(), got #{with_value.class}"
    end
  end
  return if with_statements.empty?

  if recursive_value
    arel.with :recursive, with_statements
  else
    arel.with with_statements
  end
end
# Projects a `rank()` window function onto +arel+, unless the query's only
# projection is an Arel::Nodes::Count.
#
# rank_window_options chooses the window ordering:
# - :order  reuses the relation's own orders (arel.orders);
# - Symbol  orders ascending on that column of +table+;
# - Hash    maps { column => direction }, where direction is asc or desc
#           (Symbol or String, case-insensitive);
# - other   is interpolated as raw SQL and wrapped in parentheses.
#
# Nothing is projected when the resulting window is blank.
#
# Raises ArgumentError for a Hash direction other than asc/desc. The
# direction used to be passed straight to #send, which would invoke any
# method name.
def build_rank(arel, rank_window_options)
  unless arel.projections.count == 1 && Arel::Nodes::Count === arel.projections.first
    rank_window = case rank_window_options
                  when :order
                    arel.orders
                  when Symbol
                    table[rank_window_options].asc
                  when Hash
                    rank_window_options.map do |field, dir|
                      direction = dir.to_s.downcase
                      unless %w(asc desc).include?(direction)
                        raise ArgumentError, "Direction #{dir.inspect} is invalid for .ranked(). Valid directions are: asc, desc"
                      end
                      table[field].public_send(direction)
                    end
                  else
                    Arel::Nodes::SqlLiteral.new "(#{rank_window_options})"
                  end
    unless rank_window.blank?
      rank_node = Arel::Nodes::SqlLiteral.new 'rank()'
      window = Arel::Nodes::Window.new
      if String === rank_window
        # NOTE(review): Arel's Window#frame (Arel 4-6) returns its argument
        # rather than self. This reassignment therefore makes +window+ the raw
        # SqlLiteral, which is rendered as-is after OVER. Keep the assignment
        # unless that changes.
        window = window.frame rank_window
      else
        window = window.order(rank_window)
      end
      over_node = Arel::Nodes::Over.new rank_node, window
      arel.project(over_node)
    end
  end
end
# Route build_arel through build_arel_with_extensions. alias_method_chain
# (ActiveSupport) keeps the existing build_arel as
# build_arel_without_extensions and points build_arel at
# build_arel_with_extensions. That method calls the original and then adds
# the WITH clause and the rank() projection.
# NOTE(review): alias_method_chain was deprecated in Rails 5.0 and removed in
# 5.1. Module#prepend is the usual replacement if this file is ported.
alias_method_chain :build_arel, :extensions
end
end