if self._steps_per_execution.numpy().item() == 1:

      def train_function(iterator):
        """Runs one training step against `iterator`.

        Used on the branch where `steps_per_execution` is exactly 1. It
        forwards the enclosing `self` and the given `iterator` to the
        closure-captured `step_function` and hands back whatever that call
        produces.
        """
        single_step_result = step_function(self, iterator)
        return single_step_result