irl_benchmark/metrics/avg_traj_return.py

"""Module for average trajectory return metric."""

from typing import Optional

from irl_benchmark.metrics.base_metric import BaseMetric
from irl_benchmark.utils.rl import true_reward_per_traj


class AverageTrajectoryReturn(BaseMetric):
    """Average sum of true rewards per trajectory."""

    def __init__(self, metric_input: dict):
        """Initialize the metric. Requires 'env' field in metric_input."""
        assert 'env' in metric_input
        super().__init__(metric_input)

    def evaluate(self, evaluation_input: Optional[dict] = None) -> float:
        """Evaluate the metric given some input and return result.

        Parameters
        ----------
        evaluation_input: dict
            The evaluation input. Requires 'irl_agent' field.

        Returns
        -------
        float
            The average sum of true rewards per trajectory.
        """
        assert evaluation_input is not None
        assert 'irl_agent' in evaluation_input
        irl_trajs = self.generate_traj_if_not_exists(evaluation_input)
        return true_reward_per_traj(irl_trajs)
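

# --- Illustration (not part of the module) ---
# `true_reward_per_traj` is imported from irl_benchmark.utils.rl and its
# implementation is not shown on this page. The sketch below is a minimal
# guess at the computation this metric performs: the mean of the summed
# true rewards over a list of trajectories. The 'true_rewards' field name
# and trajectory format are assumptions for illustration, not the repo's
# confirmed API.

from typing import Dict, List


def true_reward_per_traj_sketch(trajs: List[Dict]) -> float:
    """Average the per-trajectory sums of true rewards."""
    # Sum each trajectory's true rewards, then average over trajectories.
    return sum(sum(traj['true_rewards']) for traj in trajs) / len(trajs)


# Two trajectories with true-reward sums 3.0 and 1.0 average to 2.0:
trajs = [{'true_rewards': [1.0, 2.0]}, {'true_rewards': [0.5, 0.5]}]
assert true_reward_per_traj_sketch(trajs) == 2.0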