# Source: IBM/pytorchpipe (GitHub)
# File: configs/vqa_med_2019/c4_classification/c4_enc_attndec_resnet152_ewm_cat_is.yml
# Base configuration that defines the tasks used for training, validation
# and testing.
default_configs: 'vqa_med_2019/default_vqa_med_2019.yml'

# Training section.
training:
  task:
    categories: C4
    # Per the original author, this batch size requires 4 GPUs.
    batch_size: 200
    # Comma-separated preprocessing steps. Extra steps that are currently
    # disabled: random_remove_stop_words, random_shuffle_words.
    question_preprocessing: 'lowercase, remove_punctuation, tokenize'
    answer_preprocessing: 'lowercase, remove_punctuation, tokenize'
    # Same file that the sampler below takes its weights from.
    export_sample_weights: '~/data/vqa-med/answers.c4.weights.csv'
  sampler:
    weights: '~/data/vqa-med/answers.c4.weights.csv'
  dataloader:
    num_workers: 4
  # Termination criteria.
  terminal_conditions:
    loss_stop_threshold: 1.0e-2
    episode_limit: 1000000
    epoch_limit: -1

# Validation section: same category, batch size and preprocessing steps as
# the training task.
validation:
  task:
    categories: C4
    batch_size: 200
    question_preprocessing: 'lowercase, remove_punctuation, tokenize'
    answer_preprocessing: 'lowercase, remove_punctuation, tokenize'
  dataloader:
    num_workers: 4

pipeline:

  # Publishes size constants as global variables; components in this
  # pipeline refer to them by name in their `globals` sections.
  global_publisher:
    type: GlobalVariablePublisher
    priority: 0
    # Presumably keys and values are paired by position -- confirm against
    # GlobalVariablePublisher.
    # NOTE(review): the two image_size_encoder_* globals are not referenced
    # by any component visible in this file.
    keys:
      - question_encoder_output_size
      - image_encoder_output_size
      - element_wise_activation_size
      - image_size_encoder_input_size
      - image_size_encoder_output_size
    values:
      - 100
      - 100
      - 100
      - 2
      - 10

  # Question embeddings: `questions` stream -> `embedded_questions` stream.
  question_embeddings:
    type: SentenceEmbeddings
    priority: 1.0
    data_folder: '~/data/vqa-med'
    word_mappings_file: questions.all.word.mappings.csv
    pretrained_embeddings_file: glove.6B.100d.txt
    embeddings_size: 100
    additional_tokens: '<PAD>,<EOS>'
    # NOTE(review): the original remark here reads "The longest question!
    # max is 19!" while the padding is fixed at 10 -- confirm this is the
    # intended length.
    fixed_padding: 10
    streams:
      inputs: questions
      outputs: embedded_questions

  # Target encoding.
  # Indexes the `answers` stream into `indexed_answers` and exports the
  # answer vocabulary data under the ans_* global names listed below.
  target_indexer:
    type: SentenceIndexer
    priority: 1.1
    data_folder: '~/data/vqa-med'
    word_mappings_file: answer_words.c4.preprocessed.word.mappings.csv
    additional_tokens: '<PAD>,<EOS>'
    eos_token: true
    # Fixed length of the indexed answers. The decoder's
    # autoregression_length is documented as having to equal this value.
    # NOTE(review): the previous remark here ("The longest question! max is
    # 19!") looks copied from the question embeddings component.
    fixed_padding: 10
    import_word_mappings_from_globals: false
    export_word_mappings_to_globals: true
    export_pad_mapping_to_globals: true
    streams:
      inputs: answers
      outputs: indexed_answers
    globals:
      vocabulary_size: ans_vocabulary_size
      word_mappings: ans_word_mappings
      pad_index: ans_pad_index

  # Image encoder (resnet152): `images` stream -> `image_activations`.
  image_encoder:
    type: GenericImageEncoder
    priority: 2.0
    model_type: resnet152
    streams:
      inputs: images
      outputs: image_activations
    globals:
      # Bound to the global published by global_publisher (100).
      output_size: image_encoder_output_size

  # Question encoder: single-layer GRU over `embedded_questions`.
  encoder:
    type: RecurrentNeuralNetwork
    priority: 3
    # Recurrent cell settings.
    cell_type: GRU
    num_layers: 1
    hidden_size: 100
    initial_state: Trainable
    dropout_rate: 0.1
    # Output settings: the last state is also emitted (output_state stream).
    prediction_mode: Dense
    output_last_state: true
    use_logsoftmax: false
    ffn_output: false
    streams:
      inputs: embedded_questions
      predictions: s2s_encoder_output
      output_state: s2s_state_output
    globals:
      input_size: embeddings_size
      prediction_size: question_encoder_output_size

  # Reshapes the encoder state stream from [-1, 1, 100] to [-1, 100].
  reshaper_1:
    type: ReshapeTensor
    priority: 3.01
    streams:
      inputs: s2s_state_output
      outputs: s2s_state_output_reshaped
    input_dims: [-1, 1, 100]
    output_dims: [-1, 100]
    globals:
      output_size: s2s_state_output_reshaped_size

  # Fusion of the image and question encodings (described by the original
  # author as element-wise multiplication; the feed-forward part follows in
  # question_image_ffn).
  question_image_fusion:
    type: LowRankBilinearPooling
    priority: 3.1
    dropout_rate: 0.5
    streams:
      image_encodings: image_activations
      question_encodings: s2s_state_output_reshaped
      outputs: element_wise_activations
    globals:
      image_encoding_size: image_encoder_output_size
      question_encoding_size: question_encoder_output_size
      output_size: element_wise_activation_size

  # Feed-forward network applied to the fused activations; its input and
  # prediction sizes are bound to the same global.
  question_image_ffn:
    type: FeedForwardNetwork
    priority: 3.2
    dropout_rate: 0.5
    hidden_sizes: [100]
    streams:
      inputs: element_wise_activations
      predictions: question_image_activations
    globals:
      input_size: element_wise_activation_size
      prediction_size: element_wise_activation_size

  # Reshapes the fused activations from [-1, 100] to [-1, 1, 100]; the
  # decoder consumes the result as its input_state stream.
  reshaper_2:
    type: ReshapeTensor
    priority: 3.3
    streams:
      inputs: question_image_activations
      outputs: question_image_activations_reshaped
    input_dims: [-1, 100]
    output_dims: [-1, 1, 100]
    globals:
      output_size: question_image_activations_reshaped_size

  # Single-layer GRU decoder with attention. Reads the encoder predictions
  # as inputs and the reshaped question/image activations as its input
  # state.
  decoder:
    type: AttentionDecoder
    priority: 4
    hidden_size: 100
    use_logsoftmax: false
    # Per the original author, the current implementation requires this to
    # equal fixed_padding in SentenceEmbeddings / SentenceIndexer (10 here).
    autoregression_length: 10
    prediction_mode: Dense
    dropout_rate: 0.1
    streams:
      inputs: s2s_encoder_output
      predictions: s2s_decoder_output
      input_state: question_image_activations_reshaped
    globals:
      # The inputs stream is the encoder's prediction stream, whose size is
      # bound to question_encoder_output_size, so the same global is used
      # here. It was element_wise_activation_size before; both are published
      # as 100, so the resolved value is unchanged.
      input_size: question_encoder_output_size
      prediction_size: element_wise_activation_size

  # Feed-forward layer that resizes the seq2seq decoder output to the size
  # of the target (answer) vocabulary.
  ff_resize_s2s_output:
    type: FeedForwardNetwork
    priority: 5
    dimensions: 3
    dropout_rate: 0.1
    use_logsoftmax: true
    streams:
      inputs: s2s_decoder_output
    globals:
      input_size: element_wise_activation_size
      prediction_size: ans_vocabulary_size

  # Loss.
  # NLL loss against the indexed answers; ignore_index is bound to the pad
  # index global exported by target_indexer.
  nllloss:
    type: NLLLoss
    priority: 6
    num_targets_dims: 2
    globals:
      ignore_index: ans_pad_index
    streams:
      targets: indexed_answers
      loss: loss

  # Prediction decoding: `predictions` stream -> `prediction_sentences`.
  prediction_decoder:
    type: SentenceIndexer
    priority: 10
    # Run the indexer in reverse mode, taking distributions as inputs.
    reverse: true
    use_input_distributions: true
    data_folder: '~/data/vqa-med'
    # Word mappings are imported from the ans_word_mappings global.
    import_word_mappings_from_globals: true
    globals:
      word_mappings: ans_word_mappings
    streams:
      inputs: predictions
      outputs: prediction_sentences

  # Statistics components.
  batch_size:
    priority: 100.0
    type: BatchSizeStatistics

  # BLEU statistics with the indexed answers as targets, using the answer
  # word mappings global.
  bleu:
    priority: 100.2
    type: BLEUStatistics
    streams:
      targets: indexed_answers
    globals:
      word_mappings: ans_word_mappings

      
  # Viewer listing the raw questions/answers, the indexed answers and the
  # decoded prediction sentences.
  viewer:
    priority: 100.3
    type: StreamViewer
    input_streams: 'questions,answers,indexed_answers,prediction_sentences'

#: pipeline