rubocop-hq/rubocop

View on GitHub
lib/rubocop/cop/style/safe_navigation.rb

Summary

Maintainability
B
6 hrs
Test Coverage
A
100%
# frozen_string_literal: true

module RuboCop
  module Cop
    module Style
      # Transforms usages of a method call safeguarded by a non `nil`
      # check for the variable whose method is being called to
      # safe navigation (`&.`). If there is a method chain, all of the methods
      # in the chain need to be checked for safety, and all of the methods will
      # need to be changed to use safe navigation.
      #
      # The default for `ConvertCodeThatCanStartToReturnNil` is `false`.
      # When configured to `true`, this will
      # check for code in the format `!foo.nil? && foo.bar`. As it is written,
      # the return of this code is limited to `false` and whatever the return
      # of the method is. If this is converted to safe navigation,
      # `foo&.bar` can start returning `nil` as well as what the method
      # returns.
      #
      # The default for `MaxChainLength` is `2`
      # We have limited the cop to not register an offense for method chains
      # that exceed this option is set.
      #
      # @safety
      #   Autocorrection is unsafe because if a value is `false`, the resulting
      #   code will have different behavior or raise an error.
      #
      #   [source,ruby]
      #   ----
      #   x = false
      #   x && x.foo  # return false
      #   x&.foo      # raises NoMethodError
      #   ----
      #
      # @example
      #   # bad
      #   foo.bar if foo
      #   foo.bar.baz if foo
      #   foo.bar(param1, param2) if foo
      #   foo.bar { |e| e.something } if foo
      #   foo.bar(param) { |e| e.something } if foo
      #
      #   foo.bar if !foo.nil?
      #   foo.bar unless !foo
      #   foo.bar unless foo.nil?
      #
      #   foo && foo.bar
      #   foo && foo.bar.baz
      #   foo && foo.bar(param1, param2)
      #   foo && foo.bar { |e| e.something }
      #   foo && foo.bar(param) { |e| e.something }
      #
      #   foo ? foo.bar : nil
      #   foo.nil? ? nil : foo.bar
      #   !foo.nil? ? foo.bar : nil
      #   !foo ? nil : foo.bar
      #
      #   # good
      #   foo&.bar
      #   foo&.bar&.baz
      #   foo&.bar(param1, param2)
      #   foo&.bar { |e| e.something }
      #   foo&.bar(param) { |e| e.something }
      #   foo && foo.bar.baz.qux # method chain with more than 2 methods
      #   foo && foo.nil? # method that `nil` responds to
      #
      #   # Method calls that do not use `.`
      #   foo && foo < bar
      #   foo < bar if foo
      #
      #   # When checking `foo&.empty?` in a conditional, `foo` being `nil` will actually
      #   # do the opposite of what the author intends.
      #   foo && foo.empty?
      #
      #   # This could start returning `nil` as well as the return of the method
      #   foo.nil? || foo.bar
      #   !foo || foo.bar
      #
      #   # Methods that are used on assignment, arithmetic operation or
      #   # comparison should not be converted to use safe navigation
      #   foo.baz = bar if foo
      #   foo.baz + bar if foo
      #   foo.bar > 2 if foo
      class SafeNavigation < Base
        include NilMethods
        include RangeHelp
        extend AutoCorrector
        extend TargetRubyVersion

        MSG = 'Use safe navigation (`&.`) instead of checking if an object ' \
              'exists before calling the method.'
        LOGIC_JUMP_KEYWORDS = %i[break fail next raise return throw yield].freeze

        minimum_target_ruby_version 2.3

        # if format: (if checked_variable body nil)
        # unless format: (if checked_variable nil body)
        # @!method modifier_if_safe_navigation_candidate(node)
        def_node_matcher :modifier_if_safe_navigation_candidate, <<~PATTERN
          {
            (if {
                  (send $_ {:nil? :!})
                  $_
                } nil? $_)

            (if {
                  (send (send $_ :nil?) :!)
                  $_
                } $_ nil?)
          }
        PATTERN

        # @!method ternary_safe_navigation_candidate(node)
        def_node_matcher :ternary_safe_navigation_candidate, <<~PATTERN
          {
            (if (send $_ {:nil? :!}) nil $_)

            (if (send (send $_ :nil?) :!) $_ nil)

            (if $_ $_ nil)
          }
        PATTERN

        # @!method not_nil_check?(node)
        def_node_matcher :not_nil_check?, '(send (send $_ :nil?) :!)'

        def on_if(node)
          return if allowed_if_condition?(node)

          check_node(node)
        end

        def on_and(node)
          check_node(node)
        end

        private

        def check_node(node)
          checked_variable, receiver, method_chain, method = extract_parts(node)
          return if receiver != checked_variable || receiver.nil?
          return if use_var_only_in_unless_modifier?(node, checked_variable)
          return if chain_length(method_chain, method) > max_chain_length
          return if unsafe_method_used?(method_chain, method)
          return if method_chain.method?(:empty?)

          add_offense(node) { |corrector| autocorrect(corrector, node) }
        end

        def use_var_only_in_unless_modifier?(node, variable)
          node.if_type? && node.unless? && !method_called?(variable)
        end

        def autocorrect(corrector, node)
          body = extract_body(node)
          method_call = method_call(node)

          corrector.remove(begin_range(node, body))
          corrector.remove(end_range(node, body))
          corrector.insert_before(method_call.loc.dot, '&') unless method_call.safe_navigation?
          handle_comments(corrector, node, method_call)

          add_safe_nav_to_all_methods_in_chain(corrector, method_call, body)
        end

        def extract_body(node)
          if node.if_type? && node.ternary?
            node.branches.find { |branch| !branch.nil_type? }
          else
            node.node_parts[1]
          end
        end

        def handle_comments(corrector, node, method_call)
          comments = comments(node)
          return if comments.empty?

          corrector.insert_before(method_call, "#{comments.map(&:text).join("\n")}\n")
        end

        def comments(node)
          relevant_comment_ranges(node).each.with_object([]) do |range, comments|
            comments.concat(processed_source.each_comment_in_lines(range).to_a)
          end
        end

        def relevant_comment_ranges(node)
          # Get source lines ranges inside the if node that aren't inside an inner node
          # Comments inside an inner node should remain attached to that node, and not
          # moved.
          begin_pos = node.loc.first_line
          end_pos = node.loc.last_line

          node.child_nodes.each.with_object([]) do |child, ranges|
            ranges << (begin_pos...child.loc.first_line)
            begin_pos = child.loc.last_line
          end << (begin_pos...end_pos)
        end

        def allowed_if_condition?(node)
          node.else? || node.elsif?
        end

        def method_call(node)
          _checked_variable, matching_receiver, = extract_parts(node)
          matching_receiver.parent
        end

        def extract_parts(node)
          case node.type
          when :if
            extract_parts_from_if(node)
          when :and
            extract_parts_from_and(node)
          end
        end

        def extract_parts_from_if(node)
          variable, receiver =
            if node.ternary?
              ternary_safe_navigation_candidate(node)
            else
              modifier_if_safe_navigation_candidate(node)
            end

          checked_variable, matching_receiver, method = extract_common_parts(receiver, variable)

          matching_receiver = nil if receiver && LOGIC_JUMP_KEYWORDS.include?(receiver.type)

          [checked_variable, matching_receiver, receiver, method]
        end

        def extract_parts_from_and(node)
          checked_variable, rhs = *node
          if cop_config['ConvertCodeThatCanStartToReturnNil']
            checked_variable = not_nil_check?(checked_variable) || checked_variable
          end

          checked_variable, matching_receiver, method = extract_common_parts(rhs, checked_variable)
          [checked_variable, matching_receiver, rhs, method]
        end

        def extract_common_parts(method_chain, checked_variable)
          matching_receiver = find_matching_receiver_invocation(method_chain, checked_variable)

          method = matching_receiver.parent if matching_receiver

          [checked_variable, matching_receiver, method]
        end

        def find_matching_receiver_invocation(method_chain, checked_variable)
          return nil unless method_chain

          receiver = method_chain.receiver

          return receiver if receiver == checked_variable

          find_matching_receiver_invocation(receiver, checked_variable)
        end

        def chain_length(method_chain, method)
          method.each_ancestor(:send).inject(1) do |total, ancestor|
            break total + 1 if ancestor == method_chain

            total + 1
          end
        end

        def unsafe_method_used?(method_chain, method)
          return true if unsafe_method?(method)

          method.each_ancestor(:send).any? do |ancestor|
            break true unless config.for_cop('Lint/SafeNavigationChain')['Enabled']

            break true if unsafe_method?(ancestor)
            break true if nil_methods.include?(ancestor.method_name)
            break false if ancestor == method_chain
          end
        end

        def unsafe_method?(send_node)
          negated?(send_node) ||
            send_node.assignment? ||
            (!send_node.dot? && !send_node.safe_navigation?)
        end

        def negated?(send_node)
          if method_called?(send_node)
            negated?(send_node.parent)
          else
            send_node.send_type? && send_node.method?(:!)
          end
        end

        def method_called?(send_node)
          send_node&.parent&.send_type?
        end

        def begin_range(node, method_call)
          range_between(node.source_range.begin_pos, method_call.source_range.begin_pos)
        end

        def end_range(node, method_call)
          range_between(method_call.source_range.end_pos, node.source_range.end_pos)
        end

        def add_safe_nav_to_all_methods_in_chain(corrector,
                                                 start_method,
                                                 method_chain)
          start_method.each_ancestor do |ancestor|
            break unless %i[send block].include?(ancestor.type)
            next unless ancestor.send_type?

            corrector.insert_before(ancestor.loc.dot, '&')

            break if ancestor == method_chain
          end
        end

        def max_chain_length
          cop_config.fetch('MaxChainLength', 2)
        end
      end
    end
  end
end