PuzaTech/Fugue

View on GitHub
src/main/java/com/hongliangjie/fugue/distributions/MultinomialDistribution.java

Summary

Maintainability
A
0 mins
Test Coverage
package com.hongliangjie.fugue.distributions;

import com.hongliangjie.fugue.utils.LogUtils;
import com.hongliangjie.fugue.utils.MathExp;
import com.hongliangjie.fugue.utils.MathLog;

/**
 * Discrete distribution over {@code K} outcomes, sampled by inverting a running
 * (cumulative) sum of per-outcome weights.
 *
 * <p>Weights are supplied through {@link #setProbabilities(double[])}. The built-in
 * samplers scale the uniform draw by the last cumulative value, so the weights do not
 * have to sum to one. The sampling strategy is selected by name at construction time:
 * <ul>
 *   <li>{@code "normal"} - linear scan over plain cumulative sums;</li>
 *   <li>{@code "log"} - linear scan over cumulative sums combined with
 *       {@code LogUtils.logSumTwo};</li>
 *   <li>{@code "binary"} - binary search over plain cumulative sums.</li>
 * </ul>
 *
 * Created by liangjie on 3/3/16.
 */
public class MultinomialDistribution extends DiscreteDistribution {

    /** Cumulative weights written by the active sampler; always has {@code _K} slots. */
    protected double[] _accumulatedWeights;
    /** Number of outcomes. */
    protected int _K;
    protected MathLog mathLog;
    protected MathExp mathExp;
    protected LogUtils logU;
    /** Active sampling strategy; stays {@code null} when the method name was not recognised. */
    protected Sampler s;


    /**
     * Creates a distribution over {@code K} outcomes using the {@code "normal"} sampler
     * and default {@code MathLog}/{@code MathExp} instances.
     *
     * @param K number of outcomes
     */
    public MultinomialDistribution(int K){
        this(K, new MathLog(), new MathExp(), "normal");
    }


    /**
     * Creates a distribution over {@code K} outcomes.
     *
     * @param K      number of outcomes
     * @param m      logarithm implementation (used by the {@code "log"} sampler)
     * @param e      exponential implementation (handed to {@code LogUtils})
     * @param method {@code "normal"}, {@code "log"} or {@code "binary"}; any other value
     *               leaves the sampler unset, in which case {@link #sample(double)} and
     *               {@link #setProbabilities(double[])} throw {@link IllegalStateException}
     *               until a sampler is assigned to {@link #s}
     */
    public MultinomialDistribution(int K, MathLog m, MathExp e, String method){
        _accumulatedWeights = new double[K];
        _K = K;
        mathLog = m;
        mathExp = e;
        logU = new LogUtils(m, e);

        if ("normal".equals(method)){
            s = new NormalSampler();
        }
        else if ("log".equals(method)){
            s = new LogSampler();
        }
        else if ("binary".equals(method)){
            s = new NormalBinarySearchSampler();
        }
        // Unknown names intentionally do not throw here: subclasses may install their
        // own Sampler afterwards. Use before that point is reported by requireSampler().
    }

    /** Strategy that fills the cumulative-weight buffer and draws an index from it. */
    protected abstract class Sampler{
        /**
         * Maps a uniform random number to an outcome index.
         *
         * @param uniformRV uniform draw, expected to lie in {@code [0, 1)}
         * @return index of the selected outcome
         */
        public abstract int sample(double uniformRV);

        /**
         * Rebuilds the cumulative weights from per-outcome weights.
         *
         * @param weights one weight per outcome
         * @return the cumulative weights
         */
        public abstract double[] setWeights(double[] weights);
    }

    /**
     * Same cumulative sums as {@link NormalSampler}, searched by bisection instead of a
     * linear scan. The search relies on the cumulative sums being non-decreasing, i.e.
     * on the weights being non-negative.
     */
    protected class NormalBinarySearchSampler extends NormalSampler{
        /**
         * Returns the first index whose cumulative weight is strictly greater than
         * {@code uniformRV * total}, or the last index when no such entry exists.
         */
        @Override
        public int sample(double uniformRV){
            double u = uniformRV * _accumulatedWeights[_K - 1];
            int lower = 0;
            int upper = _accumulatedWeights.length - 1;
            while (lower <= upper){
                int mid = lower + (upper - lower) / 2;
                if (u < _accumulatedWeights[mid]){
                    upper = mid - 1;
                }
                else{
                    lower = mid + 1;
                }
            }
            // lower == _K when u is not below the total (rounding, uniformRV == 1, zero total).
            return clampIndex(lower);
        }
    }

    /** Linear-scan sampler over plain (additive) cumulative sums. */
    protected class NormalSampler extends Sampler{

        /**
         * Returns the first index whose cumulative weight is strictly greater than
         * {@code uniformRV * total}, or the last index when no such entry exists.
         */
        @Override
        public int sample(double uniformRV) {
            double u = uniformRV * _accumulatedWeights[_K - 1];

            int index = 0;
            while (index < _K && !(u < _accumulatedWeights[index])) {
                index++;
            }
            // The scan can run off the end when u is not below the total; keep the result valid.
            return clampIndex(index);
        }

        /**
         * Stores the running sum of {@code weights} in the shared buffer.
         *
         * @param weights one weight per outcome; must have exactly {@code _K} entries
         * @return the shared cumulative-weight buffer (not a copy)
         * @throws IllegalArgumentException if {@code weights.length != _K}
         */
        @Override
        public double[] setWeights(double[] weights) {
            checkWeights(weights);
            if (weights.length > 0) {
                _accumulatedWeights[0] = weights[0];
            }
            for (int i = 1; i < weights.length; i++){
                _accumulatedWeights[i] = _accumulatedWeights[i - 1] + weights[i];
            }
            return _accumulatedWeights;
        }
    }

    /**
     * Linear-scan sampler whose cumulative values are combined with
     * {@code logU.logSumTwo}, so the last entry plays the role of the log of the
     * normalization factor.
     */
    protected class LogSampler extends Sampler{

        /**
         * Returns the first index whose cumulative value is strictly greater than
         * {@code log(uniformRV) + logTotal}, or the last index when no such entry exists.
         */
        @Override
        public int sample(double uniformRV) {
            // log( (0,1)*N ) -> log(0, 1) + log N where N is the normalization factor
            double u = mathLog.compute(uniformRV) + _accumulatedWeights[_K - 1];

            int index = 0;
            while (index < _K && !(u < _accumulatedWeights[index])){
                index++;
            }
            // mathLog.compute may return a value >= 0 near uniformRV == 1; keep the result valid.
            return clampIndex(index);
        }

        /**
         * Stores the running {@code logSumTwo} accumulation of {@code weights} in the
         * shared buffer.
         *
         * @param weights one weight per outcome; must have exactly {@code _K} entries
         * @return the shared cumulative-weight buffer (not a copy)
         * @throws IllegalArgumentException if {@code weights.length != _K}
         */
        @Override
        public double[] setWeights(double[] weights) {
            checkWeights(weights);
            if (weights.length > 0) {
                _accumulatedWeights[0] = weights[0];
            }
            for (int i = 1; i < weights.length; i++){
                _accumulatedWeights[i] = logU.logSumTwo(_accumulatedWeights[i - 1], weights[i]);
            }
            return _accumulatedWeights;
        }
    }

    /**
     * Limits a scan result to the last valid outcome index.
     *
     * @param index index produced by a scan, possibly equal to {@code _K}
     * @return {@code index} if it is below {@code _K}, otherwise {@code _K - 1}
     */
    private int clampIndex(int index){
        return index < _K ? index : _K - 1;
    }

    /**
     * Rejects weight arrays that do not cover every outcome. A shorter array would
     * leave stale values in the tail of the buffer, and the samplers read the last
     * slot as the total.
     *
     * @throws NullPointerException     if {@code weights} is {@code null}
     * @throws IllegalArgumentException if {@code weights.length != _K}
     */
    private void checkWeights(double[] weights){
        if (weights == null){
            throw new NullPointerException("weights must not be null");
        }
        if (weights.length != _K){
            throw new IllegalArgumentException(
                    "Expected " + _K + " weights but got " + weights.length);
        }
    }

    /**
     * Returns the active sampler or fails with a descriptive error when none is set.
     *
     * @throws IllegalStateException if no sampler has been configured
     */
    private Sampler requireSampler(){
        if (s == null){
            throw new IllegalStateException(
                    "No sampler configured: the sampling method must be one of "
                            + "\"normal\", \"log\" or \"binary\"");
        }
        return s;
    }

    /**
     * @return the number of outcomes
     */
    public int dimensions(){
        return _K;
    }

    /**
     * Sets the per-outcome weights by delegating to the active sampler.
     *
     * @param prob one weight per outcome
     * @return whatever the active sampler returns; the built-in samplers return the
     *         internal cumulative-weight buffer itself, not a copy
     * @throws IllegalStateException    if no sampler has been configured
     * @throws IllegalArgumentException if a built-in sampler receives an array whose
     *                                  length differs from {@link #dimensions()}
     */
    public double[] setProbabilities(double[] prob){
        return requireSampler().setWeights(prob);
    }

    /**
     * Draws an outcome index for the given uniform random number.
     *
     * @param uniformRV uniform draw, expected to lie in {@code [0, 1)}
     * @return an index in {@code [0, dimensions() - 1]} for the built-in samplers
     * @throws IllegalStateException if no sampler has been configured
     */
    @Override
    public int sample(double uniformRV) {
        return requireSampler().sample(uniformRV);
    }

}