# rasa/core/processor.py
import json
import logging
from types import LambdaType
from typing import Any, Dict, List, Optional, Text, Tuple
import numpy as np
import time
from rasa.core import jobs
from rasa.core.actions import Action
from rasa.core.actions.action import (
ACTION_LISTEN_NAME,
ActionExecutionRejection)
from rasa.core.channels import CollectingOutputChannel, UserMessage
from rasa.core.constants import (
ACTION_NAME_SENDER_ID_CONNECTOR_STR,
USER_INTENT_RESTART)
from rasa.core.dispatcher import Dispatcher
from rasa.core.domain import Domain
from rasa.core.events import (
ActionExecuted, ActionExecutionRejected,
BotUttered, Event, ReminderCancelled, ReminderScheduled, SlotSet,
UserUttered)
from rasa.core.interpreter import (
INTENT_MESSAGE_PREFIX,
NaturalLanguageInterpreter, RegexInterpreter)
from rasa.core.nlg import NaturalLanguageGenerator
from rasa.core.policies.ensemble import PolicyEnsemble
from rasa.core.tracker_store import TrackerStore
from rasa.core.trackers import DialogueStateTracker, EventVerbosity
from rasa.core.utils import EndpointConfig
logger = logging.getLogger(__name__)


class MessageProcessor(object):
    """Processes user messages and reminders for a conversation.

    For every incoming message the processor fetches the sender's tracker
    from the tracker store, logs the parsed message on it, and then keeps
    predicting and running actions until the listen action is predicted,
    the tracker is paused, or the prediction limit is reached.
    """

    def __init__(self,
                 interpreter: NaturalLanguageInterpreter,
                 policy_ensemble: PolicyEnsemble,
                 domain: Domain,
                 tracker_store: TrackerStore,
                 generator: NaturalLanguageGenerator,
                 action_endpoint: Optional[EndpointConfig] = None,
                 max_number_of_predictions: int = 10,
                 message_preprocessor: Optional[LambdaType] = None,
                 on_circuit_break: Optional[LambdaType] = None,
                 ):
        """Store the collaborators used while processing messages.

        Args:
            interpreter: Parses message text that does not start with
                ``INTENT_MESSAGE_PREFIX``.
            policy_ensemble: Asked for action probabilities when no
                follow-up action is pending.
            domain: Used to resolve actions by name / index and to turn
                entities into slot events.
            tracker_store: Source and sink of conversation trackers.
            generator: Handed to every ``Dispatcher`` that is created.
            action_endpoint: Passed to the domain when resolving actions.
            max_number_of_predictions: Upper bound on the number of
                actions run per call of the prediction loop.
            message_preprocessor: Optional callable applied to the message
                text before the message is logged.
            on_circuit_break: Optional callback invoked with
                ``(tracker, dispatcher)`` when the prediction limit is hit.
        """
        self.interpreter = interpreter
        self.nlg = generator
        self.policy_ensemble = policy_ensemble
        self.domain = domain
        self.tracker_store = tracker_store
        self.max_number_of_predictions = max_number_of_predictions
        self.message_preprocessor = message_preprocessor
        self.on_circuit_break = on_circuit_break
        self.action_endpoint = action_endpoint

    async def handle_message(self,
                             message: UserMessage) -> Optional[List[Text]]:
        """Handle a single message with this processor.

        Returns the collected bot messages when the message's output
        channel is a ``CollectingOutputChannel``; otherwise (or when no
        tracker could be obtained) returns ``None``.
        """
        # log_message applies the optional preprocessor, updates the
        # tracker with the user message and saves it
        tracker = await self.log_message(message)
        if not tracker:
            return None

        await self._predict_and_execute_next_action(message, tracker)

        # save tracker state to continue conversation from this state
        self._save_tracker(tracker)

        if isinstance(message.output_channel, CollectingOutputChannel):
            return message.output_channel.messages
        return None

    def predict_next(self, sender_id: Text) -> Optional[Dict[Text, Any]]:
        """Predict the next action for a sender without running it.

        Returns a dict with the keys ``scores`` (one entry per domain
        action name), ``policy``, ``confidence`` (the maximum probability)
        and ``tracker`` (the tracker state), or ``None`` if no tracker
        could be obtained.
        """
        # we have a Tracker instance for each user
        # which maintains conversation state
        tracker = self._get_tracker(sender_id)
        if not tracker:
            logger.warning("Failed to retrieve or create tracker for sender "
                           "'{}'.".format(sender_id))
            return None

        probabilities, policy = \
            self._get_next_action_probabilities(tracker)

        # save tracker state to continue conversation from this state
        # (a pending follow-up action may have been cleared above)
        self._save_tracker(tracker)

        scores = [{"action": a, "score": p}
                  for a, p in zip(self.domain.action_names, probabilities)]
        return {
            "scores": scores,
            "policy": policy,
            "confidence": np.max(probabilities),
            "tracker": tracker.current_state(EventVerbosity.AFTER_RESTART)
        }

    async def log_message(self,
                          message: UserMessage
                          ) -> Optional[DialogueStateTracker]:
        """Log a user message on the sender's tracker and save the tracker.

        The optional ``message_preprocessor`` is applied to the message
        text first. Returns whatever ``_get_tracker`` produced; if that is
        falsy a warning is logged and nothing is saved.
        """
        # preprocess message if necessary
        if self.message_preprocessor is not None:
            message.text = self.message_preprocessor(message.text)

        # we have a Tracker instance for each user
        # which maintains conversation state
        tracker = self._get_tracker(message.sender_id)
        if tracker:
            await self._handle_message_with_tracker(message, tracker)
            # save tracker state to continue conversation from this state
            self._save_tracker(tracker)
        else:
            logger.warning("Failed to retrieve or create tracker for sender "
                           "'{}'.".format(message.sender_id))
        return tracker

    async def execute_action(self,
                             sender_id: Text,
                             action_name: Text,
                             dispatcher: Dispatcher,
                             policy: Text,
                             confidence: float
                             ) -> Optional[DialogueStateTracker]:
        """Run the named action on the sender's tracker and save the tracker.

        ``policy`` and ``confidence`` are recorded with the executed
        action. Returns whatever ``_get_tracker`` produced; if that is
        falsy a warning is logged and no action is run.
        """
        # we have a Tracker instance for each user
        # which maintains conversation state
        tracker = self._get_tracker(sender_id)
        if tracker:
            action = self._get_action(action_name)
            await self._run_action(action, tracker, dispatcher, policy,
                                   confidence)

            # save tracker state to continue conversation from this state
            self._save_tracker(tracker)
        else:
            logger.warning("Failed to retrieve or create tracker for sender "
                           "'{}'.".format(sender_id))
        return tracker

    def predict_next_action(self,
                            tracker: DialogueStateTracker
                            ) -> Tuple[Action, Text, float]:
        """Predict the next action the bot should take for this tracker.

        Picks the action with the highest probability and returns a tuple
        ``(action, policy, confidence)`` where ``confidence`` is that
        highest probability.
        """
        probabilities, policy = self._get_next_action_probabilities(tracker)

        max_index = int(np.argmax(probabilities))
        action = self.domain.action_for_index(max_index, self.action_endpoint)
        logger.debug("Predicted next action '{}' with prob {:.2f}.".format(
            action.name(), probabilities[max_index]))
        return action, policy, probabilities[max_index]

    @staticmethod
    def _is_reminder(e: Event, name: Text) -> bool:
        """Return True if `e` is a scheduled reminder with the given name."""
        return isinstance(e, ReminderScheduled) and e.name == name

    @staticmethod
    def _is_reminder_still_valid(tracker: DialogueStateTracker,
                                 reminder_event: ReminderScheduled
                                 ) -> bool:
        """Check that the reminder is still among the applied events.

        Returns False when no reminder with the same name is found in
        ``tracker.applied_events()``, i.e. the conversation has been
        restarted after the reminder was scheduled.
        """
        return any(MessageProcessor._is_reminder(e, reminder_event.name)
                   for e in reversed(tracker.applied_events()))

    @staticmethod
    def _has_message_after_reminder(tracker: DialogueStateTracker,
                                    reminder_event: ReminderScheduled
                                    ) -> bool:
        """Check if the user sent a message after the reminder."""
        # walk backwards: whichever comes first, the reminder or a user
        # message with text, decides the answer
        for e in reversed(tracker.events):
            if MessageProcessor._is_reminder(e, reminder_event.name):
                return False
            elif isinstance(e, UserUttered) and e.text:
                return True
        return True  # tracker has probably been restarted

    async def handle_reminder(self,
                              reminder_event: ReminderScheduled,
                              dispatcher: Dispatcher
                              ) -> None:
        """Handle a reminder that is triggered asynchronously.

        The reminder is dropped when it asks to be killed on user messages
        and the user has written since, or when it is no longer part of the
        applied events. Otherwise its action is run and, unless that was
        the listen action, further actions are predicted and executed.
        """
        tracker = self._get_tracker(dispatcher.sender_id)

        if not tracker:
            logger.warning("Failed to retrieve or create tracker for sender "
                           "'{}'.".format(dispatcher.sender_id))
            return None

        # explicit grouping of the former `A and B or not C` condition
        killed_by_user_message = (
            reminder_event.kill_on_user_message and
            self._has_message_after_reminder(tracker, reminder_event))

        if (killed_by_user_message or
                not self._is_reminder_still_valid(tracker, reminder_event)):
            logger.debug("Canceled reminder because it is outdated. "
                         "(event: {} id: {})".format(reminder_event.action_name,
                                                     reminder_event.name))
        else:
            # necessary for proper featurization, otherwise the previous
            # unrelated message would influence featurization
            tracker.update(UserUttered.empty())
            action = self._get_action(reminder_event.action_name)
            should_continue = await self._run_action(action, tracker,
                                                     dispatcher)
            if should_continue:
                user_msg = UserMessage(None,
                                       dispatcher.output_channel,
                                       dispatcher.sender_id)
                await self._predict_and_execute_next_action(user_msg, tracker)
            # save tracker state to continue conversation from this state
            self._save_tracker(tracker)

    @staticmethod
    def _log_slots(tracker):
        """Write the name and value of every tracker slot to the debug log."""
        slot_values = "\n".join(["\t{}: {}".format(s.name, s.value)
                                 for s in tracker.slots.values()])
        logger.debug("Current slot values: \n{}".format(slot_values))

    def _get_action(self, action_name):
        """Resolve an action by name through the domain."""
        return self.domain.action_for_name(action_name, self.action_endpoint)

    async def _parse_message(self, message):
        """Parse the message text into intent and entities.

        Returns the parse data, which is read here by its ``"intent"`` and
        ``"entities"`` keys.
        """
        # for testing - you can short-cut the NLU part with a message
        # in the format /intent{"entity1": val1, "entity2": val2}
        # parse_data is a dict of intent & entities
        if message.text.startswith(INTENT_MESSAGE_PREFIX):
            parse_data = await RegexInterpreter().parse(message.text,
                                                        message.message_id)
        else:
            parse_data = await self.interpreter.parse(message.text,
                                                      message.message_id)

        logger.debug("Received user message '{}' with intent '{}' "
                     "and entities '{}'".format(message.text,
                                                parse_data["intent"],
                                                parse_data["entities"]))
        return parse_data

    async def _handle_message_with_tracker(self,
                                           message: UserMessage,
                                           tracker: DialogueStateTracker
                                           ) -> None:
        """Log the user message and its entity slots on the tracker.

        Parse data already attached to the message is used as is; the
        message is only parsed when there is none.
        """
        if message.parse_data:
            parse_data = message.parse_data
        else:
            parse_data = await self._parse_message(message)

        # don't ever directly mutate the tracker
        # - instead pass its events to log
        tracker.update(UserUttered(message.text, parse_data["intent"],
                                   parse_data["entities"], parse_data,
                                   input_channel=message.input_channel,
                                   message_id=message.message_id))
        # store all entities as slots
        for e in self.domain.slots_for_entities(parse_data["entities"]):
            tracker.update(e)

        logger.debug("Logged UserUtterance - "
                     "tracker now has {} events".format(len(tracker.events)))

    @staticmethod
    def _should_handle_message(tracker):
        """Return True unless the tracker is paused.

        A paused tracker is still handled when the intent of its latest
        message is the restart intent.
        """
        return (not tracker.is_paused() or
                tracker.latest_message.intent.get("name") ==
                USER_INTENT_RESTART)

    async def _predict_and_execute_next_action(self, message, tracker):
        """Predict and run actions until the bot should stop.

        The loop ends when an action signals that no further prediction is
        wanted (the listen action), when the tracker should not be handled,
        or when ``max_number_of_predictions`` actions have been run. In the
        last case a warning is logged and ``on_circuit_break`` is called if
        it is set.
        """
        # keep taking actions decided by the policy until it chooses to 'listen'
        should_predict_another_action = True
        num_predicted_actions = 0

        def is_action_limit_reached():
            # only a tripped breaker if we would have kept on predicting
            return (num_predicted_actions == self.max_number_of_predictions and
                    should_predict_another_action)

        # this will actually send the response to the user
        dispatcher = Dispatcher(message.sender_id,
                                message.output_channel,
                                self.nlg)

        self._log_slots(tracker)

        # action loop. predicts actions until we hit action listen
        while (should_predict_another_action and
               self._should_handle_message(tracker) and
               num_predicted_actions < self.max_number_of_predictions):
            # this actually just calls the policy's method by the same name
            action, policy, confidence = self.predict_next_action(tracker)

            should_predict_another_action = await self._run_action(action,
                                                                   tracker,
                                                                   dispatcher,
                                                                   policy,
                                                                   confidence)
            num_predicted_actions += 1

        if is_action_limit_reached():
            # circuit breaker was tripped
            logger.warning(
                "Circuit breaker tripped. Stopped predicting "
                "more actions for sender '{}'".format(tracker.sender_id))
            if self.on_circuit_break:
                # call a registered callback
                self.on_circuit_break(tracker, dispatcher)

    # noinspection PyUnusedLocal
    @staticmethod
    def should_predict_another_action(action_name, events):
        """Return True unless the action just run was the listen action.

        ``events`` is accepted but not inspected.
        """
        is_listen_action = action_name == ACTION_LISTEN_NAME
        return not is_listen_action

    async def _schedule_reminders(self, events: List[Event],
                                  tracker: DialogueStateTracker,
                                  dispatcher: Dispatcher) -> None:
        """Uses the scheduler to time a job to trigger the passed reminder.

        Reminders with the same `id` property will overwrite one another
        (i.e. only one of them will eventually run)."""
        for e in events:
            if isinstance(e, ReminderScheduled):
                # the job name encodes action name and sender id so that
                # _cancel_reminders can find the job again
                (await jobs.scheduler()).add_job(
                    self.handle_reminder, "date",
                    run_date=e.trigger_date_time,
                    args=[e, dispatcher],
                    id=e.name,
                    replace_existing=True,
                    name=(str(e.action_name) +
                          ACTION_NAME_SENDER_ID_CONNECTOR_STR +
                          tracker.sender_id))

    @staticmethod
    async def _cancel_reminders(events: List[Event],
                                tracker: DialogueStateTracker) -> None:
        """Cancel reminders by action_name"""
        # All Reminders with the same action name will be cancelled
        for e in events:
            if isinstance(e, ReminderCancelled):
                name_to_check = (str(e.action_name) +
                                 ACTION_NAME_SENDER_ID_CONNECTOR_STR +
                                 tracker.sender_id)
                scheduler = await jobs.scheduler()
                for j in scheduler.get_jobs():
                    if j.name == name_to_check:
                        scheduler.remove_job(j.id)

    async def _run_action(self, action, tracker, dispatcher, policy=None,
                          confidence=None):
        """Run an action and record its outcome on the tracker.

        A rejected execution is logged as ``ActionExecutionRejected``; any
        other exception is logged and the action is recorded without
        events. Returns whether another action should be predicted.
        """
        # events and return values are used to update
        # the tracker state after an action has been taken
        try:
            events = await action.run(dispatcher, tracker, self.domain)
        except ActionExecutionRejection:
            events = [ActionExecutionRejected(action.name(),
                                              policy, confidence)]
            tracker.update(events[0])
            return self.should_predict_another_action(action.name(), events)
        except Exception as e:
            logger.error("Encountered an exception while running action '{}'. "
                         "Bot will continue, but the actions events are lost. "
                         "Make sure to fix the exception in your custom "
                         "code.".format(action.name()))
            logger.debug(e, exc_info=True)
            events = []

        # An action whose run method returns `None` must not break the
        # reminder handling below, which iterates over the events.
        if events is None:
            events = []

        self._log_action_on_tracker(tracker, action.name(), events, policy,
                                    confidence)
        self.log_bot_utterances_on_tracker(tracker, dispatcher)
        await self._schedule_reminders(events, tracker, dispatcher)
        await self._cancel_reminders(events, tracker)

        return self.should_predict_another_action(action.name(), events)

    def _warn_about_new_slots(self, tracker, action_name, events):
        """Warn when an action sets a featurized slot not seen in training.

        Nothing is checked for actions without a fingerprint in the policy
        ensemble. Setting `requested_slot` while a form is active does not
        trigger the warning.
        """
        # these are the events from that action we have seen during training
        if action_name not in self.policy_ensemble.action_fingerprints:
            return

        fp = self.policy_ensemble.action_fingerprints[action_name]
        slots_seen_during_train = fp.get("slots", set())
        for e in events:
            if not isinstance(e, SlotSet) or e.key in slots_seen_during_train:
                continue
            s = tracker.slots.get(e.key)
            if not (s and s.has_features()):
                continue
            if e.key == 'requested_slot' and tracker.active_form:
                continue
            logger.warning(
                "Action '{0}' set a slot type '{1}' that "
                "it never set during the training. This "
                "can throw of the prediction. Make sure to "
                "include training examples in your stories "
                "for the different types of slots this "
                "action can return. Remember: you need to "
                "set the slots manually in the stories by "
                "adding '- slot{{\"{1}\": {2}}}' "
                "after the action."
                "".format(action_name, e.key,
                          json.dumps(e.value)))

    @staticmethod
    def log_bot_utterances_on_tracker(tracker: DialogueStateTracker,
                                      dispatcher: Dispatcher) -> None:
        """Move the dispatcher's latest bot messages onto the tracker.

        Each message becomes a ``BotUttered`` event; afterwards the
        dispatcher's list of latest bot messages is emptied.
        """
        if dispatcher.latest_bot_messages:
            for m in dispatcher.latest_bot_messages:
                bot_utterance = BotUttered(text=m.text, data=m.data)
                logger.debug("Bot utterance '{}'".format(bot_utterance))
                tracker.update(bot_utterance)

            dispatcher.latest_bot_messages = []

    def _log_action_on_tracker(self, tracker, action_name, events, policy,
                               confidence):
        """Record the executed action and its events on the tracker."""
        # Ensures that the code still works even if a lazy programmer missed
        # to type `return []` at the end of an action or the run method
        # returns `None` for some other reason.
        if events is None:
            events = []

        logger.debug("Action '{}' ended with events '{}'".format(
            action_name, ['{}'.format(e) for e in events]))

        self._warn_about_new_slots(tracker, action_name, events)

        if action_name is not None:
            # log the action and its produced events
            tracker.update(ActionExecuted(action_name, policy, confidence))

        for e in events:
            # this makes sure the events are ordered by timestamp -
            # since the event objects are created somewhere else,
            # the timestamp would indicate a time before the time
            # of the action executed
            e.timestamp = time.time()
            tracker.update(e)

    def _get_tracker(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Fetch (or create) the tracker for a sender from the tracker store.

        A falsy `sender_id` is replaced by ``UserMessage.DEFAULT_SENDER_ID``.
        """
        sender_id = sender_id or UserMessage.DEFAULT_SENDER_ID
        tracker = self.tracker_store.get_or_create_tracker(sender_id)
        return tracker

    def _save_tracker(self, tracker):
        """Save the tracker in the tracker store."""
        self.tracker_store.save(tracker)

    def _prob_array_for_action(self,
                               action_name: Text
                               ) -> Tuple[Optional[List[float]], None]:
        """Build a one-hot probability list for the named action.

        Always returns a 2-tuple whose second element is ``None``. The
        first element is a list with 1.0 at the action's index and 0.0
        elsewhere, or ``None`` if the domain has no index for the action.
        """
        idx = self.domain.index_for_action(action_name)
        if idx is not None:
            result = [0.0] * self.domain.num_actions
            result[idx] = 1.0
            return result, None
        else:
            return None, None

    def _get_next_action_probabilities(self,
                                       tracker: DialogueStateTracker
                                       ) -> Tuple[Optional[List[float]],
                                                  Optional[Text]]:
        """Collect predictions from ensemble and return action and predictions.

        A pending follow-up action on the tracker is cleared and, if the
        domain knows it, returned as a one-hot prediction with policy
        ``None``. An unknown follow-up action is logged and ignored. In all
        other cases the policy ensemble's prediction is returned.
        """
        followup_action = tracker.followup_action
        if followup_action:
            tracker.clear_followup_action()
            probabilities, policy = \
                self._prob_array_for_action(followup_action)
            # `_prob_array_for_action` always returns a (non-empty, hence
            # truthy) tuple, so the probabilities themselves are checked.
            if probabilities is not None:
                return probabilities, policy
            logger.error(
                "Trying to run unknown follow up action '{}'! "
                "Instead of running that, we will ignore the action "
                "and predict the next action.".format(followup_action))

        return self.policy_ensemble.probabilities_using_best_policy(
            tracker, self.domain)