datavec/datavec-data/datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.image.recordreader;
import org.nd4j.shade.guava.base.Preconditions;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.conf.Configuration;
import org.datavec.api.io.labels.PathLabelGenerator;
import org.datavec.api.io.labels.PathMultiLabelGenerator;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataURI;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.util.files.FileFromPathIterator;
import org.datavec.api.util.files.URIUtil;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.batch.NDArrayRecordBatch;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.loader.ImageLoader;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ImageTransform;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.*;
import java.net.URI;
import java.util.*;
@Slf4j
public abstract class BaseImageRecordReader extends BaseRecordReader {
/** Whether the single image of an {@link InputStreamInputSplit} has already been handed out by {@link #next()}. */
protected boolean finishedInputStreamSplit;
/** Iterator over the files of the current split; created during initialization and recreated on reset. */
protected Iterator<File> iter;
/** Configuration most recently supplied to this reader. */
protected Configuration conf;
/** Most recent file pulled from {@link #iter}. */
protected File currentFile;
/** Single-label generator; may be {@code null}. */
protected PathLabelGenerator labelGenerator = null;
/** Multi-label generator; when non-null it is used instead of {@link #labelGenerator} while reading. */
protected PathMultiLabelGenerator labelMultiGenerator = null;
/** Known label classes; sorted at the end of {@code initialize(InputSplit)}. */
protected List<String> labels = new ArrayList<>();
/** Labels are added to the produced records when this flag or {@link #writeLabel} is set. */
protected boolean appendLabel = false;
/** Switched on by {@link #setLabels(List)}. */
protected boolean writeLabel = false;
/** Fallback record returned by {@link #next()} when no file iterator exists. */
protected List<Writable> record;
/** Set once the fallback {@link #record} has been returned. */
protected boolean hitImage = false;
/** Image height passed to the image loader. */
protected long height = 28;
/** Image width passed to the image loader. */
protected long width = 28;
/** Number of image channels passed to the image loader. */
protected long channels = 1;
/** Crop flag passed to {@link ImageLoader} when the "imageio" loader is configured. */
protected boolean cropImage = false;
/** Transform handed to {@link NativeImageLoader} instances created by this reader; may be {@code null}. */
protected ImageTransform imageTransform;
/** Loader used to turn files and streams into arrays. */
protected BaseImageLoader imageLoader;
/** Split this reader was last initialized with. */
protected InputSplit inputSplit;
/** File path string to label segment, filled during initialization when {@link #pattern} is set. */
protected Map<String, String> fileNameMap = new LinkedHashMap<>();
/** Pattern to split and segment file name, pass in regex. */
protected String pattern;
/** Index of the segment (after splitting with {@link #pattern}) that is stored as the label. */
protected int patternPosition = 0;
/** Whether the number of inferred label classes is logged during initialization. */
@Getter
@Setter
protected boolean logLabelCountOnInit = true;
/** When {@code false}, loaded arrays are permuted with (0,2,3,1), i.e. NCHW to NHWC. */
@Getter
@Setter
protected boolean nchw_channels_first = true;

/** Configuration key for the image height. */
public static final String HEIGHT = NAME_SPACE + ".height";
/** Configuration key for the image width. */
public static final String WIDTH = NAME_SPACE + ".width";
/** Configuration key for the number of channels. */
public static final String CHANNELS = NAME_SPACE + ".channels";
/** Configuration key for the crop flag. */
public static final String CROP_IMAGE = NAME_SPACE + ".cropimage";
/** Configuration key selecting the image loader; the value "imageio" selects {@link ImageLoader}. */
public static final String IMAGE_LOADER = NAME_SPACE + ".imageloader";
/**
 * Creates a reader with the field defaults: 28x28 images, 1 channel, no label generator and no image transform.
 */
public BaseImageRecordReader() {}
/**
 * Creates a channels-first reader that labels records with a single-label generator and applies no image transform.
 *
 * @param height         image height passed to the image loader
 * @param width          image width passed to the image loader
 * @param channels       number of image channels passed to the image loader
 * @param labelGenerator generator used to derive one label per file; may be {@code null}
 */
public BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator) {
    // Same end result as chaining through the intermediate overloads.
    this(height, width, channels, true, labelGenerator, null, null);
}
/**
 * Creates a channels-first reader that labels records with a multi-label generator and applies no image transform.
 *
 * @param height         image height passed to the image loader
 * @param width          image width passed to the image loader
 * @param channels       number of image channels passed to the image loader
 * @param labelGenerator generator used to derive the labels per file; may be {@code null}
 */
public BaseImageRecordReader(long height, long width, long channels, PathMultiLabelGenerator labelGenerator) {
    // No single-label generator, no transform.
    this(height, width, channels, true, null, labelGenerator, null);
}
/**
 * Creates a channels-first reader with a single-label generator and an image transform.
 *
 * @param height         image height passed to the image loader
 * @param width          image width passed to the image loader
 * @param channels       number of image channels passed to the image loader
 * @param labelGenerator generator used to derive one label per file; may be {@code null}
 * @param imageTransform transform handed to the image loader; may be {@code null}
 */
public BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
                ImageTransform imageTransform) {
    // No multi-label generator.
    this(height, width, channels, true, labelGenerator, null, imageTransform);
}
/**
 * Creates a channels-first ({@code nchw_channels_first = true}) reader; delegates to the full constructor.
 *
 * @param height              image height passed to the image loader
 * @param width               image width passed to the image loader
 * @param channels            number of image channels passed to the image loader
 * @param labelGenerator      single-label generator; may be {@code null}
 * @param labelMultiGenerator multi-label generator; may be {@code null}
 * @param imageTransform      transform handed to the image loader; may be {@code null}
 */
protected BaseImageRecordReader(long height, long width, long channels, PathLabelGenerator labelGenerator,
PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
this(height, width, channels, true, labelGenerator, labelMultiGenerator, imageTransform);
}
/**
 * Canonical constructor; every other parameterized constructor ends up here.
 *
 * @param height              image height passed to the image loader
 * @param width               image width passed to the image loader
 * @param channels            number of image channels passed to the image loader
 * @param nchw_channels_first {@code false} to have loaded arrays permuted from NCHW to NHWC
 * @param labelGenerator      single-label generator; may be {@code null}
 * @param labelMultiGenerator multi-label generator; may be {@code null}
 * @param imageTransform      transform handed to the image loader; may be {@code null}
 */
protected BaseImageRecordReader(long height, long width, long channels, boolean nchw_channels_first, PathLabelGenerator labelGenerator,
                PathMultiLabelGenerator labelMultiGenerator, ImageTransform imageTransform) {
    // Image geometry and layout.
    this.channels = channels;
    this.width = width;
    this.height = height;
    this.nchw_channels_first = nchw_channels_first;
    this.imageTransform = imageTransform;

    // Labelling: labels are appended as soon as any generator is supplied.
    this.labelMultiGenerator = labelMultiGenerator;
    this.labelGenerator = labelGenerator;
    this.appendLabel = !(labelGenerator == null && labelMultiGenerator == null);
}
/**
 * Checks whether the given name ends with "." followed by one of the image loader's allowed formats.
 *
 * @param format the string to test (the comparison is a case-sensitive suffix match)
 * @return {@code true} if any allowed format matches as a suffix, {@code false} otherwise
 */
protected boolean containsFormat(String format) {
    for (String allowed : imageLoader.getAllowedFormats()) {
        String suffix = "." + allowed;
        if (format.endsWith(suffix)) {
            return true;
        }
    }
    return false;
}
/**
 * Initializes the reader for the given split.
 * <p>
 * Creates a {@link NativeImageLoader} if no loader is set. For an {@link InputStreamInputSplit} only the
 * split is stored and the "finished" flag is cleared. Otherwise the split's locations are used to
 * (optionally) infer the label classes and to create the file iterator.
 *
 * @param split the split that defines the range of records to read
 * @throws IOException declared by the overridden method
 * @throws IllegalArgumentException if the split reports no locations
 */
@Override
public void initialize(InputSplit split) throws IOException {
    if (imageLoader == null) {
        imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
    }
    inputSplit = split;
    if (split instanceof InputStreamInputSplit) {
        finishedInputStreamSplit = false;
        return;
    }

    URI[] locations = split.locations();
    if (locations == null || locations.length < 1) {
        throw new IllegalArgumentException("No path locations found in the split.");
    }

    if (appendLabel && labelGenerator != null && labelGenerator.inferLabelClasses()) {
        Set<String> labelsSet = new HashSet<>();
        for (URI location : locations) {
            String name = labelGenerator.getLabelForPath(location).toString();
            labelsSet.add(name);
            if (pattern != null) {
                // Only build the File when it is actually needed: new File(URI) rejects
                // non-"file:" URIs, which must not break label inference when no pattern is set.
                File imgFile = new File(location);
                String label = name.split(pattern)[patternPosition];
                fileNameMap.put(imgFile.toString(), label);
            }
        }
        labels.clear();
        labels.addAll(labelsSet);
        if (logLabelCountOnInit) {
            log.info("ImageRecordReader: {} label classes inferred using label generator {}", labelsSet.size(), labelGenerator.getClass().getSimpleName());
        }
    }
    iter = new FileFromPathIterator(inputSplit.locationsPathIterator()); //This handles randomization internally if necessary

    if (split instanceof FileSplit) {
        //remove the root directory
        // NOTE(review): labels holds Strings; if getRootDir() returns a File this call can never
        // match an element and is effectively a no-op. Left unchanged - confirm the intent.
        FileSplit split1 = (FileSplit) split;
        labels.remove(split1.getRootDir());
    }
    //To ensure consistent order for label assignment (irrespective of file iteration order), we want to sort the list of labels
    Collections.sort(labels);
}
/**
 * Initializes the reader from a configuration and a split.
 * <p>
 * Label settings and image dimensions are read from {@code conf}, falling back to the current field
 * values for everything except the label list. A new image loader is always created: an
 * {@link ImageLoader} when {@code IMAGE_LOADER} is "imageio", otherwise a {@link NativeImageLoader}.
 *
 * @param conf  a configuration for initialization
 * @param split the split that defines the range of records to read
 * @throws IOException          if initialization for the split fails
 * @throws InterruptedException declared by the overridden method
 */
@Override
public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
    // Label related settings.
    appendLabel = conf.getBoolean(APPEND_LABEL, appendLabel);
    labels = new ArrayList<>(conf.getStringCollection(LABELS));

    // Image geometry.
    height = conf.getLong(HEIGHT, height);
    width = conf.getLong(WIDTH, width);
    channels = conf.getLong(CHANNELS, channels);
    cropImage = conf.getBoolean(CROP_IMAGE, cropImage);

    // Loader selection.
    if (!"imageio".equals(conf.get(IMAGE_LOADER))) {
        imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
    } else {
        imageLoader = new ImageLoader(height, width, channels, cropImage);
    }

    this.conf = conf;
    initialize(split);
}
/**
 * Initializes the reader for a split, using the given transform while loading images.
 * Any previously set image loader is discarded so that a new one is created with the transform.
 *
 * @param split          the split that defines the range of records to read
 * @param imageTransform the image transform to use to transform images while loading them
 * @throws java.io.IOException if initialization for the split fails
 */
public void initialize(InputSplit split, ImageTransform imageTransform) throws IOException {
    this.imageTransform = imageTransform;
    // Dropping the loader makes initialize(split) build a new one that uses the new transform.
    this.imageLoader = null;
    initialize(split);
}
/**
 * Initializes the reader from a configuration and a split, using the given transform while
 * loading images.
 *
 * @param conf           a configuration for initialization
 * @param split          the split that defines the range of records to read
 * @param imageTransform the image transform to use to transform images while loading them
 * @throws java.io.IOException  if initialization for the split fails
 * @throws InterruptedException declared by the delegated initialize call
 */
public void initialize(Configuration conf, InputSplit split, ImageTransform imageTransform)
                throws IOException, InterruptedException {
    this.imageTransform = imageTransform;
    this.imageLoader = null;
    initialize(conf, split);
}
/**
 * Returns the next record.
 * <ul>
 * <li>For an {@link InputStreamInputSplit} the stream is loaded into a single {@link NDArrayWritable}
 * and the split is marked as finished.</li>
 * <li>With a file iterator, the next file is loaded (directories are skipped), optionally permuted
 * from NCHW to NHWC, and the label(s) are appended when labels are enabled.</li>
 * <li>Otherwise the fallback {@code record} is returned, if one is set.</li>
 * </ul>
 *
 * @return the values of the next record
 * @throws RuntimeException      if loading the image or deriving its label fails (the cause is attached)
 * @throws IllegalStateException if there is neither a file iterator nor a fallback record
 */
@Override
public List<Writable> next() {
    if (inputSplit instanceof InputStreamInputSplit) {
        InputStreamInputSplit inputStreamInputSplit = (InputStreamInputSplit) inputSplit;
        try {
            NDArrayWritable ndArrayWritable =
                            new NDArrayWritable(imageLoader.asMatrix(inputStreamInputSplit.getIs()));
            finishedInputStreamSplit = true;
            return Arrays.<Writable>asList(ndArrayWritable);
        } catch (IOException e) {
            // Fail loudly with the real cause instead of logging an empty message and falling
            // through to a stale file iterator or a misleading "No more elements" error.
            throw new RuntimeException("Failed to load image from the input stream split", e);
        }
    }
    if (iter != null) {
        List<Writable> ret;
        File image = iter.next();
        currentFile = image;
        if (image.isDirectory()) {
            // Directories carry no image data; move on to the next entry.
            return next();
        }
        try {
            invokeListeners(image);
            INDArray array = imageLoader.asMatrix(image);
            if (!nchw_channels_first) {
                array = array.permute(0, 2, 3, 1); //NCHW to NHWC
            }
            Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.DEVICE);
            ret = RecordConverter.toRecord(array);
            if (appendLabel || writeLabel) {
                if (labelMultiGenerator != null) {
                    ret.addAll(labelMultiGenerator.getLabels(image.getPath()));
                } else if (labelGenerator == null || labelGenerator.inferLabelClasses()) {
                    //Standard classification use case (i.e., handle String -> integer conversion).
                    //A null generator is allowed here: getLabel(String) falls back to the file name
                    //map / parent directory name in that case.
                    ret.add(new IntWritable(labels.indexOf(getLabel(image.getPath()))));
                } else {
                    //Regression use cases, and PathLabelGenerator instances that already map to integers
                    ret.add(labelGenerator.getLabelForPath(image.getPath()));
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return ret;
    } else if (record != null) {
        hitImage = true;
        invokeListeners(record);
        return record;
    }
    throw new IllegalStateException("No more elements");
}
/**
 * Reports whether another record can be read.
 * <p>
 * For an {@link InputStreamInputSplit} exactly one record is available, so this is {@code true}
 * until {@link #next()} has consumed the stream. With a file iterator the answer comes from the
 * iterator; with only a fallback record it is {@code true} until that record has been returned.
 *
 * @return {@code true} if {@link #next()} can return another record
 * @throws IllegalStateException if the reader has neither a file iterator nor a fallback record
 */
@Override
public boolean hasNext() {
    if (inputSplit instanceof InputStreamInputSplit) {
        // The flag is false after initialization and set to true once the stream was read,
        // so "has next" is its negation.
        return !finishedInputStreamSplit;
    }
    if (iter != null) {
        return iter.hasNext();
    } else if (record != null) {
        return !hitImage;
    }
    throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
}
/**
 * Batched reads via {@link #next(int)} are available only when a {@link NativeImageLoader} is in use.
 *
 * @return {@code true} if the current image loader is a {@link NativeImageLoader}
 */
@Override
public boolean batchesSupported() {
    final boolean usesNativeLoader = imageLoader instanceof NativeImageLoader;
    return usesNativeLoader;
}
/**
 * Reads up to {@code num} files from the file iterator and returns them as one batch.
 * <p>
 * Images are loaded directly into a single {@code [count, channels, height, width]} array, which is
 * permuted to NHWC when {@code nchw_channels_first} is {@code false}. When labels are enabled, the
 * label arrays follow the feature array: one array per multi-label column, a one-hot array for
 * inferred classes, or the generator's own values otherwise.
 *
 * @param num maximum number of examples to read; must be &gt; 0
 * @return the batch of records
 * @throws IllegalArgumentException if {@code num} is not positive
 * @throws RuntimeException         if an image fails to load (the cause is attached)
 */
@Override
public List<List<Writable>> next(int num) {
    Preconditions.checkArgument(num > 0, "Number of examples must be > 0: got %s", num);
    if (imageLoader == null) {
        imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
    }
    List<File> currBatch = new ArrayList<>();
    int cnt = 0;
    int numCategories = (appendLabel || writeLabel) ? labels.size() : 0;
    List<Integer> currLabels = null;
    List<Writable> currLabelsWritable = null;
    List<List<Writable>> multiGenLabels = null;
    while (cnt < num && iter.hasNext()) {
        currentFile = iter.next();
        currBatch.add(currentFile);
        invokeListeners(currentFile);
        if (appendLabel || writeLabel) {
            //Collect the label Writables from the label generators
            if (labelMultiGenerator != null) {
                if (multiGenLabels == null) {
                    multiGenLabels = new ArrayList<>();
                }
                multiGenLabels.add(labelMultiGenerator.getLabels(currentFile.getPath()));
            } else if (labelGenerator == null || labelGenerator.inferLabelClasses()) {
                //A null generator is allowed: getLabel(String) has its own fallback for that case
                if (currLabels == null) {
                    currLabels = new ArrayList<>();
                }
                currLabels.add(labels.indexOf(getLabel(currentFile.getPath())));
            } else {
                if (currLabelsWritable == null) {
                    currLabelsWritable = new ArrayList<>();
                }
                currLabelsWritable.add(labelGenerator.getLabelForPath(currentFile.getPath()));
            }
        }
        cnt++;
    }

    INDArray features = Nd4j.createUninitialized(new long[] {cnt, channels, height, width}, 'c');
    Nd4j.getAffinityManager().tagLocation(features, AffinityManager.Location.HOST);
    for (int i = 0; i < cnt; i++) {
        File imageFile = currBatch.get(i);
        try {
            // Load straight into the i-th example's slice of the batch array.
            ((NativeImageLoader) imageLoader).asMatrixView(imageFile,
                            features.tensorAlongDimension(i, 1, 2, 3));
        } catch (Exception e) {
            // Report through the class logger rather than stdout, and carry the path in the exception.
            String message = "Image file failed during load: " + imageFile.getAbsolutePath();
            log.error(message);
            throw new RuntimeException(message, e);
        }
    }
    if (!nchw_channels_first) {
        features = features.permute(0, 2, 3, 1); //NCHW to NHWC
    }
    Nd4j.getAffinityManager().ensureLocation(features, AffinityManager.Location.DEVICE);

    List<INDArray> ret = new ArrayList<>();
    ret.add(features);
    if (appendLabel || writeLabel) {
        //And convert the previously collected label Writables from the label generators
        if (labelMultiGenerator != null) {
            List<Writable> temp = new ArrayList<>();
            List<Writable> first = multiGenLabels.get(0);
            for (int col = 0; col < first.size(); col++) {
                temp.clear();
                for (List<Writable> multiGenLabel : multiGenLabels) {
                    temp.add(multiGenLabel.get(col));
                }
                INDArray currCol = RecordConverter.toMinibatchArray(temp);
                ret.add(currCol);
            }
        } else {
            // Named labelArray so it does not shadow the "labels" field.
            INDArray labelArray;
            if (labelGenerator == null || labelGenerator.inferLabelClasses()) {
                //Standard classification use case (i.e., handle String -> integer conversion)
                labelArray = Nd4j.create(cnt, numCategories, 'c');
                Nd4j.getAffinityManager().tagLocation(labelArray, AffinityManager.Location.HOST);
                for (int i = 0; i < currLabels.size(); i++) {
                    labelArray.putScalar(i, currLabels.get(i), 1.0f);
                }
            } else {
                //Regression use cases, and PathLabelGenerator instances that already map to integers
                if (currLabelsWritable.get(0) instanceof NDArrayWritable) {
                    List<INDArray> arr = new ArrayList<>();
                    for (Writable w : currLabelsWritable) {
                        arr.add(((NDArrayWritable) w).get());
                    }
                    labelArray = Nd4j.concat(0, arr.toArray(new INDArray[arr.size()]));
                } else {
                    labelArray = RecordConverter.toMinibatchArray(currLabelsWritable);
                }
            }
            ret.add(labelArray);
        }
    }
    return new NDArrayRecordBatch(ret);
}
/**
 * Does nothing; this implementation releases no resources here.
 *
 * @throws IOException never thrown by this implementation
 */
@Override
public void close() throws IOException {
    //No op
}
/**
 * Stores the configuration without re-initializing the reader.
 *
 * @param conf the configuration to keep
 */
@Override
public void setConf(Configuration conf) {
this.conf = conf;
}
/**
 * @return the configuration last stored on this reader, or {@code null} if none was set
 */
@Override
public Configuration getConf() {
    return this.conf;
}
/**
 * Get the label from the given path.
 * <p>
 * With a label generator the generator's label (as a string) is returned. Without one, the label
 * recorded in the file name map is used if present, otherwise the name of the parent directory.
 *
 * @param path the path to get the label from
 * @return the label for the given path
 */
public String getLabel(String path) {
    if (labelGenerator == null) {
        boolean mapped = fileNameMap != null && fileNameMap.containsKey(path);
        if (mapped) {
            return fileNameMap.get(path);
        }
        File parentDir = new File(path).getParentFile();
        return parentDir.getName();
    }
    return labelGenerator.getLabelForPath(path).toString();
}
/**
 * Accumulate the label from the path: the label is added to the label list unless it is
 * already contained in it.
 *
 * @param path the path to get the label from
 */
protected void accumulateLabel(String path) {
    final String label = getLabel(path);
    if (labels.contains(label)) {
        return;
    }
    labels.add(label);
}
/**
 * Returns the file most recently taken from the file iterator by {@link #next()} or
 * {@link #next(int)}, or the file set through {@link #setCurrentFile(File)}.
 *
 * @return the current file; may be {@code null} if no file has been read yet
 */
public File getCurrentFile() {
    return this.currentFile;
}
/**
 * Sets manually the file returned by {@link #getCurrentFile()}.
 *
 * @param currentFile the file to report as current
 */
public void setCurrentFile(File currentFile) {
this.currentFile = currentFile;
}
/**
 * @return the list of label classes held by this reader (the internal list, not a copy)
 */
@Override
public List<String> getLabels() {
    return this.labels;
}
/**
 * Replaces the label list and switches label writing on, so that subsequent reads add labels
 * to the records.
 *
 * @param labels the label classes to use; the list is stored as-is, not copied
 */
public void setLabels(List<String> labels) {
    this.writeLabel = true;
    this.labels = labels;
}
/**
 * Resets the underlying split and rewinds this reader: a file iterator is recreated from the
 * split, while a fallback record is simply marked as not yet returned.
 *
 * @throws UnsupportedOperationException if the reader has not been initialized
 */
@Override
public void reset() {
    if (inputSplit == null) {
        throw new UnsupportedOperationException("Cannot reset without first initializing");
    }
    inputSplit.reset();
    if (iter == null) {
        if (record != null) {
            hitImage = false;
        }
        return;
    }
    iter = new FileFromPathIterator(inputSplit.locationsPathIterator());
}
/**
 * @return {@code false} before initialization; afterwards whatever the input split reports
 */
@Override
public boolean resetSupported() {
    return inputSplit != null && inputSplit.resetSupported();
}
/**
 * Returns the number of entries in the label list.
 *
 * @return the size of the label list
 */
public int numLabels() {
    return this.labels.size();
}
/**
 * Loads a single image from the given stream and returns it as a record.
 * <p>
 * The array is permuted from NCHW to NHWC when {@code nchw_channels_first} is {@code false}. When
 * {@code appendLabel} is set, the label is derived from the URI path the same way {@link #next()}
 * does it: all labels of the multi-label generator if one is set; otherwise the class index for a
 * missing or class-inferring generator; otherwise the generator's own value.
 *
 * @param uri             the URI the stream was opened from; its path is used for labelling
 * @param dataInputStream the stream containing the image data
 * @return the values of the record
 * @throws IOException if the image cannot be read from the stream
 */
@Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException {
    invokeListeners(uri);
    if (imageLoader == null) {
        imageLoader = new NativeImageLoader(height, width, channels, imageTransform);
    }
    INDArray array = imageLoader.asMatrix(dataInputStream);
    if (!nchw_channels_first) {
        array = array.permute(0, 2, 3, 1); //NCHW to NHWC
    }
    List<Writable> ret = RecordConverter.toRecord(array);
    if (appendLabel) {
        String path = uri.getPath();
        if (labelMultiGenerator != null) {
            // Previously ignored here, which produced a meaningless class index instead.
            ret.addAll(labelMultiGenerator.getLabels(path));
        } else if (labelGenerator == null || labelGenerator.inferLabelClasses()) {
            //Standard classification use case (i.e., handle String -> integer conversion)
            ret.add(new IntWritable(labels.indexOf(getLabel(path))));
        } else {
            //Regression use cases, and PathLabelGenerator instances that already map to integers
            ret.add(labelGenerator.getLabelForPath(path));
        }
    }
    return ret;
}
/**
 * Reads the next record via {@link #next()} and pairs it with metadata pointing at the URI of
 * the current file.
 *
 * @return the next record together with its URI metadata
 */
@Override
public Record nextRecord() {
    final List<Writable> values = next();
    // currentFile was updated by the next() call above.
    final RecordMetaDataURI meta =
                    new RecordMetaDataURI(URIUtil.fileToURI(currentFile), BaseImageRecordReader.class);
    return new org.datavec.api.records.impl.Record(values, meta);
}
/**
 * Loads the single record described by the given metadata by delegating to the list variant.
 *
 * @param recordMetaData metadata identifying the record to load
 * @return the loaded record
 * @throws IOException if the underlying file cannot be read
 */
@Override
public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
    List<Record> loaded = loadFromMetaData(Collections.singletonList(recordMetaData));
    return loaded.get(0);
}
/**
 * Loads the records described by the given metadata entries, in order. Each entry's URI is opened
 * as a file and read through {@link #record(URI, DataInputStream)}; the stream is closed afterwards.
 *
 * @param recordMetaDatas metadata identifying the records to load
 * @return one record per metadata entry, each carrying its original metadata
 * @throws IOException if a file cannot be opened or read
 */
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException {
    final List<Record> loaded = new ArrayList<>();
    for (RecordMetaData metaData : recordMetaDatas) {
        final URI location = metaData.getURI();
        final File source = new File(location);
        final List<Writable> values;
        try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(source)))) {
            values = record(location, in);
        }
        loaded.add(new org.datavec.api.records.impl.Record(values, metaData));
    }
    return loaded;
}
}