RasaHQ/rasa_core

View on GitHub
rasa/core/policies/form_policy.py

Summary

Maintainability
A
1 hr
Test Coverage
import logging
from typing import List, Optional, Dict, Text

from rasa.core.actions.action import ACTION_LISTEN_NAME
from rasa.core.domain import PREV_PREFIX, ACTIVE_FORM_PREFIX, Domain
from rasa.core.events import FormValidation
from rasa.core.featurizers import TrackerFeaturizer
from rasa.core.policies.memoization import MemoizationPolicy
from rasa.core.trackers import DialogueStateTracker

logger = logging.getLogger(__name__)


class FormPolicy(MemoizationPolicy):
    """Policy which handles prediction of Forms"""

    ENABLE_FEATURE_STRING_COMPRESSION = True

    def __init__(self,
                 featurizer: Optional[TrackerFeaturizer] = None,
                 priority: int = 4,
                 lookup: Optional[Dict] = None
                 ) -> None:

        # max history is set to 2 in order to capture
        # previous meaningful action before action listen
        super(FormPolicy, self).__init__(featurizer=featurizer,
                                         priority=priority,
                                         max_history=2,
                                         lookup=lookup)

    @staticmethod
    def _get_active_form_name(state):
        found_forms = [state_name[len(ACTIVE_FORM_PREFIX):]
                       for state_name, prob in state.items()
                       if ACTIVE_FORM_PREFIX in state_name and prob > 0]
        # by construction there is only one active form
        return found_forms[0] if found_forms else None

    @staticmethod
    def _prev_action_listen_in_state(state):
        return any(PREV_PREFIX + ACTION_LISTEN_NAME in state_name and prob > 0
                   for state_name, prob in state.items())

    @staticmethod
    def _modified_states(states):
        """Modify the states to
            - capture previous meaningful action before action_listen
            - ignore previous intent
        """
        if states[0] is None:
            action_before_listen = None
        else:
            action_before_listen = {state_name: prob
                                    for state_name, prob in states[0].items()
                                    if PREV_PREFIX in state_name and prob > 0}

        return [action_before_listen, states[-1]]

    def _add_states_to_lookup(self, trackers_as_states, trackers_as_actions,
                              domain, online=False):
        """Add states to lookup dict"""
        for states in trackers_as_states:
            active_form = self._get_active_form_name(states[-1])
            if active_form and self._prev_action_listen_in_state(states[-1]):
                # modify the states
                states = self._modified_states(states)
                feature_key = self._create_feature_key(states)
                # even if there are two identical feature keys
                # their form will be the same
                # because of `active_form_...` feature
                self.lookup[feature_key] = active_form

    def recall(self,
               states: List[Dict[Text, float]],
               tracker: DialogueStateTracker,
               domain: Domain
               ) -> Optional[int]:
        # modify the states
        return self._recall_states(self._modified_states(states))

    def state_is_unhappy(self, tracker, domain):
        # since it is assumed that training stories contain
        # only unhappy paths, notify the form that
        # it should not be validated if predicted by other policy
        tracker_as_states = self.featurizer.prediction_states(
            [tracker], domain)
        states = tracker_as_states[0]

        memorized_form = self.recall(states, tracker, domain)

        state_is_unhappy = (
            memorized_form is not None and
            memorized_form == tracker.active_form.get('name')
        )

        if state_is_unhappy:
            logger.debug("There is a memorized tracker state {}, "
                         "added `FormValidation(False)` event"
                         "".format(self._modified_states(states)))

        return state_is_unhappy

    def predict_action_probabilities(self,
                                     tracker: DialogueStateTracker,
                                     domain: Domain) -> List[float]:
        """Predicts the corresponding form action if there is an active form"""
        result = [0.0] * domain.num_actions

        if tracker.active_form.get('name'):
            logger.debug("There is an active form '{}'"
                         "".format(tracker.active_form['name']))
            if tracker.latest_action_name == ACTION_LISTEN_NAME:
                # predict form action after user utterance

                if tracker.active_form.get('rejected'):
                    if self.state_is_unhappy(tracker, domain):
                        tracker.update(FormValidation(False))
                        return result

                idx = domain.index_for_action(tracker.active_form['name'])
                result[idx] = 1.0

            elif tracker.latest_action_name == tracker.active_form.get('name'):
                # predict action_listen after form action
                idx = domain.index_for_action(ACTION_LISTEN_NAME)
                result[idx] = 1.0
        else:
            logger.debug("There is no active form")

        return result