def action(self, state, context=None):
    """Returns the next action for the state.

    Args:
      state: A [num_state_dims] tensor representing a state.