IBM/pytorchpipe

View on GitHub
configs/mnist/mnist_classification_softmax.yml

Summary

Maintainability
Test Coverage
# Load the default MNIST config, which defines the tasks used for training,
# validation and testing.
# NOTE(review): the path is presumably resolved relative to the configs
# directory, and the keys below presumably extend these defaults - confirm
# against the config loader.
default_configs: mnist/default_mnist.yml

pipeline:

  # Reshapes each input tensor from [-1, 1, 28, 28] to [-1, 784]
  # (1 * 28 * 28 = 784) so it can be fed to the classifier below.
  # NOTE(review): -1 is presumably the inferred batch dimension - confirm
  # against the ReshapeTensor component.
  reshaper:
    type: ReshapeTensor
    priority: 1
    input_dims: [-1, 1, 28, 28]
    output_dims: [-1, 784]
    # Output stream name; the classifier lists the same name as its input.
    streams:
      outputs: reshaped_images
    # Global variable name; the classifier references it as its input_size.
    globals:
      output_size: reshaped_image_size

  # Classifier: a FeedForwardNetwork component that consumes the
  # reshaped_images stream produced by the reshaper above.
  classifier:
    type: FeedForwardNetwork
    # NOTE(review): presumably scheduled after the reshaper (priority 1) -
    # confirm how the pipeline orders components by priority.
    priority: 2
    dropout_rate: 0.1
    # NOTE(review): presumably two hidden layers of 100 units each - confirm
    # against the FeedForwardNetwork component.
    hidden_sizes: [100, 100]
    streams:
      inputs: reshaped_images
    globals:
      # Same global name the reshaper uses for its output_size.
      input_size: reshaped_image_size
      # num_classes is not defined in this file; presumably it is provided
      # by the default MNIST config - verify.
      prediction_size: num_classes

#: pipeline