tensorflow/models

View on GitHub
official/recommendation/uplift/metrics/poisson_metrics.py

Summary

Maintainability
B
4 hrs
Test Coverage
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Poisson regression metrics."""

from __future__ import annotations

from typing import Any

import tensorflow as tf, tf_keras

from official.recommendation.uplift import types
from official.recommendation.uplift.metrics import loss_metric
from official.recommendation.uplift.metrics import treatment_sliced_metric


@tf_keras.utils.register_keras_serializable(package="Uplift")
class LogLoss(loss_metric.LossMetric):
  """Computes the (weighted) poisson log loss sliced by treatment group.

  Given labels `y` and the model's predictions `x`, the loss is computed as:
  `loss = x - y * log(x) + [y * log(y) - y + 0.5 * log(2 * pi * y)]`

  Note that from a numerical perspective it is preferred to compute the loss
  from the model's logits as opposed to directly using its predictions, where
  `logits = log(x) = log(prediction)`. In this case the loss is computed as:
  `loss = exp(logits) - y * logits + [y * log(y) - y + 0.5 * log(2 * pi * y)]`

  Example standalone usage:

  >>> poisson_loss = poisson_metrics.LogLoss()
  >>> y_true = tf.constant([[1.0], [0.0]])
  >>> y_pred = types.TwoTowerTrainingOutputs(
      true_logits=tf.constant([[1.0], [0.0]]),
      is_treatment=tf.constant([[1], [0]]),
  )
  >>> poisson_loss(y_true=y_true, y_pred=y_pred)
  {
      "poisson_log_loss/treatment": 1.7182817  # exp(1) - 1 * 1
      "poisson_log_loss/control": 1.0  # exp(0) - 0 * 0
      "poisson_log_loss": 1.3591409  # (1.7182817 + 1.0) / 2
  }

  Example usage with the `model.compile()` API:

  >>> model.compile(
  ...     optimizer="sgd",
  ...     loss=TrueLogitsLoss(tf.nn.log_poisson_loss),
  ...     metrics=[poisson_metrics.LogLoss()]
  ... )
  """

  def __init__(
      self,
      from_logits: bool = True,
      compute_full_loss: bool = False,
      slice_by_treatment: bool = True,
      name: str = "poisson_log_loss",
      dtype: tf.DType = tf.float32,
  ):
    """Initializes the instance.

    Args:
      from_logits: When `y_pred` is of type `tf.Tensor`, specifies whether
        `y_pred` represents the model's logits or predictions. Otherwise, when
        `y_pred` is of type `TwoTowerTrainingOutputs`, set this to `True` in
        order to compute the loss using the true logits.
      compute_full_loss: Specifies whether the full log loss will be computed.
        If `True`, the expression `[y_true * log(y_true) - y_true + 0.5 * log(2
        * pi * y_true)]` will be added to the loss, otherwise the loss will be
        computed solely by the expression `[y_pred - y_true * log(y_pred)]`.
      slice_by_treatment: Specifies whether the loss should be sliced by the
        treatment indicator tensor. If `True`, the metric's result will return
        the loss values sliced by the treatment group. Note that this can only
        be set to `True` when `y_pred` is of type `TwoTowerTrainingOutputs`.
      name: Optional name for the instance.
      dtype: Optional data type for the instance.
    """
    super().__init__(
        loss_fn=tf.nn.log_poisson_loss,
        from_logits=from_logits,
        compute_full_loss=compute_full_loss,
        slice_by_treatment=slice_by_treatment,
        name=name,
        dtype=dtype,
    )

  def update_state(
      self,
      y_true: tf.Tensor,
      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor,
      sample_weight: tf.Tensor | None = None,
  ):
    if not self._from_logits:
      if isinstance(y_pred, types.TwoTowerTrainingOutputs):
        raise ValueError(
            "`from_logits` must be set to `True` when `y_pred` is of type"
            " TwoTowerTrainingOutputs. Note that the true logits and true"
            " predictions are assumed to be linked to each other through the"
            " log link function: `true_logits = tf.math.log(true_predictions)."
        )
      y_pred = tf.math.log(y_pred)

    super().update_state(y_true, y_pred, sample_weight)

  def get_config(self) -> dict[str, Any]:
    config = super().get_config()
    del config["loss_fn"]
    return config

  @classmethod
  def from_config(cls, config: dict[str, Any]) -> LogLoss:
    return cls(**config)


def _safe_x_minus_xlogx(x: tf.Tensor) -> tf.Tensor:
  """Computes x - x * log(x) with 0 as its continuity point when x equals 0."""
  values = x * (1.0 - tf.math.log(x))
  return tf.where(tf.equal(x, 0.0), tf.zeros_like(x), values)


@tf_keras.utils.register_keras_serializable(package="Uplift")
class LogLossMeanBaseline(tf_keras.metrics.Metric):
  """Computes the (weighted) poisson log loss for a mean predictor."""

  def __init__(
      self,
      compute_full_loss: bool = False,
      slice_by_treatment: bool = True,
      name: str = "poisson_log_loss_mean_baseline",
      dtype: tf.DType = tf.float32,
  ):
    """Initializes the instance.

    Args:
      compute_full_loss: Specifies whether to compute the full poisson log loss
        for the mean predictor or not. Defaults to `False`.
      slice_by_treatment: Specifies whether the loss should be sliced by the
        treatment indicator tensor. If `True`, the metric's result will return
        the loss values sliced by the treatment group. Note that this can only
        be set to `True` when `y_pred` is of type `TwoTowerTrainingOutputs`.
      name: Optional name for the instance.
      dtype: Optional data type for the instance.
    """
    super().__init__(name=name, dtype=dtype)

    if compute_full_loss:
      raise NotImplementedError("Full loss computation is not yet supported.")

    self._compute_full_loss = compute_full_loss
    self._slice_by_treatment = slice_by_treatment

    if slice_by_treatment:
      self._mean_label = treatment_sliced_metric.TreatmentSlicedMetric(
          metric=tf_keras.metrics.Mean(name=name, dtype=dtype)
      )
    else:
      self._mean_label = tf_keras.metrics.Mean(name=name, dtype=dtype)

  def update_state(
      self,
      y_true: tf.Tensor,
      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | None = None,
      sample_weight: tf.Tensor | None = None,
  ):
    is_treatment = {}
    if self._slice_by_treatment:
      if not isinstance(y_pred, types.TwoTowerTrainingOutputs):
        raise ValueError(
            "`slice_by_treatment` must be set to `False` when `y_pred` is not"
            " of type `TwoTowerTrainingOutputs`."
        )
      is_treatment["is_treatment"] = y_pred.is_treatment

    self._mean_label.update_state(
        y_true, sample_weight=sample_weight, **is_treatment
    )

  def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
    return tf.nest.map_structure(_safe_x_minus_xlogx, self._mean_label.result())

  def get_config(self) -> dict[str, Any]:
    config = super().get_config()
    config["compute_full_loss"] = self._compute_full_loss
    config["slice_by_treatment"] = self._slice_by_treatment
    return config

  @classmethod
  def from_config(cls, config: dict[str, Any]) -> LogLossMeanBaseline:
    return cls(**config)


@tf_keras.utils.register_keras_serializable(package="Uplift")
class LogLossMinimum(tf_keras.metrics.Metric):
  """Computes the minimum achievable (weighted) poisson log loss.

  Given labels `y` and the model's predictions `x`, the minimum loss is obtained
  when `x` equals `y`. In this case the loss is computed as:
  `loss = y - y * log(y) + [y * log(y) - y + 0.5 * log(2 * pi * y)]`

  Note that `[y * log(y) - y + 0.5 * log(2 * pi * y)]` is only computed if
  `compute_full_loss` is set to `True`.
  """

  def __init__(
      self,
      compute_full_loss: bool = False,
      slice_by_treatment: bool = True,
      name: str = "poisson_log_loss_minimum",
      dtype: tf.DType = tf.float32,
  ):
    """Initializes the instance.

    Args:
      compute_full_loss: Specifies whether to compute the full minimum log loss
        or not. Defaults to `False`.
      slice_by_treatment: Specifies whether the loss should be sliced by the
        treatment indicator tensor. If `True`, the metric's result will return
        the loss values sliced by the treatment group. Note that this can only
        be set to `True` when `y_pred` is of type `TwoTowerTrainingOutputs`.
      name: Optional name for the instance.
      dtype: Optional data type for the instance.
    """
    super().__init__(name=name, dtype=dtype)

    if compute_full_loss:
      raise NotImplementedError("Full loss computation is not yet supported.")

    self._compute_full_loss = compute_full_loss
    self._slice_by_treatment = slice_by_treatment

    if slice_by_treatment:
      self._loss = treatment_sliced_metric.TreatmentSlicedMetric(
          metric=tf_keras.metrics.Mean(name=name, dtype=dtype)
      )
    else:
      self._loss = tf_keras.metrics.Mean(name=name, dtype=dtype)

  def update_state(
      self,
      y_true: tf.Tensor,
      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor | None = None,
      sample_weight: tf.Tensor | None = None,
  ):
    is_treatment = {}
    if self._slice_by_treatment:
      if not isinstance(y_pred, types.TwoTowerTrainingOutputs):
        raise ValueError(
            "`slice_by_treatment` must be set to `False` when `y_pred` is not"
            " of type `TwoTowerTrainingOutputs`."
        )
      is_treatment["is_treatment"] = y_pred.is_treatment

    self._loss.update_state(
        _safe_x_minus_xlogx(y_true), sample_weight=sample_weight, **is_treatment
    )

  def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
    return self._loss.result()

  def get_config(self) -> dict[str, Any]:
    config = super().get_config()
    config["compute_full_loss"] = self._compute_full_loss
    config["slice_by_treatment"] = self._slice_by_treatment
    return config

  @classmethod
  def from_config(cls, config: dict[str, Any]) -> LogLossMinimum:
    return cls(**config)


@tf_keras.utils.register_keras_serializable(package="Uplift")
class PseudoRSquared(tf_keras.metrics.Metric):
  """Computes the pseudo R-squared metric for poisson regression.

  The pseudo R-squared is computed from log likelihoods of three models:
  1) LLbaseline: log likelihood of a mean baseline predictor.
  2) LLfit: log likelihood of the fitted model.
  3) LLmax: maximum achievable log likelihood, which occurs when the predictions
  equal to the labels.

  The equation that computes the pseudo R-squared is:
  >>> R_squared = (LLfit - LLbaseline) / (LLmax - LLbaseline)
  """

  def __init__(
      self,
      from_logits: bool = True,
      slice_by_treatment: bool = True,
      name: str = "pseudo_r_squared",
      dtype: tf.DType = tf.float32,
  ):
    """Initializes the instance.

    Args:
      from_logits: When `y_pred` is of type `tf.Tensor`, specifies whether
        `y_pred` represents the model's logits or predictions. Otherwise, when
        `y_pred` is of type `TwoTowerTrainingOutputs`, set this to `True` in
        order to compute the loss using the true logits.
      slice_by_treatment: Specifies whether the loss should be sliced by the
        treatment indicator tensor. If `True`, the metric's result will return
        the loss values sliced by the treatment group. Note that this can only
        be set to `True` when `y_pred` is of type `TwoTowerTrainingOutputs`.
      name: Optional name for the instance.
      dtype: Optional data type for the instance.
    """
    super().__init__(name=name, dtype=dtype)

    self._from_logits = from_logits
    self._slice_by_treatment = slice_by_treatment

    # Since log_loss = -1 * log_likelihood we can just accumulate the losses.
    loss = LogLoss(
        from_logits=from_logits,
        compute_full_loss=False,
        slice_by_treatment=False,
        name=name,
        dtype=dtype,
    )
    minimum_loss = LogLossMinimum(
        compute_full_loss=False,
        slice_by_treatment=False,
        name=name,
        dtype=dtype,
    )
    mean_baseline_loss = LogLossMeanBaseline(
        compute_full_loss=False,
        slice_by_treatment=False,
        name=name,
        dtype=dtype,
    )

    if slice_by_treatment:
      self._model_loss = treatment_sliced_metric.TreatmentSlicedMetric(
          metric=loss
      )
      self._minimum_loss = treatment_sliced_metric.TreatmentSlicedMetric(
          metric=minimum_loss
      )
      self._mean_baseline_loss = treatment_sliced_metric.TreatmentSlicedMetric(
          metric=mean_baseline_loss
      )
    else:
      self._model_loss = loss
      self._minimum_loss = minimum_loss
      self._mean_baseline_loss = mean_baseline_loss

  def update_state(
      self,
      y_true: tf.Tensor,
      y_pred: types.TwoTowerTrainingOutputs | tf.Tensor,
      sample_weight: tf.Tensor | None = None,
  ):
    is_treatment = {}
    if self._slice_by_treatment:
      if not isinstance(y_pred, types.TwoTowerTrainingOutputs):
        raise ValueError(
            "`slice_by_treatment` must be set to `False` when `y_pred` is not"
            " of type `TwoTowerTrainingOutputs`."
        )
      is_treatment["is_treatment"] = y_pred.is_treatment

    self._model_loss.update_state(
        y_true, y_pred=y_pred, sample_weight=sample_weight, **is_treatment
    )
    self._minimum_loss.update_state(
        y_true, y_pred=y_pred, sample_weight=sample_weight, **is_treatment
    )
    self._mean_baseline_loss.update_state(
        y_true, y_pred=y_pred, sample_weight=sample_weight, **is_treatment
    )

  def result(self) -> tf.Tensor | dict[str, tf.Tensor]:
    def _pseudo_r_squared(
        loss_model: tf.Tensor, loss_baseline: tf.Tensor, loss_min: tf.Tensor
    ) -> tf.Tensor:
      return tf.math.divide_no_nan(
          loss_model - loss_baseline, loss_min - loss_baseline
      )

    return tf.nest.map_structure(
        _pseudo_r_squared,
        self._model_loss.result(),
        self._mean_baseline_loss.result(),
        self._minimum_loss.result(),
    )

  def get_config(self) -> dict[str, Any]:
    config = super().get_config()
    config["from_logits"] = self._from_logits
    config["slice_by_treatment"] = self._slice_by_treatment
    return config

  @classmethod
  def from_config(cls, config: dict[str, Any]) -> PseudoRSquared:
    return cls(**config)