if self._steps_per_execution.numpy().item() == 1:

      def test_function(iterator):
        """Performs a single evaluation step by delegating to `step_function`.

        This definition is selected when `self._steps_per_execution` holds 1,
        so each invocation executes exactly one step.

        Args:
          iterator: Forwarded unchanged to `step_function` (presumably a
            dataset iterator — confirm against the caller).

        Returns:
          The value produced by `step_function(self, iterator)`, unmodified.
        """
        step_outputs = step_function(self, iterator)
        return step_outputs