/*
 * deeplearning4j/deeplearning4j
 *
 * View on GitHub:
 * deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCache.java
 *
 * Summary
 *
 * Maintainability: D (2 days)
 *
 * Test Coverage
 */
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.models.word2vec.wordstore.inmemory;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.JsonTypeInfo;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.*;
import org.nd4j.shade.jackson.databind.type.CollectionType;

import java.io.IOException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;

/**
 * In-memory {@link VocabCache} implementation.
 * <p>
 * Elements are indexed by three maps: by storage id ({@link #getVocabulary()}), by label
 * ({@link #getExtendedVocabulary()}) and by Huffman index ({@link #getIdxMap()}). All three are
 * {@link ConcurrentHashMap}s, so single lookups and insertions may be called concurrently, but
 * compound operations (the merge in {@code addToken}, frequency and document-count updates on the
 * stored elements) are not synchronized by this class.
 * <p>
 * Instances are normally created through {@link Builder} and can be round-tripped through
 * {@link #toJson()} / {@link #fromJson(String)}.
 *
 * @param <T> element type stored in the vocabulary
 */
@Slf4j
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class")
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
        setterVisibility = JsonAutoDetect.Visibility.NONE)
public class AbstractCache<T extends SequenceElement> implements VocabCache<T> {
    // Property names used by toJson()/fromJson()
    private static final String CLASS_FIELD = "@class";
    private static final String VOCAB_LIST_FIELD = "VocabList";
    private static final String VOCAB_ITEM_FIELD = "VocabItem";
    private static final String DOC_CNT_FIELD = "DocumentsCounter";
    private static final String MINW_FREQ_FIELD = "MinWordsFrequency";
    private static final String HUGE_MODEL_FIELD = "HugeModelExpected";
    private static final String STOP_WORDS_FIELD = "StopWords";
    private static final String SCAVENGER_FIELD = "ScavengerThreshold";
    private static final String RETENTION_FIELD = "RetentionDelay";
    private static final String TOTAL_WORD_FIELD = "TotalWordCount";

    /** Elements keyed by their storage id. */
    private final ConcurrentMap<Long, T> vocabulary = new ConcurrentHashMap<>();

    /** Elements keyed by label; elements with a {@code null} label are not indexed here. */
    private final Map<String, T> extendedVocabulary = new ConcurrentHashMap<>();

    /** Elements keyed by Huffman index, populated through {@code addWordToIndex}. */
    private final Map<Integer, T> idxMap = new ConcurrentHashMap<>();

    /** Total number of documents observed. */
    private final AtomicLong documentsCounter = new AtomicLong(0);


    private int minWordFrequency = 0;
    private boolean hugeModelExpected = false;

    // Stop words are kept as Strings for compatibility and fail-safety: it is easier to store
    // unique labels than abstract objects of unknown size.
    // TODO: this class only stores, compares and serializes stop words; consider moving them out.
    private List<String> stopWords = new ArrayList<>();

    // Scavenger settings: this class only stores and serializes them.
    private int scavengerThreshold = 3000000;
    private int retentionDelay = 3;

    // Intended to count words actually added, for scavenger mechanics; not updated by this class.
    private transient AtomicLong hiddenWordsCounter = new AtomicLong(0);

    /** Running total of element frequencies; recomputed by {@link #updateWordsOccurrences()}. */
    private final AtomicLong totalWordCount = new AtomicLong(0);

    private static final int MAX_CODE_LENGTH = 40;

    /**
     * Deserialize vocabulary from specified path.
     * <p>
     * Currently a no-op; {@link #fromJson(String)} is the working deserialization entry point.
     */
    @Override
    public void loadVocab() {
        // TODO: this method should be static and accept path
    }

    /**
     * Returns whether the vocabulary holds at least one element.
     *
     * @return true if number of elements in vocabulary &gt; 0, false otherwise
     */
    @Override
    public boolean vocabExists() {
        return !vocabulary.isEmpty();
    }

    /**
     * Serialize vocabulary to specified path.
     * <p>
     * Currently a no-op; {@link #toJson()} is the working serialization entry point.
     */
    @Override
    public void saveVocab() {
        // TODO: this method should be static and accept path
    }

    /**
     * Returns the labels available in this vocabulary.
     *
     * @return unmodifiable live view of the labels (elements without a label are not included)
     */
    @Override
    public Collection<String> words() {
        return Collections.unmodifiableCollection(extendedVocabulary.keySet());
    }

    /**
     * Increments the frequency of the specified label by 1.
     *
     * @param word the word to increment the count for; ignored if not in the vocabulary
     */
    @Override
    public void incrementWordCount(String word) {
        incrementWordCount(word, 1);
    }


    /**
     * Increments the frequency of the specified label by the given value, and the total word
     * count by the same amount.
     *
     * @param word the word to increment the count for; ignored if not in the vocabulary
     * @param increment the amount to increment by
     */
    @Override
    public void incrementWordCount(String word, int increment) {
        T element = extendedVocabulary.get(word);
        if (element != null) {
            element.increaseElementFrequency(increment);
            totalWordCount.addAndGet(increment);
        }
    }

    /**
     * Returns the element's frequency over the training corpus.
     *
     * @param word the word to retrieve the occurrence frequency for
     * @return the frequency truncated to int, or 0 if the word is unknown
     */
    @Override
    public int wordFrequency(@NonNull String word) {
        // TODO: proper wordFrequency impl should return long, instead of int
        T element = extendedVocabulary.get(word);
        if (element != null)
            return (int) element.getElementFrequency();
        return 0;
    }

    /**
     * Checks whether the specified label exists in the vocabulary.
     *
     * @param word the word to check for
     * @return true if an element with this label is stored
     */
    @Override
    public boolean containsWord(String word) {
        return extendedVocabulary.containsKey(word);
    }

    /**
     * Checks whether the specified element exists in the vocabulary.
     * <p>
     * NOTE: this is a linear scan over all stored elements using the element's {@code equals}.
     *
     * @param element the element to look for
     * @return true if an equal element is stored
     */
    public boolean containsElement(T element) {
        // FIXME: O(n) scan; a storage-id lookup would be O(1) but may differ from equals() semantics
        return vocabulary.values().contains(element);
    }

    /**
     * Returns the label of the element at the specified Huffman index.
     *
     * @param index the index of the word to get
     * @return the label, or null if no element is registered at this index
     */
    @Override
    public String wordAtIndex(int index) {
        T element = idxMap.get(index);
        if (element != null) {
            return element.getLabel();
        }
        return null;
    }

    /**
     * Returns the element at the specified Huffman index.
     *
     * @param index the Huffman index
     * @return the element, or null if no element is registered at this index
     */
    @Override
    public T elementAtIndex(int index) {
        return idxMap.get(index);
    }

    /**
     * Returns the Huffman index of the specified label.
     *
     * @param label the label to get index for
     * @return the element's stored index if the label exists (presumably -1 if the Huffman tree
     *         wasn't built yet — depends on the element's default index), -2 if the label wasn't found
     */
    @Override
    public int indexOf(String label) {
        T token = tokenFor(label);
        if (token != null) {
            return token.getIndex();
        } else
            return -2;
    }

    /**
     * Returns the elements stored in this vocabulary.
     *
     * @return live view of the stored elements
     */
    @Override
    public Collection<T> vocabWords() {
        return vocabulary.values();
    }

    /**
     * Returns the total number of element occurrences observed.
     *
     * @return current value of the total word counter
     */
    @Override
    public long totalWordOccurrences() {
        return totalWordCount.get();
    }

    /**
     * Overwrites the total word counter.
     *
     * @param value new total
     */
    public void setTotalWordOccurences(long value) {
        totalWordCount.set(value);
    }

    /**
     * Returns the element for the specified label.
     *
     * @param label to fetch element for
     * @return the element, or null if not present
     */
    @Override
    public T wordFor(@NonNull String label) {
        return extendedVocabulary.get(label);
    }

    /**
     * Returns the element for the specified storage id.
     *
     * @param id storage id
     * @return the element, or null if not present
     */
    @Override
    public T wordFor(long id) {
        return vocabulary.get(id);
    }

    /**
     * Inserts the element with the specified label at the specified Huffman tree position and
     * stores that index on the element.
     * CAUTION: Never use this, unless you 100% sure what are you doing.
     *
     * @param index Huffman index; negative values are ignored
     * @param label label of an element already in the vocabulary; unknown labels are ignored
     */
    @Override
    public void addWordToIndex(int index, String label) {
        if (index >= 0) {
            T token = tokenFor(label);
            if (token != null) {
                idxMap.put(index, token);
                token.setIndex(index);
            }
        }
    }

    /**
     * Inserts the element with the specified storage id at the specified Huffman tree position
     * and stores that index on the element.
     *
     * @param index Huffman index; negative values are ignored
     * @param elementId storage id of an element already in the vocabulary; unknown ids are ignored
     */
    @Override
    public void addWordToIndex(int index, long elementId) {
        if (index < 0)
            return;

        T token = tokenFor(elementId);
        // ConcurrentHashMap rejects null values, so an unknown id must not reach put()
        if (token != null) {
            idxMap.put(index, token);
            token.setIndex(index);
        }
    }

    /**
     * Legacy method: adds nothing, only verifies that the label is already in the vocabulary.
     *
     * @param word label to check
     * @throws IllegalStateException if the label is not present
     */
    @Override
    @Deprecated
    public void putVocabWord(String word) {
        if (!containsWord(word))
            throw new IllegalStateException("Specified label is not present in vocabulary");
    }

    /**
     * Returns the number of elements in this vocabulary.
     *
     * @return number of stored elements
     */
    @Override
    public int numWords() {
        return vocabulary.size();
    }

    /**
     * Returns the number of documents the label was observed in.
     *
     * @param word the word to look up
     * @return the element's sequence count truncated to int, or -1 if the word is unknown
     */
    @Override
    public int docAppearedIn(String word) {
        T element = extendedVocabulary.get(word);
        if (element != null) {
            return (int) element.getSequencesCount();
        } else
            return -1;
    }

    /**
     * Increments the number of documents the label was observed in by {@code howMuch}.
     *
     * Please note: this method is NOT thread-safe
     *
     * @param word the word to increment the count for; ignored if not in the vocabulary
     * @param howMuch amount to add to the element's document count
     */
    @Override
    public void incrementDocCount(String word, long howMuch) {
        T element = extendedVocabulary.get(word);
        if (element != null) {
            element.incrementSequencesCount(howMuch);
        }
    }

    /**
     * Sets the exact number of observed documents that contain the specified word.
     *
     * Please note: this method is NOT thread-safe
     *
     * @param word the word to set the count for; ignored if not in the vocabulary
     * @param count the count of the word
     */
    @Override
    public void setCountForDoc(String word, long count) {
        T element = extendedVocabulary.get(word);
        if (element != null) {
            element.setSequencesCount(count);
        }
    }

    /**
     * Returns the total number of documents observed.
     *
     * @return full (non-truncated) value of the document counter
     */
    @Override
    public long totalNumberOfDocs() {
        return documentsCounter.get();
    }

    /**
     * Increments the total number of documents observed by 1.
     */
    @Override
    public void incrementTotalDocCount() {
        documentsCounter.incrementAndGet();
    }

    /**
     * Increments the total number of documents observed by the specified value.
     *
     * @param by amount to add
     */
    @Override
    public void incrementTotalDocCount(long by) {
        documentsCounter.addAndGet(by);
    }

    /**
     * Sets the total number of documents.
     *
     * @param by new document total
     */
    public void setTotalDocCount(long by) {
        documentsCounter.set(by);
    }


    /**
     * Returns the elements of this vocabulary. The same as {@link #vocabWords()}.
     *
     * @return live view of the stored elements
     */
    @Override
    public Collection<T> tokens() {
        return vocabWords();
    }

    /**
     * Adds the specified element to the vocabulary.
     * <p>
     * If no element with the same storage id exists, the element itself is stored (and indexed by
     * label when it has one). Otherwise its sequence count and frequency are merged into the
     * stored element. In both cases the element's frequency is added to the total word count.
     *
     * @param element the element to add
     * @return true if the element was newly inserted, false if it was merged into an existing one
     */
    @Override
    public boolean addToken(T element) {
        // Read the contribution before merging: if the caller re-adds the stored instance itself,
        // the merge below mutates this very object.
        long frequencyDelta = (long) element.getElementFrequency();

        T existing = vocabulary.putIfAbsent(element.getStorageId(), element);
        boolean added = existing == null;
        if (added) {
            // putIfAbsent stored our element; also index it by label when it has one
            if (element.getLabel() != null) {
                extendedVocabulary.put(element.getLabel(), element);
            }
        } else {
            existing.incrementSequencesCount(element.getSequencesCount());
            existing.increaseElementFrequency((int) element.getElementFrequency());
        }
        // Only the newly contributed frequency is counted; adding the merged total would count the
        // existing element's occurrences a second time.
        totalWordCount.addAndGet(frequencyDelta);
        return added;
    }

    /**
     * Same as {@code addToken(element)}, discarding the result.
     *
     * @param element the element to add
     * @param lockf unused
     */
    public void addToken(T element, boolean lockf) {
        addToken(element);
    }

    /**
     * Returns the element for the specified label. The same as {@code wordFor(String)}.
     *
     * @param label the label to get the token for
     * @return the element, or null if not present
     */
    @Override
    public T tokenFor(String label) {
        return wordFor(label);
    }

    /**
     * Returns the element for the specified storage id.
     *
     * @param id storage id
     * @return the element, or null if not present
     */
    @Override
    public T tokenFor(long id) {
        return vocabulary.get(id);
    }

    /**
     * Checks whether the specified label already exists in the vocabulary. The same as
     * {@link #containsWord(String)}.
     *
     * @param label the token to test
     * @return true if an element with this label is stored
     */
    @Override
    public boolean hasToken(String label) {
        return containsWord(label);
    }


    /**
     * Imports all elements from the given VocabCache via {@code addToken}. If an element already
     * exists, its sequence count and frequency are merged into the stored element.
     * The other cache's document total is added only if at least one new element was inserted.
     *
     * @param vocabCache source vocabulary
     */
    public void importVocabulary(@NonNull VocabCache<T> vocabCache) {
        boolean anyAdded = false;
        for (T element : vocabCache.vocabWords()) {
            if (this.addToken(element))
                anyAdded = true;
        }
        if (anyAdded)
            this.documentsCounter.addAndGet(vocabCache.totalNumberOfDocs());
    }

    /**
     * Recomputes the total word count as the sum of all positive element frequencies.
     */
    @Override
    public void updateWordsOccurrences() {
        totalWordCount.set(0);
        for (T element : vocabulary.values()) {
            long value = (long) element.getElementFrequency();

            if (value > 0) {
                totalWordCount.addAndGet(value);
            }
        }
        log.info("Updated counter: [{}]", totalWordCount.get());
    }

    /**
     * Removes the element with the given label from all indexes and subtracts its frequency from
     * the total word count.
     *
     * @param label label of the element to remove
     * @throws IllegalStateException if the label is not present
     */
    @Override
    public void removeElement(String label) {
        T element = extendedVocabulary.get(label);
        if (element != null) {
            totalWordCount.getAndAdd(-(long) element.getElementFrequency());
            // Only drop the Huffman slot if it still points at this element
            idxMap.remove(element.getIndex(), element);
            extendedVocabulary.remove(label);
            vocabulary.remove(element.getStorageId());
        } else
            throw new IllegalStateException("Can't get label: '" + label + "'");
    }

    /**
     * Removes the given element, looked up by its label.
     *
     * @param element element to remove
     * @throws IllegalStateException if its label is not present
     */
    @Override
    public void removeElement(T element) {
        removeElement(element.getLabel());
    }

    // Shared mapper for toJson()/fromJson(); volatile so that lazy initialization below is safe.
    private static volatile ObjectMapper mapper = null;
    private static final Object lock = new Object();

    /**
     * Lazily creates the shared ObjectMapper (unknown properties and empty beans tolerated).
     * The mapper is fully configured before it is published to other threads.
     */
    private static ObjectMapper mapper() {
        ObjectMapper local = mapper;
        if (local == null) {
            synchronized (lock) {
                local = mapper;
                if (local == null) {
                    local = new ObjectMapper();
                    local.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
                    local.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
                    mapper = local;
                }
            }
        }
        return local;
    }

    /**
     * Serializes this cache to a JSON document readable by {@link #fromJson(String)}.
     * <p>
     * Each property value (and each vocabulary item) is itself a Jackson-encoded string. Every
     * item is tagged with the runtime class of the first element encountered. An empty vocabulary
     * produces a document with an empty item list.
     *
     * @return JSON representation of this cache
     * @throws JsonProcessingException if Jackson fails to serialize an element
     */
    public String toJson() throws JsonProcessingException {
        ObjectMapper mapper = mapper();
        JsonObject retVal = new JsonObject();

        retVal.addProperty(CLASS_FIELD, mapper.writeValueAsString(this.getClass().getName()));

        // NOTE(review): assumes all elements share one runtime class — confirm for mixed vocabularies
        Class<?> clazz = null;
        JsonArray jsonValues = new JsonArray();
        for (T value : vocabulary.values()) {
            if (clazz == null)
                clazz = value.getClass();
            JsonObject item = new JsonObject();
            item.addProperty(CLASS_FIELD, mapper.writeValueAsString(clazz));
            item.addProperty(VOCAB_ITEM_FIELD, mapper.writeValueAsString(value));
            jsonValues.add(item);
        }
        retVal.add(VOCAB_LIST_FIELD, jsonValues);

        retVal.addProperty(DOC_CNT_FIELD, mapper.writeValueAsString(documentsCounter.longValue()));
        retVal.addProperty(MINW_FREQ_FIELD, mapper.writeValueAsString(minWordFrequency));
        retVal.addProperty(HUGE_MODEL_FIELD, mapper.writeValueAsString(hugeModelExpected));

        retVal.addProperty(STOP_WORDS_FIELD, mapper.writeValueAsString(stopWords));

        retVal.addProperty(SCAVENGER_FIELD, mapper.writeValueAsString(scavengerThreshold));
        retVal.addProperty(RETENTION_FIELD, mapper.writeValueAsString(retentionDelay));
        retVal.addProperty(TOTAL_WORD_FIELD, mapper.writeValueAsString(totalWordCount.longValue()));

        return retVal.toString();
    }

    /**
     * Rebuilds a cache from the output of {@link #toJson()}.
     * <p>
     * Items are always deserialized as {@link VocabWord}, regardless of the stored class tag.
     *
     * @param jsonString JSON produced by {@link #toJson()}
     * @param <T> element type of the returned cache
     * @return the reconstructed cache
     * @throws IOException if Jackson fails to parse an item or the stop-word list
     */
    public static <T extends SequenceElement> AbstractCache<T> fromJson(String jsonString) throws IOException {
        AbstractCache<T> retVal = new AbstractCache.Builder<T>().build();

        JsonParser parser = new JsonParser();
        JsonObject json = parser.parse(jsonString).getAsJsonObject();

        ObjectMapper mapper = mapper();

        JsonArray jsonArray = json.get(VOCAB_LIST_FIELD).getAsJsonArray();
        List<T> items = new ArrayList<>(jsonArray.size());
        for (int i = 0; i < jsonArray.size(); ++i) {
            String itemJson = jsonArray.get(i).getAsJsonObject().get(VOCAB_ITEM_FIELD).getAsString();
            @SuppressWarnings("unchecked")
            T item = (T) mapper.readValue(itemJson, VocabWord.class);
            items.add(item);
        }

        List<String> stopWords = mapper.readValue(json.get(STOP_WORDS_FIELD).getAsString(), List.class);

        long documentsCounter = json.get(DOC_CNT_FIELD).getAsLong();
        int minWordsFrequency = json.get(MINW_FREQ_FIELD).getAsInt();
        boolean hugeModelExpected = json.get(HUGE_MODEL_FIELD).getAsBoolean();
        int scavengerThreshold = json.get(SCAVENGER_FIELD).getAsInt();
        int retentionDelay = json.get(RETENTION_FIELD).getAsInt();
        long totalWordCount = json.get(TOTAL_WORD_FIELD).getAsLong();

        for (T item : items) {
            retVal.vocabulary.put(item.getStorageId(), item);
            // ConcurrentHashMap rejects null keys; unlabeled items are only indexed by id/index
            if (item.getLabel() != null)
                retVal.extendedVocabulary.put(item.getLabel(), item);
            retVal.idxMap.put(item.getIndex(), item);
        }
        if (stopWords != null)
            retVal.stopWords.addAll(stopWords);
        retVal.documentsCounter.set(documentsCounter);
        retVal.minWordFrequency = minWordsFrequency;
        retVal.hugeModelExpected = hugeModelExpected;
        retVal.scavengerThreshold = scavengerThreshold;
        retVal.retentionDelay = retentionDelay;
        retVal.totalWordCount.set(totalWordCount);
        return retVal;
    }

    /**
     * Builder for {@link AbstractCache}.
     *
     * @param <T> element type of the built cache
     */
    public static class Builder<T extends SequenceElement> {
        protected int scavengerThreshold = 3000000;
        protected int retentionDelay = 3;
        protected int minElementFrequency;
        protected boolean hugeModelExpected = false;


        public Builder<T> hugeModelExpected(boolean reallyExpected) {
            this.hugeModelExpected = reallyExpected;
            return this;
        }

        public Builder<T> scavengerThreshold(int threshold) {
            this.scavengerThreshold = threshold;
            return this;
        }

        public Builder<T> scavengerRetentionDelay(int delay) {
            this.retentionDelay = delay;
            return this;
        }

        public Builder<T> minElementFrequency(int minFrequency) {
            this.minElementFrequency = minFrequency;
            return this;
        }

        /**
         * Creates an empty cache carrying every configured setting.
         *
         * @return new cache
         */
        public AbstractCache<T> build() {
            AbstractCache<T> cache = new AbstractCache<>();
            cache.minWordFrequency = this.minElementFrequency;
            cache.scavengerThreshold = this.scavengerThreshold;
            cache.retentionDelay = this.retentionDelay;
            // previously dropped: hugeModelExpected(...) on the builder had no effect
            cache.hugeModelExpected = this.hugeModelExpected;

            return cache;
        }

    }

    @Override
    public String toString() {
        return "AbstractCache{" +
                "vocabulary=" + vocabulary +
                ", extendedVocabulary=" + extendedVocabulary +
                ", idxMap=" + idxMap +
                ", documentsCounter=" + documentsCounter +
                ", minWordFrequency=" + minWordFrequency +
                ", hugeModelExpected=" + hugeModelExpected +
                ", stopWords=" + stopWords +
                ", scavengerThreshold=" + scavengerThreshold +
                ", retentionDelay=" + retentionDelay +
                ", hiddenWordsCounter=" + hiddenWordsCounter +
                ", totalWordCount=" + totalWordCount +
                '}';
    }

    /**
     * Compares configuration, the three element maps and the stop words. Counters
     * (documents, total words) are not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof AbstractCache)) return false;
        AbstractCache<?> that = (AbstractCache<?>) o;
        return getMinWordFrequency() == that.getMinWordFrequency()
                && isHugeModelExpected() == that.isHugeModelExpected()
                && getScavengerThreshold() == that.getScavengerThreshold()
                && getRetentionDelay() == that.getRetentionDelay()
                && Objects.equals(getVocabulary(), that.getVocabulary())
                && Objects.equals(getExtendedVocabulary(), that.getExtendedVocabulary())
                && Objects.equals(getIdxMap(), that.getIdxMap())
                && Objects.equals(getStopWords(), that.getStopWords());
    }

    @Override
    public int hashCode() {
        return Objects.hash(getVocabulary(), getExtendedVocabulary(), getIdxMap(), getMinWordFrequency(),
                isHugeModelExpected(), getStopWords(), getScavengerThreshold(), getRetentionDelay());
    }

    public ConcurrentMap<Long, T> getVocabulary() {
        return vocabulary;
    }

    public Map<String, T> getExtendedVocabulary() {
        return extendedVocabulary;
    }

    public Map<Integer, T> getIdxMap() {
        return idxMap;
    }

    public AtomicLong getDocumentsCounter() {
        return documentsCounter;
    }

    public int getMinWordFrequency() {
        return minWordFrequency;
    }

    public void setMinWordFrequency(int minWordFrequency) {
        this.minWordFrequency = minWordFrequency;
    }

    public boolean isHugeModelExpected() {
        return hugeModelExpected;
    }

    public void setHugeModelExpected(boolean hugeModelExpected) {
        this.hugeModelExpected = hugeModelExpected;
    }

    public List<String> getStopWords() {
        return stopWords;
    }

    public void setStopWords(List<String> stopWords) {
        this.stopWords = stopWords;
    }

    public int getScavengerThreshold() {
        return scavengerThreshold;
    }

    public void setScavengerThreshold(int scavengerThreshold) {
        this.scavengerThreshold = scavengerThreshold;
    }

    public int getRetentionDelay() {
        return retentionDelay;
    }

    public void setRetentionDelay(int retentionDelay) {
        this.retentionDelay = retentionDelay;
    }

    public AtomicLong getHiddenWordsCounter() {
        return hiddenWordsCounter;
    }

    public void setHiddenWordsCounter(AtomicLong hiddenWordsCounter) {
        this.hiddenWordsCounter = hiddenWordsCounter;
    }

    public AtomicLong getTotalWordCount() {
        return totalWordCount;
    }

    /** Returns the shared mapper, or null if it has not been created or set yet. */
    public static ObjectMapper getMapper() {
        return mapper;
    }

    /** Replaces the shared mapper used by toJson()/fromJson(). */
    public static void setMapper(ObjectMapper mapper) {
        AbstractCache.mapper = mapper;
    }
}