IBM/pytorchpipe

View on GitHub
configs/tutorials/mnist_classification_convnet_softmax.yml

Summary

Maintainability
Test Coverage
# Training section: which task (dataset) to train on, which samples to draw
# from it, the optimizer, and the conditions that end training.
training:
  task:
    type: MNIST
    # The anchor `b` is aliased (as *b) by the validation and test sections,
    # so all three sections share this single batch size.
    batch_size: &b 64
    use_train_data: True
  # Sampler operating on a subset of the data.
  sampler:
    type: SubsetRandomSampler
    # NOTE(review): presumably a [begin, end] range rather than two literal
    # sample indices; the validation section uses [55000, 60000] - confirm
    # against the sampler implementation.
    indices: [0, 55000]
  # Optimizer settings.
  optimizer:
    type: Adam
    lr: 0.0001
  # Conditions under which training terminates.
  terminal_conditions:
    # Upper bounds on the length of training.
    epoch_limit: 10
    episode_limit: 10000
    # Loss-based stop criterion.
    loss_stop_threshold: 0.15
    # NOTE(review): -1 presumably disables early stopping - confirm.
    early_stop_validations: -1

# Validation section: same task as training, evaluated on a different subset.
validation:
  task:
    type: MNIST
    # Alias of the `b` anchor defined in the training section.
    batch_size: *b
    # True because the training set is split into two parts: one used for
    # training and one used for validation.
    use_train_data: True
  # Sampler operating on a subset of the data.
  sampler:
    type: SubsetRandomSampler
    # NOTE(review): presumably a [begin, end] range that follows the
    # training subset ([0, 55000]) - confirm end-point inclusivity.
    indices: [55000, 60000]

# Test section: same task and batch size, but on the test set.
test:
  task:
    type: MNIST
    # Alias of the `b` anchor defined in the training section.
    batch_size: *b
    # False selects the test set instead of the training set.
    use_train_data: False

# Pipeline section: the components below are listed in ascending `priority`.
# Each component names its `type`, the `streams` it reads and writes, and
# (where needed) the `globals` keys it uses.
pipeline:
  # Priority 1 - Model 1: image encoder with 3 CNN layers.
  image_encoder:
    type: ConvNetEncoder
    priority: 1
    # These are the default stream names, so this mapping could be removed;
    # it is kept only for clarity.
    streams:
      inputs: inputs
      feature_maps: feature_maps

  # Priority 2 - reshape the feature maps from [-1, 16, 1, 1] to [-1, 16].
  reshaper:
    type: ReshapeTensor
    priority: 2
    input_dims: [-1, 16, 1, 1]
    output_dims: [-1, 16]
    streams:
      inputs: feature_maps
      outputs: reshaped_maps
    # The same global key (`reshaped_maps_size`) is referenced by the
    # classifier below as its `input_size`.
    globals:
      output_size: reshaped_maps_size

  # Priority 3 - Model 2: 1 fully connected layer with softmax activation.
  classifier:
    type: FeedForwardNetwork
    priority: 3
    streams:
      inputs: reshaped_maps
      # Default stream name, so this entry could be removed; kept only for
      # clarity.
      predictions: predictions
    globals:
      input_size: reshaped_maps_size
      prediction_size: num_classes

  # Priority 4 - loss computed from the targets and predictions streams.
  nllloss:
    type: NLLLoss
    priority: 4
    # Default stream names; the mapping could be removed and is kept only
    # for clarity.
    streams:
      targets: targets
      predictions: predictions

  # Priority 5 - accuracy statistics over the same two streams.
  accuracy:
    type: AccuracyStatistics
    priority: 5
    # Default stream names; the mapping could be removed and is kept only
    # for clarity.
    streams:
      targets: targets
      predictions: predictions

  # Priority 6 - decode predictions into the `predicted_answers` stream.
  # The word mappings are imported from globals (key `label_word_mappings`)
  # rather than being defined in this section.
  answer_decoder:
    type: WordDecoder
    priority: 6
    import_word_mappings_from_globals: True
    globals:
      word_mappings: label_word_mappings
    streams:
      inputs: predictions
      outputs: predicted_answers

  # Priority 7 - viewer for the listed streams.
  stream_viewer:
    type: StreamViewer
    priority: 7
    # A single comma-separated string, not a YAML list.
    input_streams: 'labels, targets, predictions, predicted_answers'


#: pipeline