src/main/java/com/hongliangjie/fugue/topicmodeling/LDA/LDA.java
package com.hongliangjie.fugue.topicmodeling.LDA;
import com.google.gson.Gson;
import com.hongliangjie.fugue.Message;
import com.hongliangjie.fugue.distributions.MultinomialDistribution;
import com.hongliangjie.fugue.serialization.Document;
import com.hongliangjie.fugue.serialization.Feature;
import com.hongliangjie.fugue.serialization.Model;
import com.hongliangjie.fugue.topicmodeling.TopicModel;
import com.hongliangjie.fugue.utils.LogGamma;
import com.hongliangjie.fugue.utils.MathExp;
import com.hongliangjie.fugue.utils.MathLog;
import com.hongliangjie.fugue.utils.RandomUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
/**
* Created by liangjie on 10/29/14.
*/
public class LDA extends TopicModel {
protected List<Document> internalDocs;
protected double[] sample_buffer;
protected HashMap<String, Integer> wordsForwardIndex;
protected HashMap<Integer, String> wordsInvertedIndex;
protected List<ModelCountainer> modelPools;
protected Message cmdArg;
protected RandomUtils randomGNR;
protected MathExp mathExp;
protected MathLog mathLog;
protected double[] iterationTimes;
protected int TOPIC_NUM;
protected int MAX_ITER;
protected int CURRENT_ITER;
protected int BURN_IN;
protected int INTERVAL;
protected int TOTAL_TOKEN;
protected int SAVED;
protected int TOTAL_SAVES;
protected static final Logger LOGGER = LogManager.getLogger("FUGUE-TOPICMODELING");
protected final class ModelCountainer{
public double[] alpha;
public double[] beta;
public double betaSum;
public double alphaSum;
public List<int[]> wordTopicCounts; // how many times a term appears in a topic
public int[] topicCounts; // total number of terms that are assigned to a topic
public List<int[]> docTopicBuffers; // !!! DENSE !!!
public List<List<Integer>> docTopicAssignments; // !!! DENSE !!!
public Map<String, int[]> outsideWordTopicCounts; // how many times a term appears in a topic
public Map<String, Double> outsideBeta;
public double[][] phi;
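/*
 * Point estimate of the topic-word distributions from the current counts:
 *     phi[v][k] = (n(v,k) + beta[v]) / (n(k) + betaSum)
 * where n(v,k) is the number of times term v is assigned to topic k and
 * n(k) is the total number of tokens assigned to topic k.
 */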
public void computePhi(){
phi = new double[wordTopicCounts.size()][alpha.length];
for(int i = 0; i < wordTopicCounts.size(); i++){
for (int j = 0; j < wordTopicCounts.get(i).length; j++){
phi[i][j] = (wordTopicCounts.get(i)[j] + beta[i]) / (topicCounts[j] + betaSum);
}
}
}
}
public LDA() {
this(new RandomUtils(0));
LOGGER.info("Random Number Generator: Native");
}
public LDA(RandomUtils r){
TOPIC_NUM = 1;
MAX_ITER = 250;
BURN_IN = 100;
INTERVAL = 5;
TOTAL_TOKEN = 0;
SAVED = 0;
randomGNR = r;
TOTAL_SAVES = 10;
}
@Override
@SuppressWarnings("unchecked")
public void setMessage(Message m) {
cmdArg = m;
Object randomParam = cmdArg.getParam("random");
// guard the parameter object itself: calling toString() on a missing parameter
// would throw a NullPointerException before the null check below could run
String randomGNRStr = (randomParam != null) ? randomParam.toString() : null;
if (randomGNRStr != null){
if ("native".equals(randomGNRStr)){
randomGNR = new RandomUtils(0);
LOGGER.info("Random Number Generator: Native");
}
else if ("deterministic".equals(randomGNRStr)){
randomGNR = new RandomUtils(1);
LOGGER.info("Random Number Generator: Deterministic");
}
else{
randomGNR = new RandomUtils(0);
LOGGER.info("Random Number Generator: Native");
}
}
else{
randomGNR = new RandomUtils(0);
LOGGER.info("Random Number Generator: Native");
}
int expInt = Integer.parseInt(cmdArg.getParam("exp").toString());
mathExp = new MathExp(expInt);
LOGGER.info("Math Exp Function:" + expInt);
int logInt = Integer.parseInt(cmdArg.getParam("log").toString());
mathLog = new MathLog(logInt);
LOGGER.info("Math Log Function:" + logInt);
TOPIC_NUM = Integer.parseInt(cmdArg.getParam("topics").toString());
MAX_ITER = Integer.parseInt(cmdArg.getParam("iters").toString());
iterationTimes = new double[MAX_ITER];
internalDocs = (List<Document>) cmdArg.getParam("docs");
}
@Override
public Message getMessage(){
cmdArg.setParam("invertedIndex", wordsInvertedIndex);
cmdArg.setParam("forwardIndex", wordsForwardIndex);
return cmdArg;
}
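// Assign each distinct term a dense integer id (forward index) and keep the
// reverse mapping (inverted index) so count arrays can be addressed by term id.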
public void rebuildIndex(){
/* build index */
LOGGER.info("Start to build index.");
wordsForwardIndex = new HashMap<String, Integer>();
wordsInvertedIndex = new HashMap<Integer, String>();
for (int d = 0; d < internalDocs.size(); d++) {
for (Feature f : internalDocs.get(d).getFeatures()) {
String feature_name = f.getFeatureName();
if (!wordsForwardIndex.containsKey(feature_name)) {
int word_index = wordsForwardIndex.size();
wordsForwardIndex.put(feature_name, word_index);
wordsInvertedIndex.put(word_index, feature_name);
}
}
}
LOGGER.info("Index Size:" + wordsForwardIndex.size());
}
protected void initTrainModel() {
rebuildIndex();
LOGGER.info("Start to initialize model.");
LOGGER.info("Topic Num:" + TOPIC_NUM);
LOGGER.info("ForwardIndex Size:" + wordsForwardIndex.size());
modelPools = new ArrayList<ModelCountainer>();
modelPools.add(0, new ModelCountainer());
/* initialize all model parameters with fixed sizes */
modelPools.get(0).wordTopicCounts = new ArrayList<int[]>(wordsForwardIndex.size());
modelPools.get(0).beta = new double[wordsForwardIndex.size()];
for (int i = 0; i < wordsForwardIndex.size(); i++) {
int[] topicCounts = new int[TOPIC_NUM];
modelPools.get(0).wordTopicCounts.add(topicCounts);
modelPools.get(0).beta[i] = 0.01;
modelPools.get(0).betaSum += modelPools.get(0).beta[i];
}
modelPools.get(0).docTopicBuffers = new ArrayList<int[]>(internalDocs.size());
modelPools.get(0).docTopicAssignments = new ArrayList<List<Integer>>(internalDocs.size());
for (int i = 0; i < internalDocs.size(); i++) {
int[] topicBuffer = new int[TOPIC_NUM];
modelPools.get(0).docTopicBuffers.add(topicBuffer);
modelPools.get(0).docTopicAssignments.add(new ArrayList<Integer>());
}
modelPools.get(0).topicCounts = new int[TOPIC_NUM];
modelPools.get(0).alpha = new double[TOPIC_NUM];
TOTAL_TOKEN = 0;
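// symmetric priors: alpha_k = 50/K follows the common heuristic of
// Griffiths & Steyvers (2004); beta_v = 0.01 (set above) is a common default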
for (int k = 0; k < TOPIC_NUM; k++)
modelPools.get(0).alpha[k] = 50.0 / TOPIC_NUM;
modelPools.get(0).alphaSum = 50.0; // (50.0/TOPIC_NUM) * TOPIC_NUM
for (int d = 0; d < internalDocs.size(); d++) {
for (Feature f : internalDocs.get(d).getFeatures()) {
String feature_name = f.getFeatureName();
Integer feature_index = wordsForwardIndex.get(feature_name);
// we randomly assign a topic for this token
int topic = randomGNR.nextInt(TOPIC_NUM);
modelPools.get(0).docTopicAssignments.get(d).add(topic);
modelPools.get(0).docTopicBuffers.get(d)[topic]++;
modelPools.get(0).wordTopicCounts.get(feature_index)[topic]++;
modelPools.get(0).topicCounts[topic]++;
TOTAL_TOKEN++;
}
}
LOGGER.info("Term Num:" + modelPools.get(0).wordTopicCounts.size());
LOGGER.info("alphaSum:" + modelPools.get(0).alphaSum);
LOGGER.info("betaSum:" + modelPools.get(0).betaSum);
LOGGER.info("Finished initializing model");
}
protected abstract class HyperparameterOptimization{
public abstract void optimize(); // hyper-parameter optimization only happens during training, so model 0 is always used
}
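/*
 * Slice sampling (Neal, 2003) over the Dirichlet hyper-parameters: draw a
 * log-likelihood threshold log(u) below the current likelihood, place a
 * window of width _step around the current (alpha, beta), then repeatedly
 * draw uniformly from the window, shrinking it toward the current point,
 * until a draw exceeds the threshold. Only the last sample is kept.
 */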
protected class SliceSampling extends HyperparameterOptimization{
protected int _samplesNum; // the number of samples
protected double _step; // the step used in StepOut
protected int _hyperIterations;
public SliceSampling(){
_samplesNum = Integer.parseInt(cmdArg.getParam("sliceSamples").toString());
_step = Double.parseDouble(cmdArg.getParam("sliceSteps").toString());
_hyperIterations = Integer.parseInt(cmdArg.getParam("sliceIters").toString());
}
protected void copyArray(double[] src, double[] dest){
System.arraycopy(src, 0, dest, 0, src.length);
}
@Override
public void optimize() {
double[] alpha = new double[modelPools.get(0).alpha.length];
double alphaSum = modelPools.get(0).alphaSum;
double[] beta = new double[modelPools.get(0).beta.length];
double betaSum = modelPools.get(0).betaSum;
double[] alphaLeft = new double[alpha.length];
double[] alphaRight = new double[alpha.length];
double[] betaLeft = new double[beta.length];
double[] betaRight = new double[beta.length];
double[] alphaNew = new double[alpha.length];
double[] betaNew = new double[beta.length];
double alphaNewSum = 0.0;
double betaNewSum = 0.0;
copyArray(modelPools.get(0).alpha, alpha);
copyArray(modelPools.get(0).beta, beta);
for(int k = 0; k < _samplesNum; k++){
double old_likelihood = likelihood(modelPools.get(0).wordTopicCounts, modelPools.get(0).docTopicBuffers, alpha, beta, alphaSum, betaSum);
double new_likelihood = mathLog.compute(randomGNR.nextDouble()) + old_likelihood;
// stepping out
for (int i = 0; i < alpha.length; i++){
// keep the interval strictly positive: Dirichlet hyper-parameters must be > 0,
// or logGamma in the likelihood would receive an invalid argument
alphaLeft[i] = Math.max(alpha[i] - randomGNR.nextDouble() * _step, 1e-10);
alphaRight[i] = alphaLeft[i] + _step;
}
for (int i = 0; i < beta.length; i++){
betaLeft[i] = Math.max(beta[i] - randomGNR.nextDouble() * _step, 1e-10);
betaRight[i] = betaLeft[i] + _step;
}
// The stepping-out above is simplified; see Fig. 3 in Neal's "Slice Sampling" paper
for(int j = 0; j < _hyperIterations; j++){
alphaNewSum = 0.0;
betaNewSum = 0.0;
for(int i = 0; i < alpha.length; i++){
alphaNew[i] = randomGNR.nextDouble() * (alphaRight[i] - alphaLeft[i]) + alphaLeft[i];
alphaNewSum += alphaNew[i];
}
for(int i = 0; i < beta.length; i++){
betaNew[i] = randomGNR.nextDouble() * (betaRight[i] - betaLeft[i]) + betaLeft[i];
betaNewSum += betaNew[i];
}
double test_likelihood = likelihood(modelPools.get(0).wordTopicCounts, modelPools.get(0).docTopicBuffers, alphaNew, betaNew, alphaNewSum, betaNewSum);
if (test_likelihood > new_likelihood){
copyArray(alphaNew, alpha);
alphaSum = alphaNewSum;
copyArray(betaNew, beta);
betaSum = betaNewSum;
LOGGER.info("[Slice Sampling]: Sample " + k + " A new set of hyper-parameter with likelihood " + test_likelihood);
break;
}
else{
for(int i = 0; i < alpha.length; i++){
if(alphaNew[i] < alpha[i]){
alphaLeft[i] = alphaNew[i];
}
else{
alphaRight[i] = alphaNew[i];
}
}
for(int i = 0; i < beta.length; i++){
if(betaNew[i] < beta[i]){
betaLeft[i] = betaNew[i];
}
else{
betaRight[i] = betaNew[i];
}
}
}
}
}
// only keep the last sample for both alpha and beta
// update back to models
copyArray(alpha, modelPools.get(0).alpha);
modelPools.get(0).alphaSum = alphaSum;
copyArray(beta, modelPools.get(0).beta);
modelPools.get(0).betaSum = betaSum;
}
}
public double likelihood(int modelID){
return likelihood(modelPools.get(modelID).wordTopicCounts, modelPools.get(modelID).docTopicBuffers, modelPools.get(modelID).alpha, modelPools.get(modelID).beta, modelPools.get(modelID).alphaSum, modelPools.get(modelID).betaSum);
}
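/*
 * Collapsed joint log-likelihood log P(w, z | alpha, beta), the sum of a
 * topic-word term and a document-topic term:
 *   K * [logGamma(betaSum) - sum_v logGamma(beta_v)]
 *     + sum_k { sum_v logGamma(n(v,k) + beta_v) - logGamma(sum_v [n(v,k) + beta_v]) }
 * + D * [logGamma(alphaSum) - sum_k logGamma(alpha_k)]
 *     + sum_d { sum_k logGamma(n(d,k) + alpha_k) - logGamma(sum_k [n(d,k) + alpha_k]) }
 */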
public double likelihood(List<int[]> wordTopicCounts, List<int[]> docTopicBuffers, double[] alpha, double[] beta, double alphaSum, double betaSum) {
double result_1 = 0.0;
double result_2 = 0.0;
// topics side likelihood
for (int v = 0; v < beta.length; v++) {
result_1 += LogGamma.logGamma(beta[v]);
}
result_1 = TOPIC_NUM * (LogGamma.logGamma(betaSum) - result_1);
for (int k = 0; k < TOPIC_NUM; k++) {
double part_1 = 0.0;
double part_2 = 0.0;
for (int v = 0; v < wordTopicCounts.size(); v++) { // use the argument, not model 0, so this works for any count set
part_1 = part_1 + LogGamma.logGamma(wordTopicCounts.get(v)[k] + beta[v]);
part_2 = part_2 + (wordTopicCounts.get(v)[k] + beta[v]);
}
result_1 = result_1 + part_1 - LogGamma.logGamma(part_2);
}
// document side likelihood
for (int k = 0; k < TOPIC_NUM; k++) {
result_2 += LogGamma.logGamma(alpha[k]);
}
result_2 = docTopicBuffers.size() * (LogGamma.logGamma(alphaSum) - result_2);
for (int d = 0; d < docTopicBuffers.size(); d++) {
double part_1 = 0.0;
double part_2 = 0.0;
for (int k = 0; k < TOPIC_NUM; k++) {
part_1 = part_1 + LogGamma.logGamma(docTopicBuffers.get(d)[k] + alpha[k]);
part_2 = part_2 + docTopicBuffers.get(d)[k] + alpha[k];
}
result_2 = result_2 + part_1 - LogGamma.logGamma(part_2);
}
return result_1 + result_2;
}
protected abstract class Sampler{
protected MultinomialDistribution dist;
protected ProcessDocuments processor;
public abstract int draw(int modelID, int featureID, double randomRV);
public void setProcessor(ProcessDocuments proc){
processor = proc;
}
}
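// The sampler variants differ only in which drawing strategy
// MultinomialDistribution uses ("normal", "binary", or "log"); the "log"
// variant consumes unnormalized log-probabilities (see computeLogProbabilities below).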
protected class GibbsBinarySampling extends GibbsSampling{
public GibbsBinarySampling(){
dist = new MultinomialDistribution(TOPIC_NUM, mathLog, mathExp, "binary");
LOGGER.info("Gibbs Sampling: Binary");
}
}
protected class GibbsSampling extends Sampler{
public GibbsSampling(){
dist = new MultinomialDistribution(TOPIC_NUM, mathLog, mathExp, "normal");
LOGGER.info("Gibbs Sampling: Normal");
}
@Override
public int draw(int modelID, int featureID, double randomRV){
processor.computeProbabilities(modelID, featureID);
dist.setProbabilities(sample_buffer);
return dist.sample(randomRV);
}
}
protected class GibbsLogSampling extends Sampler{
public GibbsLogSampling(){
dist = new MultinomialDistribution(TOPIC_NUM, mathLog, mathExp, "log");
LOGGER.info("Gibbs Sampling: Log");
}
@Override
public int draw(int modelID, int featureID, double randomRV){
processor.computeLogProbabilities(modelID, featureID);
dist.setProbabilities(sample_buffer);
return dist.sample(randomRV);
}
}
protected class ProcessDocuments{
protected int[] docTopicBuffer;
protected List<Integer> docTopicAssignment;
protected Sampler sampler;
protected HyperparameterOptimization hyperOpt;
public ProcessDocuments(){
this(new GibbsSampling(), null);
}
public ProcessDocuments(Sampler s, HyperparameterOptimization hyper){
sample_buffer = new double[TOPIC_NUM];
sampler = s;
hyperOpt = hyper;
}
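/*
 * Full conditional for collapsed Gibbs sampling in LDA:
 *     P(z_i = k | z_-i, w) ∝ (n(d,k) + alpha_k) * (n(v,k) + beta_v) / (n(k) + betaSum)
 * where all counts exclude the token currently being resampled.
 */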
public void computeProbabilities(int modelID, int featureID){
// calculate normal probabilities
for (int k = 0; k < TOPIC_NUM; k++) {
sample_buffer[k] = ((modelPools.get(modelID).wordTopicCounts.get(featureID)[k] + modelPools.get(modelID).beta[featureID]) / (modelPools.get(modelID).topicCounts[k] + modelPools.get(modelID).betaSum)) * (docTopicBuffer[k] + modelPools.get(modelID).alpha[k]);
}
}
public void computeLogProbabilities(int modelID, int featureID){
// calculate log-probabilities
for (int k = 0; k < TOPIC_NUM; k++) {
sample_buffer[k] = mathLog.compute(docTopicBuffer[k] + modelPools.get(modelID).alpha[k]);
sample_buffer[k] += mathLog.compute(modelPools.get(modelID).wordTopicCounts.get(featureID)[k] + modelPools.get(modelID).beta[featureID]);
sample_buffer[k] -= mathLog.compute(modelPools.get(modelID).topicCounts[k] + modelPools.get(modelID).betaSum);
}
}
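// Resample every token of one document: remove the token's current topic from
// all counts, draw a new topic from the full conditional, add the new
// assignment back, and record it. Returns the number of tokens processed.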
protected int sampleOneDoc(List<Document> docs, int index, int modelID){
Document d = docs.get(index);
docTopicAssignment = modelPools.get(modelID).docTopicAssignments.get(index);
docTopicBuffer = modelPools.get(modelID).docTopicBuffers.get(index);
int pos = 0;
for (Feature f : d.getFeatures()) {
String featureName = f.getFeatureName();
Integer featureIndex = wordsForwardIndex.get(featureName);
int current_topic = docTopicAssignment.get(pos);
docTopicBuffer[current_topic]--;
modelPools.get(modelID).wordTopicCounts.get(featureIndex)[current_topic]--;
modelPools.get(modelID).topicCounts[current_topic]--;
double randomRV = randomGNR.nextDouble();
int new_topic = sampler.draw(modelID, featureIndex, randomRV);
docTopicBuffer[new_topic]++;
modelPools.get(modelID).wordTopicCounts.get(featureIndex)[new_topic]++;
modelPools.get(modelID).topicCounts[new_topic]++;
docTopicAssignment.set(pos, new_topic);
pos++;
}
return pos;
}
public void sampleOverDocs(int modelID, List<Document> docs, int start, int end, int maxIter, int save){
int overall_pos = 0;
long overall_startTime = System.currentTimeMillis();
for (CURRENT_ITER = 0; CURRENT_ITER < maxIter; CURRENT_ITER++) {
LOGGER.info("Start to Iteration " + CURRENT_ITER);
long startTime = System.currentTimeMillis();
int num_d = 0;
int total_pos = 0;
for (int d = start; d < end; d++) {
int doc_pos = sampleOneDoc(docs, d, modelID);
overall_pos += doc_pos;
total_pos += doc_pos;
num_d++;
if (num_d % 500 == 0)
LOGGER.info("Processed:" + num_d);
}
LOGGER.info("Finished sampling.");
LOGGER.info("Finished Iteration " + CURRENT_ITER);
if (CURRENT_ITER % 25 == 0) {
double likelihood = likelihood(modelID); // reuse the helper instead of repeating the full argument list
LOGGER.info("Iteration " + CURRENT_ITER + " Likelihood:" + Double.toString(likelihood));
}
if ((CURRENT_ITER % 10 == 0) && (save == 1)){
saveModel(0);
}
long endTime = System.currentTimeMillis();
double timeDifference = (endTime - startTime) / 1000.0;
double tokenPerSeconds = (total_pos / 1000.0) / timeDifference;
LOGGER.info("Iteration Duration " + CURRENT_ITER + " " + Double.toString(timeDifference));
LOGGER.info("Tokens (per-K)/Seconds " + CURRENT_ITER + " " + Double.toString(tokenPerSeconds));
iterationTimes[CURRENT_ITER] = timeDifference;
if ((CURRENT_ITER >= BURN_IN) && (CURRENT_ITER % 25 == 0) && (hyperOpt != null)){
// hyper-parameter optimization
LOGGER.info("Start Hyper-parameter Optimization");
hyperOpt.optimize();
LOGGER.info("Finished Hyper-parameter Optimization");
}
}
double averageTime = 0;
for (int k = 0; k < maxIter; k++){
averageTime += iterationTimes[k];
}
long overall_endTime = System.currentTimeMillis();
LOGGER.info("Average Iteration Duration " + Double.toString(averageTime / (double)maxIter));
LOGGER.info("Average Tokens (per-K)/Seconds " + Double.toString((overall_pos / 1000.0) /((overall_endTime - overall_startTime) / 1000.0)));
}
}
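/*
 * Held-out evaluation: each test document is split in half. The first half is
 * "folded in" by Gibbs sampling against the fixed, loaded models; the second
 * half is scored under the resulting theta to accumulate per-model perplexity.
 */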
protected class ProcessTestDocuments extends ProcessDocuments{
protected double[] modelPerplexity;
protected int[] modelSaved;
public ProcessTestDocuments(Sampler s){
super(s, null);
}
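// marginal probability of a held-out term: p(w | theta) = sum_k theta[k] * phi[w][k]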
protected double computeTermProbability(double[] theta, int featureID, ModelCountainer m){
double prob = 0.0;
for (int k = 0; k < TOPIC_NUM; k++) {
prob += theta[k] * m.phi[featureID][k];
}
return prob;
}
protected void sampleTestDoc(List<Document> docs, int maxIter, int docIndex) {
// for each test document, half of the document is used to "fold-in" and the other half is used to compute "perplexity"
for (int m = 0; m < modelPools.size(); m++) {
docTopicAssignment = new ArrayList<Integer>();
docTopicBuffer = new int[TOPIC_NUM];
double[] theta = new double[TOPIC_NUM];
List<Feature> currentFeatures = docs.get(docIndex).getFeatures();
int docLength = currentFeatures.size();
int foldIn = docLength / 2;
// firstly init topic assignments
for (int i = 0; i < foldIn; i++) {
// we randomly assign a topic for this token
int topic = randomGNR.nextInt(TOPIC_NUM);
docTopicAssignment.add(topic);
docTopicBuffer[topic]++;
}
int testBurnIn = 0; // burn-in for the fold-in sampler; renamed so it does not shadow the BURN_IN field
for (CURRENT_ITER = 0; CURRENT_ITER < maxIter; CURRENT_ITER++) {
// fold-in
for (int i = 0; i < foldIn; i++) {
String featureName = currentFeatures.get(i).getFeatureName();
Integer featureIndex = wordsForwardIndex.get(featureName);
int current_topic = docTopicAssignment.get(i);
docTopicBuffer[current_topic]--;
double randomRV = randomGNR.nextDouble();
int new_topic = sampler.draw(m, featureIndex, randomRV);
docTopicBuffer[new_topic]++;
docTopicAssignment.set(i, new_topic);
}
if ((CURRENT_ITER >= testBurnIn) && (CURRENT_ITER % 5 == 0)) {
// estimate theta
for (int k = 0; k < TOPIC_NUM; k++) {
theta[k] = (docTopicBuffer[k] + modelPools.get(m).alpha[k]) / (foldIn + modelPools.get(m).alphaSum);
}
// compute perplexity
for (int i = foldIn; i < currentFeatures.size(); i++) {
String featureName = currentFeatures.get(i).getFeatureName();
Integer featureID = wordsForwardIndex.get(featureName);
double prob = computeTermProbability(theta, featureID, modelPools.get(m));
modelPerplexity[m] = modelPerplexity[m] + Math.log(prob);
modelSaved[m] ++;
}
}
}
}
}
@Override
public void sampleOverDocs(int modelID, List<Document> docs, int start, int end, int maxIter, int save){
LOGGER.info("Start to testing.");
long overall_startTime = System.currentTimeMillis();
int num_d = 0;
modelPerplexity = new double[modelPools.size()];
modelSaved = new int[modelPools.size()];
for (int d = start; d < end; d++){
sampleTestDoc(docs, 150, d); // note: fold-in runs a fixed 150 iterations rather than the maxIter argument
num_d++;
LOGGER.info("Processed:" + d);
}
// average perplexity
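// per-model perplexity over the held-out tokens: exp(-(sum_i log p(w_i)) / N)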
double totalAverage = 0.0;
for (int m = 0; m < modelPools.size(); m++){
modelPerplexity[m] = Math.exp(-modelPerplexity[m] / modelSaved[m]);
LOGGER.info("Model " + m + "\t" + modelPerplexity[m]);
totalAverage += modelPerplexity[m];
}
LOGGER.info("Total Average Perplexity:" + totalAverage / modelPools.size());
LOGGER.info("Finished testing.");
long overall_endTime = System.currentTimeMillis();
LOGGER.info("Average Document (per-K)/Seconds " + Double.toString((num_d / 1000.0) /((overall_endTime - overall_startTime) / 1000.0)));
}
}
protected Sampler getSampler(String samplerStr){
Sampler s = null;
if (samplerStr != null) {
if ("normal".equals(samplerStr)) {
s = new GibbsSampling();
}
else if ("log".equals(samplerStr)){
s = new GibbsLogSampling();
}
else if ("binary".equals(samplerStr)){
s = new GibbsBinarySampling();
}
else{
s = new GibbsSampling();
}
}
else{
s = new GibbsSampling();
}
return s;
}
protected HyperparameterOptimization getHyperOpt(String hyperOptStr){
HyperparameterOptimization hyper = null;
if (hyperOptStr != null){
if ("none".equals(hyperOptStr)){
hyper = null;
}
else if ("slice".equals(hyperOptStr)){
hyper = new SliceSampling();
}
else{
hyper = null;
}
}
return hyper;
}
public void train(){
int start = Integer.parseInt(cmdArg.getParam("start").toString());
int end = Integer.parseInt(cmdArg.getParam("end").toString());
if (start < 0)
start = 0;
if (end < 0)
end = internalDocs.size();
train(start, end);
}
public void train(int start, int end){
initTrainModel();
LOGGER.info("Start to perform Gibbs Sampling");
LOGGER.info("MAX_ITER:" + MAX_ITER);
Object samplerParam = cmdArg.getParam("LDASampler");
String samplerStr = (samplerParam != null) ? samplerParam.toString() : null;
Object hyperOptParam = cmdArg.getParam("LDAHyperOpt");
String hyperOptStr = (hyperOptParam != null) ? hyperOptParam.toString() : null;
int save = Integer.parseInt(cmdArg.getParam("saveModel").toString());
Sampler s = getSampler(samplerStr);
HyperparameterOptimization hyper = getHyperOpt(hyperOptStr);
ProcessDocuments p = new ProcessDocuments(s, hyper);
s.setProcessor(p);
p.sampleOverDocs(0, internalDocs, start, end, MAX_ITER, save);
if (save == 1)
saveModel(0);
}
public void initTestModels(){
for (ModelCountainer m : modelPools){
// size by the rebuilt dictionary: term indices come from wordsForwardIndex,
// which may contain terms the saved model has never seen
m.beta = new double[wordsForwardIndex.size()];
for (int v = 0; v < m.beta.length; v++){
// this is the default value
m.beta[v] = 0.01;
}
for (Map.Entry<String, Double> entry : m.outsideBeta.entrySet()){
String sKey = entry.getKey();
Double sValue = entry.getValue();
if (wordsForwardIndex.containsKey(sKey)){
int wordIndex = wordsForwardIndex.get(sKey);
m.beta[wordIndex] = sValue;
}
else{
LOGGER.warn("Term:" + sKey + " is not in the dictionary when constructing beta array.");
}
}
// betaSum must also cover the default entries, so sum after the array is filled
m.betaSum = 0.0;
for (double b : m.beta){
m.betaSum += b;
}
m.alphaSum = 0.0;
for (int k = 0; k < m.alpha.length; k ++){
m.alphaSum += m.alpha[k];
}
m.wordTopicCounts = new ArrayList<int[]>(wordsForwardIndex.size());
for (int k = 0; k < wordsForwardIndex.size(); k++){
int[] topicCounts = new int[m.alpha.length];
m.wordTopicCounts.add(topicCounts);
}
for (Map.Entry<String, int[]> entry : m.outsideWordTopicCounts.entrySet()){
String sKey = entry.getKey();
if (wordsForwardIndex.containsKey(sKey)){
int wordIndex = wordsForwardIndex.get(sKey);
int[] wordCounts = entry.getValue();
for ( int k = 0; k < wordCounts.length; k ++) {
m.wordTopicCounts.get(wordIndex)[k] = wordCounts[k];
}
}
else{
LOGGER.info("[WARNING]: Term:" + sKey + " is not in the dictionary when constructing word topic counts.");
}
}
if (m.alpha.length != TOPIC_NUM)
TOPIC_NUM = m.alpha.length;
m.computePhi(); // cache phi
LOGGER.info("TOPIC NUM:" + TOPIC_NUM);
LOGGER.info("Outside Term Num:" + m.outsideWordTopicCounts.size());
LOGGER.info("Term Num:" + m.wordTopicCounts.size());
LOGGER.info("alphaSum:" + m.alphaSum);
LOGGER.info("betaSum:" + m.betaSum);
LOGGER.info("Finished initializing model");
}
}
public void test(){
int start = Integer.parseInt(cmdArg.getParam("start").toString());
int end = Integer.parseInt(cmdArg.getParam("end").toString());
if (start < 0)
start = 0;
if (end < 0)
end = internalDocs.size();
test(start, end);
}
public void test(int start, int end){
try {
loadModel();
rebuildIndex();
initTestModels();
LOGGER.info("Start to perform Gibbs Sampling");
LOGGER.info("MAX_ITER:" + MAX_ITER);
Object samplerParam = cmdArg.getParam("LDASampler");
String samplerStr = (samplerParam != null) ? samplerParam.toString() : null;
Sampler s = getSampler(samplerStr);
ProcessDocuments p = new ProcessTestDocuments(s);
s.setProcessor(p);
p.sampleOverDocs(-1, internalDocs, start, end, MAX_ITER, 0); // the modelID argument is ignored: the test processor iterates over all loaded models
} catch (IOException e) {
e.printStackTrace();
}
}
@SuppressWarnings("unchecked")
public void loadModel() throws IOException {
int multipleModels = Integer.parseInt(cmdArg.getParam("multipleModels").toString());
String[] modelFileNames = null;
LOGGER.info("Load Multiple Test Models:" + multipleModels);
if (multipleModels == 1) {
modelFileNames = getModelFiles(true);
}
else{
modelFileNames = new String[1];
modelFileNames[0] = cmdArg.getParam("modelFile").toString();
}
Gson gson = new Gson();
modelPools = new ArrayList<ModelCountainer>();
for (int s = 0; s < modelFileNames.length; s++) {
String modelFileName = modelFileNames[s];
LOGGER.info("Trying to load " + modelFileName);
File modelFile = new File(modelFileName);
if (modelFile.exists() && !modelFile.isDirectory()) {
BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(modelFile), "UTF8"));
String line = br.readLine();
if (line != null){
Model obj = gson.fromJson(line, LDAModel.class);
Message msg = obj.getParameters();
ModelCountainer currentModel = new ModelCountainer();
currentModel.alpha = (double[])msg.getParam("alpha");
currentModel.outsideBeta = (Map<String, Double>)msg.getParam("beta");
currentModel.outsideWordTopicCounts = (Map<String, int[]>)msg.getParam("wordTopicCounts");
currentModel.topicCounts = (int[])msg.getParam("topicCounts");
modelPools.add(currentModel);
}
br.close();
LOGGER.info("Loaded " + modelFileName);
}
}
}
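/*
 * Build model file names from the "modelFile" parameter by splicing a save
 * slot number in front of the extension: "model.json" becomes "model.0.json"
 * through "model.9.json" for TOTAL_SAVES = 10. With all == false, only the
 * name for the current save slot (SAVED % TOTAL_SAVES) is returned.
 */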
private String[] getModelFiles(boolean all){
String[] outputFileParts = cmdArg.getParam("modelFile").toString().split(Pattern.quote("."));
StringBuilder outputFilePrefix = new StringBuilder();
for(int i = 0; i < outputFileParts.length - 1; i ++){
outputFilePrefix.append(outputFileParts[i] + ".");
}
if (!all) {
String[] oneFile = new String[1];
outputFilePrefix.append(Integer.toString(SAVED % TOTAL_SAVES) + ".");
outputFilePrefix.append(outputFileParts[outputFileParts.length - 1]);
String outputFileName = outputFilePrefix.toString();
oneFile[0] = outputFileName;
return oneFile;
}
String[] returnFiles = new String[TOTAL_SAVES];
for(int i = 0; i < TOTAL_SAVES; i++){
StringBuilder firstPart = new StringBuilder(outputFilePrefix.toString());
firstPart.append(Integer.toString(i) + ".");
firstPart.append(outputFileParts[outputFileParts.length - 1]);
String outputFileName = firstPart.toString();
returnFiles[i] = outputFileName;
}
return returnFiles;
}
public void saveModel(int modelID) {
String outputFileName = getModelFiles(false)[0];
LOGGER.info("Starting to save model to:" + outputFileName);
Gson gson = new Gson();
Model obj = new LDAModel();
cmdArg.setParam("alpha", modelPools.get(modelID).alpha);
cmdArg.setParam("beta", modelPools.get(modelID).beta);
cmdArg.setParam("topicCounts", modelPools.get(modelID).topicCounts);
cmdArg.setParam("wordTopicCounts", modelPools.get(modelID).wordTopicCounts);
cmdArg.setParam("invertedIndex", wordsInvertedIndex);
obj.setParameters(cmdArg);
// try-with-resources closes the writer even if an exception occurs mid-write
try (BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFileName), "UTF8"))) {
String json = gson.toJson(obj);
bw.write(json);
SAVED++;
} catch (IOException e) {
e.printStackTrace();
}
LOGGER.info("Finished save model to:" + outputFileName);
}
}