# rasa/core/tracker_store.py
import itertools
import json
import logging
import pickle
# noinspection PyPep8Naming
from typing import Iterator, KeysView, List, Optional, Text
from rasa.core.actions.action import ACTION_LISTEN_NAME
from rasa.core.broker import EventChannel
from rasa.core.domain import Domain
from rasa.core.trackers import (
ActionExecuted, DialogueStateTracker, EventVerbosity)
from rasa.core.utils import class_from_module_path
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
class TrackerStore(object):
    """Base class for objects that persist conversation trackers.

    Subclasses must implement `save`, `retrieve` and `keys`. The base class
    provides tracker creation, pickle based (de)serialisation and publishing
    of new events to an optional event broker.
    """

    def __init__(self,
                 domain: Optional[Domain],
                 event_broker: Optional[EventChannel] = None) -> None:
        """Remember the domain and the (optional) event broker.

        Args:
            domain: Domain whose slots are used to initialise trackers; when
                it is falsy `init_tracker` returns `None`.
            event_broker: Channel that `stream_events` publishes to.
        """
        self.domain = domain
        self.event_broker = event_broker
        # Passed on to every `DialogueStateTracker` built by `init_tracker`.
        self.max_event_history = None

    @staticmethod
    def find_tracker_store(domain, store=None, event_broker=None):
        """Build the tracker store described by the `store` configuration.

        Args:
            domain: Domain handed to the created store.
            store: Configuration object exposing `type`, `url` and `kwargs`;
                `None` (or a `None` type) selects the in-memory store.
            event_broker: Event broker handed to the built-in stores.

        Returns:
            A tracker store instance. Unknown types are treated as a module
            path to a custom store class.
        """
        if store is None or store.type is None:
            return InMemoryTrackerStore(domain, event_broker=event_broker)
        elif store.type == 'redis':
            return RedisTrackerStore(domain=domain,
                                     host=store.url,
                                     event_broker=event_broker,
                                     **store.kwargs)
        elif store.type == 'mongod':
            return MongoTrackerStore(domain=domain,
                                     host=store.url,
                                     event_broker=event_broker,
                                     **store.kwargs)
        # NOTE: only the 'sql' type is matched case-insensitively.
        elif store.type.lower() == 'sql':
            return SQLTrackerStore(domain=domain,
                                   url=store.url,
                                   event_broker=event_broker,
                                   **store.kwargs)
        else:
            return TrackerStore.load_tracker_from_module_string(domain, store)

    @staticmethod
    def load_tracker_from_module_string(domain, store):
        """Instantiate a custom store whose class path is `store.type`.

        Falls back to an `InMemoryTrackerStore` (without event broker) when
        the class cannot be resolved.
        """
        custom_tracker = None
        try:
            custom_tracker = class_from_module_path(store.type)
        except (AttributeError, ImportError):
            logger.warning("Store type '{}' not found. "
                           "Using InMemoryTrackerStore instead"
                           .format(store.type))

        if custom_tracker:
            return custom_tracker(domain=domain,
                                  url=store.url, **store.kwargs)
        else:
            return InMemoryTrackerStore(domain)

    def get_or_create_tracker(self, sender_id, max_event_history=None):
        """Return the stored tracker for `sender_id`, creating it if missing.

        Args:
            sender_id: Id of the conversation.
            max_event_history: History limit applied to trackers that this
                store initialises from now on.
        """
        # Set the limit *before* retrieving: `retrieve` implementations that
        # go through `deserialise_tracker` call `init_tracker`, which reads
        # `self.max_event_history`. Setting it afterwards would rebuild the
        # tracker with the value left over from the previous call.
        self.max_event_history = max_event_history
        tracker = self.retrieve(sender_id)
        if tracker is None:
            tracker = self.create_tracker(sender_id)
        return tracker

    def init_tracker(self, sender_id):
        """Return a fresh, empty tracker, or `None` if no domain is set."""
        if self.domain:
            return DialogueStateTracker(
                sender_id,
                self.domain.slots,
                max_event_history=self.max_event_history)
        else:
            return None

    def create_tracker(self, sender_id, append_action_listen=True):
        """Creates a new tracker for the sender_id.

        The tracker is initially listening (unless `append_action_listen` is
        false) and is saved right away. Returns `None` when no tracker could
        be initialised.
        """
        tracker = self.init_tracker(sender_id)
        if tracker:
            if append_action_listen:
                tracker.update(ActionExecuted(ACTION_LISTEN_NAME))
            self.save(tracker)
        return tracker

    def save(self, tracker):
        """Persist `tracker`; must be implemented by subclasses."""
        raise NotImplementedError()

    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Load the tracker for `sender_id`; must be implemented by
        subclasses."""
        raise NotImplementedError()

    def stream_events(self, tracker: DialogueStateTracker) -> None:
        """Publish the events of `tracker` that are not stored yet.

        The number of already stored events is taken from the currently
        persisted tracker, so this has to run before the new state is saved.
        """
        if not self.event_broker:
            # Nothing to publish to.
            return

        old_tracker = self.retrieve(tracker.sender_id)
        offset = len(old_tracker.events) if old_tracker else 0
        evts = tracker.events
        # `islice` instead of slicing: works without assuming `events`
        # supports slice syntax.
        for evt in itertools.islice(evts, offset, len(evts)):
            body = {
                "sender_id": tracker.sender_id,
            }
            body.update(evt.as_dict())
            self.event_broker.publish(body)

    def keys(self):
        # type: () -> Optional[List[Text]]
        """Return the keys of all stored trackers; must be implemented by
        subclasses."""
        raise NotImplementedError()

    @staticmethod
    def serialise_tracker(tracker):
        """Pickle the dialogue of `tracker` and return the resulting bytes."""
        dialogue = tracker.as_dialogue()
        return pickle.dumps(dialogue)

    def deserialise_tracker(self, sender_id, _json):
        """Rebuild a tracker from data produced by `serialise_tracker`.

        Args:
            sender_id: Id of the conversation the data belongs to.
            _json: Pickled dialogue (despite the name, this is not JSON).

        Returns:
            The recreated tracker, or `None` when no tracker can be
            initialised because the store has no domain.
        """
        tracker = self.init_tracker(sender_id)
        if tracker is None:
            logger.warning("Can't recreate tracker for id '{}' "
                           "because no domain is set. Returning `None` "
                           "instead.".format(sender_id))
            return None

        # NOTE(security): `pickle.loads` can execute arbitrary code. This is
        # only safe as long as the backing store contains trusted data.
        dialogue = pickle.loads(_json)
        tracker.recreate_from_dialogue(dialogue)
        return tracker
class InMemoryTrackerStore(TrackerStore):
    """Tracker store backed by a plain dictionary held in this process."""

    def __init__(self,
                 domain: Domain,
                 event_broker: Optional[EventChannel] = None
                 ) -> None:
        """Create an empty store for `domain` with an optional broker."""
        super().__init__(domain, event_broker)
        # Maps sender id -> serialised tracker.
        self.store = {}

    def save(self, tracker: DialogueStateTracker) -> None:
        """Publish new events (if a broker is set) and store the tracker."""
        # Streaming has to happen first: it compares against the version
        # that is still in `self.store`.
        if self.event_broker:
            self.stream_events(tracker)

        self.store[tracker.sender_id] = \
            InMemoryTrackerStore.serialise_tracker(tracker)

    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Return the stored tracker for `sender_id` or `None` if unknown."""
        if sender_id not in self.store:
            logger.debug("Creating a new tracker for "
                         "id '{}'.".format(sender_id))
            return None

        logger.debug("Recreating tracker for "
                     "id '{}'".format(sender_id))
        serialised = self.store[sender_id]
        return self.deserialise_tracker(sender_id, serialised)

    def keys(self) -> KeysView[Text]:
        """Return a view on the sender ids currently held in the store."""
        return self.store.keys()
class RedisTrackerStore(TrackerStore):
    """Tracker store that keeps serialised trackers in a Redis database."""

    def __init__(self, domain, host='localhost',
                 port=6379, db=0, password=None, event_broker=None,
                 record_exp=None):
        """Create the Redis client.

        Args:
            domain: Domain used to initialise trackers.
            host: Redis host.
            port: Redis port.
            db: Redis database number.
            password: Password used for the connection, if any.
            event_broker: Optional broker that new events are published to.
            record_exp: Default expiry passed as `ex` when saving; `None`
                stores records without expiry.
        """
        # Imported lazily so `redis` is only needed when this store is used.
        import redis
        self.red = redis.StrictRedis(host=host, port=port, db=db,
                                     password=password)
        self.record_exp = record_exp
        super(RedisTrackerStore, self).__init__(domain, event_broker)

    def keys(self):
        """Return all keys of the configured Redis database as text.

        NOTE(review): this lists every key in the database, so it assumes
        the database is used for trackers only — confirm for your setup.
        """
        # The client is created without `decode_responses`, so keys may come
        # back as bytes; normalise them to text.
        return [key.decode("utf-8", errors="replace")
                if isinstance(key, bytes) else key
                for key in self.red.keys()]

    def save(self, tracker, timeout=None):
        """Publish new events (if a broker is set) and store the tracker.

        Args:
            tracker: Tracker to persist under its `sender_id`.
            timeout: Expiry for this record; falls back to `record_exp` when
                falsy.
        """
        # Stream before writing: the offset of new events is computed from
        # the version that is still stored.
        if self.event_broker:
            self.stream_events(tracker)

        if not timeout and self.record_exp:
            timeout = self.record_exp

        serialised_tracker = self.serialise_tracker(tracker)
        self.red.set(tracker.sender_id, serialised_tracker, ex=timeout)

    def retrieve(self, sender_id):
        """Return the tracker stored for `sender_id` or `None` if missing."""
        stored = self.red.get(sender_id)
        if stored is None:
            return None
        return self.deserialise_tracker(sender_id, stored)
class MongoTrackerStore(TrackerStore):
    """Tracker store that keeps tracker states in a MongoDB collection."""

    def __init__(self,
                 domain,
                 host="mongodb://localhost:27017",
                 db="rasa",
                 username=None,
                 password=None,
                 auth_source="admin",
                 collection="conversations",
                 event_broker=None):
        """Create the Mongo client and make sure the index exists.

        Args:
            domain: Domain used to recreate trackers.
            host: MongoDB connection URI.
            db: Name of the database.
            username: User name for authentication, if any.
            password: Password for authentication, if any.
            auth_source: Database to authenticate against.
            collection: Name of the collection holding the conversations.
            event_broker: Optional broker that new events are published to.
        """
        # Imported lazily so `pymongo` is only needed when this store is used.
        from pymongo.database import Database
        from pymongo import MongoClient

        self.client = MongoClient(host,
                                  username=username,
                                  password=password,
                                  authSource=auth_source,
                                  # delay connect until process forking is done
                                  connect=False)

        self.db = Database(self.client, db)
        self.collection = collection
        super(MongoTrackerStore, self).__init__(domain, event_broker)

        self._ensure_indices()

    @property
    def conversations(self):
        """The collection that holds the stored conversations."""
        return self.db[self.collection]

    def _ensure_indices(self):
        """Create the index on `sender_id` used by the lookups."""
        self.conversations.create_index("sender_id")

    def save(self, tracker, timeout=None):
        """Publish new events (if a broker is set) and upsert the state.

        Args:
            tracker: Tracker whose full state is written.
            timeout: Unused by this store.
        """
        # Stream before writing: the offset of new events is computed from
        # the version that is still stored.
        if self.event_broker:
            self.stream_events(tracker)

        state = tracker.current_state(EventVerbosity.ALL)

        self.conversations.update_one(
            {"sender_id": tracker.sender_id},
            {"$set": state},
            upsert=True)

    def retrieve(self, sender_id):
        """Return the tracker stored for `sender_id`.

        Returns `None` when nothing is stored or when no domain is set.
        """
        stored = self.conversations.find_one({"sender_id": sender_id})

        # look for conversations which have used an `int` sender_id in the past
        # and update them.
        if stored is None and sender_id.isdigit():
            from pymongo import ReturnDocument
            stored = self.conversations.find_one_and_update(
                {"sender_id": int(sender_id)},
                {"$set": {"sender_id": str(sender_id)}},
                return_document=ReturnDocument.AFTER)

        if stored is None:
            return None

        if not self.domain:
            logger.warning("Can't recreate tracker from mongo storage "
                           "because no domain is set. Returning `None` "
                           "instead.")
            return None

        return DialogueStateTracker.from_dict(sender_id,
                                              stored.get("events"),
                                              self.domain.slots)

    def keys(self):
        """Return the `sender_id` of every stored conversation."""
        # Project only `sender_id` so the (potentially large) event lists are
        # not transferred just to read the ids.
        cursor = self.conversations.find({}, {"sender_id": 1})
        return [c["sender_id"] for c in cursor]
class SQLTrackerStore(TrackerStore):
    """Store which can save and retrieve trackers from an SQL database."""

    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class SQLEvent(Base):
        """ORM mapping of a single tracker event (table `events`)."""
        from sqlalchemy import Column, Integer, String, Float

        __tablename__ = 'events'

        id = Column(Integer, primary_key=True)
        sender_id = Column(String, nullable=False)
        type_name = Column(String, nullable=False)
        timestamp = Column(Float)
        intent_name = Column(String)
        action_name = Column(String)
        # JSON dump of the complete event dictionary.
        data = Column(String)

    def __init__(self,
                 domain: Optional[Domain] = None,
                 dialect: Text = 'sqlite',
                 url: Text = None,
                 db: Text = 'rasa.db',
                 username: Text = None,
                 password: Text = None,
                 event_broker: Optional[EventChannel] = None) -> None:
        """Connect to the database and create the tables if necessary.

        Args:
            domain: Domain used to recreate trackers.
            dialect: SQLAlchemy driver name.
            url: Host of the database.
            db: Database name.
            username: User name for the connection, if any.
            password: Password for the connection, if any.
            event_broker: Optional broker that new events are published to.
        """
        from sqlalchemy.orm import sessionmaker
        from sqlalchemy.engine.url import URL
        from sqlalchemy import create_engine

        engine_url = URL(dialect, username, password, url, database=db)

        logger.debug('Attempting to connect to database '
                     'via "{}"'.format(engine_url.__to_string__()))

        self.engine = create_engine(engine_url)
        self.session = sessionmaker(bind=self.engine)()

        self.Base.metadata.create_all(self.engine)

        logger.debug("Connection to SQL database '{}' "
                     "successful".format(db))

        super(SQLTrackerStore, self).__init__(domain, event_broker)

    def keys(self) -> List[Text]:
        """Collect the sender ids of all conversations in the database."""
        # Query the distinct sender ids; the previous implementation returned
        # the *column names* of the events table instead.
        rows = self.session.query(self.SQLEvent.sender_id).distinct().all()
        return [sender_id for (sender_id,) in rows]

    def retrieve(self, sender_id: Text) -> Optional[DialogueStateTracker]:
        """Create a tracker from all previously stored events.

        Returns `None` when no domain is set or no events are stored for
        `sender_id`.
        """
        query = self.session.query(self.SQLEvent)
        # Order by primary key so events are replayed in insertion order;
        # without ORDER BY the row order is up to the database.
        result = (query.filter_by(sender_id=sender_id)
                  .order_by(self.SQLEvent.id)
                  .all())
        events = [json.loads(event.data) for event in result]

        if self.domain and events:
            logger.debug("Recreating tracker "
                         "from sender id '{}'".format(sender_id))

            return DialogueStateTracker.from_dict(sender_id, events,
                                                  self.domain.slots)

        logger.debug("Can't retrieve tracker matching "
                     "sender id '{}' from SQL storage. "
                     "Returning `None` instead.".format(sender_id))
        return None

    def save(self, tracker: DialogueStateTracker) -> None:
        """Update database with events from the current conversation."""
        # Stream before writing: the offset of new events is computed from
        # the version that is still stored.
        if self.event_broker:
            self.stream_events(tracker)

        events = self._additional_events(tracker)  # only store recent events

        try:
            for event in events:
                data = event.as_dict()

                # `parse_data` / `intent` may be missing or `None`.
                parse_data = data.get("parse_data") or {}
                intent = (parse_data.get("intent") or {}).get("name")
                action = data.get("name")
                timestamp = data.get("timestamp")

                # noinspection PyArgumentList
                self.session.add(self.SQLEvent(sender_id=tracker.sender_id,
                                               type_name=event.type_name,
                                               timestamp=timestamp,
                                               intent_name=intent,
                                               action_name=action,
                                               data=json.dumps(data)))
            self.session.commit()
        except Exception:
            # Leave the session usable for later calls, then propagate.
            self.session.rollback()
            raise

        logger.debug("Tracker with sender_id '{}' "
                     "stored to database".format(tracker.sender_id))

    def _additional_events(self, tracker: DialogueStateTracker) -> Iterator:
        """Return events from the tracker which aren't currently stored.

        Events count as stored when their timestamp is not greater than the
        largest timestamp saved for this sender.
        """
        from sqlalchemy import func
        query = self.session.query(func.max(self.SQLEvent.timestamp))
        max_timestamp = query.filter_by(sender_id=tracker.sender_id).scalar()

        if max_timestamp is None:
            max_timestamp = 0

        # Walk backwards from the newest event and stop at the first one
        # that is already stored.
        latest_events = []

        for event in reversed(tracker.events):
            if event.timestamp > max_timestamp:
                latest_events.append(event)
            else:
                break

        # Restore chronological order.
        return reversed(latest_events)