deeplearning4j/deeplearning4j

View on GitHub
nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java

Summary

Maintainability
F
3 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.samediff.internal;

import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.function.Predicate;
import org.nd4j.common.primitives.Pair;

import java.util.*;

@Slf4j
public abstract class AbstractDependencyTracker<T, D> {
    @Getter
    private final IDependencyMap<T, D> dependencies; // Key: the dependent. Value: all things that the key depends on
    @Getter
    private final IDependencyMap<T, Pair<D, D>> orDependencies; // Key: the dependent. Value: the set of OR dependencies
    @Getter
    private final Map<D, Set<T>> reverseDependencies = new LinkedHashMap<>(); // Key: the dependee. Value: The set of
                                                                              // all dependents that depend on this
                                                                              // value
    @Getter
    private final Map<D, Set<T>> reverseOrDependencies = new HashMap<>();
    @Getter
    private final Set<D> satisfiedDependencies = new LinkedHashSet<>(); // Mark the dependency as satisfied. If not in
                                                                        // set: assumed to not be satisfied
    @Getter
    private final Set<T> allSatisfied; // Set of all dependent values (Ys) that have all dependencies satisfied
    @Getter
    private final Queue<T> allSatisfiedQueue = new LinkedList<>(); // Queue for *new* "all satisfied" values. Values are
                                                                   // removed using the "new all satisfied" methods

    protected AbstractDependencyTracker() {
        dependencies = (IDependencyMap<T, D>) newTMap();
        orDependencies = (IDependencyMap<T, Pair<D, D>>) newTMap();
        allSatisfied = newTSet();
    }

    /**
     * @return A new map where the dependents (i.e., Y in "X -> Y") are the key
     */
    protected abstract IDependencyMap<T, ?> newTMap();

    /**
     * @return A new set where the dependents (i.e., Y in "X -> Y") are the key
     */
    protected abstract Set<T> newTSet();

    /**
     * @return A String representation of the dependent object
     */
    protected abstract String toStringT(T t);

    /**
     * @return A String representation of the dependee object
     */
    protected abstract String toStringD(D d);

    /**
     * Clear all internal state for the dependency tracker
     */
    public void clear() {
        dependencies.clear();
        orDependencies.clear();
        reverseDependencies.clear();
        reverseOrDependencies.clear();
        satisfiedDependencies.clear();
        allSatisfied.clear();
        allSatisfiedQueue.clear();
    }

    /**
     * @return True if no dependencies have been defined
     */
    public boolean isEmpty() {
        return dependencies.isEmpty() && orDependencies.isEmpty() &&
                allSatisfiedQueue.isEmpty();
    }

    /**
     * @return True if the dependency has been marked as satisfied using
     *         {@link #markSatisfied(Object, boolean)}
     */
    public boolean isSatisfied(@NonNull D x) {

        boolean ret = satisfiedDependencies.contains(x);

        return ret;
    }

    /**
     * Mark the specified value as satisfied.
     * For example, if two dependencies have been previously added (X -> Y) and (X
     * -> A) then after the markSatisfied(X, true)
     * call, both of these dependencies are considered satisfied.
     *
     * @param x         Value to mark
     * @param satisfied Whether to mark as satisfied (true) or unsatisfied (false)
     */
    public void markSatisfied(@NonNull D x, boolean satisfied) {

        if (satisfied) {
            boolean alreadySatisfied = satisfiedDependencies.contains(x);

            if (!alreadySatisfied) {
                satisfiedDependencies.add(x);

                // Check if any Y's exist that have dependencies that are all satisfied, for X
                // -> Y
                Set<T> s = reverseDependencies.get(x);
                Set<T> s2 = reverseOrDependencies.get(x);

                Set<T> set;
                if (s != null && s2 != null) {
                    set = newTSet();
                    set.addAll(s);
                    set.addAll(s2);
                } else if (s != null) {
                    set = s;
                } else if (s2 != null) {
                    set = s2;
                } else {
                    if (log.isTraceEnabled()) {
                        log.trace("No values depend on: {}", toStringD(x));
                    }
                    return;
                }

                for (T t : set) {

                    boolean allSatisfied = true;
                    Iterable<D> it = dependencies.getDependantsForEach(t);
                    if (it != null) {
                        for (D d : it) {
                            if (!isSatisfied(d)) {
                                allSatisfied = false;
                                break;
                            }
                        }
                    }

                    if (allSatisfied) {
                        Iterable<Pair<D, D>> itOr = orDependencies.getDependantsForEach(t);
                        if (itOr != null) {
                            for (Pair<D, D> p : itOr) {
                                if (!isSatisfied(p.getFirst()) && !isSatisfied(p.getSecond())) {
                                    allSatisfied = false;
                                    break;
                                }
                            }
                        }
                    }

                    if (allSatisfied && !this.allSatisfied.contains(t)) {
                        this.allSatisfied.add(t);
                        this.allSatisfiedQueue.add(t);
                    }
                }
            }

        } else {
            satisfiedDependencies.remove(x);
            if (!allSatisfied.isEmpty()) {

                Set<T> reverse = reverseDependencies.get(x);
                if (reverse != null) {
                    for (T y : reverse) {
                        if (allSatisfied.contains(y)) {
                            allSatisfied.remove(y);
                            allSatisfiedQueue.remove(y);
                        }
                    }
                }
                Set<T> orReverse = reverseOrDependencies.get(x);
                if (orReverse != null) {
                    for (T y : orReverse) {
                        if (allSatisfied.contains(y) && !isAllSatisfied(y)) {
                            allSatisfied.remove(y);
                            allSatisfiedQueue.remove(y);
                        }
                    }
                }
            }
        }
    }

    /**
     * Check whether any dependencies x -> y exist, for y (i.e., anything previously
     * added by {@link #addDependency(Object, Object)}
     * or {@link #addOrDependency(Object, Object, Object)}
     *
     * @param y Dependent to check
     * @return True if Y depends on any values
     */
    public boolean hasDependency(@NonNull T y) {

        if (dependencies.containsAny(y)) {

            return true;
        }

        if (orDependencies.containsAny(y)) {

            return true;
        }

        return false;
    }

    /**
     * Get all dependencies x, for x -> y, and (x1 or x2) -> y
     *
     * @param y Dependent to get dependencies for
     * @return List of dependencies
     */
    public DependencyList<T, D> getDependencies(@NonNull T y) {
        Iterable<D> s1 = dependencies.getDependantsForEach(y);
        Iterable<Pair<D, D>> s2 = orDependencies.getDependantsForEach(y);

        return new DependencyList<>(y, s1, s2);
    }

    /**
     * Add a dependency: y depends on x, as in x -> y
     *
     * @param y The dependent
     * @param x The dependee that is required for Y
     */
    public void addDependency(@NonNull T y, @NonNull D x) {

        if (!reverseDependencies.containsKey(x))
            reverseDependencies.put(x, newTSet());

        dependencies.add(y, x);
        reverseDependencies.get(x).add(y);

        checkAndUpdateIfAllSatisfied(y);
    }

    protected void checkAndUpdateIfAllSatisfied(@NonNull T y) {

        boolean allSat = isAllSatisfied(y);
        if (allSat) {
            // Case where "x is satisfied" happened before x->y added
            if (!allSatisfied.contains(y)) {
                allSatisfied.add(y);
                allSatisfiedQueue.add(y);
            }
        } else if (allSatisfied.contains(y)) {
            if (!allSatisfiedQueue.contains(y)) {
                StringBuilder sb = new StringBuilder();
                sb.append("Dependent object \"").append(toStringT(y))
                        .append("\" was previously processed after all dependencies")
                        .append(" were marked satisfied, but is now additional dependencies have been added.\n");
                Iterable<D> dl = dependencies.getDependantsForEach(y);
                if (dl != null) {
                    sb.append("Dependencies:\n");
                    for (D d : dl) {
                        sb.append(d).append(" - ").append(isSatisfied(d) ? "Satisfied" : "Not satisfied").append("\n");
                    }
                }
                Iterable<Pair<D, D>> dlOr = orDependencies.getDependantsForEach(y);
                if (dlOr != null) {
                    sb.append("Or dependencies:\n");
                    for (Pair<D, D> p : dlOr) {
                        sb.append(p).append(" - satisfied=(").append(isSatisfied(p.getFirst())).append(",")
                                .append(isSatisfied(p.getSecond())).append(")");
                    }
                }

                allSatisfiedQueue.add(y);
                log.warn(sb.toString());
            }

            // Not satisfied, but is in the queue -> needs to be removed
            allSatisfied.remove(y);
            allSatisfiedQueue.remove(y);
        }
    }

    protected boolean isAllSatisfied(@NonNull T y) {

        Iterable<D> set1 = dependencies.getDependantsForEach(y);

        boolean retVal = true;
        if (set1 != null) {
            for (D d : set1) {
                retVal = isSatisfied(d);
                if (!retVal)
                    break;
            }
        }
        if (retVal) {
            Iterable<Pair<D, D>> set2 = orDependencies.getDependantsForEach(y);
            if (set2 != null) {
                for (Pair<D, D> p : set2) {
                    retVal = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond());
                    if (!retVal)
                        break;
                }
            }
        }

        return retVal;
    }

    /**
     * Remove a dependency (x -> y)
     *
     * @param y The dependent that currently requires X
     * @param x The dependee that is no longer required for Y
     */
    public void removeDependency(@NonNull T y, @NonNull D x) {

        dependencies.removeGroupReturn(y, t -> t.equals(x));

        Set<T> s2 = reverseDependencies.get(x);
        if (s2 != null) {
            s2.remove(y);
            if (s2.isEmpty())
                reverseDependencies.remove(x);
        }

        Iterable<Pair<D, D>> s3 = orDependencies.removeGroupReturn(y, t -> {
            return x.equals(t.getFirst()) || x.equals(t.getSecond());
        });
        if (s3 != null) {
            boolean removedReverse = false;
            for (Pair<D, D> p : s3) {
                if (!removedReverse) {
                    Set<T> set1 = reverseOrDependencies.get(p.getFirst());
                    Set<T> set2 = reverseOrDependencies.get(p.getSecond());

                    set1.remove(y);
                    set2.remove(y);

                    if (set1.isEmpty())
                        reverseOrDependencies.remove(p.getFirst());
                    if (set2.isEmpty())
                        reverseOrDependencies.remove(p.getSecond());

                    removedReverse = true;
                }
            }
        }

    }

    /**
     * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) ->
     * Y<br>
     * If either x1 or x2 (or both) are marked satisfied via
     * {@link #markSatisfied(Object, boolean)} then the
     * dependency is considered satisfied
     *
     * @param y  Dependent
     * @param x1 Dependee 1
     * @param x2 Dependee 2
     */
    public void addOrDependency(@NonNull T y, @NonNull D x1, @NonNull D x2) {

        if (!reverseOrDependencies.containsKey(x1))
            reverseOrDependencies.put(x1, newTSet());
        if (!reverseOrDependencies.containsKey(x2))
            reverseOrDependencies.put(x2, newTSet());

        orDependencies.add(y, new Pair<>(x1, x2));
        reverseOrDependencies.get(x1).add(y);
        reverseOrDependencies.get(x2).add(y);

        checkAndUpdateIfAllSatisfied(y);
    }

    /**
     * @return True if there are any new/unprocessed "all satisfied dependents" (Ys
     *         in X->Y)
     */
    public boolean hasNewAllSatisfied() {
        return !allSatisfiedQueue.isEmpty();
    }

    /**
     * Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked
     * as satisfied via {@link #markSatisfied(Object, boolean)}
     * Throws an exception if {@link #hasNewAllSatisfied()} returns false.<br>
     * Note that once a value has been retrieved from here, no new dependencies of
     * the form (X -> Y) can be added for this value;
     * the value is considered "processed" at this point.
     *
     * @return The next new "all satisfied dependent"
     */
    public T getNewAllSatisfied() {
        Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
        return allSatisfiedQueue.remove();
    }

    /**
     * @return As per {@link #getNewAllSatisfied()} but returns all values
     */
    public List<T> getNewAllSatisfiedList() {
        Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
        List<T> ret = new ArrayList<>(allSatisfiedQueue);
        allSatisfiedQueue.clear();
        return ret;
    }

    /**
     * As per {@link #getNewAllSatisfied()} but instead of returning the first
     * dependee, it returns the first that matches
     * the provided predicate. If no value matches the predicate, null is returned
     *
     * @param predicate Predicate gor checking
     * @return The first value matching the predicate, or null if no values match
     *         the predicate
     */
    public T getFirstNewAllSatisfiedMatching(@NonNull Predicate<T> predicate) {
        Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");

        T t = allSatisfiedQueue.peek();
        if (predicate.test(t)) {
            t = allSatisfiedQueue.remove();
            allSatisfied.remove(t);
            return t;
        }

        if (allSatisfiedQueue.size() > 1) {
            Iterator<T> iter = allSatisfiedQueue.iterator();
            while (iter.hasNext()) {
                t = iter.next();
                if (predicate.test(t)) {
                    iter.remove();
                    allSatisfied.remove(t);
                    return t;
                }
            }
        }

        return null; // None match predicate
    }
}