tensorflow/models

View on GitHub
official/modeling/multitask/evaluator.py

Summary

Maintainability
A
2 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multitask Evaluator implementation.

The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Dict, List, Optional, Union
import gin
import orbit
import tensorflow as tf, tf_keras

from official.core import base_task
from official.core import train_utils
from official.modeling.multitask import base_model


@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
  """Evaluates a (multi-task) model on several tasks, one task after another.

  Implements the Orbit `AbstractEvaluator` interface. For every task it builds
  a distributed validation dataset, a per-task loss metric plus the task's own
  validation metrics, and a looped `tf.function` validation step. `evaluate`
  runs each task's loop and returns a dictionary of logs keyed by task name.
  """

  def __init__(
      self,
      eval_tasks: List[base_task.Task],
      model: Union[tf_keras.Model, base_model.MultiTaskBaseModel],
      global_step: Optional[tf.Variable] = None,
      eval_steps: Optional[Dict[str, int]] = None,
      checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
    """Initializes the multi-task evaluator.

    Args:
      eval_tasks: A list of tasks to evaluate.
      model: A `tf_keras.Model`, or a `base_model.MultiTaskBaseModel`, in which
        case the sub-model `model.sub_tasks[task.name]` evaluates each task.
      global_step: the global step variable. When `None`, a new one is created
        with `orbit.utils.create_global_step()`.
      eval_steps: a dictionary of steps to run eval keyed by task names. Tasks
        without a (non-zero) entry use the `num_steps` passed to `evaluate`.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface; it is called at the end of every `evaluate` call.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._tasks = eval_tasks
    self._model = model
    # Compare against None explicitly: `global_step or ...` evaluates the
    # variable's truthiness (its value), so a shared step that is currently 0
    # would be silently replaced by a brand-new, unshared variable.
    if global_step is None:
      global_step = orbit.utils.create_global_step()
    self._global_step = global_step
    self._checkpoint_exporter = checkpoint_exporter
    # Optional extra objects the model exposes to be saved alongside it; they
    # are forwarded as keyword arguments to `tf.train.Checkpoint`.
    checkpoint_items = getattr(self.model, "checkpoint_items", {})

    self._checkpoint = tf.train.Checkpoint(
        model=self.model,
        global_step=self.global_step,
        **checkpoint_items)

    # Created lazily by the `validation_losses` / `validation_metrics`
    # properties.
    self._validation_losses = None
    self._validation_metrics = None

    # Builds per-task distributed datasets.
    self.eval_datasets = {}
    self.eval_steps = eval_steps or {}
    for task in self.tasks:
      self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs, task.task_config.validation_data)

    # Builds per-task validation loops.
    def get_function(task_name, task):
      """Returns an Orbit loop function running `task`'s validation step."""
      task_metrics = self.validation_metrics[task_name]
      task_loss = self.validation_losses[task_name]
      if isinstance(self.model, base_model.MultiTaskBaseModel):
        model = self.model.sub_tasks[task_name]
      else:
        model = self.model

      def step_fn(inputs):
        logs = task.validation_step(inputs, model=model, metrics=task_metrics)
        # `task.loss` is the key of the step logs that holds the loss value.
        task_loss.update_state(logs[task.loss])
        return logs

      @tf.function
      def eval_step_fn(iterator):
        distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
        return tf.nest.map_structure(self.strategy.experimental_local_results,
                                     distributed_outputs)

      return orbit.utils.create_loop_fn(eval_step_fn)

    self.task_fns = {
        task.name: get_function(task.name, task) for task in self.tasks
    }

  @property
  def strategy(self):
    """The distribution strategy captured at construction time."""
    return self._strategy

  @property
  def tasks(self):
    """The list of tasks being evaluated."""
    return self._tasks

  @property
  def model(self):
    """The model under evaluation."""
    return self._model

  @property
  def global_step(self):
    """The global step variable."""
    return self._global_step

  @property
  def validation_losses(self):
    """Accesses the per-task validation loss metric objects.

    Built on first access: one `tf_keras.metrics.Mean` named
    "validation_loss" per task, keyed by task name.
    """
    if self._validation_losses is None:
      self._validation_losses = {
          task.name: tf_keras.metrics.Mean("validation_loss", dtype=tf.float32)
          for task in self.tasks
      }
    return self._validation_losses

  @property
  def validation_metrics(self):
    """Accesses all per-task validation metric objects.

    Built on first access from `task.build_metrics(training=False)`, keyed by
    task name.
    """
    if self._validation_metrics is None:
      self._validation_metrics = {
          task.name: task.build_metrics(training=False) for task in self.tasks
      }
    return self._validation_metrics

  @property
  def checkpoint(self):
    """Accesses the checkpoint holding the model and the global step."""
    return self._checkpoint

  def evaluate(self, num_steps: tf.Tensor):
    """Performs evaluation for each `EvalTask`.

    All loss and metric objects are reset first, and a fresh iterator is
    created over every evaluation dataset on each call.

    Args:
      num_steps: number of steps to run for every task that has no entry (or a
        falsy entry) in `self.eval_steps`.

    Returns:
      A dictionary keyed by task name. Each value maps metric names (including
      "validation_loss") to their results, updated with the output of
      `task.reduce_aggregated_logs` whenever the aggregated step outputs are
      truthy.
    """
    for metric in self.validation_losses.values():
      metric.reset_states()
    for metrics in self.validation_metrics.values():
      for metric in metrics:
        metric.reset_states()
    results = {}
    eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

    for task in self.tasks:
      name = task.name
      # A per-task step count overrides `num_steps` unless missing or zero.
      task_eval_steps = self.eval_steps.get(name) or num_steps
      outputs = self.task_fns[name](
          eval_iters[name],
          task_eval_steps,
          state=None,
          reduce_fn=task.aggregate_logs)
      task_loss = self.validation_losses[name]
      # Metric results first, then the loss; on a name clash the later entry
      # wins, exactly as with sequential assignment.
      logs = {
          metric.name: metric.result()
          for metric in [*self.validation_metrics[name], task_loss]
      }
      if outputs:
        logs.update(
            task.reduce_aggregated_logs(outputs, global_step=self.global_step))
      results[name] = logs

    if self._checkpoint_exporter:
      self._checkpoint_exporter.maybe_export_checkpoint(
          self.checkpoint, results, self.global_step.numpy())
    return results