# activerecord/lib/active_record/associations/join_dependency.rb
#
# Summary
#
# Maintainability: C (1 day)
# Test Coverage: (no value shown)
# frozen_string_literal: true

module ActiveRecord
  module Associations
    class JoinDependency # :nodoc:
      extend ActiveSupport::Autoload

      eager_autoload do
        autoload :JoinBase
        autoload :JoinAssociation
      end

      # Column-alias registry for an eager-loading query: for every join part
      # (node) it keeps the list of Aliases::Column entries and answers
      # "which alias was generated for column X of node N".
      class Aliases # :nodoc:
        # +tables+ is an array of Aliases::Table, one per join part. If two
        # tables share a node, the later one wins.
        def initialize(tables)
          @tables = tables
          @alias_cache = {}
          @columns_cache = {}

          tables.each do |table|
            @columns_cache[table.node] = table.columns
            @alias_cache[table.node] = table.columns.to_h { |column| [column.name, column.alias] }
          end
        end

        # Every table's "column AS alias" projections, flattened in table order.
        def columns
          @tables.flat_map { |table| table.column_aliases }
        end

        # The Aliases::Column list registered for +node+ (nil if unknown).
        def column_aliases(node)
          @columns_cache[node]
        end

        # The alias generated for +column+ of +node+, or nil if that column was
        # not aliased. Raises NoMethodError when +node+ was never registered.
        def column_alias(node, column)
          @alias_cache[node][column]
        end

        Table = Struct.new(:node, :columns) do # :nodoc:
          # Projections built from node.table: one "<name> AS <alias>" per column.
          def column_aliases
            arel_table = node.table
            columns.map { |column| arel_table[column.name].as(column.alias) }
          end
        end
        Column = Struct.new(:name, :alias)
      end

      # Normalizes an association specification (as given to +joins+,
      # +eager_load+, ...) into a nested hash of association names, e.g.
      #
      #   make_tree([:posts, { posts: :comments }, "tags"])
      #   # => { posts: { comments: {} }, tags: {} }
      #
      # String names become symbols whether they appear on their own or as
      # hash keys, so "posts" and :posts merge into the same branch.
      def self.make_tree(associations)
        hash = {}
        walk_tree associations, hash
        hash
      end

      # Recursively merges +associations+ into +hash+ (mutated in place).
      #
      # Raises ConfigurationError for anything that is not a Symbol, String,
      # Array or Hash.
      def self.walk_tree(associations, hash)
        case associations
        when Symbol, String
          hash[associations.to_sym] ||= {}
        when Array
          associations.each do |assoc|
            walk_tree assoc, hash
          end
        when Hash
          associations.each do |k, v|
            # Normalize String keys exactly like bare String names above.
            # Otherwise "posts" => ... and :posts become two sibling branches,
            # and each branch is later built into its own JoinAssociation.
            key = k.is_a?(String) ? k.to_sym : k
            cache = hash[key] ||= {}
            walk_tree v, cache
          end
        else
          raise ConfigurationError, associations.inspect
        end
      end

      # +base+ is the model class the query starts from and +table+ is handed
      # to JoinBase unchanged. +associations+ is any specification accepted by
      # make_tree. +join_type+ is the join type used for the root's own
      # children in join_constraints.
      def initialize(base, table, associations, join_type)
        @join_type = join_type
        association_tree = self.class.make_tree(associations)
        @join_root = JoinBase.new(base, table, build(association_tree, base))
      end

      # The model class at the root of the join tree.
      def base_klass
        join_root.base_klass
      end

      # The reflection of every join part except the first one yielded by
      # join_root (presumably the root itself), in iteration order.
      def reflections
        join_root.drop(1).map { |join_part| join_part.reflection }
      end

      # Returns one flat array of join constraints: first this dependency's
      # own, then those of each dependency in +joins_to_add+.
      #
      # An added dependency whose root matches ours is merged into our tree
      # with walk (matching parts take over our tables). Any other added
      # dependency gets constraints built from its own root.
      #
      # Entries of +references+ that are Arel::Nodes::SqlLiteral are indexed by
      # name; make_constraints passes them to alias_tracker as the table name
      # for the reflection of the same name.
      def join_constraints(joins_to_add, alias_tracker, references)
        @alias_tracker = alias_tracker
        @joined_tables = {}
        @references = {}

        unless references.empty?
          references.each do |table_name|
            next unless table_name.is_a?(Arel::Nodes::SqlLiteral)
            @references[table_name.to_sym] = table_name
          end
        end

        joins = make_join_constraints(join_root, join_type)

        extra_joins = joins_to_add.flat_map do |other|
          if join_root.match?(other.join_root)
            walk(join_root, other.join_root, other.join_type)
          else
            make_join_constraints(other.join_root, other.join_type)
          end
        end
        joins.concat(extra_joins)
      end

      # Builds the root records, with their eager-loaded associations, from
      # the rows of +result_set+.
      #
      # Rows are grouped by the root's primary-key value(s). A parent that
      # spans several rows (one per joined child) is therefore instantiated
      # once, and construct attaches every row's children to it. When the root
      # has no primary key, or a key column has no alias, each distinct row
      # hash is treated as its own parent.
      #
      # Result columns whose names do not look like generated aliases
      # ("t<i>_r<j>") are treated as extra selected columns and are loaded onto
      # the root records too. Their types come from +result_set+, except names
      # the root model already has an attribute type for.
      #
      # +strict_loading_value+ is forwarded to construct; +block+ is passed to
      # the root's instantiate. The work is wrapped in an
      # "instantiation.active_record" notification.
      #
      # Returns the root records in the order they were first seen.
      def instantiate(result_set, strict_loading_value, &block)
        # Array() so a composite primary key is looked up column by column, as
        # construct already does for child nodes. Looking the whole Array up as
        # a single column never matched an alias. Every row then became its own
        # parent, which duplicated the parents.
        key_aliases = Array(join_root.primary_key).map { |column| aliases.column_alias(join_root, column) }
        # If any key column is not aliased, fall back to whole-row identity
        # (what happened before whenever the lookup returned nil).
        key_aliases = [] if key_aliases.include?(nil)

        # ar_parent (compared by identity) => join node => id => model
        seen = Hash.new { |i, parent|
          i[parent] = Hash.new { |j, child_class|
            j[child_class] = {}
          }
        }.compare_by_identity

        model_cache = Hash.new { |h, klass| h[klass] = {} }
        parents = model_cache[join_root]

        column_aliases = aliases.column_aliases(join_root)
        column_names = []

        result_set.columns.each do |name|
          column_names << name unless /\At\d+_r\d+\z/.match?(name)
        end

        if column_names.empty?
          column_types = {}
        else
          column_types = result_set.column_types
          unless column_types.empty?
            attribute_types = join_root.attribute_types
            column_types = column_types.slice(*column_names).delete_if { |k, _| attribute_types.key?(k) }
          end
          # `+=` builds a new array so the cached root alias list is not mutated.
          column_aliases += column_names.map! { |name| Aliases::Column.new(name, name) }
        end

        message_bus = ActiveSupport::Notifications.instrumenter

        payload = {
          record_count: result_set.length,
          class_name: join_root.base_klass.name
        }

        message_bus.instrument("instantiation.active_record", payload) do
          result_set.each { |row_hash|
            parent_key = key_aliases.empty? ? row_hash : key_aliases.map { |key| row_hash[key] }
            parent = parents[parent_key] ||= join_root.instantiate(row_hash, column_aliases, column_types, &block)
            construct(parent, join_root, row_hash, seen, model_cache, strict_loading_value)
          }
        end

        parents.values
      end

      # Makes +relation+ select the aliased columns of every join part. The
      # column list is handed over as a lambda, so it is only computed when
      # the relation calls it. Also records whether +relation+ had no explicit
      # select; aliases consults that flag to decide how the root's columns
      # are aliased.
      def apply_column_aliases(relation)
        aliased_columns = -> { aliases.columns }
        @join_root_alias = relation.select_values.empty?
        relation._select!(aliased_columns)
      end

      # Iterates over the join parts by delegating to join_root.each; without a
      # block it returns whatever join_root.each returns.
      def each(&visitor)
        join_root.each(&visitor)
      end

      protected
        # Read on other instances: join_constraints calls +join_root+ and
        # +join_type+ on each entry of +joins_to_add+ (presumably other
        # JoinDependency objects, hence protected rather than private).
        attr_reader :join_root, :join_type

      private
        # +alias_tracker+ is captured by join_constraints; +join_root_alias+ is
        # set by apply_column_aliases and consulted by aliases.
        attr_reader :alias_tracker, :join_root_alias

        # Lazily builds (and memoizes) the Aliases for every join part. The
        # j-th column of the i-th part is aliased "t<i>_r<j>".
        #
        # When the relation had an explicit select (join_root_alias is false),
        # only the root's primary-key column(s) are aliased for the root.
        # Otherwise every column of every part is aliased.
        def aliases
          @aliases ||= Aliases.new join_root.each_with_index.map { |join_part, i|
            column_names = if join_part == join_root && !join_root_alias
              primary_key = join_root.primary_key
              # Array() expands a composite primary key into one aliased column
              # per key column. Wrapping it as [primary_key] produced a single
              # "column" whose name was an Array.
              primary_key ? Array(primary_key) : []
            else
              join_part.column_names
            end

            columns = column_names.each_with_index.map { |column_name, j|
              Aliases::Column.new column_name, "t#{i}_r#{j}"
            }
            Aliases::Table.new(join_part, columns)
          }
        end

        # Constraints for every child of +root+ (and, through make_constraints,
        # all of their descendants), flattened into a single array.
        def make_join_constraints(root, join_type)
          root.children.flat_map { |child| make_constraints(root, child, join_type) }
        end

        # Returns the join constraints for +child+ (joined onto +parent+'s
        # table), followed by those of all of +child+'s descendants, recursively.
        #
        # child.join_constraints calls the block once per reflection, and the
        # block supplies the table to join for that reflection:
        # * If @joined_tables already records a table for the reflection, that
        #   table is reused and the block returns [table, true]. The exception
        #   is when this is the child's own reflection and the entry is already
        #   marked terminated. (NOTE(review): the meaning of the second element
        #   lives in JoinAssociation#join_constraints; confirm there.)
        # * Otherwise alias_tracker produces a table. The table name comes from
        #   @references for this reflection's name when present; otherwise an
        #   alias derived from parent.table_name is used.
        # Only tables created while building outer joins are recorded in
        # @joined_tables, so only those become reusable.
        def make_constraints(parent, child, join_type)
          foreign_table = parent.table
          foreign_klass = parent.base_klass
          child.join_constraints(foreign_table, foreign_klass, join_type, alias_tracker) do |reflection|
            table, terminated = @joined_tables[reflection]
            # True for the child's own reflection; false for any other reflection
            # join_constraints yields (presumably intermediate/through ones).
            root = reflection == child.reflection

            if table && (!root || !terminated)
              # Reaching a recorded table via the child's own reflection marks
              # the entry as terminated.
              @joined_tables[reflection] = [table, root] if root
              # `next a, b` hands [a, b] back to join_constraints.
              next table, true
            end

            table_name = @references[reflection.name.to_sym]&.to_s

            table = alias_tracker.aliased_table_for(reflection.klass.arel_table, table_name) do
              name = reflection.alias_candidate(parent.table_name)
              # Non-root reflections get a "_join" suffix on their alias.
              root ? name : "#{name}_join"
            end

            @joined_tables[reflection] ||= [table, root] if join_type == Arel::Nodes::OuterJoin
            table
          end.concat child.children.flat_map { |c| make_constraints(child, c, join_type) }
        end

        # Merges the subtree under +right+ into the already-joined subtree under
        # +left+.
        #
        # Each child of +right+ that matches a child of +left+ takes over that
        # child's table and is walked recursively. The unmatched children get
        # fresh constraints joined onto +left+.
        #
        # All matches are resolved before any table is reassigned or any
        # constraint is built. Returns the constraints for the unmatched nodes
        # (recursive ones first).
        def walk(left, right, join_type)
          matched = []
          unmatched = []
          right.children.each do |right_child|
            left_child = left.children.find { |candidate| right_child.match?(candidate) }
            if left_child
              matched << [left_child, right_child]
            else
              unmatched << right_child
            end
          end

          joins = matched.flat_map do |left_child, right_child|
            right_child.table = left_child.table
            walk(left_child, right_child, join_type)
          end
          joins.concat(unmatched.flat_map { |orphan| make_constraints(left, orphan, join_type) })
        end

        # Looks up association +name+ on +klass+. Raises ConfigurationError
        # when the class has no such association.
        def find_reflection(klass, name)
          reflection = klass._reflect_on_association(name)
          return reflection if reflection

          raise ConfigurationError, "Can't join '#{klass.name}' to association named '#{name}'; perhaps you misspelled it?"
        end

        # Turns a make_tree hash into JoinAssociation nodes, one per key. Each
        # node's subtree is built against its association's class.
        #
        # Every reflection is validated. Polymorphic ones raise
        # EagerLoadPolymorphicError.
        def build(associations, base_klass)
          associations.map do |name, right|
            reflection = find_reflection(base_klass, name)
            reflection.check_validity!
            reflection.check_eager_loadable!
            raise EagerLoadPolymorphicError.new(reflection) if reflection.polymorphic?

            nested = build(right, reflection.klass)
            JoinAssociation.new(reflection, nested)
          end
        end

        # Recursively attaches the models found in +row+ beneath +ar_parent+,
        # walking the join tree from +parent+ (the join node +ar_parent+ came
        # from). Does nothing when +ar_parent+ is nil.
        #
        # * Collection associations are marked loaded and receive each child
        #   model from construct_model. Singular associations already cached on
        #   +ar_parent+ are not rebuilt; the walk just descends into their target.
        # * If any key column of a node is NULL in +row+, that association is
        #   marked loaded and the branch is skipped for this row.
        # * +seen+ (ar_parent => node => id => model) keeps the same child from
        #   being attached to the same parent twice. +model_cache+ is handed to
        #   construct_model. Neither caches anything for nodes without a
        #   primary key (id is nil).
        def construct(ar_parent, parent, row, seen, model_cache, strict_loading_value)
          return if ar_parent.nil?

          parent.children.each do |node|
            if node.reflection.collection?
              other = ar_parent.association(node.reflection.name)
              other.loaded!
            elsif ar_parent.association_cached?(node.reflection.name)
              model = ar_parent.association(node.reflection.name).target
              construct(model, node, row, seen, model_cache, strict_loading_value)
              next
            end

            if node.primary_key
              keys = Array(node.primary_key).map { |column| aliases.column_alias(node, column) }
              id = keys.map { |key| row[key] }
            else
              keys = Array(node.reflection.join_primary_key).map { |column| aliases.column_alias(node, column.to_s) }
              # Without a primary key rows cannot be told apart, so skip
              # id-based caching. This must be nil rather than an array of nils:
              # [nil] is truthy, so the `if id` guards below (and in
              # construct_model) would cache it, and every later row would reuse
              # the first child model instead of adding its own.
              id = nil
            end

            if keys.any? { |key| row[key].nil? }
              nil_association = ar_parent.association(node.reflection.name)
              nil_association.loaded!
              next
            end

            unless model = seen[ar_parent][node][id]
              model = construct_model(ar_parent, node, row, model_cache, id, strict_loading_value)
              seen[ar_parent][node][id] = model if id
            end

            construct(model, node, row, seen, model_cache, strict_loading_value)
          end
        end

        # Returns the model for +node+ in +row+ and attaches it to +record+'s
        # association for that node. Collections get it pushed onto the target;
        # singular associations get it assigned as the target.
        #
        # A model already in +model_cache+ under (node, id) is reused. A newly
        # instantiated one is cached only when +id+ is non-nil. New models are
        # put into strict-loading mode when +strict_loading_value+ is truthy, and
        # the association sets its inverse on them. Finally the model is made
        # readonly / strict-loading if +node+ asks for it.
        def construct_model(record, node, row, model_cache, id, strict_loading_value)
          association = record.association(node.reflection.name)
          model = model_cache[node][id]

          unless model
            model = node.instantiate(row, aliases.column_aliases(node)) do |instance|
              instance.strict_loading! if strict_loading_value
              association.set_inverse_instance(instance)
            end
            model_cache[node][id] = model if id
          end

          if node.reflection.collection?
            association.target.push(model)
          else
            association.target = model
          end

          model.readonly! if node.readonly?
          model.strict_loading! if node.strict_loading?
          model
        end
    end
  end
end