deeplearning4j/deeplearning4j-data/deeplearning4j-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/EmnistDataSetIterator.java
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.datasets.iterator.impl;
import lombok.Getter;
import org.deeplearning4j.common.resources.DL4JResources;
import org.deeplearning4j.common.resources.ResourceType;
import org.deeplearning4j.datasets.fetchers.EmnistDataFetcher;
import org.eclipse.deeplearning4j.resources.utils.EMnistSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.BaseDatasetIterator;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * A {@link BaseDatasetIterator} over one of the EMNIST subsets ({@link EMnistSet}), backed by an
 * {@link EmnistDataFetcher}.
 * <p>
 * The static helpers expose per-subset metadata: the number of examples in the train/test
 * splits, the number of labels, the label-index-to-character mapping, and whether the training
 * split is class-balanced.
 */
public class EmnistDataSetIterator extends BaseDatasetIterator {

    // Number of examples in each subset's train/test split. One of these is passed to
    // BaseDatasetIterator as the total example count (see numExamples(boolean, EMnistSet)).
    private static final int NUM_COMPLETE_TRAIN = 697932;
    private static final int NUM_COMPLETE_TEST = 116323;
    private static final int NUM_MERGE_TRAIN = 697932;
    private static final int NUM_MERGE_TEST = 116323;
    private static final int NUM_BALANCED_TRAIN = 112800;
    private static final int NUM_BALANCED_TEST = 18800;
    private static final int NUM_DIGITS_TRAIN = 240000;
    private static final int NUM_DIGITS_TEST = 40000;
    private static final int NUM_LETTERS_TRAIN = 88800;
    private static final int NUM_LETTERS_TEST = 14800;
    private static final int NUM_MNIST_TRAIN = 60000;
    private static final int NUM_MNIST_TEST = 10000;

    // Label index -> character mappings, written as ASCII code points
    // ('0'-'9' = 48-57, 'A'-'Z' = 65-90, 'a'-'z' = 97-122).
    // These arrays are shared static state: never hand them to callers without copying.
    private static final char[] LABELS_COMPLETE = new char[] {48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 65, 66, 67, 68,
            69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 97, 98, 99,
            100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
            120, 121, 122};
    private static final char[] LABELS_MERGE = new char[] {48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 65, 66, 67, 68, 69,
            70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 97, 98, 100,
            101, 102, 103, 104, 110, 113, 114, 116};
    private static final char[] LABELS_BALANCED = new char[] {48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 65, 66, 67, 68,
            69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 97, 98, 100,
            101, 102, 103, 104, 110, 113, 114, 116};
    private static final char[] LABELS_DIGITS = new char[] {48, 49, 50, 51, 52, 53, 54, 55, 56, 57};
    private static final char[] LABELS_LETTERS = new char[] {65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
            80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90};

    /** The EMNIST subset this iterator was created for. */
    protected EMnistSet dataSet;
    // NOTE(review): never assigned in this class. If BaseDatasetIterator declares fields with the
    // same names, these shadow them. Confirm before relying on these values.
    protected int batch, numExamples;
    // NOTE(review): never assigned in this class, so the Lombok-generated getter returns null
    // unless a subclass sets it. If BaseDatasetIterator keeps its own preProcessor (e.g. via
    // setPreProcessor), this declaration shadows it. Confirm and consider removing.
    @Getter
    protected DataSetPreProcessor preProcessor;

    /**
     * Create an EMNIST iterator with randomly shuffled data, seeded from
     * {@link System#currentTimeMillis()}.
     *
     * @param dataSet Dataset (subset) to return
     * @param batch   Batch size
     * @param train   If true: use training set. If false: use test set
     * @throws IOException If an error occurs when loading/downloading the dataset
     */
    public EmnistDataSetIterator(EMnistSet dataSet, int batch, boolean train) throws IOException {
        this(dataSet, batch, train, System.currentTimeMillis());
    }

    /**
     * Create an EMNIST iterator with randomly shuffled, non-binarized data based on a specified
     * RNG seed.
     *
     * @param dataSet   Dataset (subset) to return
     * @param batchSize Batch size
     * @param train     If true: use training set. If false: use test set
     * @param seed      Random number generator seed
     * @throws IOException If an error occurs when loading/downloading the dataset
     */
    public EmnistDataSetIterator(EMnistSet dataSet, int batchSize, boolean train, long seed) throws IOException {
        // binarize = false, shuffle = true
        this(dataSet, batchSize, false, train, true, seed);
    }

    /**
     * Create an EMNIST iterator over the given subset (test or train split), with optional
     * shuffling and binarization.
     *
     * @param dataSet     Dataset (subset) to return
     * @param batch       Size of each minibatch
     * @param binarize    Whether to binarize the data or not (if false: normalize in range 0 to 1)
     * @param train       Train vs. test set
     * @param shuffle     Whether to shuffle the examples
     * @param rngSeed     Random number generator seed to use when shuffling examples
     * @param topLevelDir Intended dataset directory. NOTE(review): currently ignored. It is not
     *                    forwarded to {@link EmnistDataFetcher}, so confirm whether the fetcher
     *                    can accept it.
     * @throws IOException If an error occurs when loading/downloading the dataset
     */
    public EmnistDataSetIterator(EMnistSet dataSet, int batch, boolean binarize, boolean train, boolean shuffle,
                    long rngSeed, File topLevelDir) throws IOException {
        super(batch, numExamples(train, dataSet), new EmnistDataFetcher(dataSet, binarize, train, shuffle, rngSeed));
        this.dataSet = dataSet;
    }

    /**
     * Create an EMNIST iterator over the given subset (test or train split), with optional
     * shuffling and binarization, using the default DL4J "emnist" dataset resource directory.
     *
     * @param dataSet  Dataset (subset) to return
     * @param batch    Size of each minibatch
     * @param binarize Whether to binarize the data or not (if false: normalize in range 0 to 1)
     * @param train    Train vs. test set
     * @param shuffle  Whether to shuffle the examples
     * @param rngSeed  Random number generator seed to use when shuffling examples
     * @throws IOException If an error occurs when loading/downloading the dataset
     */
    public EmnistDataSetIterator(EMnistSet dataSet, int batch, boolean binarize, boolean train, boolean shuffle,
                    long rngSeed) throws IOException {
        this(dataSet, batch, binarize, train, shuffle, rngSeed,
                        DL4JResources.getDirectory(ResourceType.DATASET, "emnist"));
    }

    /** Number of examples in the train or test split of the given subset. */
    private static int numExamples(boolean train, EMnistSet ds) {
        return train ? numExamplesTrain(ds) : numExamplesTest(ds);
    }

    /**
     * Get the number of training examples for the specified subset
     *
     * @param dataSet Subset to get
     * @return Number of examples for the specified subset
     * @throws UnsupportedOperationException if the subset is not recognized
     */
    public static int numExamplesTrain(EMnistSet dataSet) {
        switch (dataSet) {
            case COMPLETE:
                return NUM_COMPLETE_TRAIN;
            case MERGE:
                return NUM_MERGE_TRAIN;
            case BALANCED:
                return NUM_BALANCED_TRAIN;
            case LETTERS:
                return NUM_LETTERS_TRAIN;
            case DIGITS:
                return NUM_DIGITS_TRAIN;
            case MNIST:
                return NUM_MNIST_TRAIN;
            default:
                throw new UnsupportedOperationException("Unknown Set: " + dataSet);
        }
    }

    /**
     * Get the number of test examples for the specified subset
     *
     * @param dataSet Subset to get
     * @return Number of examples for the specified subset
     * @throws UnsupportedOperationException if the subset is not recognized
     */
    public static int numExamplesTest(EMnistSet dataSet) {
        switch (dataSet) {
            case COMPLETE:
                return NUM_COMPLETE_TEST;
            case MERGE:
                return NUM_MERGE_TEST;
            case BALANCED:
                return NUM_BALANCED_TEST;
            case LETTERS:
                return NUM_LETTERS_TEST;
            case DIGITS:
                return NUM_DIGITS_TEST;
            case MNIST:
                return NUM_MNIST_TEST;
            default:
                throw new UnsupportedOperationException("Unknown Set: " + dataSet);
        }
    }

    /**
     * Get the number of labels for the specified subset.
     * <p>
     * This is derived from the subset's label table, so it can never disagree with
     * {@link #getLabelsArray(EMnistSet)}.
     *
     * @param dataSet Subset to get
     * @return Number of labels for the specified subset
     * @throws UnsupportedOperationException if the subset is not recognized
     */
    public static int numLabels(EMnistSet dataSet) {
        return labelsFor(dataSet).length;
    }

    /**
     * Get the labels of this iterator's subset as a character array.
     *
     * @return A fresh copy of the labels; modifying it does not affect this class
     */
    public char[] getLabelsArrays() {
        return getLabelsArray(dataSet);
    }

    /**
     * Get the labels of this iterator's subset as a List&lt;String&gt;.
     *
     * @return Labels, one single-character string per label index
     */
    public List<String> getLabels() {
        return getLabels(dataSet);
    }

    /**
     * Get the label assignments for the given set as a character array.
     *
     * @param dataSet DataSet to get the label assignment for
     * @return A fresh copy of the label assignment for the given dataset. Element i is the
     *         character for label index i. Modifying the returned array does not affect this class.
     * @throws UnsupportedOperationException if the subset is not recognized
     */
    public static char[] getLabelsArray(EMnistSet dataSet) {
        // Defensive copy: the backing arrays are static and shared by every caller and instance.
        return labelsFor(dataSet).clone();
    }

    /**
     * Get the label assignments for the given set as a List&lt;String&gt;.
     *
     * @param dataSet DataSet to get the label assignment for
     * @return Label assignment for the given dataset, one single-character string per label index
     * @throws UnsupportedOperationException if the subset is not recognized
     */
    public static List<String> getLabels(EMnistSet dataSet) {
        char[] labels = labelsFor(dataSet);
        List<String> result = new ArrayList<>(labels.length);
        for (char label : labels) {
            result.add(String.valueOf(label));
        }
        return result;
    }

    /**
     * Internal lookup of the shared label table for a subset. The returned array is the shared
     * static instance and must not be exposed to callers without copying.
     */
    private static char[] labelsFor(EMnistSet dataSet) {
        switch (dataSet) {
            case COMPLETE:
                return LABELS_COMPLETE;
            case MERGE:
                return LABELS_MERGE;
            case BALANCED:
                return LABELS_BALANCED;
            case LETTERS:
                return LABELS_LETTERS;
            case DIGITS:
            case MNIST:
                return LABELS_DIGITS;
            default:
                throw new UnsupportedOperationException("Unknown Set: " + dataSet);
        }
    }

    /**
     * Are the labels balanced in the training set (that is: are the number of examples for each label equal?)
     *
     * @param dataSet Set to get balanced value for
     * @return True if balanced dataset, false otherwise
     * @throws UnsupportedOperationException if the subset is not recognized
     */
    public static boolean isBalanced(EMnistSet dataSet) {
        switch (dataSet) {
            case COMPLETE:
            case MERGE:
            case LETTERS:
                //Note: EMNIST docs claims letters is balanced, but this is not possible for training set:
                // 88800 examples / 26 classes = 3415.38 (not a whole number)
                return false;
            case BALANCED:
            case DIGITS:
            case MNIST:
                return true;
            default:
                throw new UnsupportedOperationException("Unknown Set: " + dataSet);
        }
    }
}