rasa/shared/core/training_data/visualization.py

Summary

Maintainability
D
3 days
Test Coverage
B
86%
from collections import defaultdict, deque

import random
from typing import (
    Any,
    Text,
    List,
    Deque,
    Dict,
    Optional,
    Set,
    TYPE_CHECKING,
    Union,
    cast,
)

import rasa.shared.utils.io
from rasa.shared.constants import INTENT_MESSAGE_PREFIX
from rasa.shared.core.constants import ACTION_LISTEN_NAME
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import UserUttered, ActionExecuted, Event
from rasa.shared.core.generator import TrainingDataGenerator
from rasa.shared.core.training_data.structures import StoryGraph, StoryStep
from rasa.shared.nlu.constants import (
    ENTITY_ATTRIBUTE_VALUE,
    INTENT,
    TEXT,
    ENTITY_ATTRIBUTE_TYPE,
    INTENT_NAME_KEY,
)

if TYPE_CHECKING:
    from rasa.shared.nlu.training_data.training_data import TrainingData
    from rasa.shared.nlu.training_data.message import Message
    import networkx

# Edge key used by `_add_edge` for edges that carry no user message.
EDGE_NONE_LABEL = "NONE"

# Ids of the default nodes that `_add_default_nodes` puts into every graph.
START_NODE_ID = 0
END_NODE_ID = -1
TMP_NODE_ID = -2

# Resource name of the HTML template; resolved in `visualization_html_path`.
VISUALIZATION_TEMPLATE_PATH = "/visualization.html"


class UserMessageGenerator:
    """Looks up example user messages for an intent in NLU training data."""

    def __init__(self, nlu_training_data: "TrainingData") -> None:
        """Store the training data and index its examples by intent.

        Args:
            nlu_training_data: NLU training data whose examples serve as
                sample user messages.
        """
        self.nlu_training_data = nlu_training_data
        self.mapping = self._create_reverse_mapping(self.nlu_training_data)

    @staticmethod
    def _create_reverse_mapping(
        data: "TrainingData",
    ) -> Dict[Text, List["Message"]]:
        """Create a mapping from intent to messages.

        This allows a faster intent lookup. Examples that do not carry an
        intent are skipped.
        """
        d: Dict[Text, List["Message"]] = defaultdict(list)
        for example in data.training_examples:
            # no `{}` default here: a missing intent must be skipped, not used
            # as an (unhashable) dictionary key
            intent = example.get(INTENT)
            if intent is not None:
                d[intent].append(example)
        return d

    @staticmethod
    def _contains_same_entity(entities: Dict[Text, Any], e: Dict[Text, Any]) -> bool:
        """Compare the value stored in `entities` for the type of `e` with `e`.

        NOTE(review): despite the name, this returns `True` when `entities` has
        no value for the entity type of `e` or when the stored value differs
        from the value of `e` - confirm the intended polarity with callers.
        """
        known_value = entities.get(e.get(ENTITY_ATTRIBUTE_TYPE))
        return known_value is None or known_value != e.get(ENTITY_ATTRIBUTE_VALUE)

    def message_for_data(self, structured_info: Dict[Text, Any]) -> Any:
        """Find a data sample with the same intent.

        Args:
            structured_info: Parsed message data; its intent name selects the
                pool of training examples.

        Returns:
            The text of a randomly chosen training example with the same
            intent, or the text of `structured_info` if there is no intent or
            no example for it.
        """
        intent = structured_info.get(INTENT)
        if intent is not None:
            usable_examples = self.mapping.get(intent.get(INTENT_NAME_KEY), [])
            if usable_examples:
                # pick one example directly instead of copying and shuffling
                # the whole list just to read its first element
                return random.choice(usable_examples).get(TEXT)

        return structured_info.get(TEXT)


def _fingerprint_node(
    graph: "networkx.MultiDiGraph", node: int, max_history: int
) -> Set[Text]:
    """Fingerprint a node in a graph.

    Can be used to identify nodes that are similar and can be merged within the
    graph.
    Generates all paths starting at `node` following the directed graph up to
    the length of `max_history`, and returns a set of strings describing the
    found paths. If the fingerprint creation for two nodes results in the same
    sets these nodes are indistinguishable if we walk along the path and only
    remember max history number of nodes we have visited. Hence, if we randomly
    walk on our directed graph, always only remembering the last `max_history`
    nodes we have visited, we can never remember if we have visited node A or
    node B if both have the same fingerprint.
    """
    # the candidate list contains all node paths that haven't been
    # extended till `max_history` length yet.
    candidates: Deque = deque()
    candidates.append([node])
    continuations = []
    while len(candidates) > 0:
        candidate = candidates.pop()
        last = candidate[-1]
        empty = True
        for _, succ_node in graph.out_edges(last):
            next_candidate = candidate[:]
            next_candidate.append(succ_node)
            # if the path is already long enough, we add it to the results,
            # otherwise we add it to the candidates
            # that we still need to visit
            if len(next_candidate) == max_history:
                continuations.append(next_candidate)
            else:
                candidates.append(next_candidate)
            empty = False
        if empty:
            continuations.append(candidate)
    return {
        " - ".join([graph.nodes[node]["label"] for node in continuation])
        for continuation in continuations
    }


def _incoming_edges(graph: "networkx.MultiDiGraph", node: int) -> set:
    return {(prev_node, k) for prev_node, _, k in graph.in_edges(node, keys=True)}


def _outgoing_edges(graph: "networkx.MultiDiGraph", node: int) -> set:
    return {(succ_node, k) for _, succ_node, k in graph.out_edges(node, keys=True)}


def _outgoing_edges_are_similar(
    graph: "networkx.MultiDiGraph", node_a: int, node_b: int
) -> bool:
    """Check whether leaving `node_a` or `node_b` leads to the same places.

    Edges that point at either of the two nodes themselves are not taken into
    account. The nodes count as similar if the remaining outgoing edges are
    identical, or if at least one of the nodes has no remaining outgoing edge.
    """

    def _edges_leaving_pair(node: int) -> set:
        # outgoing (target, key) pairs that do not point back into the pair
        return {
            (target, k)
            for target, k in _outgoing_edges(graph, node)
            if target not in (node_a, node_b)
        }

    a_edges = _edges_leaving_pair(node_a)
    b_edges = _edges_leaving_pair(node_b)

    if not a_edges or not b_edges:
        return True
    return a_edges == b_edges


def _nodes_are_equivalent(
    graph: "networkx.MultiDiGraph", node_a: int, node_b: int, max_history: int
) -> bool:
    """Decide whether two nodes can be treated as the same node.

    The labels have to match. In addition, one of the following has to hold:
    the outgoing edges are similar, the incoming edges are identical, or the
    fingerprints of both nodes for `max_history` are identical.
    """
    node_view = graph.nodes
    if node_view[node_a]["label"] != node_view[node_b]["label"]:
        return False

    if _outgoing_edges_are_similar(graph, node_a, node_b):
        return True

    if _incoming_edges(graph, node_a) == _incoming_edges(graph, node_b):
        return True

    fingerprint_a = _fingerprint_node(graph, node_a, max_history)
    fingerprint_b = _fingerprint_node(graph, node_b, max_history)
    return fingerprint_a == fingerprint_b


def _add_edge(
    graph: "networkx.MultiDiGraph",
    u: int,
    v: int,
    key: Optional[Text],
    label: Optional[Text] = None,
    **kwargs: Any,
) -> None:
    """Add an edge from `u` to `v`, keyed by `key`.

    A missing key is replaced by `EDGE_NONE_LABEL`, and edges with that key
    always get an empty label. If an edge keyed `EDGE_NONE_LABEL` already
    connects `u` and `v`, no edge is added; only the special style classes
    from `kwargs` are copied onto that existing edge.
    """
    edge_key = EDGE_NONE_LABEL if key is None else key
    edge_label = "" if edge_key == EDGE_NONE_LABEL else label

    if graph.has_edge(u, v, key=EDGE_NONE_LABEL):
        existing_data = graph.get_edge_data(u, v, key=EDGE_NONE_LABEL)
        _transfer_style(kwargs, existing_data)
        return

    graph.add_edge(u, v, key=edge_key, label=edge_label, **kwargs)


def _transfer_style(
    source: Dict[Text, Any], target: Dict[Text, Any]
) -> Dict[Text, Any]:
    """Copy over class names from source to target for all special classes.

    Used if a node is highlighted and merged with another node.
    """
    clazzes = source.get("class", "")

    special_classes = {"dashed", "active"}

    if "class" not in target:
        target["class"] = ""

    for c in special_classes:
        if c in clazzes and c not in target["class"]:
            target["class"] += " " + c

    target["class"] = target["class"].strip()
    return target


def _merge_equivalent_nodes(graph: "networkx.MultiDiGraph", max_history: int) -> None:
    """Merge nodes of the graph that are equivalent to each other, in place.

    Only nodes with an id greater than zero are considered. Whenever two
    nodes are equivalent, the one that comes later is folded into the earlier
    one: its special style classes and all of its edges are moved over and the
    node itself is removed.
    """
    # a merge changes the graph and may enable merges that were impossible
    # before, hence we start over until a full pass merges nothing
    merged_any = True
    while merged_any:
        merged_any = False
        candidate_ids = [node_id for node_id in graph.nodes() if node_id > 0]

        for position, keep in enumerate(candidate_ids):
            if not graph.has_node(keep):
                continue

            # assumes node equivalence is cumulative
            for drop in candidate_ids[position + 1 :]:
                if not graph.has_node(drop):
                    continue
                if not _nodes_are_equivalent(graph, keep, drop, max_history):
                    continue

                # the surviving node inherits the special styles
                node_attributes = graph.nodes(data=True)
                _transfer_style(node_attributes[drop], node_attributes[keep])
                merged_any = True

                # re-attach everything leaving the dropped node
                leaving = list(graph.out_edges(drop, keys=True, data=True))
                for _, succ_node, edge_key, edge_data in leaving:
                    _add_edge(
                        graph,
                        keep,
                        succ_node,
                        edge_key,
                        edge_data.get("label"),
                        **{"class": edge_data.get("class", "")},
                    )
                    graph.remove_edge(drop, succ_node)

                # re-attach everything entering the dropped node
                entering = list(graph.in_edges(drop, keys=True, data=True))
                for prev_node, _, edge_key, edge_data in entering:
                    _add_edge(
                        graph,
                        prev_node,
                        keep,
                        edge_key,
                        edge_data.get("label"),
                        **{"class": edge_data.get("class", "")},
                    )
                    graph.remove_edge(prev_node, drop)

                graph.remove_node(drop)


def _replace_edge_labels_with_nodes(
    graph: "networkx.MultiDiGraph",
    next_id: int,
    nlu_training_data: Optional["TrainingData"],
) -> None:
    """Replaces edge labels with nodes.

    User messages are created as edge labels. This removes the labels and
    creates nodes instead.

    The algorithms (e.g. merging) are simpler if the user messages are labels
    on the edges. But it sometimes
    looks better if in the final graphs the user messages are nodes instead
    of edge labels.

    Args:
        graph: The graph to modify in place.
        next_id: Highest node id in use; new nodes get the ids following it.
        nlu_training_data: If given, labels are replaced by example messages
            looked up through `UserMessageGenerator`.
    """
    message_generator = (
        UserMessageGenerator(nlu_training_data) if nlu_training_data else None
    )

    # snapshot the edges, the graph is modified while we walk over them
    for s, e, k, d in list(graph.edges(keys=True, data=True)):
        if k == EDGE_NONE_LABEL:
            continue

        label = d.get("label")
        if label is None:
            # an edge may be stored with an explicit `None` label
            label = k

        if message_generator:
            parsed_info: Dict[Text, Any] = {TEXT: label}
            if label.startswith(INTENT_MESSAGE_PREFIX):
                # strip exactly the prefix instead of assuming it is one
                # character long
                intent_name = label[len(INTENT_MESSAGE_PREFIX) :]
                parsed_info[INTENT] = {INTENT_NAME_KEY: intent_name}

            label = message_generator.message_for_data(parsed_info)

        next_id += 1
        graph.remove_edge(s, e, k)
        graph.add_node(
            next_id,
            label=label,
            shape="rect",
            style="filled",
            fillcolor="lightblue",
            **_transfer_style(d, {"class": "intent"}),
        )
        # route the former edge through the new message node
        graph.add_edge(s, next_id, **{"class": d.get("class", "")})
        graph.add_edge(next_id, e, **{"class": d.get("class", "")})


def visualization_html_path() -> Text:
    """Return the file name `pkg_resources` resolves for the HTML template.

    The template is looked up as `VISUALIZATION_TEMPLATE_PATH` relative to this
    module.
    """
    from pkg_resources import resource_filename

    template_path = resource_filename(__name__, VISUALIZATION_TEMPLATE_PATH)
    return template_path


def persist_graph(graph: "networkx.Graph", output_file: Text) -> None:
    """Render the graph into the HTML template and write it to `output_file`.

    The graph is converted to its pydot text form and inserted into the
    template in place of the `// { graph-content }` marker.
    """
    import networkx as nx

    pydot_graph = nx.nx_pydot.to_pydot(graph)
    html = rasa.shared.utils.io.read_file(visualization_html_path())

    # switch the page to client mode
    html = html.replace("// { is-client }", "isClient = true", 1)

    # the graph text ends up between backticks in the page, so every
    # backslash is doubled first
    escaped_graph = pydot_graph.to_string().replace("\\", "\\\\")
    html = html.replace("// { graph-content }", f"graph = `{escaped_graph}`", 1)

    rasa.shared.utils.io.write_text_file(html, output_file)


def _length_of_common_action_prefix(this: List[Event], other: List[Event]) -> int:
    """Calculate number of actions that two conversations have in common."""
    num_common_actions = 0
    t_cleaned = cast(
        List[Union[ActionExecuted, UserUttered]],
        [e for e in this if e.type_name in {"user", "action"}],
    )
    o_cleaned = cast(
        List[Union[ActionExecuted, UserUttered]],
        [e for e in other if e.type_name in {"user", "action"}],
    )

    for i, e in enumerate(t_cleaned):
        o = o_cleaned[i]
        if i == len(o_cleaned):
            break
        elif isinstance(e, UserUttered) and isinstance(o, UserUttered):
            continue
        elif (
            isinstance(e, ActionExecuted)
            and isinstance(o, ActionExecuted)
            and o.action_name == e.action_name
        ):
            num_common_actions += 1
        else:
            break
    return num_common_actions


def _add_default_nodes(graph: "networkx.MultiDiGraph", fontsize: int = 12) -> None:
    """Add the START, END and TMP nodes to the graph.

    START and END are filled nodes rendered with `fontsize`; TMP is an
    invisible node.
    """
    start_attributes = {
        "label": "START",
        "fillcolor": "green",
        "style": "filled",
        "fontsize": fontsize,
        "class": "start active",
    }
    end_attributes = {
        "label": "END",
        "fillcolor": "red",
        "style": "filled",
        "fontsize": fontsize,
        "class": "end",
    }
    tmp_attributes = {"label": "TMP", "style": "invis", "class": "invisible"}

    graph.add_node(START_NODE_ID, **start_attributes)
    graph.add_node(END_NODE_ID, **end_attributes)
    graph.add_node(TMP_NODE_ID, **tmp_attributes)


def _create_graph(fontsize: int = 12) -> "networkx.MultiDiGraph":
    """Build a new directed multigraph that already holds the default nodes."""
    from networkx import MultiDiGraph

    new_graph = MultiDiGraph()
    _add_default_nodes(new_graph, fontsize)
    return new_graph


def _add_message_edge(
    graph: "networkx.MultiDiGraph",
    message: Optional[Dict[Text, Any]],
    current_node: int,
    next_node_idx: int,
    is_current: bool,
) -> None:
    """Connect two nodes with an edge describing the user message.

    The intent name of the message becomes the edge key and its text the edge
    label; without a message both are `None`. Edges of the current
    conversation get the `active` class.
    """
    message_key = None
    message_label = None
    if message:
        message_key = message.get("intent", {}).get("name")
        message_label = message.get("text")

    edge_class = "active" if is_current else ""
    _add_edge(
        graph,
        current_node,
        next_node_idx,
        message_key,
        message_label,
        **{"class": edge_class},
    )


def visualize_neighborhood(
    current: Optional[List[Event]],
    event_sequences: List[List[Event]],
    output_file: Optional[Text] = None,
    max_history: int = 2,
    nlu_training_data: Optional["TrainingData"] = None,
    should_merge_nodes: bool = True,
    max_distance: int = 1,
    fontsize: int = 12,
) -> "networkx.MultiDiGraph":
    """Given a set of event lists, visualizing the flows.

    Args:
        current: Events of the conversation to highlight, if any.
        event_sequences: The conversations to draw.
        output_file: If given, the graph is also persisted to this file.
        max_history: History length used to decide which nodes get merged.
        nlu_training_data: If given, used to replace intent labels with
            example messages.
        should_merge_nodes: Whether equivalent nodes are merged.
        max_distance: If falsy, conversations are drawn completely instead of
            only as far as they share actions with `current`.
        fontsize: Font size of the START, END and action nodes.

    Returns:
        The created graph.
    """
    # `_create_graph` already adds the default nodes with the requested font
    # size; adding them a second time would reset them to the default size
    graph = _create_graph(fontsize)

    next_node_idx = START_NODE_ID
    # ids of "..." nodes are counted down from below the TMP node id
    special_node_idx = TMP_NODE_ID - 1
    path_ellipsis_ends: Set[int] = set()

    for events in event_sequences:
        if current and max_distance:
            prefix = _length_of_common_action_prefix(current, events)
        else:
            prefix = len(events)

        message: Optional[Dict[Text, Any]] = None
        current_node = START_NODE_ID
        idx = 0
        is_current = events == current

        for idx, el in enumerate(events):
            if not prefix:
                # the common part is used up; `idx` points at the last event
                # that still belongs to it
                idx -= 1
                break
            if isinstance(el, UserUttered):
                # work on a copy so the parse data of the caller's event keeps
                # its original text
                message = {
                    **el.parse_data,
                    TEXT: f"{INTENT_MESSAGE_PREFIX}{el.intent_name}",
                }
            elif (
                isinstance(el, ActionExecuted) and el.action_name != ACTION_LISTEN_NAME
            ):
                next_node_idx += 1
                graph.add_node(
                    next_node_idx,
                    label=el.action_name,
                    fontsize=fontsize,
                    **{"class": "active" if is_current else ""},
                )

                _add_message_edge(
                    graph, message, current_node, next_node_idx, is_current
                )
                current_node = next_node_idx

                message = None
                prefix -= 1

        # determine what the end node of the conversation is going to be
        # this can either be an ellipsis "...", the conversation end node
        # "END" or a "TMP" node if this is the active conversation
        if is_current:
            # an empty current conversation has no last event to inspect
            last_event = events[idx] if events else None
            if (
                isinstance(last_event, ActionExecuted)
                and last_event.action_name == ACTION_LISTEN_NAME
            ):
                next_node_idx += 1
                if message is None:
                    label = "  ?  "
                else:
                    intent = message.get("intent", {})
                    label = intent.get("name", "  ?  ")
                graph.add_node(
                    next_node_idx,
                    label=label,
                    shape="rect",
                    **{"class": "intent dashed active"},
                )
                target = next_node_idx
            elif current_node:
                d = graph.nodes(data=True)[current_node]
                d["class"] = "dashed active"
                target = TMP_NODE_ID
            else:
                target = TMP_NODE_ID
        elif idx == len(events) - 1:
            target = END_NODE_ID
        elif current_node and current_node not in path_ellipsis_ends:
            graph.add_node(special_node_idx, label="...", **{"class": "ellipsis"})
            target = special_node_idx
            path_ellipsis_ends.add(current_node)
            special_node_idx -= 1
        else:
            target = END_NODE_ID

        _add_message_edge(graph, message, current_node, target, is_current)

    if should_merge_nodes:
        _merge_equivalent_nodes(graph, max_history)
    _replace_edge_labels_with_nodes(graph, next_node_idx, nlu_training_data)

    _remove_auxiliary_nodes(graph, special_node_idx)

    if output_file:
        persist_graph(graph, output_file)
    return graph


def _remove_auxiliary_nodes(
    graph: "networkx.MultiDiGraph", special_node_idx: int
) -> None:
    """Remove any temporary or unused nodes.

    Args:
        graph: The graph to clean up in place.
        special_node_idx: The next unused id for "..." nodes; the existing
            ones have the ids between this value and `TMP_NODE_ID`
            (both exclusive).
    """
    graph.remove_node(TMP_NODE_ID)

    # `predecessors` returns an iterator, which is truthy even when it is
    # empty - materialise it to find out whether END is actually used
    if not list(graph.predecessors(END_NODE_ID)):
        graph.remove_node(END_NODE_ID)

    # remove duplicated "..." nodes after merging
    predecessors_seen: Set[int] = set()
    for i in range(special_node_idx + 1, TMP_NODE_ID):
        if not graph.has_node(i):
            continue
        # take a snapshot: the iterator can only be consumed once and the
        # node may be removed below
        predecessors = set(graph.predecessors(i))
        if predecessors & predecessors_seen:
            graph.remove_node(i)
        predecessors_seen |= predecessors


def visualize_stories(
    story_steps: List[StoryStep],
    domain: Domain,
    output_file: Optional[Text],
    max_history: int,
    nlu_training_data: Optional["TrainingData"] = None,
    should_merge_nodes: bool = True,
    fontsize: int = 12,
) -> "networkx.MultiDiGraph":
    """Given a set of stories, generates a graph visualizing the flows in the stories.

    Visualization is always a trade-off between making the graph as small as
    possible and making sure its meaning doesn't change too much. The
    algorithm compresses the graph generated from the stories by merging
    nodes that are similar. Hence, it might create paths through the graph
    that aren't actually specified in the stories, but we try to minimize
    that.

    Args:
        story_steps: The story steps to visualize.
        domain: The domain the stories belong to.
        output_file: Defines if and where a file containing the plotted graph
            should be stored.
        max_history: Defines how much 'memory' the graph has. Nodes are only
            merged if they are equal within the history, so the larger the
            history, the less likely it is that any nodes get merged.
        nlu_training_data: Can be used to pass in a Rasa NLU training data
            instance. It is used to replace the user messages from the story
            file with actual messages from the training data.
        should_merge_nodes: Whether similar nodes are merged.
        fontsize: Font size used for the nodes of the graph.

    Returns:
        The graph visualizing the stories.
    """
    generator = TrainingDataGenerator(
        StoryGraph(story_steps),
        domain,
        use_story_concatenation=False,
        tracker_limit=100,
        augmentation_factor=0,
    )
    event_sequences = [tracker.events for tracker in generator.generate()]

    return visualize_neighborhood(
        current=None,
        event_sequences=event_sequences,
        output_file=output_file,
        max_history=max_history,
        nlu_training_data=nlu_training_data,
        should_merge_nodes=should_merge_nodes,
        max_distance=1,
        fontsize=fontsize,
    )