RasaHQ/rasa_core

View on GitHub
rasa/test.py

Summary

Maintainability
A
1 hr
Test Coverage
import asyncio
import logging
from typing import Text, Dict
import os

from rasa.constants import DEFAULT_RESULTS_PATH
from rasa.model import get_model, get_model_subdirectories
from rasa.cli.utils import minimal_kwargs

# Module-level logger; used by the cross-validation helper below to report
# evaluation progress and results.
logger = logging.getLogger(__name__)


def test(model: Text, stories: Text, nlu_data: Text, endpoints: Text = None,
         output: Text = DEFAULT_RESULTS_PATH, **kwargs):
    """Evaluate both the Core and the NLU part of ``model``.

    Core is evaluated first against ``stories`` (using ``endpoints`` and
    writing to ``output``); NLU is evaluated afterwards against
    ``nlu_data``. The same extra ``kwargs`` are handed to both evaluations.
    """
    test_core(model, stories, endpoints=endpoints, output=output, **kwargs)
    test_nlu(model, nlu_data, **kwargs)


def test_core(model: Text, stories: Text, endpoints: Text = None,
              output: Text = DEFAULT_RESULTS_PATH, model_path: Text = None,
              **kwargs: Dict):
    """Evaluate a single Core model, or compare several, on test stories.

    Args:
        model: Path to a packed model archive. If no single model is resolved,
            this is treated as a directory (it must contain
            ``num_stories.json``) whose models are compared.
        stories: Path to the test stories.
        endpoints: Optional endpoint configuration file; its ``nlu`` entry is
            passed to the NLU interpreter.
        output: Results directory; created when truthy.
        model_path: Path to an already unpacked model. It is replaced by the
            unpacked archive when ``model`` is a file.
        **kwargs: Extra options, reduced via ``minimal_kwargs`` to those
            matching ``rasa.core.test`` before being forwarded.
    """
    import rasa.core.test
    import rasa.core.utils as core_utils
    from rasa_nlu import utils as nlu_utils
    from rasa.core.interpreter import NaturalLanguageInterpreter
    from rasa.core.agent import Agent

    _endpoints = core_utils.AvailableEndpoints.read_endpoints(endpoints)

    if output:
        nlu_utils.create_dir(output)

    if os.path.isfile(model):
        # Unpack the archive exactly once. The resulting path is reused
        # below instead of unpacking the same archive a second time, which
        # also respects a caller-supplied `model_path`.
        model_path = get_model(model)

    if model_path:
        # Single model: normal evaluation.
        core_path, nlu_path = get_model_subdirectories(model_path)

        _interpreter = NaturalLanguageInterpreter.create(nlu_path,
                                                         _endpoints.nlu)
        _agent = Agent.load(core_path, interpreter=_interpreter)

        kwargs = minimal_kwargs(kwargs, rasa.core.test)
        loop = asyncio.get_event_loop()
        loop.run_until_complete(
            rasa.core.test(stories, _agent, out_directory=output, **kwargs))
    else:
        # No single model resolved: compare the models found under `model`
        # and plot the results against the story counts in num_stories.json.
        from rasa.core.test import compare, plot_curve

        compare(model, stories, output)

        story_n_path = os.path.join(model, 'num_stories.json')
        number_of_stories = core_utils.read_json_file(story_n_path)
        plot_curve(output, number_of_stories)


def test_nlu(model: Text, nlu_data: Text, **kwargs: Dict):
    """Evaluate the NLU part of ``model`` against ``nlu_data``.

    The model is resolved with ``get_model`` and its ``nlu`` subdirectory is
    evaluated. ``kwargs`` are reduced via ``minimal_kwargs`` to those
    matching ``run_evaluation`` before being passed on.
    """
    from rasa_nlu.test import run_evaluation

    nlu_model_dir = os.path.join(get_model(model), "nlu")
    evaluation_kwargs = minimal_kwargs(kwargs, run_evaluation)
    run_evaluation(nlu_data, nlu_model_dir, **evaluation_kwargs)


def test_nlu_with_cross_validation(config: Text, nlu: Text, folds: int = 3):
    """Cross-validate an NLU configuration on the given training data.

    Intent and entity results are reported (train and test) when present.

    Args:
        config: Path to the NLU pipeline configuration.
        nlu: Path to the NLU training data.
        folds: Number of cross-validation folds; cast with ``int`` before use.
    """
    import rasa_nlu.config
    import rasa_nlu.test as nlu_test
    # Imported explicitly because `rasa_nlu.training_data.load_data` is used
    # below. Without this, the attribute only exists if some other import
    # happened to load the submodule as a side effect.
    import rasa_nlu.training_data

    nlu_config = rasa_nlu.config.load(config)
    data = rasa_nlu.training_data.load_data(nlu)
    # Presumably removes intents with fewer than `cutoff` examples so every
    # fold can contain them -- confirm against drop_intents_below_freq.
    data = nlu_test.drop_intents_below_freq(data, cutoff=5)
    results, entity_results = nlu_test.cross_validate(data, int(folds),
                                                      nlu_config)
    logger.info("CV evaluation (n=%s)", folds)

    if any(results):
        logger.info("Intent evaluation results")
        nlu_test.return_results(results.train, "train")
        nlu_test.return_results(results.test, "test")
    if any(entity_results):
        logger.info("Entity evaluation results")
        nlu_test.return_entity_results(entity_results.train, "train")
        nlu_test.return_entity_results(entity_results.test, "test")