deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.models.word2vec.wordstore;
import lombok.Data;
import lombok.NonNull;
import lombok.val;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.nd4j.common.util.ThreadUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threadly.concurrent.PriorityScheduler;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
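/**
 * Builds vocabularies (VocabCache instances) out of one or more sources of sequences,
 * optionally applying per-source frequency filtering, label fetching, vocabulary merging
 * and Huffman tree construction.
 */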
public class VocabConstructor<T extends SequenceElement> {
private List<VocabSource<T>> sources = new ArrayList<>();
private VocabCache<T> cache;
private Collection<String> stopWords;
private boolean useAdaGrad = false;
private boolean fetchLabels = false;
private int limit;
private AtomicLong seqCount = new AtomicLong(0);
private InvertedIndex<T> index;
private boolean enableScavenger = false;
private T unk;
private boolean allowParallelBuilder = true;
private boolean lockf = false;
protected static final Logger log = LoggerFactory.getLogger(VocabConstructor.class);
private VocabConstructor() {
}
/**
 * Placeholder for future implementation
 *
 * @return null; reserved for a future extended lookup table implementation
 */
protected WeightLookupTable<T> buildExtendedLookupTable() {
return null;
}
/**
 * Placeholder for future implementation
 *
 * @return null; reserved for a future extended vocabulary implementation
 */
protected VocabCache<T> buildExtendedVocabulary() {
return null;
}
/**
 * This method transfers an existing WordVectors model into the current one
 *
 * @param wordVectors model to transfer the vocabulary from
 * @param fetchLabels if true, label elements will be transferred as well
 * @return the merged VocabCache
 */
@SuppressWarnings("unchecked") // method is safe, since all calls inside are using generic SequenceElement methods
public VocabCache<T> buildMergedVocabulary(@NonNull WordVectors wordVectors, boolean fetchLabels) {
return buildMergedVocabulary((VocabCache<T>) wordVectors.vocab(), fetchLabels);
}
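// A minimal usage sketch for the merge path above. This is a hedged example: `existingModel`
// stands for a hypothetical, previously trained WordVectors instance whose vocabulary already
// carries Huffman indexes, and `sequenceIterator` for a hypothetical labelled source.
//
//   VocabConstructor<VocabWord> constructor = new VocabConstructor.Builder<VocabWord>()
//                   .addSource(sequenceIterator, 5)
//                   .fetchLabels(true)
//                   .build();
//   VocabCache<VocabWord> merged = constructor.buildMergedVocabulary(existingModel, true);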
/**
 * This method returns the total number of sequences passed through this VocabConstructor
 *
 * @return number of sequences processed
 */
public long getNumberOfSequences() {
return seqCount.get();
}
/**
 * This method transfers an existing vocabulary into the current one
 *
 * Please note: this method expects the source vocabulary to have Huffman tree indexes applied
 *
 * @param vocabCache source vocabulary to merge from
 * @param fetchLabels if true, sequence labels will be gathered from all registered sources as well
 * @return the merged VocabCache
 */
public VocabCache<T> buildMergedVocabulary(@NonNull VocabCache<T> vocabCache, boolean fetchLabels) {
if (cache == null)
cache = new AbstractCache.Builder<T>().build();
for (int t = 0; t < vocabCache.numWords(); t++) {
String label = vocabCache.wordAtIndex(t);
if (label == null)
continue;
T element = vocabCache.wordFor(label);
// skip this element if it's a label and the user doesn't want labels to be merged
if (!fetchLabels && element.isLabel())
continue;
cache.addToken(element);
cache.addWordToIndex(element.getIndex(), element.getLabel());
// backward compatibility code
cache.putVocabWord(element.getLabel());
}
if (cache.numWords() == 0)
throw new IllegalStateException("Source VocabCache has no indexes available, transfer is impossible");
/*
    Now that the vocab is transferred, we should roll over the iterators and gather labels, if any
 */
log.info("Vocab size before labels: " + cache.numWords());
if (fetchLabels) {
for (VocabSource<T> source : sources) {
SequenceIterator<T> iterator = source.getIterator();
iterator.reset();
while (iterator.hasMoreSequences()) {
Sequence<T> sequence = iterator.nextSequence();
seqCount.incrementAndGet();
if (sequence.getSequenceLabels() != null)
for (T label : sequence.getSequenceLabels()) {
if (!cache.containsWord(label.getLabel())) {
label.markAsLabel(true);
label.setSpecial(true);
label.setIndex(cache.numWords());
cache.addToken(label);
cache.addWordToIndex(label.getIndex(), label.getLabel());
// backward compatibility code
cache.putVocabWord(label.getLabel());
}
}
}
}
}
log.info("Vocab size after labels: " + cache.numWords());
return cache;
}
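/**
 * This method transfers all tokens from the given vocabulary into the target one, creating a new
 * AbstractCache if no target was set. Existing element indexes are preserved where available;
 * elements without an index get the next free one.
 *
 * @param vocabCache source vocabulary to transfer tokens from
 * @param buildHuffman if true, a Huffman tree will be built and its indexes applied to the result
 * @return the resulting VocabCache
 */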
public VocabCache<T> transferVocabulary(@NonNull VocabCache<T> vocabCache, boolean buildHuffman) {
val result = cache != null ? cache : new AbstractCache.Builder<T>().build();
for (val v: vocabCache.tokens()) {
result.addToken(v);
// optionally transferring indices
if (v.getIndex() >= 0)
result.addWordToIndex(v.getIndex(), v.getLabel());
else
result.addWordToIndex(result.numWords(), v.getLabel());
}
if (buildHuffman) {
val huffman = new Huffman(result.vocabWords());
huffman.build();
huffman.applyIndexes(result);
}
return result;
}
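/**
 * This method accumulates counts for a single document: element frequency is updated for every
 * occurrence, while sequencesCount is updated at most once per document (tracked via the local
 * seqMap). When label fetching is enabled, new sequence labels are added as special tokens.
 *
 * @param targetVocab vocabulary the counts are accumulated into
 * @param document sequence to process
 * @param finalCounter incremented once this document is fully processed
 * @param loopCounter counter of newly added elements, used by the scavenger heuristic
 */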
public void processDocument(AbstractCache<T> targetVocab, Sequence<T> document,
AtomicLong finalCounter, AtomicLong loopCounter) {
try {
Map<String, AtomicLong> seqMap = new HashMap<>();
if (fetchLabels && document.getSequenceLabels() != null) {
for (T labelWord : document.getSequenceLabels()) {
if (!targetVocab.hasToken(labelWord.getLabel())) {
labelWord.setSpecial(true);
labelWord.markAsLabel(true);
labelWord.setElementFrequency(1);
targetVocab.addToken(labelWord);
}
}
}
List<String> tokens = document.asLabels();
for (String token : tokens) {
if (stopWords != null && stopWords.contains(token))
continue;
if (token == null || token.isEmpty())
continue;
if (!targetVocab.containsWord(token)) {
T element = document.getElementByLabel(token);
element.setElementFrequency(1);
element.setSequencesCount(1);
targetVocab.addToken(element);
loopCounter.incrementAndGet();
// new element: sequencesCount was already set to 1 above, so we just mark the token as seen
seqMap.put(token, new AtomicLong(0));
} else {
targetVocab.incrementWordCount(token);
// if the element already exists in the vocab, update its sequencesCount, but only once per sequence
if (!seqMap.containsKey(token)) {
seqMap.put(token, new AtomicLong(1));
T element = targetVocab.wordFor(token);
element.incrementSequencesCount();
}
if (index != null) {
if (document.getSequenceLabel() != null) {
index.addWordsToDoc(index.numDocuments(), document.getElements(), document.getSequenceLabel());
} else {
index.addWordsToDoc(index.numDocuments(), document.getElements());
}
}
}
}
} catch (Exception e) {
throw new RuntimeException(e);
}
finally {
finalCounter.incrementAndGet();
}
}
/**
 * This method scans all sources passed through the builder, and returns all words as a vocab.
 * If a target VocabCache was set during instance creation, it will be filled as well.
 *
 * @param resetCounters if true, element frequency counters will be reset after the build
 * @param buildHuffmanTree if true, a Huffman tree will be built and its indexes applied
 * @return the resulting VocabCache
 */
public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
long lastTime = System.currentTimeMillis();
long lastSequences = 0;
long lastElements = 0;
long startTime = lastTime;
AtomicLong parsedCount = new AtomicLong(0);
if (resetCounters && buildHuffmanTree)
throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
if (cache == null)
cache = new AbstractCache.Builder<T>().build();
log.debug("Target vocab size before building: [" + cache.numWords() + "]");
final AtomicLong loopCounter = new AtomicLong(0);
AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();
int cnt = 0;
int numProc = Runtime.getRuntime().availableProcessors();
int numThreads = Math.max(numProc / 2, 2);
PriorityScheduler executorService = new PriorityScheduler(numThreads);
final AtomicLong execCounter = new AtomicLong(0);
final AtomicLong finCounter = new AtomicLong(0);
for (VocabSource<T> source : sources) {
SequenceIterator<T> iterator = source.getIterator();
iterator.reset();
log.debug("Trying source iterator: [" + cnt + "]");
log.debug("Target vocab size before building: [" + cache.numWords() + "]");
cnt++;
AbstractCache<T> tempHolder = new AbstractCache.Builder<T>().build();
int sequences = 0;
while (iterator.hasMoreSequences()) {
Sequence<T> document = iterator.nextSequence();
seqCount.incrementAndGet();
parsedCount.addAndGet(document.size());
tempHolder.incrementTotalDocCount();
execCounter.incrementAndGet();
if (allowParallelBuilder) {
executorService.execute(new VocabRunnable(tempHolder, document, finCounter, loopCounter));
// as seen in the profiler, this backpressure loop doesn't engage too often,
// and we don't want too large a backlog left in the tail
while (execCounter.get() - finCounter.get() > numProc) {
ThreadUtils.uncheckedSleep(1);
}
}
else {
processDocument(tempHolder, document, finCounter, loopCounter);
}
sequences++;
if (seqCount.get() % 100000 == 0) {
long currentTime = System.currentTimeMillis();
long currentSequences = seqCount.get();
long currentElements = parsedCount.get();
double seconds = (currentTime - lastTime) / (double) 1000;
double seqPerSec = (currentSequences - lastSequences) / seconds;
double elPerSec = (currentElements - lastElements) / seconds;
// log.info("Document time: {} us; hasNext time: {} us", timesNext.get(timesNext.size() / 2), timesHasNext.get(timesHasNext.size() / 2));
log.info("Sequences checked: [{}]; Current vocabulary size: [{}]; Sequences/sec: {}; Words/sec: {};",
seqCount.get(), tempHolder.numWords(), String.format("%.2f", seqPerSec),
String.format("%.2f", elPerSec));
lastTime = currentTime;
lastElements = currentElements;
lastSequences = currentSequences;
}
// firing scavenger loop
if (enableScavenger && loopCounter.get() >= 2000000 && tempHolder.numWords() > 10000000) {
log.info("Starting scavenger...");
while (execCounter.get() != finCounter.get()) {
ThreadUtils.uncheckedSleep(1);
}
filterVocab(tempHolder, Math.max(1, source.getMinWordFrequency() / 2));
loopCounter.set(0);
}
}
// block until all threads are finished
log.debug("Waiting till all processes stop...");
while (execCounter.get() != finCounter.get()) {
ThreadUtils.uncheckedSleep(1);
}
// apply minWordFrequency set for this source
log.debug("Vocab size before truncation: [" + tempHolder.numWords() + "], NumWords: ["
+ tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get()
+ "], counter: [" + parsedCount.get() + "]");
if (source.getMinWordFrequency() > 0) {
filterVocab(tempHolder, source.getMinWordFrequency());
}
log.debug("Vocab size after truncation: [" + tempHolder.numWords() + "], NumWords: ["
+ tempHolder.totalWordOccurrences() + "], sequences parsed: [" + seqCount.get()
+ "], counter: [" + parsedCount.get() + "]");
// at this moment we're ready to transfer
topHolder.importVocabulary(tempHolder);
}
// at this moment, we have a vocabulary full of words, and we have to reset counters before transferring everything back to the VocabCache
System.gc();
cache.importVocabulary(topHolder);
// adding UNK word
if (unk != null) {
log.info("Adding UNK element to vocab...");
unk.setSpecial(true);
cache.addToken(unk);
}
if (resetCounters) {
for (T element : cache.vocabWords()) {
element.setElementFrequency(0);
}
cache.updateWordsOccurrences();
}
if (buildHuffmanTree) {
if (limit > 0) {
// we want to sort elements before truncating, so we keep the most important words
val words = new ArrayList<T>(cache.vocabWords());
Collections.sort(words);
// now rolling through them
for (val element : words) {
if (element.getIndex() > limit && !element.isSpecial() && !element.isLabel())
cache.removeElement(element.getLabel());
}
}
// and now we're building the Huffman tree
val huffman = new Huffman(cache.vocabWords());
huffman.build();
huffman.applyIndexes(cache);
}
executorService.shutdown();
System.gc();
long endSequences = seqCount.get();
long endTime = System.currentTimeMillis();
double seconds = (endTime - startTime) / (double) 1000;
double seqPerSec = endSequences / seconds;
log.info("Sequences checked: [{}], Current vocabulary size: [{}]; Sequences/sec: [{}];", seqCount.get(),
cache.numWords(), String.format("%.2f", seqPerSec));
return cache;
}
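// A hedged, end-to-end sketch of joint vocabulary construction. Assumptions: `iterator` is any
// SequenceIterator<VocabWord> over your corpus, and 5 is an illustrative frequency cutoff.
//
//   VocabCache<VocabWord> vocab = new VocabConstructor.Builder<VocabWord>()
//                   .addSource(iterator, 5)
//                   .build()
//                   .buildJointVocabulary(false, true);
/**
 * This method removes all non-special, non-label elements with a frequency below the given
 * threshold from the supplied vocabulary.
 *
 * @param cache vocabulary to filter
 * @param minWordFrequency minimum frequency an element must have to be kept
 */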
protected void filterVocab(AbstractCache<T> cache, int minWordFrequency) {
int numWords = cache.numWords();
LinkedBlockingQueue<String> labelsToRemove = new LinkedBlockingQueue<>();
for (T element : cache.vocabWords()) {
if (element.getElementFrequency() < minWordFrequency && !element.isSpecial() && !element.isLabel())
labelsToRemove.add(element.getLabel());
}
for (String label : labelsToRemove) {
cache.removeElement(label);
}
log.debug("Scavenger: Words before: {}; Words after: {};", numWords, cache.numWords());
}
public static class Builder<T extends SequenceElement> {
private List<VocabSource<T>> sources = new ArrayList<>();
private VocabCache<T> cache;
private Collection<String> stopWords = new ArrayList<>();
private boolean useAdaGrad = false;
private boolean fetchLabels = false;
private InvertedIndex<T> index;
private int limit;
private boolean enableScavenger = false;
private T unk;
private boolean allowParallelBuilder = true;
private boolean lockf = false;
public Builder() {
}
/**
 * This method sets a limit on the resulting vocabulary size.
 *
 * PLEASE NOTE: This method is applicable only if a Huffman tree is built.
 *
 * @param limit maximum number of elements to keep in the vocabulary
 * @return this Builder
 */
public Builder<T> setEntriesLimit(int limit) {
this.limit = limit;
return this;
}
public Builder<T> allowParallelTokenization(boolean reallyAllow) {
this.allowParallelBuilder = reallyAllow;
return this;
}
/**
 * Defines whether adaptive gradients should be created during vocabulary construction
 *
 * @param useAdaGrad if true, AdaGrad will be used
 * @return this Builder
 */
protected Builder<T> useAdaGrad(boolean useAdaGrad) {
this.useAdaGrad = useAdaGrad;
return this;
}
/**
 * After the temporary internal vocabulary is built, it will be transferred to the target VocabCache you pass here
 *
 * @param cache target VocabCache
 * @return this Builder
 */
public Builder<T> setTargetVocabCache(@NonNull VocabCache<T> cache) {
this.cache = cache;
return this;
}
/**
 * Adds a SequenceIterator for vocabulary construction.
 * Please note: you can add as many sources as you wish.
 *
 * @param iterator SequenceIterator to build the vocabulary from
 * @param minElementFrequency elements with a frequency below this value will be removed from the vocabulary
 * @return this Builder
 */
public Builder<T> addSource(@NonNull SequenceIterator<T> iterator, int minElementFrequency) {
sources.add(new VocabSource<T>(iterator, minElementFrequency));
return this;
}
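// A hedged sketch of multi-source usage (iteratorA and iteratorB are hypothetical
// SequenceIterator instances; each source keeps its own frequency cutoff):
//
//   builder.addSource(iteratorA, 5)
//          .addSource(iteratorB, 1);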
public Builder<T> setStopWords(@NonNull Collection<String> stopWords) {
this.stopWords = stopWords;
return this;
}
/**
 * Specifies whether sequence labels should be fetched during vocab building
 *
 * @param reallyFetch if true, sequence labels will be added to the vocabulary
 * @return this Builder
 */
public Builder<T> fetchLabels(boolean reallyFetch) {
this.fetchLabels = reallyFetch;
return this;
}
public Builder<T> setIndex(InvertedIndex<T> index) {
this.index = index;
return this;
}
public Builder<T> enableScavenger(boolean reallyEnable) {
this.enableScavenger = reallyEnable;
return this;
}
public Builder<T> setUnk(T unk) {
this.unk = unk;
return this;
}
public VocabConstructor<T> build() {
VocabConstructor<T> constructor = new VocabConstructor<>();
constructor.sources = this.sources;
constructor.cache = this.cache;
constructor.stopWords = this.stopWords;
constructor.useAdaGrad = this.useAdaGrad;
constructor.fetchLabels = this.fetchLabels;
constructor.limit = this.limit;
constructor.index = this.index;
constructor.enableScavenger = this.enableScavenger;
constructor.unk = this.unk;
constructor.allowParallelBuilder = this.allowParallelBuilder;
constructor.lockf = this.lockf;
return constructor;
}
public Builder<T> setLockFactor(boolean lockf) {
this.lockf = lockf;
return this;
}
}
@Data
private static class VocabSource<T extends SequenceElement> {
@NonNull
private SequenceIterator<T> iterator;
@NonNull
private int minWordFrequency;
}
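/**
 * Runnable wrapper that feeds a single document into processDocument(); used by the parallel
 * builder path, with finalCounter serving as the completion signal for backpressure.
 */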
protected class VocabRunnable implements Runnable {
private final AtomicLong finalCounter;
private final Sequence<T> document;
private final AbstractCache<T> targetVocab;
private final AtomicLong loopCounter;
private AtomicBoolean done = new AtomicBoolean(false);
public VocabRunnable(@NonNull AbstractCache<T> targetVocab, @NonNull Sequence<T> sequence,
@NonNull AtomicLong finalCounter, @NonNull AtomicLong loopCounter) {
this.finalCounter = finalCounter;
this.document = sequence;
this.targetVocab = targetVocab;
this.loopCounter = loopCounter;
}
@Override
public void run() {
try {
processDocument(targetVocab, document, finalCounter, loopCounter);
} catch (Exception e) {
throw new RuntimeException(e);
}
finally {
done.set(true);
}
}
}
}