deeplearning4j/deeplearning4j

View on GitHub
datavec/datavec-api/src/main/java/org/datavec/api/split/partition/NumberOfRecordsPartitioner.java

Summary

Maintainability
A
2 hrs
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.api.split.partition;

import org.datavec.api.conf.Configuration;
import org.datavec.api.split.InputSplit;

import java.io.OutputStream;
import java.net.URI;

public class NumberOfRecordsPartitioner implements Partitioner {

    private URI[] locations;
    private int recordsPerFile = DEFAULT_RECORDS_PER_FILE;
    //all records in to 1 file
    public final static int DEFAULT_RECORDS_PER_FILE = -1;

    public final static String RECORDS_PER_FILE_CONFIG = "org.datavec.api.split.partition.numrecordsperfile";
    private int numRecordsSoFar = 0;
    private int currLocation;
    private InputSplit inputSplit;
    private OutputStream current;
    private boolean doneWithCurrentLocation = false;
    private int totalRecordsWritten;

    @Override
    public int totalRecordsWritten() {
        return totalRecordsWritten;
    }

    @Override
    public int numRecordsWritten() {
        return numRecordsSoFar;
    }

    @Override
    public int numPartitions() {
        //possible it's a directory
        if(locations.length < 2) {

            if(locations.length > 0 && locations[0].isAbsolute()) {
                return recordsPerFile;
            }
            //append all results to 1 file when -1
            else {
                return 1;
            }
        }

        //otherwise it's a series of specified files.
        return locations.length / recordsPerFile;
    }

    @Override
    public void init(InputSplit inputSplit) {
        this.locations = inputSplit.locations();
        this.inputSplit = inputSplit;

    }

    @Override
    public void init(Configuration configuration, InputSplit split) {
        init(split);
        this.recordsPerFile = configuration.getInt(RECORDS_PER_FILE_CONFIG,DEFAULT_RECORDS_PER_FILE);
    }

    @Override
    public void updatePartitionInfo(PartitionMetaData metadata) {
        this.numRecordsSoFar += metadata.getNumRecordsUpdated();
        this.totalRecordsWritten += metadata.getNumRecordsUpdated();
        if(numRecordsSoFar >= recordsPerFile && recordsPerFile > 0)  {
            doneWithCurrentLocation = true;
        }
    }

    @Override
    public boolean needsNewPartition() {
        doneWithCurrentLocation = numRecordsSoFar >= recordsPerFile && recordsPerFile > 0;
        return recordsPerFile > 0 && numRecordsSoFar >= recordsPerFile ||  doneWithCurrentLocation;
    }

    @Override
    public OutputStream openNewStream() {
        //reset status of location
        doneWithCurrentLocation = false;
        //ensure count is 0 for records so far for current record
        numRecordsSoFar = 0;

        //only append when directory, also ensure we can bootstrap and we can write to the current location
        if(currLocation >= locations.length - 1 && locations.length >= 1 && needsNewPartition() || inputSplit.needsBootstrapForWrite() ||
                locations.length < 1 ||
                currLocation >= locations.length || !inputSplit.canWriteToLocation(locations[currLocation])
                && needsNewPartition()) {

            String newInput = inputSplit.addNewLocation();
            try {
                OutputStream ret =  inputSplit.openOutputStreamFor(newInput);
                this.current = ret;
                return ret;
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

        else {
            try {
                OutputStream ret =  inputSplit.openOutputStreamFor(locations[currLocation].toString());
                currLocation++;
                this.current = ret;
                return ret;
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }

    }

    @Override
    public OutputStream currentOutputStream() {
        if(current == null) {
            current = openNewStream();
        }
        return current;
    }
}