IBM/pytorchpipe
configs/wikitext/wikitext_language_modeling_rnn.yml
# Training parameters:
training:
  task:
    type: &p_type WikiTextLanguageModeling
    data_folder: &data_folder ~/data/language_modeling/wikitext-2
    dataset: &dataset wikitext-2
    subset: train
    sentence_length: 10
    batch_size: 64
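    # "&" defines YAML anchors (p_type, data_folder, dataset) that the
    # validation and test sections below reuse via "*" aliases, so all three
    # splits are guaranteed to share the same task type and data location.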

  # optimizer parameters:
  optimizer:
    type: Adam
    lr: 0.1
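    # "type" presumably resolves to torch.optim.Adam; note that 0.1 is two
    # orders of magnitude above Adam's customary 1.0e-3 default.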

  # terminal condition parameters:
  terminal_conditions:
    loss_stop_threshold: 1.0e-2
    episode_limit: 10000
    epoch_limit: 100
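    # Training stops as soon as any of these is met (presumably whichever
    # comes first): loss below 1.0e-2, 10000 episodes (batches), or 100 epochs.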

# Validation parameters:
validation:
  partial_validation_interval: 10
  task:
    type: *p_type
    data_folder: *data_folder
    dataset: *dataset
    subset: valid
    sentence_length: 20
    batch_size: 64
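  # Every 10 training episodes a partial validation runs, presumably on a
  # single batch of the validation subset.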

# Testing parameters:
test:
  task:
    type: *p_type 
    data_folder: *data_folder
    dataset: *dataset
    subset: test
    sentence_length: 50
    batch_size: 64
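  # sentence_length grows across splits (train 10, valid 20, test 50), so
  # evaluation probes the model on longer contexts than it was trained on.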

pipeline:
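  # Components run in ascending order of "priority", so the statistics and
  # viewers (priorities 100.x) execute after the models and the loss.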

  # Model 1: Source sentence encoding.
  source_sentence_embedding:
    type: SentenceEmbeddings
    priority: 1
    embeddings_size: 50
    pretrained_embeddings: glove.6B.50d.txt
    data_folder: *data_folder
    source_vocabulary_files: wiki.train.tokens,wiki.valid.tokens,wiki.test.tokens
    vocabulary_mappings_file: wiki.all.tokenized_words
    additional_tokens: <eos>
    export_word_mappings_to_globals: True
    streams:
      inputs: sources
      outputs: embedded_sources
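    # Builds the vocabulary from the three token files, initializes the
    # embedding table from GloVe 50d, and exports the word mappings (and,
    # presumably, vocabulary_size) to globals for the components below.
    # "streams" rebinds the component's input/output stream names.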
        
  # Target encoding.
  target_indexer:
    type: SentenceIndexer
    priority: 2
    data_folder: *data_folder
    import_word_mappings_from_globals: True
    streams:
      inputs: targets
      outputs: indexed_targets
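    # Maps target words to indices, reusing the word mappings exported to
    # globals by source_sentence_embedding above.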
  
  # Model 2: RNN
  lstm:
    type: RecurrentNeuralNetwork
    priority: 3
    initial_state: Zero
    streams:
      inputs: embedded_sources
    globals:
      input_size: embeddings_size
      prediction_size: vocabulary_size 
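    # input_size and prediction_size are fetched from globals (embeddings_size
    # and the exported vocabulary_size). The output stream is left at its
    # default, presumably "predictions", which the loss and decoder consume.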

  # Loss
  nllloss:
    type: NLLLoss
    priority: 6
    num_targets_dims: 2
    streams:
      targets: indexed_targets
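    # num_targets_dims: 2 because the targets are (batch, sequence) index
    # tensors produced by target_indexer above.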

  # Prediction decoding.
  prediction_decoder:
    type: SentenceIndexer
    priority: 10
    # Reverse mode.
    reverse: True
    # Use distributions as inputs.
    use_input_distributions: True
    data_folder: *data_folder
    import_word_mappings_from_globals: True
    streams:
      inputs: predictions
      outputs: prediction_sentences
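    # Runs the indexer in reverse (indices back to words); with
    # use_input_distributions: True it presumably takes the argmax of each
    # predicted distribution before the reverse lookup.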

  # Statistics.
  batch_size:
    type: BatchSizeStatistics
    priority: 100.0

  #accuracy:
  #  type: AccuracyStatistics
  #  priority: 100.1
  #  streams:
  #    targets: indexed_targets

  bleu:
    type: BLEUStatistics
    priority: 100.2
    streams:
      targets: indexed_targets

      
  # Viewers.
  viewer:
    type: StreamViewer
    priority: 100.3
    input_streams: sources,targets,indexed_targets,prediction_sentences
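  # Presumably logs the listed streams for inspection, making it easy to
  # compare sources, targets, and decoded predictions side by side.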

#: pipeline
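
For orientation, below is a minimal plain-PyTorch sketch of the network this
pipeline wires up: an embedding table of GloVe size, an LSTM with zero initial
state, and a log-softmax head trained with NLLLoss on 2-D targets. This is not
pytorchpipe's API; the class, the hidden size (the config leaves it to the
component's default), and the vocabulary size are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNLanguageModel(nn.Module):
    def __init__(self, vocab_size, embeddings_size=50, hidden_size=50):
        super().__init__()
        # In the real pipeline, GloVe 50d vectors initialize this table.
        self.embedding = nn.Embedding(vocab_size, embeddings_size)
        self.lstm = nn.LSTM(embeddings_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, vocab_size)

    def forward(self, sources):                    # sources: (batch, seq) indices
        embedded = self.embedding(sources)         # (batch, seq, embeddings_size)
        outputs, _ = self.lstm(embedded)           # zero initial state by default
        return F.log_softmax(self.head(outputs), dim=-1)

# Shapes mirror the training section: batch_size=64, sentence_length=10.
vocab_size = 33278                                 # roughly wikitext-2's vocabulary
model = RNNLanguageModel(vocab_size)
sources = torch.randint(vocab_size, (64, 10))
targets = torch.randint(vocab_size, (64, 10))      # 2-D targets, cf. num_targets_dims: 2
log_probs = model(sources)                         # (64, 10, vocab_size)
loss = nn.NLLLoss()(log_probs.transpose(1, 2), targets)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)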