MiniProfiler/rack-mini-profiler

View on GitHub
lib/mini_profiler/storage/redis_store.rb

Summary

Maintainability
C
1 day
Test Coverage
# frozen_string_literal: true

require 'digest'
require 'securerandom'

module Rack
  class MiniProfiler
    class RedisStore < AbstractStore

      # Namespace prefix used when building this store's Redis keys.
      attr_reader :prefix

      # Default lifetime of a saved page, in seconds (24 hours). Used by the
      # constructor when no :expires_in option is supplied.
      EXPIRES_IN_SECONDS = 60 * 60 * 24

      # Builds a store from an optional options hash.
      #
      # Recognised options (all optional):
      #   :prefix     - string used to namespace keys (default 'MPRedisStore')
      #   :connection - an existing Redis client to reuse
      #   :expires_in - lifetime of saved pages in seconds
      #                 (default EXPIRES_IN_SECONDS)
      #
      # Whatever is left after those three are extracted is kept in @args and
      # passed to Redis.new when a connection has to be created lazily.
      #
      # The options are extracted from a copy of the hash: deleting from the
      # caller's own hash would strip :prefix/:connection/:expires_in from it
      # (so a second store built from the same hash would silently fall back
      # to the defaults) and would raise on a frozen hash.
      def initialize(args = nil)
        @args               = (args || {}).dup
        @prefix             = @args.delete(:prefix) || 'MPRedisStore'
        @redis_connection   = @args.delete(:connection)
        @expires_in_seconds = @args.delete(:expires_in) || EXPIRES_IN_SECONDS
      end

      # Serialises +page_struct+ with Marshal and stores it under its prefixed
      # id; the entry expires after @expires_in_seconds.
      def save(page_struct)
        storage_key = prefixed_id(page_struct[:id])
        payload     = Marshal.dump(page_struct)
        redis.setex(storage_key, @expires_in_seconds, payload)
      end

      # Reads and unmarshals the page stored for +id+.
      #
      # Returns nil when nothing is stored under the key. If the stored bytes
      # cannot be unmarshalled, the key is deleted and nil is returned.
      def load(id)
        storage_key = prefixed_id(id)
        payload = redis.get(storage_key)
        return nil unless payload

        begin
          # rubocop:disable Security/MarshalLoad
          Marshal.load(payload)
          # rubocop:enable Security/MarshalLoad
        rescue
          # unreadable payload: drop the stale entry
          redis.del(storage_key)
          nil
        end
      end

      # Marks +id+ as not yet viewed by +user+.
      #
      # The id is added to the user's sorted set only when the page itself
      # still exists in Redis. Its score is the Unix time at which the page
      # will expire, which lets get_unviewed_ids drop stale ids with a single
      # range delete. The user's set gets a TTL of @expires_in_seconds whether
      # or not the id was added.
      def set_unviewed(user, id)
        key = user_key(user)
        add_unviewed_id(key, id)
        redis.expire(key, @expires_in_seconds)
      end

      # Replaces the user's unviewed set with +ids+: the existing set is
      # deleted, each id whose page still exists is added (see set_unviewed),
      # and the set's TTL is set to @expires_in_seconds.
      def set_all_unviewed(user, ids)
        key = user_key(user)
        redis.del(key)
        ids.each { |id| add_unviewed_id(key, id) }
        redis.expire(key, @expires_in_seconds)
      end

      # Removes +id+ from the user's unviewed set.
      def set_viewed(user, id)
        redis.zrem(user_key(user), id)
      end

      # Remove expired ids from the unviewed sorted set and return the remaining ids
      # (highest score, i.e. latest expiry, first).
      def get_unviewed_ids(user)
        key = user_key(user)
        redis.zremrangebyscore(key, '-inf', unviewed_clock)
        redis.zrevrangebyscore(key, '+inf', '-inf')
      end

      # Adds +id+ to the sorted set at +key+, scored with the page's expiry
      # time, provided the page key still exists.
      private def add_unviewed_id(key, id)
        page_key = prefixed_id(id)
        return unless redis.call([:exists, page_key]) == 1

        redis.zadd(key, unviewed_clock + redis.ttl(page_key), id)
      end

      # Current time in whole seconds since the Unix epoch.
      #
      # Scores are persisted in Redis and may be read back by another process,
      # another host, or after a reboot, so they must come from a clock with a
      # shared origin. CLOCK_MONOTONIC has an unspecified, per-machine starting
      # point and is therefore only suitable for measuring elapsed time within
      # one process.
      # NOTE(review): scores written by the previous monotonic-based code are
      # presumably far smaller than a Unix timestamp, so such ids would be
      # purged on the next get_unviewed_ids call - confirm this is acceptable.
      private def unviewed_clock
        Time.now.to_i
      end

      # Returns a newline-terminated, three-line summary: the key prefix, the
      # Redis host/port/db, and the unviewed ids for +user+.
      def diagnostics(user)
        # Some client versions expose the low-level client as #_client,
        # others as #client.
        client = redis.respond_to?(:_client) ? redis._client : redis.client
        [
          "Redis prefix: #{@prefix}",
          "Redis location: #{client.host}:#{client.port} db: #{client.db}",
          "unviewed_ids: #{get_unviewed_ids(user)}",
          ''
        ].join("\n")
      end

      # Deletes the three token keys maintained by allowed_tokens.
      def flush_tokens
        token_keys = %w[key1 key1_old key2].map { |suffix| "#{@prefix}-#{suffix}" }
        redis.del(*token_keys)
      end

      # Only used for testing: removes the current token key, as if it had
      # expired.
      def simulate_expire
        current_token_key = "#{@prefix}-key1"
        redis.del(current_token_key)
      end

      # Returns the currently stored tokens as an array of one or two strings.
      #
      # When a 32-character current token is stored it is returned together
      # with the stored secondary token (if any). Otherwise a new token is
      # generated and written; a 32-character previous token, if present, is
      # written as the secondary token and returned alongside the new one.
      def allowed_tokens
        current_key   = "#{@prefix}-key1"
        previous_key  = "#{@prefix}-key1_old"
        secondary_key = "#{@prefix}-key2"

        current, previous, secondary = redis.mget(current_key, previous_key, secondary_key)
        return [current, secondary].compact if current && current.length == 32

        max_age = Rack::MiniProfiler::AbstractStore::MAX_TOKEN_AGE

        # TODO: this could be moved to Lua to correct a concurrency flaw.
        # It is not critical: worst case, some requests will miss profiling info.

        # no usable current token, so mint a new one
        fresh = SecureRandom.hex

        if previous && previous.length == 32
          secondary = previous
          redis.setex(secondary_key, max_age, secondary)
        else
          secondary = nil
        end

        redis.setex(current_key, max_age, fresh)
        # the copy outlives the current token by one extra period
        redis.setex(previous_key, max_age * 2, fresh)

        [fresh, secondary].compact
      end

      # Lua source for the snapshot sampling counter: increments the counter at
      # KEYS[1] and, when the new value is a multiple of ARGV[1], deletes the
      # counter and returns 1; otherwise returns 0.
      COUNTER_LUA = <<~LUA
        if redis.call("INCR", KEYS[1]) % ARGV[1] == 0 then
          redis.call("DEL", KEYS[1])
          return 1
        else
          return 0
        end
      LUA

      # SHA1 hex digest of COUNTER_LUA, handed to cached_redis_eval so the
      # script can be invoked through EVALSHA.
      COUNTER_LUA_SHA = Digest::SHA1.hexdigest(COUNTER_LUA)

      # Runs the counter script against the snapshot counter key with +period+
      # as its argument and returns true only when the script answers 1.
      # Command errors are not re-raised (reraise: false), in which case the
      # result is false.
      def should_take_snapshot?(period)
        outcome = cached_redis_eval(
          COUNTER_LUA,
          COUNTER_LUA_SHA,
          keys: [snapshot_counter_key],
          argv: [period],
          reraise: false
        )
        outcome == 1
      end

      # Stores a snapshot in +group_name+ through a single Lua script.
      #
      # What the script does:
      # * records the snapshot's duration as the group's score in the overview
      #   zset when the group is new or the duration exceeds the stored score;
      # * while the overview holds more than config.max_snapshot_groups
      #   groups, removes the lowest-scored group; if that is this group the
      #   snapshot is not saved, otherwise that group's zset and hash are
      #   deleted;
      # * if still saving, adds the id to the group zset and, while the group
      #   holds more than config.max_snapshots_per_group entries, removes the
      #   lowest-scored one; if that is this snapshot it is not stored;
      # * finally writes the marshalled snapshot into the group hash.
      #
      # The script rebuilds other groups' key names from the prefix, so it has
      # to stay in sync with group_snapshot_zset_key / group_snapshot_hash_key.
      def push_snapshot(page_struct, group_name, config)
        script_keys = [
          group_snapshot_zset_key(group_name),
          group_snapshot_hash_key(group_name),
          snapshot_overview_zset_key
        ]
        script_args = [
          page_struct[:id],
          page_struct.duration_ms.to_s,
          group_name,
          config.max_snapshots_per_group.to_s,
          config.max_snapshot_groups.to_s,
          @prefix,
          Marshal.dump(page_struct)
        ]

        script = <<~LUA
          local group_zset_key = KEYS[1]
          local group_hash_key = KEYS[2]
          local overview_zset_key = KEYS[3]

          local id = ARGV[1]
          local score = tonumber(ARGV[2])
          local group_name = ARGV[3]
          local per_group_limit = tonumber(ARGV[4])
          local groups_limit = tonumber(ARGV[5])
          local prefix = ARGV[6]
          local bytes = ARGV[7]

          local current_group_score = redis.call("ZSCORE", overview_zset_key, group_name)
          if current_group_score == false or score > tonumber(current_group_score) then
            redis.call("ZADD", overview_zset_key, score, group_name)
          end

          local do_save = true
          local overview_size = redis.call("ZCARD", overview_zset_key)
          while (overview_size > groups_limit) do
            local lowest_group = redis.call("ZRANGE", overview_zset_key, 0, 0)[1]
            redis.call("ZREM", overview_zset_key, lowest_group)
            if lowest_group == group_name then
              do_save = false
            else
              local lowest_group_zset_key = prefix .. "-mp-group-snapshot-zset-key-" .. lowest_group
              local lowest_group_hash_key = prefix .. "-mp-group-snapshot-hash-key-" .. lowest_group
              redis.call("DEL", lowest_group_zset_key, lowest_group_hash_key)
            end
            overview_size = overview_size - 1
          end

          if do_save then
            redis.call("ZADD", group_zset_key, score, id)
            local group_size = redis.call("ZCARD", group_zset_key)
            while (group_size > per_group_limit) do
              local lowest_snapshot_id = redis.call("ZRANGE", group_zset_key, 0, 0)[1]
              redis.call("ZREM", group_zset_key, lowest_snapshot_id)
              if lowest_snapshot_id == id then
                do_save = false
              else
                redis.call("HDEL", group_hash_key, lowest_snapshot_id)
              end
              group_size = group_size - 1
            end
            if do_save then
              redis.call("HSET", group_hash_key, id, bytes)
            end
          end
        LUA

        redis.eval(script, keys: script_keys, argv: script_args)
      end

      # Builds a hash of group name => { worst_score:, best_score:,
      # snapshots_count: }.
      #
      # Group names and their worst scores come from the overview zset; the
      # lowest score and the entry count of every group zset are fetched with
      # one Lua call. Missing values end up as 0.0 / 0 through to_f / to_i.
      def fetch_snapshots_overview
        scored_groups = redis.zrange(snapshot_overview_zset_key, 0, -1, withscores: true)
        overview = scored_groups.map { |name, worst| [name, { worst_score: worst }] }
        zset_keys = overview.map { |name, _| group_snapshot_zset_key(name) }

        script = <<~LUA
          local metadata = {}
          for i, k in ipairs(KEYS) do
            local best = redis.call("ZRANGE", k, 0, 0, "WITHSCORES")[2]
            local count = redis.call("ZCARD", k)
            metadata[i] = {best, count}
          end
          return metadata
        LUA
        stats = redis.eval(script, keys: zset_keys)

        overview.each_with_index do |(_, details), position|
          best, count = stats[position]
          details[:best_score] = best.to_f
          details[:snapshots_count] = count.to_i
        end
        overview.to_h
      end

      # Returns every snapshot of +group_name+ that can be unmarshalled.
      # Ids whose payload fails to unmarshal are collected and, if there are
      # any, handed to cleanup_corrupt_snapshots.
      def fetch_snapshots_group(group_name)
        loaded = []
        broken_ids = []

        redis.hgetall(group_snapshot_hash_key(group_name)).each do |snapshot_id, payload|
          begin
            # rubocop:disable Security/MarshalLoad
            loaded << Marshal.load(payload)
            # rubocop:enable Security/MarshalLoad
          rescue
            broken_ids << snapshot_id
          end
        end

        cleanup_corrupt_snapshots(broken_ids, group_name) unless broken_ids.empty?
        loaded
      end

      # Returns the snapshot +id+ from +group_name+, or nil when it is absent.
      # A payload that cannot be unmarshalled is passed to
      # cleanup_corrupt_snapshots and nil is returned.
      def load_snapshot(id, group_name)
        payload = redis.hget(group_snapshot_hash_key(group_name), id)
        if payload
          begin
            # rubocop:disable Security/MarshalLoad
            Marshal.load(payload)
            # rubocop:enable Security/MarshalLoad
          rescue
            cleanup_corrupt_snapshots([id], group_name)
            nil
          end
        end
      end

      private

      # Key of the sorted set that holds the unviewed ids of +user+.
      def user_key(user)
        format('%s-%s-v1', @prefix, user)
      end

      # Key under which the page with +id+ is stored: the prefix immediately
      # followed by the id, with no separator.
      def prefixed_id(id)
        @prefix.to_s + id.to_s
      end

      # Returns the Redis client, creating it on first use from the options
      # left in @args. The redis library is required only at that point, and
      # only if the Redis constant is not already defined.
      def redis
        return @redis_connection if @redis_connection

        require 'redis' unless defined?(Redis)
        @redis_connection = Redis.new(@args)
      end

      # Key of the snapshot sampling counter; computed once and cached.
      def snapshot_counter_key
        return @snapshot_counter_key if @snapshot_counter_key

        @snapshot_counter_key = "#{@prefix}-mini-profiler-snapshots-counter"
      end

      # Key of the sorted set that ranks the snapshots of +group_name+.
      def group_snapshot_zset_key(group_name)
        # NOTE: the Lua script in push_snapshot rebuilds this key from the
        # prefix; change both together.
        @prefix.to_s + '-mp-group-snapshot-zset-key-' + group_name.to_s
      end

      # Key of the hash that stores the marshalled snapshots of +group_name+.
      def group_snapshot_hash_key(group_name)
        # NOTE: the Lua script in push_snapshot rebuilds this key from the
        # prefix; change both together.
        format('%s-mp-group-snapshot-hash-key-%s', @prefix, group_name)
      end

      # Key of the sorted set that lists every snapshot group with its score.
      def snapshot_overview_zset_key
        @prefix.to_s + '-mp-overviewgroup-snapshot-zset-key'
      end

      # Runs a Lua script through EVALSHA, falling back to EVAL with the full
      # source when the server answers NOSCRIPT (script not cached yet).
      #
      # script     - Lua source
      # script_sha - SHA1 hex digest of +script+
      # reraise    - when true (default) a Redis::CommandError is re-raised;
      #              when false it is swallowed and nil is returned
      # argv, keys - forwarded to the script
      #
      # Returns the script's result, or nil when a command error was swallowed.
      #
      # The reraise policy covers both attempts. Previously an error raised by
      # the EVAL fallback escaped even with reraise: false, so the same
      # failing script raised on a cold script cache but returned nil once the
      # script had been cached.
      def cached_redis_eval(script, script_sha, reraise: true, argv: [], keys: [])
        begin
          redis.evalsha(script_sha, argv: argv, keys: keys)
        rescue ::Redis::CommandError => e
          # anything other than "script not cached" is handled by the outer rescue
          raise unless e.message.start_with?('NOSCRIPT')

          redis.eval(script, argv: argv, keys: keys)
        end
      rescue ::Redis::CommandError
        raise if reraise

        nil
      end

      # Removes the given snapshot ids from a group through one Lua script.
      #
      # The script deletes every ARGV entry that differs from the group name
      # from both the group hash and the group zset. If the zset is then
      # empty, the group is removed from the overview and both group keys are
      # deleted; otherwise the group's overview score is reset to the highest
      # score left in the zset.
      def cleanup_corrupt_snapshots(corrupt_snapshots_ids, group_name)
        script = <<~LUA
          local group_hash_key = KEYS[1]
          local group_zset_key = KEYS[2]
          local overview_zset_key = KEYS[3]
          local group_name = ARGV[1]
          for i, k in ipairs(ARGV) do
            if k ~= group_name then
              redis.call("HDEL", group_hash_key, k)
              redis.call("ZREM", group_zset_key, k)
            end
          end
          if redis.call("ZCARD", group_zset_key) == 0 then
            redis.call("ZREM", overview_zset_key, group_name)
            redis.call("DEL", group_hash_key, group_zset_key)
          else
            local worst_score = tonumber(redis.call("ZRANGE", group_zset_key, -1, -1, "WITHSCORES")[2])
            redis.call("ZADD", overview_zset_key, worst_score, group_name)
          end
        LUA

        script_keys = [
          group_snapshot_hash_key(group_name),
          group_snapshot_zset_key(group_name),
          snapshot_overview_zset_key
        ]
        # the group name travels as ARGV[1]; the ids follow it
        redis.eval(script, keys: script_keys, argv: [group_name, *corrupt_snapshots_ids])
      end

      # only used in tests: deletes every group hash and group zset found by
      # pattern, plus the overview zset and the snapshot counter.
      def wipe_snapshots_data
        hash_keys = redis.keys(group_snapshot_hash_key('*'))
        zset_keys = redis.keys(group_snapshot_zset_key('*'))
        redis.del(hash_keys + zset_keys, snapshot_overview_zset_key, snapshot_counter_key)
      end
    end
  end
end