fair-search/fairsearch-elasticsearch-plugin

View on GitHub
src/main/java/com/purbon/search/fair/utils/RecursiveNumericFailprobabilityCalculator.java

Summary

Maintainability
B
6 hrs
Test Coverage
package com.purbon.search.fair.utils;

import java.util.ArrayList;
import java.util.HashMap;

public class RecursiveNumericFailprobabilityCalculator extends FailprobabilityCalculator {

    private static final double EPS = 0.0000000000000001;

    private HashMap<LegalAssignmentKey, Double> legalAssignmentCache = new HashMap<>();

    public RecursiveNumericFailprobabilityCalculator(int k, double p, double alpha) {
        super(k, p, alpha);
    }


    public MTableFailProbPair adjustAlpha() {
        double aMin = 0;
        double aMax = alpha;
        double aMid = (aMin + aMax) / 2;

        MTableFailProbPair min = computeBoundary(k, p, aMin);
        MTableFailProbPair max = computeBoundary(k, p, aMax);
        MTableFailProbPair mid = computeBoundary(k, p, aMid);

        while (min.getMassOfMTable() < max.getMassOfMTable() && mid.getFailProb() != alpha) {


            if (mid.getFailProb() < alpha) {
                aMin = aMid;
                min = computeBoundary(k, p, aMin);
            } else if (mid.getFailProb() > alpha) {
                aMax = aMid;
                max = computeBoundary(k, p, aMax);
            }
            aMid = (aMin + aMax) / 2;
            mid = computeBoundary(k, p, aMid);

            int maxMass = max.getMassOfMTable();
            int minMass = min.getMassOfMTable();
            int midMass = mid.getMassOfMTable();
            if(maxMass-minMass == 1 || max.getAlpha()-min.getAlpha()<=EPS){
                double minDiff = Math.abs(min.getFailProb() - alpha);
                double maxDiff = Math.abs(max.getFailProb() - alpha);
                if(minDiff<= maxDiff){
                    return min;
                }else{
                    return max;
                }
            }
            if (maxMass-midMass == 1 && midMass-minMass == 1) {
                double minDiff = Math.abs(min.getFailProb() - alpha);
                double maxDiff = Math.abs(max.getFailProb() - alpha);
                double midDiff = Math.abs(mid.getFailProb() - alpha);
                if (midDiff <= maxDiff && midDiff <= minDiff) {
                    return mid;
                }
                if(minDiff <= midDiff && minDiff <= maxDiff){
                    return min;
                }else{
                    return max;
                }
            }
        }

        return mid;


    }

    private MTableFailProbPair computeBoundary(int k, double p, double alpha) {
        int[] mTable = new MTableGenerator(k, p, alpha, false).getMTable();
        double failProb = calculateFailprobability(mTable, alpha);
        return new MTableFailProbPair(k, p, alpha, failProb, mTable);
    }

    @Override
    double calculateFailprobability(int[] mtable, double alpha) {
        if (mtable[mtable.length - 1] == 0) {
            return 0;
        }
        MTableGenerator generator = new MTableGenerator(k, p, alpha, false);
        this.mTable = mtable;
        this.auxMTable = generator.computeAuxTMTable();
        int maxProtected = auxMTable.getSumOf("block");
        ArrayList<Integer> blockSizes = auxMTable.getColumn("block");
        blockSizes = sublist(blockSizes, 1, blockSizes.size());
        double succesProb = findLegalAssignments(maxProtected, blockSizes);
        return succesProb == 0 ? 0 : 1 - succesProb;
    }

    private double findLegalAssignments(int numCandidates, ArrayList<Integer> blockSizes) {

        return findLegalAssignmentsAux(numCandidates, blockSizes, 1, 0);
    }

    private double findLegalAssignmentsAux(int numCandidates, ArrayList<Integer> blockSizes
            , int currentBlockNumber, int candidatesAssignedSoFar) {
        if (blockSizes.size() == 0) {
            return 1;
        } else {
            int minNeededThisBlock = currentBlockNumber - candidatesAssignedSoFar;
            if (minNeededThisBlock < 0) {
                minNeededThisBlock = 0;
            }
            int maxPossibleThisBlock = Math.min(blockSizes.get(0), numCandidates);
            double assignments = 0;
            ArrayList<Integer> newRemainingBlockSizes = sublist(blockSizes, 1, blockSizes.size());
            for (int itemsThisBlock = minNeededThisBlock; itemsThisBlock <= maxPossibleThisBlock; itemsThisBlock++) {
                int newRemainingCandidates = numCandidates - itemsThisBlock;

                double suffixes = calculateLegalAssignmentsAux(newRemainingCandidates, newRemainingBlockSizes
                        , currentBlockNumber + 1, candidatesAssignedSoFar + itemsThisBlock);

                assignments = assignments + getFromPmfCache(maxPossibleThisBlock, itemsThisBlock) * suffixes;

            }
            return assignments;
        }
    }

    private double calculateLegalAssignmentsAux(int remainingCandidates, ArrayList<Integer> remainingBlockSizes
            , int currentBlockNumber, int candidatesAssignedSoFar) {
        LegalAssignmentKey key = new LegalAssignmentKey(remainingCandidates, remainingBlockSizes
                , currentBlockNumber, candidatesAssignedSoFar);
        if (legalAssignmentCache.get(key) != null) {
            return legalAssignmentCache.get(key);
        } else {
            double value = findLegalAssignmentsAux(remainingCandidates, remainingBlockSizes, currentBlockNumber, candidatesAssignedSoFar);
            legalAssignmentCache.put(key, value);
            return value;
        }

    }


}