# IBM/pytorchpipe
#
# View on GitHub:
# configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes_2losses.yml
#
# Summary
#
# Maintainability
# Test Coverage
# Base configuration: defines the MNIST tasks used for training,
# validation and testing.
default_configs: 'mnist/default_mnist.yml'

# Training section - values set here overwrite the defaults.
training:
  # Task overrides; only the batch size is active, the image resize
  # override is left commented out.
  task:
    #resize_image: [32, 32]
    batch_size: 64
  # Optional optimizer / stop-condition overrides (currently inactive).
  #optimizer:
  #  #type: Adam
  #  lr: 0.001
  #terminal_conditions:
  #  loss_stop_threshold: 0.08

# Validation parameters - overwrite defaults:
#validation:
#  partial_validation_interval: 10
#  task:
#    resize_image: [32, 32]

# Testing parameters - overwrite defaults:
#test:
#  task:
#    resize_image: [32, 32]

# Definition of the pipeline.
pipeline:

  # Components of the "default" flow that are disabled in this pipeline
  # (single comma-separated string).
  disable: 'nllloss, accuracy, precision_recall, answer_decoder, image_viewer'

  ################# SHARED #################

  # Publish global variables: `keys` and `values` are parallel lists, so the
  # i-th key is paired with the i-th value. Flow 1 uses a 3-class label
  # mapping, flow 2 a 7-class one.
  global_publisher:
    type: GlobalVariablePublisher
    priority: 0.1
    keys:
      - num_classes1
      - num_classes2
      - word_to_ix1
      - word_to_ix2
    values:
      # num_classes1
      - 3
      # num_classes2
      - 7
      # word_to_ix1
      - Three: 0
        One: 1
        Five: 2
      # word_to_ix2
      - Four: 0
        Two: 1
        Zero: 2
        Six: 3
        Seven: 4
        Eight: 5
        Nine: 6
    # Alternative split of the labels between the two flows (inactive):
    #values: [3, 7, {"Zero": 0, "One": 1, "Two": 2}, {"Three": 0, "Four": 1, "Five": 2, "Six": 3, "Seven": 4, "Eight": 5, "Nine": 6}]

  # Encoder shared by both flows (priority 0.2, right after the global
  # publisher at 0.1). No stream names are overridden here.
  image_encoder:
    priority: 0.2
    type: ConvNetEncoder

  # Reshape the `feature_maps` stream from [-1, 16, 1, 1] to [-1, 16] and
  # emit it as `reshaped_maps`; the output size is stored under the global
  # key `reshaped_maps_size`, which both classifiers read as their input size.
  reshaper:
    type: ReshapeTensor
    priority: 0.3
    input_dims:
      - -1
      - 16
      - 1
      - 1
    output_dims:
      - -1
      - 16
    globals:
      output_size: reshaped_maps_size
    streams:
      inputs: feature_maps
      outputs: reshaped_maps

  ################# Flow 1 #################
  # Flow 1 classifier: reads `reshaped_maps` and writes `flow1_predictions`.
  # Input size comes from the global `reshaped_maps_size`, the number of
  # outputs from the global `num_classes1`.
  # NOTE: a stray empty `streams:` key used to precede `globals:`; duplicate
  # mapping keys are invalid YAML (strict loaders reject them, lenient ones
  # silently keep the last), so only the populated `streams:` is kept.
  flow1_classifier:
    type: FeedForwardNetwork
    priority: 1.1
    globals:
      input_size: reshaped_maps_size
      prediction_size: num_classes1
    streams:
      inputs: reshaped_maps
      predictions: flow1_predictions
      
  # Turn the `labels` stream into `flow1_masks` using the flow-1 word
  # mappings (`word_to_ix1`).
  flow1_label_to_mask1:
    type: StringToMask
    priority: 1.2
    streams:
      strings: labels
      masks: flow1_masks
    globals:
      word_mappings: word_to_ix1

  # Index the `labels` stream into `flow1_targets`; the word mappings are
  # imported from the global `word_to_ix1`.
  flow1_label_to_target1:
    type: LabelIndexer
    priority: 1.3
    import_word_mappings_from_globals: true
    streams:
      inputs: labels
      outputs: flow1_targets
    globals:
      word_mappings: word_to_ix1

  # Flow 1 loss with masking enabled: consumes predictions, targets and
  # masks of flow 1 and writes the result to `flow1_loss`.
  flow1_nllloss:
    type: NLLLoss
    priority: 1.4
    use_masking: true
    streams:
      predictions: flow1_predictions
      targets: flow1_targets
      masks: flow1_masks
      loss: flow1_loss

  # Flow 1 accuracy statistics with masking enabled, reported under the
  # name `flow1_accuracy`.
  flow1_accuracy:
    type: AccuracyStatistics
    priority: 1.51
    use_masking: true
    statistics:
      accuracy: flow1_accuracy
    streams:
      targets: flow1_targets
      predictions: flow1_predictions
      masks: flow1_masks

  # Flow 1 precision / recall / F1 statistics with masking enabled; class
  # scores and the confusion matrix are both switched on.
  flow1_precision_recall:
    type: PrecisionRecallStatistics
    priority: 1.52
    use_masking: true
    use_word_mappings: true
    show_class_scores: true
    show_confusion_matrix: true
    streams:
      predictions: flow1_predictions
      targets: flow1_targets
      masks: flow1_masks
    globals:
      num_classes: num_classes1
      word_mappings: word_to_ix1
    statistics:
      precision: flow1_precision
      recall: flow1_recall
      f1score: flow1_f1score

  ################# Flow 2 #################
  # Flow 2 classifier: reads `reshaped_maps` and writes `flow2_predictions`.
  # Input size comes from the global `reshaped_maps_size`, the number of
  # outputs from the global `num_classes2`.
  # NOTE: a stray empty `streams:` key used to precede `globals:`; duplicate
  # mapping keys are invalid YAML (strict loaders reject them, lenient ones
  # silently keep the last), so only the populated `streams:` is kept.
  flow2_classifier:
    type: FeedForwardNetwork
    priority: 2.1
    globals:
      input_size: reshaped_maps_size
      prediction_size: num_classes2
    streams:
      inputs: reshaped_maps
      predictions: flow2_predictions
      
  # Turn the `labels` stream into `flow2_masks` using the flow-2 word
  # mappings (`word_to_ix2`).
  flow2_label_to_mask2:
    type: StringToMask
    priority: 2.2
    streams:
      strings: labels
      masks: flow2_masks
    globals:
      word_mappings: word_to_ix2

  # Index the `labels` stream into `flow2_targets`; the word mappings are
  # imported from the global `word_to_ix2`.
  flow2_label_to_target2:
    type: LabelIndexer
    priority: 2.3
    import_word_mappings_from_globals: true
    streams:
      inputs: labels
      outputs: flow2_targets
    globals:
      word_mappings: word_to_ix2

  # Flow 2 loss with masking enabled: consumes predictions, targets and
  # masks of flow 2 and writes the result to `flow2_loss`.
  flow2_nllloss:
    type: NLLLoss
    priority: 2.4
    use_masking: true
    streams:
      predictions: flow2_predictions
      targets: flow2_targets
      masks: flow2_masks
      loss: flow2_loss

  # Flow 2 accuracy statistics with masking enabled, reported under the
  # name `flow2_accuracy`.
  flow2_accuracy:
    type: AccuracyStatistics
    priority: 2.41
    use_masking: true
    statistics:
      accuracy: flow2_accuracy
    streams:
      predictions: flow2_predictions
      targets: flow2_targets
      masks: flow2_masks

  # Flow 2 precision / recall / F1 statistics with masking enabled; class
  # scores and the confusion matrix are both switched on.
  flow2_precision_recall:
    type: PrecisionRecallStatistics
    priority: 2.42
    use_masking: true
    use_word_mappings: true
    show_class_scores: true
    show_confusion_matrix: true
    streams:
      predictions: flow2_predictions
      targets: flow2_targets
      masks: flow2_masks
    globals:
      num_classes: num_classes2
      word_mappings: word_to_ix2
    statistics:
      precision: flow2_precision
      recall: flow2_recall
      f1score: flow2_f1score

  ################# JOIN #################
  # Join the masked predictions of both flows. The three input lists are
  # parallel: entry i names the predictions, masks and word mappings of
  # flow i+1. Results are emitted as strings and as indices.
  joined_predictions:
    type: JoinMaskedPredictions
    priority: 3.1
    # Names of the input streams / globals, one entry per flow.
    input_prediction_streams:
      - flow1_predictions
      - flow2_predictions
    input_mask_streams:
      - flow1_masks
      - flow2_masks
    input_word_mappings:
      - word_to_ix1
      - word_to_ix2
    streams:
      output_strings: merged_predictions
      output_indices: merged_indices
    globals:
      output_word_mappings: label_word_mappings  # from MNIST task.

  # Precision / recall / F1 statistics over the joined output. No masking
  # option is set here, unlike the per-flow statistics.
  joined_precision_recall:
    type: PrecisionRecallStatistics
    priority: 3.22
    # Predictions arrive as indices (`merged_indices`), not distributions.
    use_prediction_distributions: false
    use_word_mappings: true
    show_class_scores: true
    show_confusion_matrix: true
    streams:
      predictions: merged_indices
      targets: targets  # straight from MNIST
    globals:
      word_mappings: label_word_mappings  # straight from MNIST
      #num_classes: num_classes
    statistics:
      precision: joined_precision
      recall: joined_recall
      f1score: joined_f1score

  # Overwrite the viewer's `answers` stream so it shows the joined output.
  # NOTE(review): `image_viewer` is also named in this pipeline's `disable`
  # string, so this override may have no effect - confirm.
  image_viewer:
    streams:
      answers: merged_predictions

  # Streams passed to the stream viewer (single comma-separated string).
  stream_viewer:
    input_streams: 'labels, merged_predictions'

#: pipeline