deeplearning4j/deeplearning4j

View on GitHub
datavec/datavec-api/src/main/java/org/datavec/api/io/filters/RandomPathFilter.java

Summary

Maintainability
A
25 mins
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.io.filters;

import java.net.URI;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;

/**
 * A {@link PathFilter} that shuffles the supplied paths with a caller-provided
 * {@link Random}, keeps only those whose string form ends with one of the
 * configured extensions, and optionally caps the number of paths returned.
 *
 * @author saudet
 */
public class RandomPathFilter implements PathFilter {

    /** Source of randomness handed to {@link Collections#shuffle(java.util.List, Random)}. */
    protected Random random;
    /** Extensions (without the leading dot) to keep; {@code null} or empty keeps everything. */
    protected String[] extensions;
    /** Upper bound on the number of returned paths; values {@code <= 0} disable the bound. */
    protected long maxPaths = 0;

    /** Equivalent to {@code this(random, extensions, 0)}, i.e. no cap on the result size. */
    public RandomPathFilter(Random random, String... extensions) {
        this(random, extensions, 0);
    }

    /**
     * Creates a filter with an explicit cap on the number of returned paths.
     *
     * @param random     random number generator used for shuffling
     * @param extensions file extensions to keep, given without the leading dot
     * @param maxPaths   maximum number of paths to return (0 == unlimited)
     */
    public RandomPathFilter(Random random, String[] extensions, long maxPaths) {
        this.random = random;
        this.extensions = extensions;
        this.maxPaths = maxPaths;
    }

    /**
     * Decides whether a path name passes the extension filter.
     *
     * @param name string form of the path
     * @return {@code true} when no extensions are configured, or when {@code name}
     *         ends with {@code "." + extension} for at least one configured extension
     */
    protected boolean accept(String name) {
        boolean unrestricted = extensions == null || extensions.length == 0;
        if (unrestricted) {
            return true;
        }
        boolean matched = false;
        for (int i = 0; i < extensions.length && !matched; i++) {
            String suffix = "." + extensions[i];
            matched = name.endsWith(suffix);
        }
        return matched;
    }

    /**
     * Shuffles a copy of {@code paths}, then collects the accepted ones in shuffled
     * order until the input is exhausted or the {@code maxPaths} cap is reached.
     * The input array itself is not modified.
     *
     * @param paths candidate paths
     * @return the accepted paths, in shuffled order
     */
    @Override
    public URI[] filter(URI[] paths) {
        // Shuffle first so that the cap below does not favour the head of the input.
        ArrayList<URI> shuffled = new ArrayList<URI>(paths.length);
        Collections.addAll(shuffled, paths);
        Collections.shuffle(shuffled, random);

        ArrayList<URI> selected = new ArrayList<URI>();
        for (int i = 0, total = shuffled.size(); i < total; i++) {
            URI candidate = shuffled.get(i);
            if (!accept(candidate.toString())) {
                continue;
            }
            selected.add(candidate);
            boolean limitReached = maxPaths > 0 && selected.size() >= maxPaths;
            if (limitReached) {
                break;
            }
        }
        return selected.toArray(new URI[0]);
    }
}