SciRuby/statsample

View on GitHub
lib/statsample/formula/fit_model.rb

Summary

Maintainability
A
0 mins
Test Coverage
require 'statsample/formula/formula'

module Statsample
  # Class for performing regression
  class FitModel
    def initialize(formula, df, opts = {})
      @formula = FormulaWrapper.new formula, df
      @df = df
      @opts = opts
    end

    def model
      @model || fit_model
    end

    def predict(new_data)
      model.predict(df_for_prediction(new_data))
    end

    def df_for_prediction df
      canonicalize_df(df)
    end

    def df_for_regression
      df = canonicalize_df(@df)
      df[@formula.y.value] = @df[@formula.y.value]
      df        
    end

    def canonicalize_df(orig_df)
      tokens = @formula.canonical_tokens
      tokens.shift if tokens.first.value == '1'
      df = tokens.map { |t| t.to_df orig_df }.reduce(&:merge)
      df
    end

    def fit_model
      # TODO: Add support for inclusion/exclusion of intercept
      @model = Statsample::Regression.multiple(
        df_for_regression,
        @formula.y.value,
        @opts
      )
    end
  end
end