# IBM/pytorchpipe
#
# View on GitHub
# configs/vqa_med_2019/question_categorization/question_categorization_onehot_bow.yml
#
# Summary
#
# Maintainability
# Test Coverage
# Default configuration to load; it defines the tasks for training,
# validation and testing.
default_configs: 'vqa_med_2019/question_categorization/default_question_categorization.yml'

pipeline:

  # Questions encoding.
  # First of the three encoding components (priorities 1.1 - 1.3): reads the
  # `questions` stream and writes `tokenized_questions`.
  question_tokenizer:
    type: SentenceTokenizer
    priority: 1.1
    streams:
      inputs: questions
      outputs: tokenized_questions

  # Reads `tokenized_questions` and writes `encoded_questions`, using the word
  # mappings file named below (looked up in data_folder -- presumably; confirm
  # against the SentenceOneHotEncoder documentation).
  question_encoder:
    type: SentenceOneHotEncoder
    priority: 1.2
    # Quoted defensively: a bare `~` is YAML null, so keep this path an
    # explicit string.
    data_folder: '~/data/vqa-med'
    word_mappings_file: questions.all.word.mappings.csv
    export_word_mappings_to_globals: true
    streams:
      inputs: tokenized_questions
      outputs: encoded_questions
    globals:
      # Global name for the vocabulary size; bow_encoder and classifier in
      # this file refer to the same name.
      vocabulary_size: question_vocabulary_size

  # Reads `encoded_questions` and writes `bow_questions`.
  bow_encoder:
    type: BOWEncoder
    priority: 1.3
    streams:
      inputs: encoded_questions
      outputs: bow_questions
    globals:
      # Same global name that question_encoder maps its vocabulary_size to.
      bow_size: question_vocabulary_size  # Set by question_encoder.

  # Model
  # Consumes `bow_questions`. No output stream is named here; the
  # prediction_decoder below reads `predictions`, presumably this component's
  # default output name -- confirm against FeedForwardNetwork.
  classifier:
    type: FeedForwardNetwork
    priority: 3
    # Optional switch, currently disabled:
    #freeze: True
    globals:
      # Set by question_encoder.
      input_size: question_vocabulary_size
      # C1,C2,C3,C4
      prediction_size: num_categories
    streams:
      inputs: bow_questions

  # Predictions decoder.
  # Reads `predictions` and writes `predicted_categories`.
  prediction_decoder:
    type: WordDecoder
    priority: 4
    # Use the same word mappings as label indexer.
    import_word_mappings_from_globals: true
    streams:
      inputs: predictions
      outputs: predicted_categories
    globals:
      vocabulary_size: num_categories
      word_mappings: category_word_mappings

  # Loss
  # Takes its targets from the `category_ids` stream and publishes the result
  # on the `loss` stream.
  nllloss:
    type: NLLLoss
    priority: 6
    streams:
      targets: category_ids
      loss: loss
    targets_dim: 1

  # Statistics.
  # Accuracy, with targets taken from the `category_ids` stream.
  accuracy:
    type: AccuracyStatistics
    priority: 10
    streams:
      targets: category_ids

  # Batch size statistics; no streams or globals are overridden here.
  batch_size:
    type: BatchSizeStatistics
    priority: 11

  # Viewers.
  # stream_names is a single comma-separated string (not a YAML list) naming
  # the streams to show.
  viewer:
    type: StreamViewer
    priority: 12
    stream_names: 'questions,category_names,predicted_categories'

  #: pipeline