HazyResearch/numbskull

View on GitHub
numbskull/inference.py

Summary

Maintainability
F
3 wks
Test Coverage
"""Numba-compiled Gibbs sampling kernels and factor function evaluation."""

from __future__ import print_function, absolute_import
import numba
from numba import jit
import numpy as np
import math


@jit(nopython=True, cache=True, nogil=True)
def gibbsthread(shardID, nshards, var_copy, weight_copy, weight, variable,
                factor, fmap, vmap, factor_index, Z, cstart,
                count, var_value, weight_value, sample_evidence, burnin):
    """Resample the variables of one shard and optionally tally the result.

    The variable ids ``0 .. variable.shape[0] - 1`` are divided into
    ``nshards`` contiguous ranges and this call handles range ``shardID``.
    Each sampled value is written to ``var_value[var_copy]``; unless
    ``burnin`` is truthy it is also accumulated into ``count``.

    Args:
        shardID: Index of the shard to process; also selects the scratch
            row ``Z[shardID]`` handed to ``draw_sample``.
        nshards: Total number of shards.
        var_copy: Row of ``var_value`` to read from and write to.
        weight_copy: Row of ``weight_value`` to read weights from.
        weight, factor, fmap, vmap, factor_index: Factor graph arrays,
            forwarded unchanged to ``draw_sample``.
        variable: Record array with at least the fields ``isEvidence`` and
            ``cardinality``.
        Z: Per-shard scratch buffers.
        cstart: For each variable, its first slot in ``count``.
        count: Tally array, updated in place.
        var_value: Current variable assignments, updated in place.
        weight_value: Current weight values.
        sample_evidence: When truthy, variables with a nonzero
            ``isEvidence`` flag are resampled as well.
        burnin: When truthy, samples are drawn but not tallied.
    """
    # Contiguous range of variable ids assigned to this shard
    total = variable.shape[0]
    first = (shardID * total) // nshards
    stop = ((shardID + 1) * total) // nshards

    # TODO: give option do not store result, or just store tally
    for vid in range(first, stop):
        evidence = variable[vid]["isEvidence"]
        if evidence == 4:
            # This variable is not owned by this machine
            continue
        if evidence != 0 and not sample_evidence:
            # Flagged variables are only resampled when explicitly requested
            continue

        sampled = draw_sample(vid, var_copy, weight_copy, weight, variable,
                              factor, fmap, vmap, factor_index, Z[shardID],
                              var_value, weight_value)
        var_value[var_copy][vid] = sampled

        if burnin:
            continue

        slot = cstart[vid]
        if variable[vid]["cardinality"] == 2:
            # Two-valued variables use a single slot that sums the samples
            count[slot] += sampled
        else:
            # Otherwise one slot per value, incremented on each hit
            count[slot + sampled] += 1


@jit(nopython=True, cache=True, nogil=True)
def draw_sample(var_samp, var_copy, weight_copy, weight, variable, factor,
                fmap, vmap, factor_index, Z, var_value, weight_value):
    """Draw a new value for ``var_samp`` from its conditional distribution.

    Value ``k`` is chosen with probability proportional to
    ``exp(potential(var_samp, k, ...))``.

    Args:
        var_samp: Id of the variable to sample.
        var_copy: Row of ``var_value`` holding the other variables' values.
        weight_copy: Row of ``weight_value`` to read weights from.
        weight, factor, fmap, vmap, factor_index: Factor graph arrays,
            forwarded unchanged to ``potential``.
        variable: Record array; ``variable[var_samp]["cardinality"]`` gives
            the number of candidate values.
        Z: Scratch buffer with at least ``cardinality`` entries. It is
            overwritten with the running sum of the (rescaled)
            unnormalized probabilities.
        var_value: Current variable assignments.
        weight_value: Current weight values.

    Returns:
        The sampled value, an index in ``range(cardinality)``.
    """
    cardinality = variable[var_samp]["cardinality"]

    # Store the log-potential of every candidate value and track the largest
    shift = -np.inf
    for value in range(cardinality):
        Z[value] = potential(var_samp, value, var_copy, weight_copy,
                             weight, variable, factor, fmap,
                             vmap, factor_index, var_value,
                             weight_value)
        if Z[value] > shift:
            shift = Z[value]

    # Subtracting the maximum before exponentiating leaves the sampling
    # distribution unchanged but keeps exp() from overflowing to inf (or
    # everything underflowing to 0) when the potentials are large.
    # A non-finite maximum cannot be subtracted safely, so skip the shift.
    if not np.isfinite(shift):
        shift = 0.0

    # Running sum of the unnormalized probabilities
    running = 0.0
    for j in range(cardinality):
        running += np.exp(Z[j] - shift)
        Z[j] = running

    # Inverse-CDF draw: first bucket whose cumulative mass reaches z
    z = np.random.rand() * Z[cardinality - 1]

    return np.argmax(Z[:cardinality] >= z)


@jit(nopython=True, cache=True, nogil=True)
def potential(var_samp, value, var_copy, weight_copy, weight, variable, factor,
              fmap, vmap, factor_index, var_value, weight_value):
    """Return the weighted sum of factor values for ``var_samp == value``.

    The factors to evaluate are found through ``vmap`` and ``factor_index``.
    Each one contributes its weight, read from ``weight_value[weight_copy]``,
    multiplied by the result of ``eval_factor``.

    Args:
        var_samp: Id of the variable being considered.
        value: Candidate value for that variable.
        var_copy: Row of ``var_value`` holding the other variables' values.
        weight_copy: Row of ``weight_value`` to read weights from.
        weight: Not used by this function.
        variable: Record array with fields ``dataType`` and ``vtf_offset``.
        factor: Record array with field ``weightId`` (further fields are
            read by ``eval_factor``).
        fmap: Factor-to-variable map, forwarded to ``eval_factor``.
        vmap: Record array with fields ``factor_index_offset`` and
            ``factor_index_length``.
        factor_index: Array of factor ids addressed through ``vmap``.
        var_value: Current variable assignments.
        weight_value: Current weight values.

    Returns:
        The accumulated potential as a float.
    """
    # Variables with dataType 0 always use the vmap entry at vtf_offset;
    # all others pick the entry belonging to the candidate value.
    offset = 0 if variable[var_samp]["dataType"] == 0 else value
    entry = vmap[variable[var_samp]["vtf_offset"] + offset]

    first = entry["factor_index_offset"]
    stop = first + entry["factor_index_length"]

    total = 0.0
    for k in range(first, stop):
        fid = factor_index[k]
        w = weight_value[weight_copy][factor[fid]["weightId"]]
        total += w * eval_factor(fid, var_samp, value, var_copy, variable,
                                 factor, fmap, var_value)
    return total


# Integer code of every known factor function, keyed by name.
FACTORS = {
    # ---- Boolean variables ----
    "NOOP": -1,
    "IMPLY_NATURAL": 0,
    "OR": 1,
    "EQUAL": 3,
    "AND": 2,
    "ISTRUE": 4,
    "LINEAR": 7,
    "RATIO": 8,
    "LOGICAL": 9,
    "IMPLY_MLN": 13,

    # ---- Categorical variables ----
    "AND_CAT": 12,
    "OR_CAT": 14,
    "EQUAL_CAT_CONST": 15,
    "IMPLY_NATURAL_CAT": 16,
    "IMPLY_MLN_CAT": 17,

    # ---- Generative models for data programming ----
    #
    # Two kinds of categorical variables take part in these factors:
    #
    #   y in {1, -1}     a latent label
    #   l in {1, 0, -1}  the output of a labeling function
    #
    # y is stored in a Numbskull variable y_index using {-1: 0, 1: 1};
    # l is stored in a Numbskull variable l_index using {-1: 0, 0: 1, 1: 2}.
    # NOTE(review): several eval_factor branches treat index
    # cardinality - 1 as the abstain value (l == 0) -- confirm which
    # encoding of l is the current one.

    # h(y) := y
    "DP_GEN_CLASS_PRIOR": 18,

    # h(l) := l
    "DP_GEN_LF_PRIOR": 19,

    # h(l) := l * l
    "DP_GEN_LF_PROPENSITY": 20,

    # h(y, l) := y * l
    "DP_GEN_LF_ACCURACY": 21,

    # h(y, l) := y * l * l
    "DP_GEN_LF_CLASS_PROPENSITY": 22,

    # l_2 corrects mistakes of l_1:
    #
    # h(y, l_1, l_2) := -1 if l_1 == 0 and l_2 != 0
    #                    1 if l_1 == -1 * y and l_2 == y
    #                    0 otherwise
    "DP_GEN_DEP_FIXING": 23,

    # l_2 backs up what l_1 says:
    #
    # h(y, l_1, l_2) := -1 if l_1 == 0 and l_2 != 0
    #                    1 if l_1 == y and l_2 == y
    #                    0 otherwise
    "DP_GEN_DEP_REINFORCING": 24,

    # h(l_1, l_2) := -1 if l_1 != 0 and l_2 != 0, otherwise 0
    "DP_GEN_DEP_EXCLUSIVE": 25,

    # h(l_1, l_2) := 1 if l_1 == l_2, otherwise 0
    "DP_GEN_DEP_SIMILAR": 26,

    # ---- Distribution ----
    "UFO": 30
}

# Expose each code as a module-level constant named FUNC_<NAME>
# (e.g. FUNC_OR == 1) for use by the jitted functions in this module.
for (key, value) in FACTORS.items():
    globals()["FUNC_" + key] = value


@jit(nopython=True, cache=True, nogil=True)
def eval_factor(factor_id, var_samp, value, var_copy, variable, factor, fmap,
                var_value):
    """Evaluate one factor function with ``var_samp`` set to ``value``.

    Every other variable attached to the factor takes its current value from
    ``var_value[var_copy]``. The attached variables are the ``arity`` entries
    of ``fmap`` starting at the factor's ``ftv_offset``; for the implication
    style functions the last entry is the head and the others form the body.

    Args:
        factor_id: Index of the factor in ``factor``.
        var_samp: Id of the variable whose value is overridden.
        value: Value to use for ``var_samp``.
        var_copy: Row of ``var_value`` to read the other values from.
        variable: Record array; only ``cardinality`` is read here.
        factor: Record array with fields ``factorFunction``, ``ftv_offset``
            and ``arity``.
        fmap: Record array with fields ``vid`` and ``dense_equal_to``.
        var_value: Current variable assignments.

    Returns:
        The numeric value of the factor function.

    Raises:
        NotImplementedError: If the factor's function code is not handled.
    """
    fac = factor[factor_id]
    func = fac["factorFunction"]
    ftv_start = fac["ftv_offset"]
    ftv_end = ftv_start + fac["arity"]
    # Current assignment of every variable in the selected copy
    current = var_value[var_copy]

    ####################
    # BINARY VARIABLES #
    ####################
    if func == FUNC_NOOP:
        return 0
    elif func == FUNC_IMPLY_NATURAL:
        # Only the body is scanned here; the head is checked below so that
        # a true body with a false head yields -1 instead of 0.
        for i in range(ftv_start, ftv_end - 1):
            v = value if (fmap[i]["vid"] == var_samp) else \
                current[fmap[i]["vid"]]
            if v == 0:
                # Early return if body is not satisfied
                return 0

        # If this point is reached, body must be true
        i = ftv_end - 1
        head = value if (fmap[i]["vid"] == var_samp) else \
            current[fmap[i]["vid"]]
        if head:
            return 1
        return -1
    elif func == FUNC_OR:
        for i in range(ftv_start, ftv_end):
            v = value if (fmap[i]["vid"] == var_samp) else \
                current[fmap[i]["vid"]]
            if v == 1:
                return 1
        return -1
    elif func == FUNC_EQUAL:
        # All attached variables must match the first one
        v = value if (fmap[ftv_start]["vid"] == var_samp) \
            else current[fmap[ftv_start]["vid"]]
        for i in range(ftv_start + 1, ftv_end):
            w = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v != w:
                return -1
        return 1
    elif func == FUNC_AND or func == FUNC_ISTRUE:
        for i in range(ftv_start, ftv_end):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v == 0:
                return -1
        return 1
    elif func == FUNC_LINEAR:
        # Number of body variables that agree with the head
        res = 0
        head = value if (fmap[ftv_end - 1]["vid"] == var_samp) \
            else current[fmap[ftv_end - 1]["vid"]]
        for i in range(ftv_start, ftv_end - 1):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v == head:
                res += 1
        # This does not match Dimmwitted, but matches the eq in the paper
        return res
    elif func == FUNC_RATIO:
        # log(1 + number of body variables that agree with the head)
        res = 1
        head = value if (fmap[ftv_end - 1]["vid"] == var_samp) \
            else current[fmap[ftv_end - 1]["vid"]]
        for i in range(ftv_start, ftv_end - 1):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v == head:
                res += 1
        # This does not match Dimmwitted, but matches the eq in the paper
        return math.log(res)  # TODO: use log2?
    elif func == FUNC_LOGICAL:
        # 1 as soon as any body variable agrees with the head
        head = value if (fmap[ftv_end - 1]["vid"] == var_samp) \
            else current[fmap[ftv_end - 1]["vid"]]
        for i in range(ftv_start, ftv_end - 1):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v == head:
                return 1
        return 0
    elif func == FUNC_IMPLY_MLN:
        for i in range(ftv_start, ftv_end - 1):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v == 0:
                # Early return if body is not satisfied
                return 1

        # If this point is reached, body must be true
        i = ftv_end - 1
        # The head must be looked up by variable id, not by fmap position
        head = value if (fmap[i]["vid"] == var_samp) \
            else current[fmap[i]["vid"]]
        if head:
            return 1
        return 0

    #########################
    # CATEGORICAL VARIABLES #
    #########################
    elif func == FUNC_AND_CAT or func == FUNC_EQUAL_CAT_CONST:
        for i in range(ftv_start, ftv_end):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v != fmap[i]["dense_equal_to"]:
                return 0
        return 1
    elif func == FUNC_OR_CAT:
        for i in range(ftv_start, ftv_end):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v == fmap[i]["dense_equal_to"]:
                return 1
        return -1
    elif func == FUNC_IMPLY_NATURAL_CAT:
        for i in range(ftv_start, ftv_end - 1):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v != fmap[i]["dense_equal_to"]:
                # Early return if body is not satisfied
                return 0

        # If this point is reached, body must be true
        i = ftv_end - 1
        # The head must be looked up by variable id, not by fmap position
        head = value if (fmap[i]["vid"] == var_samp) \
            else current[fmap[i]["vid"]]
        if head == fmap[i]["dense_equal_to"]:
            return 1
        return -1
    elif func == FUNC_IMPLY_MLN_CAT:
        for i in range(ftv_start, ftv_end - 1):
            v = value if (fmap[i]["vid"] == var_samp) \
                else current[fmap[i]["vid"]]
            if v != fmap[i]["dense_equal_to"]:
                # Early return if body is not satisfied
                return 1

        # If this point is reached, body must be true
        i = ftv_end - 1
        # The head must be looked up by variable id, not by fmap position
        head = value if (fmap[i]["vid"] == var_samp) \
            else current[fmap[i]["vid"]]
        if head == fmap[i]["dense_equal_to"]:
            return 1
        return 0

    #####################
    # DATA PROGRAMMING  #
    # GENERATIVE MODELS #
    #####################
    elif func == FUNC_DP_GEN_CLASS_PRIOR:
        # NB: this doesn't make sense for categoricals
        y_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        return 1 if y_index == 1 else -1
    elif func == FUNC_DP_GEN_LF_PRIOR:
        # NB: this doesn't make sense for categoricals
        # NOTE(review): unlike the branches below, this one does not use
        # cardinality - 1 as the abstain index -- confirm the encoding.
        l_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        if l_index == 2:
            return -1
        elif l_index == 0:
            return 0
        else:
            return 1
    elif func == FUNC_DP_GEN_LF_PROPENSITY:
        l_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        # The last value of the labeling function variable means "abstain"
        abstain = variable[fmap[ftv_start]["vid"]]["cardinality"] - 1
        return 0 if l_index == abstain else 1
    elif func == FUNC_DP_GEN_LF_ACCURACY:
        y_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        l_index = value if fmap[ftv_start + 1]["vid"] == var_samp else \
            current[fmap[ftv_start + 1]["vid"]]
        abstain = variable[fmap[ftv_start + 1]["vid"]]["cardinality"] - 1
        if l_index == abstain:
            return 0
        elif y_index == l_index:
            return 1
        else:
            return -1
    elif func == FUNC_DP_GEN_LF_CLASS_PROPENSITY:
        # NB: this doesn't make sense for categoricals
        y_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        l_index = value if fmap[ftv_start + 1]["vid"] == var_samp else \
            current[fmap[ftv_start + 1]["vid"]]
        abstain = variable[fmap[ftv_start + 1]["vid"]]["cardinality"] - 1
        if l_index == abstain:
            return 0
        elif y_index == 1:
            return 1
        else:
            return -1
    elif func == FUNC_DP_GEN_DEP_FIXING:
        # NB: this doesn't make sense for categoricals
        y_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        l1_index = value if fmap[ftv_start + 1]["vid"] == var_samp else \
            current[fmap[ftv_start + 1]["vid"]]
        l2_index = value if fmap[ftv_start + 2]["vid"] == var_samp else \
            current[fmap[ftv_start + 2]["vid"]]
        abstain = variable[fmap[ftv_start + 1]["vid"]]["cardinality"] - 1
        if l1_index == abstain:
            # NOTE(review): l2 is compared with index 1, not `abstain` --
            # confirm this is the intended encoding.
            return -1 if l2_index != 1 else 0
        elif l1_index == 0 and l2_index == 1 and y_index == 1:
            return 1
        elif l1_index == 1 and l2_index == 0 and y_index == 0:
            return 1
        else:
            return 0
    elif func == FUNC_DP_GEN_DEP_REINFORCING:
        # NB: this doesn't make sense for categoricals
        y_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        l1_index = value if fmap[ftv_start + 1]["vid"] == var_samp else \
            current[fmap[ftv_start + 1]["vid"]]
        l2_index = value if fmap[ftv_start + 2]["vid"] == var_samp else \
            current[fmap[ftv_start + 2]["vid"]]
        abstain = variable[fmap[ftv_start + 1]["vid"]]["cardinality"] - 1
        if l1_index == abstain:
            # NOTE(review): l2 is compared with index 1, not `abstain` --
            # confirm this is the intended encoding.
            return -1 if l2_index != 1 else 0
        elif l1_index == 0 and l2_index == 0 and y_index == 0:
            return 1
        elif l1_index == 1 and l2_index == 1 and y_index == 1:
            return 1
        else:
            return 0
    elif func == FUNC_DP_GEN_DEP_EXCLUSIVE:
        l1_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        l2_index = value if fmap[ftv_start + 1]["vid"] == var_samp else \
            current[fmap[ftv_start + 1]["vid"]]
        abstain = variable[fmap[ftv_start]["vid"]]["cardinality"] - 1
        return 0 if l1_index == abstain or l2_index == abstain else -1
    elif func == FUNC_DP_GEN_DEP_SIMILAR:
        l1_index = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        l2_index = value if fmap[ftv_start + 1]["vid"] == var_samp else \
            current[fmap[ftv_start + 1]["vid"]]
        return 1 if l1_index == l2_index else 0

    ###########################################
    # FACTORS FOR OPTIMIZING DISTRIBUTED CODE #
    ###########################################
    elif func == FUNC_UFO:
        # The first variable selects which attached variable is returned;
        # 0 selects none and yields 0.
        v = value if fmap[ftv_start]["vid"] == var_samp else \
            current[fmap[ftv_start]["vid"]]
        if v == 0:
            return 0

        return value if fmap[ftv_start + v - 1]["vid"] == var_samp else \
            current[fmap[ftv_start + v - 1]["vid"]]

    ######################
    # FACTOR NOT DEFINED #
    ######################
    else:  # FUNC_UNDEFINED
        print("Error: Factor Function", func,
              "( used in factor", factor_id, ") is not implemented.")
        raise NotImplementedError("Factor function is not implemented.")