giuse/machine_learning_workbench

View on GitHub
lib/machine_learning_workbench/compressor/decaying_learning_rate_vq.rb

Summary

Maintainability
A
0 mins
Test Coverage
# frozen_string_literal: true

module MachineLearningWorkbench::Compressor
  # VQ with per-centroid decaying learning rates.
  # Optimized for online training.
  class DecayingLearningRateVQ < VectorQuantization

    attr_reader :lrate_min, :lrate_min_den, :decay_rate

    def initialize **opts
      puts "Ignoring learning rate: `lrate: #{opts[:lrate]}`" if opts[:lrate]
      @lrate_min = opts.delete(:lrate_min) || 0.001
      @lrate_min_den = opts.delete(:lrate_min_den) || 1
      @decay_rate = opts.delete(:decay_rate) || 1
      super **opts.merge({lrate: nil})
    end

    # Overloading lrate check from original VQ
    def check_lrate lrate; nil; end

    # Decaying per-centroid learning rate.
    # @param centr_idx [Integer] index of the centroid
    # @param lower_bound [Float] minimum learning rate
    # @note nicely overloads the `attr_reader` of parent class
    def lrate centr_idx, min_den: lrate_min_den, lower_bound: lrate_min, decay: decay_rate
      [1.0/(ntrains[centr_idx]*decay+min_den), lower_bound].max
      .tap { |l| puts "centr: #{centr_idx}, ntrains: #{ntrains[centr_idx]}, lrate: #{l}" }
    end

    # Train on one vector
    # @return [Integer] index of trained centroid
    def train_one vec, eps: nil
      # NOTE: ignores epsilon if passed
      trg_idx, _simil = most_similar_centr(vec)
      # norm_vec = vec / NLinalg.norm(vec)
      # centrs[trg_idx, true] = centrs[trg_idx, true] * (1-lrate(trg_idx)) + norm_vec * lrate(trg_idx)
      centrs[trg_idx, true] = centrs[trg_idx, true] * (1-lrate(trg_idx)) + vec * lrate(trg_idx)
      trg_idx
    end

  end
end