# lib/ml_algorithms.rb
# Class represents computational graph
module BackPropogation
# Computational graph built from an infix arithmetic expression.
#
# The expression is converted to postfix (reverse Polish) notation by
# +polish_parser+ and every postfix token becomes one node of +graph+:
# '+', '*' and '/' become gates from ComputationalGates, any other operand
# token becomes a CompGate leaf named after the token.
class ComputationalGraph
  # Operator precedence used by the infix-to-postfix conversion. '(' has the
  # lowest value, so an operator never pops it off the stack.
  PRIORITY = { '(' => 0, '+' => 1, '-' => 1, '*' => 2, '/' => 2, '^' => 3 }.freeze

  # Patterns that peel the leading token off an expression.
  # :operand and :string capture the token (group 1) and the remainder
  # (group 2); :brackets and :nested capture only the remainder after an
  # opening / closing bracket (group 1).
  TEMPLATES = {
    operand: /^\s*([^\+\-\*\/\(\)\^\s]+)\s*(.*)/,
    string: /^\s*([\+\-\*\/\^])\s*(.*)/,
    brackets: /^\s*\(\s*(.*)/,
    nested: /^\s*\)\s*(.*)/
  }.freeze

  # Nodes in postfix order: inputs always precede the gate that consumes them.
  attr_accessor :graph

  # Parses +expr_s+ and wires up the graph.
  #
  # @param expr_s [String] infix expression, e.g. "(a + b) * c"
  # @raise [ArgumentError] on unbalanced brackets, on an operator that has no
  #   gate implementation ('-' and '^'), or on an operator missing an operand
  def initialize(expr_s)
    postfix = ComputationalGraph.polish_parser(expr_s, [])
    pending = [] # nodes whose output has not been consumed yet
    @graph = []
    postfix.split.each do |token|
      node = build_node(token, pending)
      pending.push(node)
      @graph.push(node)
    end
  end

  # Computes the value of the expression.
  #
  # Every leaf (including numeric-looking tokens) is looked up by name in
  # +variables_val+; gates then compute their value from their inputs.
  #
  # @param variables_val [Hash] operand name => value
  # @return the value stored in the last node of the graph
  def forward_pass(variables_val)
    @graph.each do |node|
      if leaf?(node)
        node.frwrd = variables_val[node.name]
      else
        node.forward_pass
      end
    end
    @graph.last.frwrd
  end

  # Computes the gradient for every input. Must be called after
  # +forward_pass+, because the gates read the stored forward values.
  #
  # @param loss_value gradient fed into the output node
  # @return [Hash] operand name => gradient
  def backward_pass(loss_value)
    param_grad = {}
    @graph.last.bckwrd = loss_value
    # Reverse postfix order guarantees a gate runs before its inputs are read.
    @graph.reverse_each do |node|
      if leaf?(node)
        # Each occurrence of a variable is a separate leaf, so contributions
        # are summed instead of letting the last one win.
        param_grad[node.name] = [param_grad[node.name], node.bckwrd].compact.reduce(:+)
      else
        node.backward_pass
      end
    end
    param_grad
  end

  # Emits the operand +left+ and continues with the remainder +right+.
  def self.parse_operand(left, right, stack)
    "#{left} #{polish_parser(right, stack)}"
  end

  # Handles the operator +left+. It is pushed when the stack is empty or its
  # top has strictly lower priority; otherwise the top operator is emitted and
  # the untouched input +i_str+ is parsed again against the shorter stack.
  def self.parse_string(left, right, i_str, stack)
    if stack.empty? || PRIORITY[stack.last] < PRIORITY[left]
      polish_parser(right, stack.push(left))
    else
      "#{stack.pop} #{polish_parser(i_str, stack)}"
    end
  end

  # Handles a closing bracket. +left+ is the text after the bracket, +right+
  # is the input that still starts with it. Operators are emitted one per call
  # (re-parsing +right+) until the matching '(' is popped.
  #
  # @raise [ArgumentError] when no opening bracket is left on the stack
  def self.parse_nested(left, right, stack)
    raise ArgumentError, 'Error: Excess of closing brackets.' if stack.empty?

    head = stack.pop
    if PRIORITY[head].positive?
      "#{head} #{polish_parser(right, stack)}"
    else
      polish_parser(left, stack)
    end
  end

  # Continues parsing after an opening bracket; the caller has already pushed
  # '(' onto +stack+.
  def self.parse_brackets(left, stack)
    polish_parser(left, stack)
  end

  # Reached when no token pattern matches (the input is exhausted): flushes
  # the remaining operators to the output.
  #
  # @raise [ArgumentError] when an unmatched '(' is still on the stack
  def self.parse_default(left, stack)
    return '' if stack.empty?
    raise ArgumentError, 'Error: Excess of opening brackets.' unless PRIORITY[stack.last] > 0

    "#{stack.pop} #{polish_parser(left, stack)}"
  end

  # Converts the infix expression +i_str+ to postfix notation.
  #
  # @param i_str [String] remaining infix input
  # @param stack [Array<String>] operator stack; mutated while parsing
  # @return [String] postfix tokens, each followed by a single space
  # @raise [ArgumentError] on unbalanced brackets
  def self.polish_parser(i_str, stack)
    case i_str
    when TEMPLATES[:operand]
      parse_operand(Regexp.last_match(1), Regexp.last_match(2), stack)
    when TEMPLATES[:string]
      parse_string(Regexp.last_match(1), Regexp.last_match(2), i_str, stack)
    when TEMPLATES[:brackets]
      parse_brackets(Regexp.last_match(1), stack.push('('))
    when TEMPLATES[:nested]
      parse_nested(Regexp.last_match(1), i_str, stack)
    else
      parse_default(i_str, stack)
    end
  end

  private

  # True for plain operand nodes (exactly CompGate, not one of its subclasses).
  def leaf?(node)
    node.instance_of?(ComputationalGates::CompGate)
  end

  # Creates the node for one postfix token, consuming inputs from +pending+.
  def build_node(token, pending)
    case token
    when '+' then connect_binary(ComputationalGates::SummGate.new(token), pending)
    when '*' then connect_binary(ComputationalGates::MultGate.new(token), pending)
    when '/' then connect_binary(ComputationalGates::DivGate.new(token), pending)
    when '-', '^'
      # The parser accepts these operators but no gate implements them;
      # fail loudly instead of treating the operator as a variable name.
      raise ArgumentError, "Error: Unsupported operator '#{token}'."
    else
      ComputationalGates::CompGate.new(token)
    end
  end

  # Pops the two most recent nodes and connects them to +gate+ in
  # expression order (left operand first).
  def connect_binary(gate, pending)
    right = pending.pop
    left = pending.pop
    if left.nil? || right.nil?
      raise ArgumentError, "Error: Missing operand for operator '#{gate.name}'."
    end

    gate.connect(left, right)
    gate
  end
end
# Node types of the computational graph.
module ComputationalGates
  # Leaf node holding a named operand. Also the base class of every gate.
  #
  # frwrd  - value from the forward pass (the node itself until one is assigned)
  # bckwrd - gradient received during the backward pass
  # out    - gate that consumes this node, set by that gate's +connect+
  # name   - token the node was created from
  class CompGate
    attr_accessor :frwrd, :bckwrd, :out, :name

    def initialize(name)
      @name = name
      @frwrd = self
    end
  end

  # Input wiring shared by all gates that take two operands.
  module BinaryInputs
    attr_accessor :in_frst, :in_scnd

    # Stores both inputs and registers this gate as their consumer.
    def connect(first, second)
      @in_frst = first
      @in_scnd = second
      first.out = self
      second.out = self
    end
  end

  # Addition gate: frwrd = in_frst + in_scnd.
  class SummGate < CompGate
    include BinaryInputs

    def forward_pass
      @frwrd = @in_frst.frwrd + @in_scnd.frwrd
    end

    # The derivative of a sum is 1 for each term, so the gradient passes through.
    def backward_pass
      @in_frst.bckwrd = @bckwrd
      @in_scnd.bckwrd = @bckwrd
    end
  end

  # Multiplication gate: frwrd = in_frst * in_scnd.
  class MultGate < CompGate
    include BinaryInputs

    def forward_pass
      @frwrd = @in_frst.frwrd * @in_scnd.frwrd
    end

    # Each input receives the incoming gradient scaled by the other input.
    def backward_pass
      @in_frst.bckwrd = @bckwrd * @in_scnd.frwrd
      @in_scnd.bckwrd = @bckwrd * @in_frst.frwrd
    end
  end

  # Division gate: frwrd = in_frst / in_scnd.
  class DivGate < CompGate
    include BinaryInputs

    # NOTE: uses the operands' own `/`, so two Integer inputs divide as integers.
    def forward_pass
      @frwrd = @in_frst.frwrd / @in_scnd.frwrd
    end

    # d(a/b)/da = 1/b and d(a/b)/db = -a/b**2. +fdiv+ keeps the derivatives
    # fractional even when the forward values are Integers.
    def backward_pass
      numerator = @in_frst.frwrd
      denominator = @in_scnd.frwrd
      @in_frst.bckwrd = @bckwrd.fdiv(denominator)
      @in_scnd.bckwrd = -@bckwrd * numerator.fdiv(denominator**2)
    end
  end
end
end