rasa/engine/caching.py

Summary: Maintainability B (6 hrs) · Test Coverage A (100%)

from __future__ import annotations

import abc
import logging
import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Text, Any, Optional, Tuple, List

from packaging import version
from sqlalchemy.engine import URL

from sqlalchemy.exc import OperationalError
from typing_extensions import Protocol, runtime_checkable

import rasa
import rasa.model
import rasa.utils.common
import rasa.shared.utils.common
from rasa.constants import MINIMUM_COMPATIBLE_VERSION
import sqlalchemy as sa
import sqlalchemy.orm
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta

from rasa.engine.storage.storage import ModelStorage

logger = logging.getLogger(__name__)

DEFAULT_CACHE_LOCATION = Path(".rasa", "cache")
DEFAULT_CACHE_NAME = "cache.db"
DEFAULT_CACHE_SIZE_MB = 1000

CACHE_LOCATION_ENV = "RASA_CACHE_DIRECTORY"
CACHE_DB_NAME_ENV = "RASA_CACHE_NAME"
CACHE_SIZE_ENV = "RASA_MAX_CACHE_SIZE"


class TrainingCache(abc.ABC):
    """Stores training results in a persistent cache.

    Used to minimize retraining when the data / config didn't change between
    training runs.
    """

    @abc.abstractmethod
    def cache_output(
        self,
        fingerprint_key: Text,
        output: Any,
        output_fingerprint: Text,
        model_storage: ModelStorage,
    ) -> None:
        """Adds the output to the cache.

        If the output is of type `Cacheable`, it is persisted to disk in addition
        to its fingerprint.

        Args:
            fingerprint_key: The fingerprint key serves as key for the cache. Graph
                components can use their fingerprint key to lookup fingerprints of
                previous training runs.
            output: The output. The output is only cached to disk if it's of type
                `Cacheable`.
            output_fingerprint: The fingerprint of the output. This can be used
                to lookup potentially persisted outputs on disk.
            model_storage: Required for caching `Resource` instances. E.g. `Resource`s
                use that to copy data from the model storage to the cache.
        """

    ...

    @abc.abstractmethod
    def get_cached_output_fingerprint(self, fingerprint_key: Text) -> Optional[Text]:
        """Retrieves fingerprint of output based on fingerprint key.

        Args:
            fingerprint_key: The fingerprint serves as key for the lookup of output
                fingerprints.

        Returns:
            The fingerprint of a matching output or `None` in case no cache entry was
            found for the given fingerprint key.
        """
        ...

    @abc.abstractmethod
    def get_cached_result(
        self, output_fingerprint_key: Text, node_name: Text, model_storage: ModelStorage
    ) -> Optional[Cacheable]:
        """Returns a potentially cached output result.

        Args:
            output_fingerprint_key: The fingerprint key of the output serves as lookup
                key for a potentially cached version of this output.
            node_name: The name of the graph node which wants to use this cached result.
            model_storage: The current model storage (e.g. used when restoring
                `Resource` objects so that they can fill the model storage with data).

        Returns:
            The restored `Cacheable`, or `None` if no matching result was found.
        """
        ...


@runtime_checkable
class Cacheable(Protocol):
    """Protocol for cacheable graph component outputs.

    We only cache graph component outputs which are `Cacheable`. For everything
    else, we only store the output fingerprint.
    """

    def to_cache(self, directory: Path, model_storage: ModelStorage) -> None:
        """Persists `Cacheable` to disk.

        Args:
            directory: The directory where the `Cacheable` can persist itself to.
            model_storage: The current model storage (e.g. used when caching `Resource`
                objects).
        """
        ...

    @classmethod
    def from_cache(
        cls,
        node_name: Text,
        directory: Path,
        model_storage: ModelStorage,
        output_fingerprint: Text,
    ) -> Cacheable:
        """Loads `Cacheable` from cache.

        Args:
            node_name: The name of the graph node which wants to use this cached result.
            directory: Directory containing the persisted `Cacheable`.
            model_storage: The current model storage (e.g. used when restoring
                `Resource` objects so that they can fill the model storage with data).
            output_fingerprint: The fingerprint of the cached result (e.g. used when
                restoring `Resource` objects as the fingerprint cannot be easily
                calculated from the object itself).

        Returns:
            Instantiated `Cacheable`.
        """
        ...


class LocalTrainingCache(TrainingCache):
    """Caches training results on local disk (see parent class for full docstring)."""

    Base: DeclarativeMeta = declarative_base()

    class CacheEntry(Base):
        """Stores metadata about a single cache entry."""

        __tablename__ = "cache_entry"

        fingerprint_key = sa.Column(sa.String(), primary_key=True)
        output_fingerprint_key = sa.Column(sa.String(), nullable=False, index=True)
        last_used = sa.Column(sa.DateTime(timezone=True), nullable=False)
        rasa_version = sa.Column(sa.String(255), nullable=False)
        result_location = sa.Column(sa.String())
        result_type = sa.Column(sa.String())

    def __init__(self) -> None:
        """Creates cache.

        The `Cache` setting can be configured via environment variables.
        """
        self._cache_location = LocalTrainingCache._get_cache_location()

        self._max_cache_size = float(
            os.environ.get(CACHE_SIZE_ENV, DEFAULT_CACHE_SIZE_MB)
        )

        self._cache_database_name = os.environ.get(
            CACHE_DB_NAME_ENV, DEFAULT_CACHE_NAME
        )

        if not self._cache_location.exists() and not self._is_disabled():
            logger.debug(
                f"Creating caching directory '{self._cache_location}' because "
                f"it doesn't exist yet."
            )
            self._cache_location.mkdir(parents=True)

        self._sessionmaker = self._create_database()

        self._drop_cache_entries_from_incompatible_versions()

    @staticmethod
    def _get_cache_location() -> Path:
        return Path(os.environ.get(CACHE_LOCATION_ENV, DEFAULT_CACHE_LOCATION))

    def _create_database(self) -> sqlalchemy.orm.sessionmaker:
        if self._is_disabled():
            # Use in-memory database as mock to avoid having to check `_is_disabled`
            # everywhere
            database = ""
        else:
            database = str(self._cache_location / self._cache_database_name)

        # Use `future=True` as we are using the 2.x query style
        engine = sa.create_engine(
            URL.create(drivername="sqlite", database=database), future=True
        )
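        # Create the `cache_entry` table if it doesn't exist yet.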
        self.Base.metadata.create_all(engine)

        return sa.orm.sessionmaker(engine)

    def _drop_cache_entries_from_incompatible_versions(self) -> None:
        incompatible_entries = self._find_incompatible_cache_entries()

        for entry in incompatible_entries:
            self._delete_cached_result(entry)

        self._delete_incompatible_entries_from_cache(incompatible_entries)

        logger.debug(
            f"Deleted {len(incompatible_entries)} from disk as their version "
            f"is older than the minimum compatible version "
            f"('{MINIMUM_COMPATIBLE_VERSION}')."
        )

    def _find_incompatible_cache_entries(self) -> List[LocalTrainingCache.CacheEntry]:
        with self._sessionmaker() as session:
            query_for_cache_entries = sa.select(self.CacheEntry)
            all_entries: List[LocalTrainingCache.CacheEntry] = (
                session.execute(query_for_cache_entries).scalars().all()
            )

        return [
            entry
            for entry in all_entries
            if version.parse(MINIMUM_COMPATIBLE_VERSION)
            > version.parse(entry.rasa_version)
        ]

    def _delete_incompatible_entries_from_cache(
        self, incompatible_entries: List[LocalTrainingCache.CacheEntry]
    ) -> None:
        incompatible_fingerprints = [
            entry.fingerprint_key for entry in incompatible_entries
        ]
        with self._sessionmaker.begin() as session:
            delete_query = sa.delete(self.CacheEntry).where(
                self.CacheEntry.fingerprint_key.in_(incompatible_fingerprints)
            )
            session.execute(delete_query)

    @staticmethod
    def _delete_cached_result(entry: LocalTrainingCache.CacheEntry) -> None:
        if entry.result_location and Path(entry.result_location).is_dir():
            shutil.rmtree(entry.result_location)

    def cache_output(
        self,
        fingerprint_key: Text,
        output: Any,
        output_fingerprint: Text,
        model_storage: ModelStorage,
    ) -> None:
        """Adds the output to the cache (see parent class for full docstring)."""
        if self._is_disabled():
            return

        cache_dir, output_type = None, None
        if isinstance(output, Cacheable):
            cache_dir, output_type = self._cache_output_to_disk(output, model_storage)

        try:
            self._add_cache_entry(
                cache_dir, fingerprint_key, output_fingerprint, output_type
            )
        except OperationalError:
            if cache_dir:
                shutil.rmtree(cache_dir)

            raise

    def _add_cache_entry(
        self,
        cache_dir: Optional[Text],
        fingerprint_key: Text,
        output_fingerprint: Text,
        output_type: Optional[Text],
    ) -> None:
        with self._sessionmaker.begin() as session:
            cache_entry = self.CacheEntry(
                fingerprint_key=fingerprint_key,
                output_fingerprint_key=output_fingerprint,
                last_used=datetime.utcnow(),
                rasa_version=rasa.__version__,
                result_location=cache_dir,
                result_type=output_type,
            )
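            # `merge` inserts a new entry or updates an existing one with the same
            # fingerprint key (the table's primary key).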
            session.merge(cache_entry)

    def _is_disabled(self) -> bool:
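        # Caching is disabled when the maximum cache size (`RASA_MAX_CACHE_SIZE`)
        # is set to 0.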
        return self._max_cache_size == 0.0

    def _cache_output_to_disk(
        self, output: Cacheable, model_storage: ModelStorage
    ) -> Tuple[Optional[Text], Optional[Text]]:
        tempdir_name = rasa.utils.common.get_temp_dir_name()

        # Use `TempDirectoryPath` instead of `tempfile.TemporaryDirectory` as this
        # leads to errors on Windows when the context manager tries to delete an
        # already deleted temporary directory (e.g. https://bugs.python.org/issue29982)
        with rasa.utils.common.TempDirectoryPath(tempdir_name) as temp_dir:
            tmp_path = Path(temp_dir)
            try:
                output.to_cache(tmp_path, model_storage)

                logger.debug(
                    f"Caching output of type '{type(output).__name__}' succeeded."
                )
            except Exception as e:
                logger.error(
                    f"Caching output of type '{type(output).__name__}' failed with the "
                    f"following error:\n{e}"
                )
                return None, None

            output_size = rasa.utils.common.directory_size_in_mb(tmp_path)
            if output_size > self._max_cache_size:
                logger.debug(
                    f"Caching result of type '{type(output).__name__}' was skipped "
                    f"because it exceeds the maximum cache size of "
                    f"{self._max_cache_size} MiB."
                )
                return None, None

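            # Evict the least recently used cache entries until the new output fits
            # within the configured maximum cache size.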
            while (
                rasa.utils.common.directory_size_in_mb(
                    self._cache_location,
                    filenames_to_exclude=[self._cache_database_name],
                )
                + output_size
                > self._max_cache_size
            ):
                self._drop_least_recently_used_item()

            output_type = rasa.shared.utils.common.module_path_from_instance(output)
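            # Move the persisted output from the temporary directory into the
            # permanent cache location.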
            cache_path = shutil.move(temp_dir, self._cache_location)

            return cache_path, output_type

    def _drop_least_recently_used_item(self) -> None:
        with self._sessionmaker.begin() as session:
            query_for_least_recently_used_entry = sa.select(self.CacheEntry).order_by(
                self.CacheEntry.last_used.asc()
            )
            oldest_cache_item = (
                session.execute(query_for_least_recently_used_entry).scalars().first()
            )

            if not oldest_cache_item:
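                # No entries are left in the database, but the cache directory still
                # exceeds the size limit: remove leftover files which are not tracked
                # in the database.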
                self._purge_cache_dir_content()
                return

            self._delete_cached_result(oldest_cache_item)
            delete_query = sa.delete(self.CacheEntry).where(
                self.CacheEntry.fingerprint_key == oldest_cache_item.fingerprint_key
            )
            session.execute(delete_query)

            logger.debug(
                f"Deleted item with fingerprint "
                f"'{oldest_cache_item.fingerprint_key}' to free space."
            )

    def _purge_cache_dir_content(self) -> None:
        for item in self._cache_location.glob("*"):
            if item.name == self._cache_database_name:
                continue

            if item.is_dir():
                shutil.rmtree(item)
            else:
                item.unlink()

    def get_cached_output_fingerprint(self, fingerprint_key: Text) -> Optional[Text]:
        """Returns cached output fingerprint (see parent class for full docstring)."""
        with self._sessionmaker.begin() as session:
            query = sa.select(self.CacheEntry).filter_by(
                fingerprint_key=fingerprint_key
            )
            match = session.execute(query).scalars().first()

            if match:
                # This result was used during a fingerprint run.
                match.last_used = datetime.utcnow()
                return match.output_fingerprint_key

            return None

    def get_cached_result(
        self, output_fingerprint_key: Text, node_name: Text, model_storage: ModelStorage
    ) -> Optional[Cacheable]:
        """Returns a potentially cached output (see parent class for full docstring)."""
        result_location, result_type = self._get_cached_result(output_fingerprint_key)

        if not result_location:
            logger.debug(f"No cached output found for '{output_fingerprint_key}'")
            return None

        path_to_cached = Path(result_location)
        if not path_to_cached.is_dir():
            logger.debug(
                f"Cached output for '{output_fingerprint_key}' can't be found on disk."
            )
            return None

        return self._load_from_cache(
            result_location,
            result_type,
            node_name,
            model_storage,
            output_fingerprint_key,
        )

    def _get_cached_result(
        self, output_fingerprint_key: Text
    ) -> Tuple[Optional[Path], Optional[Text]]:
        with self._sessionmaker.begin() as session:
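            # Only consider entries whose output was persisted to disk
            # (non-null result location).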
            query = sa.select(
                self.CacheEntry.result_location, self.CacheEntry.result_type
            ).where(
                self.CacheEntry.output_fingerprint_key == output_fingerprint_key,
                self.CacheEntry.result_location != sa.null(),
            )

            match = session.execute(query).first()

            if match:
                return Path(match.result_location), match.result_type

            return None, None

    @staticmethod
    def _load_from_cache(
        path_to_cached: Path,
        result_type: Text,
        node_name: Text,
        model_storage: ModelStorage,
        output_fingerprint_key: Text,
    ) -> Optional[Cacheable]:
        try:
            module = rasa.shared.utils.common.class_from_module_path(result_type)

            if not isinstance(module, Cacheable):
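                # The restored class does not satisfy the runtime checkable
                # `Cacheable` protocol, i.e. it provides no `to_cache` / `from_cache`.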
                logger.warning(
                    "Failed to restore a non cacheable module from cache. "
                    "Please implement the 'Cacheable' interface for module "
                    f"'{result_type}'."
                )
                return None

            return module.from_cache(
                node_name, path_to_cached, model_storage, output_fingerprint_key
            )
        except Exception as e:
            logger.warning(
                f"Failed to restore cached output of type '{result_type}' from "
                f"cache. Error:\n{e}"
            )
            return None