
View on GitHub


3 days
Test Coverage
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.image.recordreader;

import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.util.files.FileFromPathIterator;
import org.datavec.api.util.files.URIUtil;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.loader.ImageLoader;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;

/**
 * Base class for record readers that turn images into NDArray-backed records, optionally
 * followed by a label derived from each image's path.
 *
 * <p>Images come either from the locations of an {@link InputSplit} or from a single stream in an
 * {@link InputStreamInputSplit}. They are decoded by a {@link BaseImageLoader}; when none has been
 * configured, a {@link NativeImageLoader} for {@code height x width x channels} is created lazily.
 */
public abstract class BaseImageRecordReader extends BaseRecordReader {
    // Set to true by next() once the image of an InputStreamInputSplit has been read.
    protected boolean finishedInputStreamSplit;
    // Iterator over the split's files; created by initialize(InputSplit) and recreated by reset().
    protected Iterator<File> iter;
    // Configuration supplied through initialize(Configuration, InputSplit) or setConf(Configuration).
    protected Configuration conf;
    // File most recently taken by next() / next(int); nextRecord() builds its metadata URI from it.
    protected File currentFile;
    // Produces one label per image path; the multi-label constructor leaves this null.
    protected PathLabelGenerator labelGenerator = null;
    // Produces several label writables per image path; null unless supplied to a constructor.
    protected PathMultiLabelGenerator labelMultiGenerator = null;
    // Known label classes. For generators that infer label classes, next() emits the index of the
    // image's label in this list.
    protected List<String> labels = new ArrayList<>();
    // Whether labels are appended; set when a label generator is supplied to the constructor, or
    // from the APPEND_LABEL configuration key.
    protected boolean appendLabel = false;
    // Set by setLabels(List); next() and next(int) also emit labels when this is true.
    protected boolean writeLabel = false;
    // Optional preset record returned (once) by next() when no file iterator exists.
    protected List<Writable> record;
    // Whether the preset record has already been returned; reset() clears it.
    protected boolean hitImage = false;
    // Target image dimensions (defaults 28 x 28 x 1).
    protected long height = 28, width = 28, channels = 1;
    // Passed to ImageLoader when the "imageio" loader is selected through configuration.
    protected boolean cropImage = false;
    // Transform handed to the NativeImageLoader instances this class creates.
    protected ImageTransform imageTransform;
    // Decoder used to turn images into INDArrays.
    protected BaseImageLoader imageLoader;
    // Split this reader was initialized with; null before initialization (reset() checks this).
    protected InputSplit inputSplit;
    // File path -> label segment extracted with `pattern` in initialize(InputSplit); consulted by
    // getLabel(String) when no label generator is set.
    protected Map<String, String> fileNameMap = new LinkedHashMap<>();
    protected String pattern; // Pattern to split and segment file name, pass in regex
    // Index of the segment (after splitting on `pattern`) that is used as the label.
    protected int patternPosition = 0;
    // Whether initialize(InputSplit) logs how many label classes were inferred.
    @Getter @Setter
    protected boolean logLabelCountOnInit = true;
    // true (the default) requests channels-first (NCHW) arrays. NOTE(review): the NCHW -> NHWC
    // permutes elsewhere in this class are presumably meant to run only when this is false — confirm.
    @Getter @Setter
    protected boolean nchw_channels_first = true;

    // Configuration keys read by initialize(Configuration, InputSplit).
    public final static String HEIGHT = NAME_SPACE + ".height";
    public final static String WIDTH = NAME_SPACE + ".width";
    public final static String CHANNELS = NAME_SPACE + ".channels";
    public final static String CROP_IMAGE = NAME_SPACE + ".cropimage";
    // Selects the loader: the value "imageio" chooses ImageLoader, anything else NativeImageLoader.
    public final static String IMAGE_LOADER = NAME_SPACE + ".imageloader";

    /**
     * Creates a reader that keeps every field at its declared default (28 x 28 x 1 images, no label
     * generator, no transform). Settings can be supplied later, for example through
     * {@code initialize(Configuration, InputSplit)}.
     */
    public BaseImageRecordReader() {
        // Intentionally empty.
    }

    /**
     * Creates a reader labelled by a single-label path generator, without an image transform.
     *
     * @param height         image height
     * @param width          image width
     * @param channels       number of image channels
     * @param labelGenerator derives a label from each image path; {@code null} means no labels
     */
    public BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
        this(height, width, channels, labelGenerator, null);
    }

    /**
     * Creates a reader labelled by a multi-label path generator, without an image transform.
     *
     * @param height         image height
     * @param width          image width
     * @param channels       number of image channels
     * @param labelGenerator derives the label writables from each image path; {@code null} means no labels
     */
    public BaseImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
        this(height, width, channels, null, labelGenerator, null);
    }

    /**
     * Creates a reader labelled by a single-label path generator that applies the given transform
     * while loading images.
     *
     * @param height         image height
     * @param width          image width
     * @param channels       number of image channels
     * @param labelGenerator derives a label from each image path; {@code null} means no labels
     * @param imageTransform transform handed to the image loader; may be {@code null}
     */
    public BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
                                 ImageTransform imageTransform) {
        this(height, width, channels, labelGenerator, null, imageTransform);
    }

    /**
     * Creates a channels-first reader with either (or neither) kind of label generator.
     *
     * @param height              image height
     * @param width               image width
     * @param channels            number of image channels
     * @param labelGenerator      single-label path generator; may be {@code null}
     * @param labelMultiGenerator multi-label path generator; may be {@code null}
     * @param imageTransform      transform handed to the image loader; may be {@code null}
     */
    protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
                                    PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
        // Channels-first (NCHW) is the default layout.
        this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform);
    }

    /**
     * Fully specified constructor; every other constructor delegates here.
     *
     * @param height              image height
     * @param width               image width
     * @param channels            number of image channels
     * @param nchw_channels_first {@code true} for channels-first (NCHW) output
     * @param labelGenerator      single-label path generator; may be {@code null}
     * @param labelMultiGenerator multi-label path generator; may be {@code null}
     * @param imageTransform      transform handed to the image loader; may be {@code null}
     */
    protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
                                    PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
        this.height = height;
        this.width = width;
        this.channels = channels;
        this.labelGenerator = labelGenerator;
        this.labelMultiGenerator = labelMultiGenerator;
        this.imageTransform = imageTransform;
        // Labels are appended whenever either kind of generator is present.
        this.appendLabel = (labelGenerator != null || labelMultiGenerator != null);
        this.nchw_channels_first = nchw_channels_first;
    }

    protected boolean containsFormat(String format) {
        for (String format2 : imageLoader.getAllowedFormats())
            if (format.endsWith("." + format2))
                return true;
        return false;

    /**
     * Initializes the reader from an input split.
     *
     * <p>Creates a {@link NativeImageLoader} if no loader is set yet, stores the split, optionally
     * derives label names for every location with {@code labelGenerator}, and creates {@code iter}
     * over the split's locations.
     *
     * NOTE(review): this view of the method is missing lines (closing braces and several
     * statements); the comments below describe only what is visible and flag the apparent gaps.
     *
     * @param split the split that defines the images to read
     * @throws IOException as declared; no visible line throws it directly
     * @throws IllegalArgumentException if the split reports no locations (null or empty)
     */
    public void initialize(InputSplit split) throws IOException {
        // Lazily create a default loader from the configured dimensions and transform.
        if (imageLoader == null) {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);

        // Stream-backed split: next() reads a single image from the stream and then sets
        // finishedInputStreamSplit. NOTE(review): no return is visible here, so as shown control
        // falls through to the location-based setup below — presumably upstream returns early.
        if(split instanceof InputStreamInputSplit) {
            this.inputSplit = split;
            this.finishedInputStreamSplit = false;

        inputSplit = split;

        URI[] locations = split.locations();
        if (locations != null && locations.length >= 1) {
            // Only generators that infer label classes take part in up-front label discovery.
            if (appendLabel && labelGenerator != null && labelGenerator.inferLabelClasses()) {
                // NOTE(review): no visible line adds to labelsSet or copies it into labels —
                // presumably lost from this view; confirm.
                Set<String> labelsSet = new HashSet<>();
                for (URI location : locations) {
                    File imgFile = new File(location);
                    String name = labelGenerator.getLabelForPath(location).toString();
                    // With a regex pattern set, the label is the patternPosition-th segment of the
                    // generated name, stored in fileNameMap keyed by the file path.
                    if (pattern != null) {
                        String label = name.split(pattern)[patternPosition];
                        fileNameMap.put(imgFile.toString(), label);
                // NOTE(review): the receiver of the logging call on the line after the next one
                // (presumably log.info( via Lombok @Slf4j) appears to be missing — confirm.
                if(logLabelCountOnInit) {
          "ImageRecordReader: {} label classes inferred using label generator {}", labelsSet.size(), labelGenerator.getClass().getSimpleName());
            iter = new FileFromPathIterator(inputSplit.locationsPathIterator()); //This handles randomization internally if necessary
        } else
            throw new IllegalArgumentException("No path locations found in the split.");

        if (split instanceof FileSplit) {
            //remove the root directory
            // NOTE(review): split1 is not used by any visible line.
            FileSplit split1 = (FileSplit) split;

        //To ensure consistent order for label assignment (irrespective of file iteration order), we want to sort the list of labels
        // NOTE(review): no sort of labels is visible after the comment above — confirm it was not lost.

    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
        this.appendLabel = conf.getBoolean(APPEND_LABEL, appendLabel);
        this.labels = new ArrayList<>(conf.getStringCollection(LABELS));
        this.height = conf.getLong(HEIGHT, height);
        this.width = conf.getLong(WIDTH, width);
        this.channels = conf.getLong(CHANNELS, channels);
        this.cropImage = conf.getBoolean(CROP_IMAGE, cropImage);
        if ("imageio".equals(conf.get(IMAGE_LOADER))) {
            this.imageLoader = new ImageLoader(height, width, channels, cropImage);
        } else {
            this.imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
        this.conf = conf;

     * Called once at initialization.
     * @param split          the split that defines the range of records to read
     * @param imageTransform the image transform to use to transform images while loading them
     * @throws
    public void initialize(InputSplit split, ImageTransform imageTransform) throws IOException {
        this.imageLoader = null;
        this.imageTransform = imageTransform;

     * Called once at initialization.
     * @param conf           a configuration for initialization
     * @param split          the split that defines the range of records to read
     * @param imageTransform the image transform to use to transform images while loading them
     * @throws
     * @throws InterruptedException
    public void initialize(Configuration conf, InputSplit split, ImageTransform imageTransform)
            throws IOException, InterruptedException {
        this.imageLoader = null;
        this.imageTransform = imageTransform;
        initialize(conf, split);

    /**
     * Returns the next image as a record.
     *
     * <p>For an {@link InputStreamInputSplit}, the image in the stream is loaded and returned as a
     * single {@link NDArrayWritable}. Otherwise the next entry of {@code iter} is loaded
     * (directories are skipped by recursing), converted with {@link RecordConverter#toRecord}, and
     * a label is appended when {@code appendLabel || writeLabel}. Without an iterator, a preset
     * {@code record} is returned and {@code hitImage} is set.
     *
     * NOTE(review): this view of the method is missing lines (closing braces, the stream branch's
     * catch body, and parts of the label handling); comments describe only what is visible.
     *
     * @return the writables for the next image
     * @throws IllegalStateException if neither a file iterator nor a preset record exists
     * @throws RuntimeException wrapping any exception raised while loading a file from the iterator
     */
    public List<Writable> next() {
        // Stream-backed split: the single image in the stream becomes one NDArrayWritable.
        if(inputSplit instanceof InputStreamInputSplit) {
            InputStreamInputSplit inputStreamInputSplit = (InputStreamInputSplit) inputSplit;
            try {
                NDArrayWritable ndArrayWritable =  new NDArrayWritable(imageLoader.asMatrix(inputStreamInputSplit.getIs()));
                finishedInputStreamSplit = true;
                return Arrays.<Writable>asList(ndArrayWritable);
            // NOTE(review): the handler body for this catch and the closing braces are missing
            // from this view — confirm how the IOException is reported.
            } catch (IOException e) {
        if (iter != null) {
            List<Writable> ret;
            // NOTE(review): the right-hand side below is missing — presumably iter.next(); confirm.
            File image =;
            currentFile = image;

            // Directories are skipped by recursing to the following entry.
            if (image.isDirectory())
                return next();
            try {
                INDArray array = imageLoader.asMatrix(image);
                // NOTE(review): the indentation suggests this permute lost its guard (presumably
                // if (!nchw_channels_first)); as shown it always converts to NHWC — confirm.
                    array = array.permute(0,2,3,1);     //NCHW to NHWC

                // Ask the affinity manager to place the array at Location.DEVICE.
                Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
                ret = RecordConverter.toRecord(array);
                if (appendLabel || writeLabel){
                    // NOTE(review): the multi-label branch is empty in this view — presumably it
                    // appends labelMultiGenerator's writables; confirm.
                    if(labelMultiGenerator != null){
                    } else {
                        if (labelGenerator.inferLabelClasses()) {
                            //Standard classification use case (i.e., handle String -> integer conversion
                            ret.add(new IntWritable(labels.indexOf(getLabel(image.getPath()))));
                        } else {
                            //Regression use cases, and PathLabelGenerator instances that already map to integers
                            // NOTE(review): the statement for this branch is missing from this view.
            } catch (Exception e) {
                throw new RuntimeException(e);
            return ret;
        // No iterator: hand out the preset record and mark it consumed for hasNext().
        } else if (record != null) {
            hitImage = true;
            return record;
        throw new IllegalStateException("No more elements");

    public boolean hasNext() {
        if(inputSplit instanceof InputStreamInputSplit) {
            return finishedInputStreamSplit;

        if (iter != null) {
            return iter.hasNext();
        } else if (record != null) {
            return !hitImage;
        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");

    public boolean batchesSupported() {
        return (imageLoader instanceof NativeImageLoader);

    /**
     * Reads up to {@code num} images as a single minibatch.
     *
     * <p>Files drawn from {@code iter} are written into one preallocated
     * {@code [cnt, channels, height, width]} features array through
     * {@code NativeImageLoader.asMatrixView}, so the loader must be a {@link NativeImageLoader}
     * (see {@link #batchesSupported()}). When {@code appendLabel || writeLabel}, a labels array is
     * built alongside it.
     *
     * NOTE(review): this view of the method is missing lines (closing braces, the statements that
     * fill currBatch, the label lists and ret, and the cnt increment); comments describe only what
     * is visible.
     *
     * @param num maximum number of examples to read; must be greater than 0
     * @return the batch wrapped in an {@link NDArrayRecordBatch}
     * @throws IllegalArgumentException if {@code num <= 0}
     * @throws RuntimeException wrapping any exception raised while loading an image
     */
    public List<List<Writable>> next(int num) {
        Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num);

        if (imageLoader == null) {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);

        List<File> currBatch = new ArrayList<>();

        // Number of files collected for this batch. NOTE(review): no increment is visible —
        // presumably cnt++ inside the loop below.
        int cnt = 0;

        // Width of the one-hot labels matrix built for generators that infer label classes.
        int numCategories = (appendLabel || writeLabel) ? labels.size() : 0;
        List<Integer> currLabels = null;
        List<Writable> currLabelsWritable = null;
        List<List<Writable>> multiGenLabels = null;
        while (cnt < num && iter.hasNext()) {
            // NOTE(review): the right-hand side below is missing — presumably iter.next(); confirm.
            currentFile =;
            if (appendLabel || writeLabel) {
                //Collect the label Writables from the label generators
                if(labelMultiGenerator != null){
                    if(multiGenLabels == null)
                        multiGenLabels = new ArrayList<>();

                } else {
                    if (labelGenerator.inferLabelClasses()) {
                        if (currLabels == null)
                            currLabels = new ArrayList<>();
                    } else {
                        if (currLabelsWritable == null)
                            currLabelsWritable = new ArrayList<>();

        INDArray features = Nd4j.createUninitialized(new long[] {cnt, channels, height, width}, 'c');
        Nd4j.getAffinityManager().tagLocation(features, AffinityManager.Location.HOST);
        // Decode each file straight into its slice along dimension 0 of the features array.
        for (int i = 0; i < cnt; i++) {
            try {
                ((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i),
                        features.tensorAlongDimension(i, 1, 2, 3));
            } catch (Exception e) {
                // NOTE(review): reports the failing file on stdout rather than through a logger.
                System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath());
                throw new RuntimeException(e);
            // NOTE(review): this permute appears to have lost its guard (presumably
            // if (!nchw_channels_first)) and the loop's closing brace; confirm it runs once, after the loop.
            features = features.permute(0,2,3,1);   //NCHW to NHWC
        Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);

        List<INDArray> ret = new ArrayList<>();
        if (appendLabel || writeLabel) {
            //And convert the previously collected label Writables from the label generators
            if(labelMultiGenerator != null){
                List<Writable> temp = new ArrayList<>();
                // NOTE(review): multiGenLabels.get(0) fails if no labels were collected.
                List<Writable> first = multiGenLabels.get(0);
                for(int col=0; col<first.size(); col++ ){
                    for (List<Writable> multiGenLabel : multiGenLabels) {
                    INDArray currCol = RecordConverter.toMinibatchArray(temp);
            } else {
                INDArray labels;
                if (labelGenerator.inferLabelClasses()) {
                    //Standard classification use case (i.e., handle String -> integer conversion)
                    // One-hot [cnt, numCategories] matrix: row i gets 1.0 at column currLabels[i].
                    labels = Nd4j.create(cnt, numCategories, 'c');
                    Nd4j.getAffinityManager().tagLocation(labels, AffinityManager.Location.HOST);
                    for (int i = 0; i < currLabels.size(); i++) {
                        labels.putScalar(i, currLabels.get(i), 1.0f);
                } else {
                    //Regression use cases, and PathLabelGenerator instances that already map to integers
                    // NOTE(review): currLabelsWritable.get(0) fails if no labels were collected.
                    if (currLabelsWritable.get(0) instanceof NDArrayWritable) {
                        List<INDArray> arr = new ArrayList<>();
                        for (Writable w : currLabelsWritable) {
                            arr.add(((NDArrayWritable) w).get());
                        labels = Nd4j.concat(0, arr.toArray(new INDArray[arr.size()]));
                    } else {
                        labels = RecordConverter.toMinibatchArray(currLabelsWritable);


        // NOTE(review): no visible line adds features or labels to ret — presumably lost from this view.
        return new NDArrayRecordBatch(ret);

    public void close() throws IOException {
        //No op

    public void setConf(Configuration conf) {
        this.conf = conf;

    public Configuration getConf() {
        return conf;

     * Get the label from the given path
     * @param path the path to get the label from
     * @return the label for the given path
    public String getLabel(String path) {
        if (labelGenerator != null) {
            return labelGenerator.getLabelForPath(path).toString();
        if (fileNameMap != null && fileNameMap.containsKey(path))
            return fileNameMap.get(path);
        return (new File(path)).getParentFile().getName();

     * Accumulate the label from the path
     * @param path the path to get the label from
    protected void accumulateLabel(String path) {
        String name = getLabel(path);
        if (!labels.contains(name))

     * Returns the file loaded last by {@link #next()}.
    public File getCurrentFile() {
        return currentFile;

     * Sets manually the file returned by {@link #getCurrentFile()}.
    public void setCurrentFile(File currentFile) {
        this.currentFile = currentFile;

    public List<String> getLabels() {
        return labels;

    public void setLabels(List<String> labels) {
        this.labels = labels;
        this.writeLabel = true;

    public void reset() {
        if (inputSplit == null)
            throw new UnsupportedOperationException("Cannot reset without first initializing");
        if (iter != null) {
            iter = new FileFromPathIterator(inputSplit.locationsPathIterator());
        } else if (record != null) {
            hitImage = false;

    public boolean resetSupported(){
        if(inputSplit == null){
            return false;
        return inputSplit.resetSupported();

     * Returns {@code getLabels().size()}.
    public int numLabels() {
        return labels.size();

    public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
        if (imageLoader == null) {
            imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
        INDArray array = imageLoader.asMatrix(dataInputStream);
            array = array.permute(0,2,3,1);
        List<Writable> ret = RecordConverter.toRecord(array);
        if (appendLabel)
            ret.add(new IntWritable(labels.indexOf(getLabel(uri.getPath()))));
        return ret;

    public Record nextRecord() {
        List<Writable> list = next();
        URI uri = URIUtil.fileToURI(currentFile);
        return new org.datavec.api.records.impl.Record(list, new RecordMetaDataURI(uri, BaseImageRecordReader.class));

    public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
        return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0);

    public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
        List<Record> out = new ArrayList<>();
        for (RecordMetaData meta : recordMetaDatas) {
            URI uri = meta.getURI();
            File f = new File(uri);

            List<Writable> next;
            try (DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(f)))) {
                next = record(uri, dis);
            out.add(new org.datavec.api.records.impl.Record(next, meta));
        return out;