deeplearning4j/deeplearning4j

View on GitHub
deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/HistogramBin.java

Summary

Maintainability
A
3 hrs
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.ui.model.weights;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Fixed-width histogram over the values of an {@link INDArray}.
 *
 * <p>Instances are expected to be created through {@link Builder}, which computes the
 * histogram once during {@link Builder#build()}. The value range {@code [min, max]} is
 * split so that bin {@code x} starts at {@code min + x * binSize} with
 * {@code binSize = (max - min) / (numberOfBins - 1)}; the maximum value therefore lands
 * exactly on the start of the last bin.</p>
 *
 * <p>Counts are exposed twice: as an {@link INDArray} ({@code bins}, one entry per bin)
 * and as an insertion-ordered map ({@code data}) keyed by the bin start rounded up to
 * {@code rounds} decimal places. If rounding makes two bin starts equal, those bins share
 * a single map entry.</p>
 */
@Data
public class HistogramBin implements Serializable {
    /** Array the histogram was computed from; transient, so not part of Java serialization. */
    private transient INDArray sourceArray;
    /** Number of bins in the histogram. */
    private int numberOfBins;
    /** Number of decimal places the bin keys in {@link #data} are rounded to. */
    private int rounds;
    /** Per-bin counts, indexed by bin number; transient, so not part of Java serialization. */
    private transient INDArray bins;
    /** Largest value of the source array (after replacing NaN/infinite values). */
    private double max;
    /** Smallest value of the source array (after replacing NaN/infinite values). */
    private double min;
    /** Bin start (rounded) mapped to the number of values counted in that bin, in bin order. */
    private Map<BigDecimal, AtomicInteger> data = new LinkedHashMap<>();

    private static final Logger log = LoggerFactory.getLogger(HistogramBin.class);

    /**
     * No-Args constructor should be used only for serialization/deserialization purposes.
     * In all other cases please use {@link Builder}.
     */
    public HistogramBin() {

    }

    /**
     * Creates an empty instance.
     *
     * <p>NOTE(review): the body is empty, so {@code array} is currently ignored and no
     * histogram is computed. Use {@link Builder} to obtain a populated histogram.</p>
     *
     * @param array source array (unused)
     */
    public HistogramBin(INDArray array) {

    }

    /**
     * Computes {@link #min}, {@link #max}, {@link #bins} and {@link #data} from
     * {@link #sourceArray}, {@link #numberOfBins} and {@link #rounds}.
     *
     * @throws IllegalArgumentException if {@link #numberOfBins} is smaller than 1
     */
    @JsonIgnore
    private synchronized void calcHistogram() {
        if (numberOfBins < 1) {
            throw new IllegalArgumentException("Number of bins must be at least 1, got " + numberOfBins);
        }

        max = replaceNonFinite(sourceArray.maxNumber().doubleValue());
        min = replaceNonFinite(sourceArray.minNumber().doubleValue());

        bins = Nd4j.create(numberOfBins);

        // With a single bin there is no interval to divide; a zero bin size sends every
        // value to bin 0 (see the clamp below) instead of dividing by zero here.
        final double binSize = numberOfBins > 1 ? (max - min) / (numberOfBins - 1) : 0.0;

        data = new LinkedHashMap<>();
        final BigDecimal[] keys = new BigDecimal[numberOfBins];

        for (int x = 0; x < numberOfBins; x++) {
            BigDecimal pos = new BigDecimal(min + (x * binSize)).setScale(rounds, RoundingMode.CEILING);
            // Keys that collide after rounding intentionally share the most recent counter.
            data.put(pos, new AtomicInteger(0));
            keys[x] = pos;
        }

        final int lastBin = numberOfBins - 1;
        final long length = sourceArray.length();

        // long index: length() is a long, an int counter would overflow on very large arrays
        for (long i = 0; i < length; i++) {
            final double d = sourceArray.getDouble(i);

            // The cast maps NaN to 0 and saturates +/-Infinity to the int range, so the
            // clamp below always yields a valid bin, even when binSize is 0.
            final int rawBin = (int) ((d - min) / binSize);
            final int bin = Math.max(0, Math.min(lastBin, rawBin));

            bins.putScalar(bin, bins.getDouble(bin) + 1);
            data.get(keys[bin]).incrementAndGet();
        }
    }

    /**
     * Replaces values that cannot be used as a histogram bound.
     *
     * @param value raw min/max value
     * @return {@code Float.MIN_VALUE} for NaN, {@code +/-Float.MAX_VALUE} (matching the sign)
     *         for infinities, otherwise {@code value} unchanged
     */
    private static double replaceNonFinite(double value) {
        // TODO: there's probably better way to get around of possible NaNs in max/min
        if (Double.isNaN(value)) {
            return Float.MIN_VALUE;
        }
        if (Double.isInfinite(value)) {
            return value > 0 ? Float.MAX_VALUE : -Float.MAX_VALUE;
        }
        return value;
    }

    /**
     * Fluent builder for {@link HistogramBin}. The bin count has no usable default and
     * must be set via {@link #setBinCount(int)} before calling {@link #build()}.
     */
    public static class Builder {
        private INDArray source;
        private int binCount;
        private int rounds = 2;

        /**
         * Creates a builder for the specified array.
         *
         * @param array array whose values will be binned
         */
        public Builder(INDArray array) {
            this.source = array;
        }

        /**
         * Sets the number of digits kept after the decimal point in the bin keys
         * (defaults to 2).
         *
         * @param rounds number of decimal places
         * @return this builder
         */
        public Builder setRounding(int rounds) {
            this.rounds = rounds;
            return this;
        }

        /**
         * Specifies the number of bins for the output histogram.
         *
         * @param bins number of bins; must be at least 1
         * @return this builder
         */
        public Builder setBinCount(int bins) {
            this.binCount = bins;
            return this;
        }

        /**
         * Computes the histogram and returns the ready-to-use instance.
         *
         * @return populated histogram
         * @throws IllegalArgumentException if the configured bin count is smaller than 1
         */
        public HistogramBin build() {
            HistogramBin histogram = new HistogramBin();
            histogram.sourceArray = this.source;
            histogram.numberOfBins = this.binCount;
            histogram.rounds = this.rounds;

            histogram.calcHistogram();

            return histogram;
        }
    }
}