# rasa/core/training/story_conflict.py
#
# Summary
#   Maintainability: A (3 hrs)
#   Test Coverage: A (93%)
from collections import defaultdict
import logging
import json
from typing import DefaultDict, Dict, Generator, List, NamedTuple, Optional, Text, Tuple

from rasa.core.featurizers.tracker_featurizers import MaxHistoryTrackerFeaturizer
from rasa.shared.core.constants import (
    ACTION_LISTEN_NAME,
    PREVIOUS_ACTION,
    ACTION_UNLIKELY_INTENT_NAME,
    USER,
)
from rasa.shared.core.domain import Domain, State
from rasa.shared.core.events import ActionExecuted, Event
from rasa.shared.core.generator import TrackerWithCachedStates

from rasa.nlu.tokenizers.tokenizer import Tokenizer
from rasa.shared.nlu.constants import TEXT
from rasa.shared.nlu.training_data.message import Message

logger = logging.getLogger(__name__)


class StoryConflict:
    """Represents a conflict between two or more stories.

    Here, a conflict means that different actions are supposed to follow from
    the same dialogue state, which most policies cannot learn.
    """

    def __init__(self, sliced_states: List[State]) -> None:
        """Creates a `StoryConflict` from a given state.

        Args:
            sliced_states: The (sliced) dialogue state at which the conflict occurs.
        """
        self._sliced_states = sliced_states
        # Maps every action that follows from the state to the stories in which
        # it does so: {"action": ["story_1", ...], ...}
        self._conflicting_actions: DefaultDict[Text, List[Text]] = defaultdict(list)

    def __hash__(self) -> int:
        """Hashes the conflict based on the string form of its sliced states."""
        return hash(str(list(self._sliced_states)))

    def add_conflicting_action(self, action: Text, story_name: Text) -> None:
        """Adds another action that follows from the same state.

        Args:
            action: Name of the action.
            story_name: Name of the story where this action is chosen.
        """
        self._conflicting_actions[action].append(story_name)

    @property
    def conflicting_actions(self) -> List[Text]:
        """List of conflicting actions.

        Returns:
            The names of all actions added so far, in insertion order.
        """
        return list(self._conflicting_actions)

    @property
    def _last_state(self) -> Optional[State]:
        """Returns the most recent sliced state, or `None` if there is none."""
        # Guard against an empty history so that callers do not hit an
        # `IndexError`; `_get_previous_event` treats `None` as "no prior event".
        return self._sliced_states[-1] if self._sliced_states else None

    @property
    def conflict_has_prior_events(self) -> bool:
        """Checks if prior events exist.

        Returns:
            `True` if anything has happened before this conflict, otherwise `False`.
            A conflict without any sliced states has no prior events.
        """
        return _get_previous_event(self._last_state)[0] is not None

    def __str__(self) -> Text:
        """Describes where the conflict occurs and which stories are involved."""
        # Describe where the conflict occurs in the stories
        last_event_type, last_event_name = _get_previous_event(self._last_state)
        if last_event_type:
            header = (
                f"Story structure conflict after {last_event_type} "
                f"'{last_event_name}':\n"
            )
        else:
            header = "Story structure conflict at the beginning of stories:\n"

        # List which stories are in conflict with one another; every summary
        # already ends with a newline, so the parts are joined without separator.
        summaries = (
            f"  {self._summarize_conflicting_actions(action, stories)}"
            for action, stories in self._conflicting_actions.items()
        )
        return header + "".join(summaries)

    @staticmethod
    def _summarize_conflicting_actions(action: Text, stories: List[Text]) -> Text:
        """Gives a summarized textual description of where one action occurs.

        Args:
            action: The name of the action.
            stories: The stories in which the action occurs.

        Returns:
            A textual summary, terminated by a newline.

        Raises:
            ValueError: If `stories` is empty.
        """
        if len(stories) > 3:
            # Four or more stories are present: name the first two, count the rest
            conflict_description = (
                f"'{stories[0]}', '{stories[1]}', and {len(stories) - 2} other trackers"
            )
        elif len(stories) == 3:
            conflict_description = f"'{stories[0]}', '{stories[1]}', and '{stories[2]}'"
        elif len(stories) == 2:
            conflict_description = f"'{stories[0]}' and '{stories[1]}'"
        elif len(stories) == 1:
            conflict_description = f"'{stories[0]}'"
        else:
            raise ValueError(
                "An internal error occurred while trying to summarise a conflict "
                "without stories. Please file a bug report at "
                "https://github.com/RasaHQ/rasa."
            )

        return f"{action} predicted in {conflict_description}\n"


class TrackerEventStateTuple(NamedTuple):
    """Bundles a tracker with one of its events and the matching sliced states."""

    # Tracker the event belongs to.
    tracker: TrackerWithCachedStates
    # Event taken from the tracker.
    event: Event
    # Sliced dialogue states associated with the tracker and the event.
    sliced_states: List[State]

    @property
    def sliced_states_hash(self) -> int:
        """Hashes the sliced states via their key-sorted JSON serialisation.

        Returns:
            The hash of the JSON dump of `sliced_states`. Keys are sorted so
            that the key order inside a state does not influence the result.
        """
        serialized_states = json.dumps(self.sliced_states, sort_keys=True)
        return hash(serialized_states)


def find_story_conflicts(
    trackers: List[TrackerWithCachedStates],
    domain: Domain,
    max_history: Optional[int] = None,
) -> List[StoryConflict]:
    """Searches the given trackers for story conflicts.

    Args:
        trackers: Trackers in which to search for conflicts.
        domain: The domain.
        max_history: The maximum history length to be taken into account. A falsy
            value (`None` or `0`) is reported as "all preceding turns".

    Returns:
        `StoryConflict` objects describing the conflicts that were found.
    """
    if not max_history:
        logger.info("Considering all preceding turns for conflict analysis.")
    else:
        logger.info(
            f"Considering the preceding {max_history} turns for conflict analysis."
        )

    # The search runs in two passes to reduce memory consumption.

    # Pass 1: collect, per state hash, the actions that follow from that state
    # and keep only the states that are in conflict.
    state_hash_to_actions = _find_conflicting_states(trackers, domain, max_history)

    # Pass 2: walk over all states again and record the (unhashed) states
    # for which a conflict was detected.
    return _build_conflicts_from_states(
        trackers, domain, max_history, state_hash_to_actions
    )


def _find_conflicting_states(
    trackers: List[TrackerWithCachedStates],
    domain: Domain,
    max_history: Optional[int],
    tokenizer: Optional[Tokenizer] = None,
) -> Dict[int, List[int]]:
    """Collects every state hash from which conflicting actions follow.

    Args:
        trackers: Trackers that contain the states.
        domain: The domain object.
        max_history: Number of turns to take into account for the state descriptions.
        tokenizer: A tokenizer to tokenize the user messages.

    Returns:
        A dictionary mapping state-hashes to the list of action hashes that
        follow from each state. Only states followed by more than one recorded
        action, or by `action_unlikely_intent`, are included.
    """
    # 'state hash -> hashes of the actions following that state'
    actions_by_state: DefaultDict[int, List[int]] = defaultdict(list)

    for item in _sliced_states_iterator(trackers, domain, max_history, tokenizer):
        state_hash = item.sliced_states_hash
        event_hash = hash(item.event)

        recorded_actions = actions_by_state[state_hash]
        is_new_action = event_hash not in recorded_actions
        # Unlearnable actions are recorded every time they occur, all other
        # actions only once per state.
        if is_new_action or _unlearnable_action(item.event):
            recorded_actions.append(event_hash)

    # Keep only the states that are in conflict
    # or that are followed by `action_unlikely_intent`
    unlikely_intent_hash = hash(
        ActionExecuted(action_name=ACTION_UNLIKELY_INTENT_NAME)
    )
    conflicting_states = {}
    for state_hash, action_hashes in actions_by_state.items():
        if len(action_hashes) > 1 or unlikely_intent_hash in action_hashes:
            conflicting_states[state_hash] = action_hashes
    return conflicting_states


def _unlearnable_action(event: Event) -> bool:
    """Checks whether the event is an action that policies cannot learn from stories.

    Args:
        event: An event to be checked.

    Returns:
        `True` if the event is an executed `action_unlikely_intent` (i.e. it
        cannot be learned), `False` otherwise.
    """
    if not isinstance(event, ActionExecuted):
        return False
    return event.action_name == ACTION_UNLIKELY_INTENT_NAME


def _build_conflicts_from_states(
    trackers: List[TrackerWithCachedStates],
    domain: Domain,
    max_history: Optional[int],
    conflicting_state_action_mapping: Dict[int, List[int]],
    tokenizer: Optional[Tokenizer] = None,
) -> List["StoryConflict"]:
    """Creates one `StoryConflict` object per conflicting state.

    Args:
        trackers: Trackers that contain the states.
        domain: The domain object.
        max_history: Number of turns to take into account for the state descriptions.
        conflicting_state_action_mapping: A dictionary mapping state-hashes to a list
            of actions that follow from each state.
        tokenizer: A tokenizer to tokenize the user messages.

    Returns:
        A list of `StoryConflict` objects that describe inconsistencies in the story
        structure. These objects also contain the history that leads up to the conflict.
    """
    # Second pass over all states: this time the (unhashed) state is stored
    # for every state hash that was flagged as conflicting.
    conflicts_by_state: Dict[int, StoryConflict] = {}
    for item in _sliced_states_iterator(trackers, domain, max_history, tokenizer):
        state_hash = item.sliced_states_hash
        if state_hash not in conflicting_state_action_mapping:
            continue

        conflict = conflicts_by_state.get(state_hash)
        if conflict is None:
            conflict = StoryConflict(item.sliced_states)
            conflicts_by_state[state_hash] = conflict

        conflict.add_conflicting_action(
            action=str(item.event), story_name=item.tracker.sender_id
        )

    # Only report conflicts that have prior events; conflicts at the very
    # start of a conversation are dropped.
    return [
        conflict
        for conflict in conflicts_by_state.values()
        if conflict.conflict_has_prior_events
    ]


def _sliced_states_iterator(
    trackers: List[TrackerWithCachedStates],
    domain: Domain,
    max_history: Optional[int],
    tokenizer: Optional[Tokenizer],
) -> Generator[TrackerEventStateTuple, None, None]:
    """Iterates over the sliced states of all trackers.

    For every `ActionExecuted` event of every tracker, the states up to that
    event are sliced according to `max_history` and yielded together with the
    tracker and the event.

    Args:
        trackers: List of trackers.
        domain: Domain (used for tracker.past_states).
        max_history: Assumed `max_history` value for slicing.
        tokenizer: A tokenizer to tokenize the user messages. If given, it is
            applied to the sliced states before they are yielded.

    Yields:
        A (tracker, event, sliced_states) triplet.
    """
    for tracker in trackers:
        states = tracker.past_states(domain)

        # Only executed actions advance the position within `states`.
        executed_actions = (
            event for event in tracker.events if isinstance(event, ActionExecuted)
        )
        for history_end, event in enumerate(executed_actions, start=1):
            sliced_states = MaxHistoryTrackerFeaturizer.slice_state_history(
                states[:history_end], max_history
            )
            if tokenizer:
                _apply_tokenizer_to_states(tokenizer, sliced_states)
            # TODO: deal with oov (different tokens can lead to identical features
            # if some of those tokens are out of vocabulary for all featurizers)
            yield TrackerEventStateTuple(tracker, event, sliced_states)


def _apply_tokenizer_to_states(tokenizer: Tokenizer, states: List[State]) -> None:
    """Replaces each user text by its tokens joined with single spaces.

    The states are modified in place; states without a user text are skipped.

    Args:
        tokenizer: A tokenizer to tokenize the user messages.
        states: The states to be tokenized.
    """
    for state in states:
        if USER not in state or TEXT not in state[USER]:
            continue

        message = Message({TEXT: state[USER][TEXT]})
        tokens = tokenizer.tokenize(message, TEXT)
        state[USER][TEXT] = " ".join(token.text for token in tokens)


def _get_previous_event(
    state: Optional[State],
) -> Tuple[Optional[Text], Optional[Text]]:
    """Determines type and name of the event that led to the given state.

    The event is looked up in this order: a previous action other than
    `action_listen`, a previous bot utterance (action text), the user intent,
    and finally the user text.

    Args:
        state: Element of sliced states.

    Returns:
        Tuple of (type, name) strings of the prior event. Both entries are `None`
        if the state is empty or contains none of the events listed above.

    Raises:
        TypeError: If the name found in the state is neither a string nor `None`.
    """
    previous_event_type = None
    previous_event_name = None

    # A typical state might be
    # `{'user': {'intent': 'greet'}, 'prev_action': {'action_name': 'action_listen'}}`.
    if state:
        # A missing previous action behaves like an empty sub-state below.
        prev_action = state[PREVIOUS_ACTION] if PREVIOUS_ACTION in state else {}

        if (
            "action_name" in prev_action
            and prev_action["action_name"] != ACTION_LISTEN_NAME
        ):
            previous_event_type = "action"
            previous_event_name = prev_action["action_name"]
        elif "action_text" in prev_action:
            previous_event_type = "bot utterance"
            previous_event_name = prev_action["action_text"]
        elif USER in state:
            user_substate = state[USER]
            if "intent" in user_substate:
                previous_event_type = "intent"
                previous_event_name = user_substate["intent"]
            elif "text" in user_substate:
                previous_event_type = "user utterance"
                previous_event_name = user_substate["text"]

    if previous_event_name is not None and not isinstance(previous_event_name, str):
        # While the Substate type doesn't restrict the value of `action_text` /
        # `intent`, etc. to be a string, it always should be
        raise TypeError(
            f"The value '{previous_event_name}' in the substate should be a string or "
            f"None, not {type(previous_event_name)}. Did you modify Rasa source code?"
        )

    return previous_event_type, previous_event_name