configs/mnist/mnist_classification_vf_shared_convnet_2softmaxes_2losses.yml
# Load the base config defining the MNIST tasks for training, validation and
# testing; the sections below overwrite selected defaults from it.
default_configs: mnist/default_mnist.yml
# Training parameters - overwrite defaults:
# only the task batch size is overridden; the remaining entries are
# commented-out templates kept for quick experimentation.
training:
  task:
    #resize_image: [32, 32]
    batch_size: 64
  #optimizer:
  #  #type: Adam
  #  lr: 0.001
  #terminal_conditions:
  #  loss_stop_threshold: 0.08
# Validation parameters - overwrite defaults:
#validation:
# partial_validation_interval: 10
# task:
# resize_image: [32, 32]
# Testing parameters - overwrite defaults:
#test:
# task:
# resize_image: [32, 32]
# Definition of the pipeline: a shared ConvNet encoder feeds two masked
# classification flows (3 and 7 classes), each with its own NLL loss; the
# predictions of both flows are then joined into a single answer stream.
pipeline:

  # Disable components for "default" flow
  # (presumably the ones defined in the default MNIST config - confirm there).
  disable: nllloss, accuracy, precision_recall, answer_decoder, image_viewer

  ################# SHARED #################

  # Add global variables: class counts and label-to-index mappings used by
  # the two flows. `values` are listed in the same order as `keys`.
  global_publisher:
    type: GlobalVariablePublisher
    priority: 0.1
    keys: [num_classes1, num_classes2, word_to_ix1, word_to_ix2]
    values:
      - 3  # num_classes1
      - 7  # num_classes2
      - {"Three": 0, "One": 1, "Five": 2}  # word_to_ix1
      - {"Four": 0, "Two": 1, "Zero": 2, "Six": 3, "Seven": 4, "Eight": 5, "Nine": 6}  # word_to_ix2
    # Alternative split of the digits between the two flows:
    #values: [3, 7, {"Zero": 0, "One": 1, "Two": 2}, {"Three": 0, "Four": 1, "Five": 2, "Six": 3, "Seven": 4, "Eight": 5, "Nine": 6}]

  # Shared model - encoder.
  image_encoder:
    type: ConvNetEncoder
    priority: 0.2

  # Reshape inputs: feature maps of dims [-1, 16, 1, 1] become [-1, 16];
  # the resulting size is exported as the global `reshaped_maps_size`.
  reshaper:
    type: ReshapeTensor
    input_dims: [-1, 16, 1, 1]
    output_dims: [-1, 16]
    priority: 0.3
    streams:
      inputs: feature_maps
      outputs: reshaped_maps
    globals:
      output_size: reshaped_maps_size

  ################# Flow 1 #################

  # Classifier: reshaped maps -> predictions over `num_classes1` classes.
  flow1_classifier:
    type: FeedForwardNetwork
    priority: 1.1
    globals:
      input_size: reshaped_maps_size
      prediction_size: num_classes1
    streams:
      inputs: reshaped_maps
      predictions: flow1_predictions

  # Mask derived from the string labels and `word_to_ix1`
  # (presumably selects samples whose label belongs to flow 1 - confirm).
  flow1_label_to_mask1:
    type: StringToMask
    priority: 1.2
    globals:
      word_mappings: word_to_ix1
    streams:
      strings: labels
      masks: flow1_masks

  # Targets for flow 1: labels indexed with `word_to_ix1`.
  flow1_label_to_target1:
    type: LabelIndexer
    priority: 1.3
    import_word_mappings_from_globals: True
    globals:
      word_mappings: word_to_ix1
    streams:
      inputs: labels
      outputs: flow1_targets

  # Masked loss.
  flow1_nllloss:
    type: NLLLoss
    priority: 1.4
    use_masking: True
    streams:
      targets: flow1_targets
      predictions: flow1_predictions
      masks: flow1_masks
      loss: flow1_loss

  # Statistics.
  flow1_accuracy:
    type: AccuracyStatistics
    priority: 1.51
    use_masking: True
    streams:
      predictions: flow1_predictions
      targets: flow1_targets
      masks: flow1_masks
    statistics:
      accuracy: flow1_accuracy

  flow1_precision_recall:
    type: PrecisionRecallStatistics
    priority: 1.52
    use_word_mappings: True
    show_class_scores: True
    show_confusion_matrix: True
    use_masking: True
    globals:
      word_mappings: word_to_ix1
      num_classes: num_classes1
    streams:
      targets: flow1_targets
      predictions: flow1_predictions
      masks: flow1_masks
    statistics:
      precision: flow1_precision
      recall: flow1_recall
      f1score: flow1_f1score

  ################# Flow 2 #################

  # Classifier: reshaped maps -> predictions over `num_classes2` classes.
  flow2_classifier:
    type: FeedForwardNetwork
    priority: 2.1
    globals:
      input_size: reshaped_maps_size
      prediction_size: num_classes2
    streams:
      inputs: reshaped_maps
      predictions: flow2_predictions

  # Mask derived from the string labels and `word_to_ix2`
  # (presumably selects samples whose label belongs to flow 2 - confirm).
  flow2_label_to_mask2:
    type: StringToMask
    priority: 2.2
    globals:
      word_mappings: word_to_ix2
    streams:
      strings: labels
      masks: flow2_masks

  # Targets for flow 2: labels indexed with `word_to_ix2`.
  flow2_label_to_target2:
    type: LabelIndexer
    priority: 2.3
    import_word_mappings_from_globals: True
    globals:
      word_mappings: word_to_ix2
    streams:
      inputs: labels
      outputs: flow2_targets

  # Masked loss.
  flow2_nllloss:
    type: NLLLoss
    priority: 2.4
    use_masking: True
    streams:
      targets: flow2_targets
      predictions: flow2_predictions
      masks: flow2_masks
      loss: flow2_loss

  # Statistics.
  flow2_accuracy:
    type: AccuracyStatistics
    priority: 2.41
    use_masking: True
    streams:
      targets: flow2_targets
      predictions: flow2_predictions
      masks: flow2_masks
    statistics:
      accuracy: flow2_accuracy

  flow2_precision_recall:
    type: PrecisionRecallStatistics
    priority: 2.42
    use_word_mappings: True
    show_class_scores: True
    show_confusion_matrix: True
    use_masking: True
    globals:
      word_mappings: word_to_ix2
      num_classes: num_classes2
    streams:
      targets: flow2_targets
      predictions: flow2_predictions
      masks: flow2_masks
    statistics:
      precision: flow2_precision
      recall: flow2_recall
      f1score: flow2_f1score

  ################# JOIN #################

  # Merge the masked predictions of both flows into one stream of answers.
  joined_predictions:
    type: JoinMaskedPredictions
    priority: 3.1
    # Names of used input streams.
    input_prediction_streams: [flow1_predictions, flow2_predictions]
    input_mask_streams: [flow1_masks, flow2_masks]
    input_word_mappings: [word_to_ix1, word_to_ix2]
    globals:
      output_word_mappings: label_word_mappings  # from MNIST task.
    streams:
      output_strings: merged_predictions
      output_indices: merged_indices

  # Statistics.
  joined_precision_recall:
    type: PrecisionRecallStatistics
    priority: 3.22
    # Use prediction indices instead of distributions.
    use_prediction_distributions: False
    use_word_mappings: True
    show_class_scores: True
    show_confusion_matrix: True
    globals:
      word_mappings: label_word_mappings  # straight from MNIST
      #num_classes: num_classes
    streams:
      targets: targets  # straight from MNIST
      predictions: merged_indices
    statistics:
      precision: joined_precision
      recall: joined_recall
      f1score: joined_f1score

  # "Fix" (overwrite) stream names in viewers.
  image_viewer:
    streams:
      answers: merged_predictions

  stream_viewer:
    input_streams: labels, merged_predictions
#: pipeline