# tensorflow/models: research/cvt_text/model/multitask_model.py (View on GitHub)
# Source-page summary: Maintainability A (35 mins); Test Coverage (no value shown).
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A multi-task and semi-supervised NLP model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from model import encoder
from model import shared_inputs


class Inference(object):
  """One forward-pass graph: a shared encoder plus one module per task.

  The encoder is built inside the 'encoder' variable scope. Each task's
  module is built inside a variable scope named after that task. All of
  these scopes nest under whatever variable scope the caller has open.

  Args:
    config: configuration object, forwarded to `encoder.Encoder`.
    inputs: forwarded to `encoder.Encoder` and to every task's `get_module`.
    pretrained_embeddings: forwarded unchanged to `encoder.Encoder`.
    tasks: iterable of task objects exposing `name` and
      `get_module(inputs, encoder)`.

  Attributes:
    encoder: the `encoder.Encoder` instance shared by all task modules.
    modules: dict mapping each task name to the object returned by that
      task's `get_module`.
  """

  def __init__(self, config, inputs, pretrained_embeddings, tasks):
    with tf.variable_scope('encoder'):
      shared_encoder = encoder.Encoder(config, inputs, pretrained_embeddings)
    self.encoder = shared_encoder

    self.modules = {}
    for current_task in tasks:
      task_name = current_task.name
      # A per-task scope keeps each task's variables separate, while every
      # module is handed the same encoder instance.
      with tf.variable_scope(task_name):
        self.modules[task_name] = current_task.get_module(
            inputs, shared_encoder)


class Model(object):
  """Multi-task model trained with supervised and consistency losses.

  A single `Inference` graph is built under the 'model' variable scope and
  serves as trainer, tester and teacher. When `config.ema_test` or
  `config.ema_teacher` is set, a second `Inference` is built on top of the
  exponential moving averages of the model's trainable variables. That
  graph replaces the tester and/or teacher accordingly.

  Args:
    config: configuration object. This class reads `ema_test`,
      `ema_teacher`, `ema_decay`, `lr`, `lr_decay`, `warm_up_steps`,
      `momentum` and `grad_clip`. It is also forwarded to
      `shared_inputs.Inputs` and `Inference`.
    pretrained_embeddings: forwarded unchanged to `Inference`.
    tasks: re-iterable collection (e.g. a list) of task objects exposing
      `name` and `get_module`; it is iterated several times.
  """

  def __init__(self, config, pretrained_embeddings, tasks):
    self._config = config
    self._tasks = tasks

    self._global_step, self._optimizer = self._get_optimizer()
    self._inputs = shared_inputs.Inputs(config)
    with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
      inference = Inference(config, self._inputs, pretrained_embeddings,
                            tasks)
      self._trainer = inference
      self._tester = inference
      self._teacher = inference
      if config.ema_test or config.ema_teacher:
        ema = tf.train.ExponentialMovingAverage(config.ema_decay)
        model_vars = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'model')
        ema_op = ema.apply(model_vars)
        # Train ops run after every op in UPDATE_OPS (see _get_train_op),
        # so the averages are refreshed on each training step.
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

        def ema_getter(getter, name, *args, **kwargs):
          """Returns the moving-average shadow of a variable if one exists."""
          var = getter(name, *args, **kwargs)
          average = ema.average(var)
          # ema.average() returns None for variables that were never passed
          # to ema.apply() (e.g. non-trainable ones). Fall back to the raw
          # variable rather than handing None to the model code.
          return var if average is None else average

        # AUTO_REUSE makes the second Inference reuse the variables created
        # above, and the custom getter swaps each one for its average.
        scope.set_custom_getter(ema_getter)
        inference_ema = Inference(
            config, self._inputs, pretrained_embeddings, tasks)
        if config.ema_teacher:
          self._teacher = inference_ema
        if config.ema_test:
          self._tester = inference_ema

    self._unlabeled_loss = self._get_consistency_loss(tasks)
    self._unlabeled_train_op = self._get_train_op(self._unlabeled_loss)
    self._labeled_train_ops = {}
    for task in self._tasks:
      task_loss = self._trainer.modules[task.name].supervised_loss
      self._labeled_train_ops[task.name] = self._get_train_op(task_loss)

  def _get_consistency_loss(self, tasks):
    """Returns the sum of the trainer modules' unsupervised losses."""
    return sum(self._trainer.modules[task.name].unsupervised_loss
               for task in tasks)

  def _get_optimizer(self):
    """Creates the global step and a momentum optimizer with an LR schedule.

    The learning rate is `config.lr` scaled by two factors. The first is a
    linear warm-up, min(step, warm_up_steps) / warm_up_steps, or 1.0 when
    warm_up_steps <= 0. The second is a decay factor,
    1 / (1 + lr_decay * sqrt(step)).

    Returns:
      A `(global_step, optimizer)` tuple.
    """
    global_step = tf.get_variable('global_step', initializer=0, trainable=False)
    step = tf.to_float(global_step)
    warm_up_steps = self._config.warm_up_steps
    if warm_up_steps > 0:
      warm_up_multiplier = tf.minimum(step, warm_up_steps) / warm_up_steps
    else:
      # Without this guard a zero warm-up evaluates 0 / 0 and the learning
      # rate becomes NaN.
      warm_up_multiplier = 1.0
    decay_multiplier = 1.0 / (1 + self._config.lr_decay * tf.sqrt(step))
    lr = self._config.lr * warm_up_multiplier * decay_multiplier
    optimizer = tf.train.MomentumOptimizer(lr, self._config.momentum)
    return global_step, optimizer

  def _get_train_op(self, loss):
    """Builds an op applying one gradient-clipped optimizer step on `loss`.

    Gradients are clipped by their global norm to `config.grad_clip`. The
    step runs only after every op in the UPDATE_OPS collection and
    increments the global step.
    """
    grads, variables = zip(*self._optimizer.compute_gradients(loss))
    grads, _ = tf.clip_by_global_norm(grads, self._config.grad_clip)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      return self._optimizer.apply_gradients(
          zip(grads, variables), global_step=self._global_step)

  def _create_feed_dict(self, mb, model, is_training=True):
    """Builds the feed dict for minibatch `mb` on the given Inference graph.

    If `mb.task_name` names one of `model`'s modules, only that module adds
    its entries. Otherwise every module does (e.g. an unlabeled minibatch
    whose task_name matches no module, presumably).
    """
    feed = self._inputs.create_feed_dict(mb, is_training)
    if mb.task_name in model.modules:
      model.modules[mb.task_name].update_feed_dict(feed, mb)
    else:
      for module in model.modules.values():
        module.update_feed_dict(feed, mb)
    return feed

  def train_unlabeled(self, sess, mb):
    """Runs one consistency-loss training step; returns the loss value."""
    return sess.run([self._unlabeled_train_op, self._unlabeled_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]

  def train_labeled(self, sess, mb):
    """Runs one supervised step for `mb.task_name`; returns its loss value."""
    return sess.run([self._labeled_train_ops[mb.task_name],
                     self._trainer.modules[mb.task_name].supervised_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]

  def run_teacher(self, sess, mb):
    """Stores the teacher's per-task `probs` on `mb` as float16 arrays.

    Runs the teacher graph with `is_training=False` and writes one entry per
    task into `mb.teacher_predictions[task_name]`.
    """
    fetches = {task.name: self._teacher.modules[task.name].probs
               for task in self._tasks}
    result = sess.run(fetches,
                      feed_dict=self._create_feed_dict(mb, self._teacher,
                                                       False))
    # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
    for task_name, probs in result.items():
      mb.teacher_predictions[task_name] = probs.astype('float16')

  def test(self, sess, mb):
    """Returns `[supervised_loss, preds]` from the tester graph for `mb`."""
    return sess.run(
        [self._tester.modules[mb.task_name].supervised_loss,
         self._tester.modules[mb.task_name].preds],
        feed_dict=self._create_feed_dict(mb, self._tester, False))

  def get_global_step(self, sess):
    """Returns the current value of the global step."""
    return sess.run(self._global_step)