RasaHQ/rasa_core

View on GitHub
rasa/cli/train.py

Summary

Maintainability
A
25 mins
Test Coverage
import argparse
import tempfile
import typing
from typing import List, Optional, Text

from rasa.cli.default_arguments import (
    add_config_param, add_domain_param, add_nlu_data_param, add_stories_param)
from rasa.cli.utils import get_validated_path
from rasa.constants import (
    DEFAULT_CONFIG_PATH, DEFAULT_DATA_PATH, DEFAULT_DOMAIN_PATH,
    DEFAULT_MODELS_PATH)

if typing.TYPE_CHECKING:
    from rasa_nlu.model import Interpreter


# noinspection PyProtectedMember
def add_subparser(subparsers: argparse._SubParsersAction,
                  parents: List[argparse.ArgumentParser]):
    """Register the ``train`` command with its ``core`` and ``nlu`` sub-commands.

    Three parsers are created: ``train`` (dispatches to :func:`train`),
    ``train core`` (dispatches to :func:`train_core`) and ``train nlu``
    (dispatches to :func:`train_nlu`). The dispatch target is stored as the
    ``func`` default of each parser.

    Args:
        subparsers: Sub-parser action of the top-level CLI parser that the
            ``train`` command is attached to.
        parents: Parent parsers; they are only forwarded to the ``nlu``
            sub-parser.
    """
    import rasa.core.cli.train as core_cli

    joint_parser = subparsers.add_parser("train", help="Train the Rasa bot")
    nested_parsers = joint_parser.add_subparsers()

    # "resolve" allows an option registered later on this parser to replace
    # an earlier one with the same option string (presumably needed for the
    # ``--config`` option re-registered by ``_add_core_compare_arguments``).
    core_parser = nested_parsers.add_parser(
        "core",
        conflict_handler="resolve",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        help="Train Rasa Core",
    )
    core_parser.set_defaults(func=train_core)

    nlu_parser = nested_parsers.add_parser(
        "nlu",
        parents=parents,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        help="Train Rasa NLU",
    )
    nlu_parser.set_defaults(func=train_nlu)

    # Options shared by all three parsers.
    for target in (joint_parser, core_parser, nlu_parser):
        add_general_arguments(target)

    # Options needed whenever a Core model is trained.
    for target in (core_parser, joint_parser):
        add_domain_param(target)
        core_cli.add_general_args(target)

    # Core-only options.
    add_stories_param(core_parser)
    _add_core_compare_arguments(core_parser)

    # NLU-only options.
    add_nlu_data_param(nlu_parser)

    # Options that only apply to the combined training.
    add_joint_parser_arguments(joint_parser)
    joint_parser.set_defaults(func=train)


def add_joint_parser_arguments(parser: argparse.ArgumentParser):
    """Add the options that are specific to the combined ``train`` command.

    Args:
        parser: Parser that receives the ``--force`` flag and the ``--data``
            option (one or more paths, defaulting to ``[DEFAULT_DATA_PATH]``).
    """
    force_help = "Force a model training even if the data has not changed."
    parser.add_argument("--force", action="store_true", help=force_help)

    data_help = "Paths to the Core and NLU training files."
    parser.add_argument(
        "--data",
        nargs='+',
        default=[DEFAULT_DATA_PATH],
        help=data_help,
    )


def add_general_arguments(parser: argparse.ArgumentParser):
    """Add the options shared by every training parser.

    Registers the config option via ``add_config_param`` and then the
    ``-o`` / ``--out`` option, which defaults to ``DEFAULT_MODELS_PATH``.

    Args:
        parser: Parser that receives the arguments.
    """
    add_config_param(parser)

    out_option = {
        "type": str,
        "default": DEFAULT_MODELS_PATH,
        "help": "Directory where your models are stored",
    }
    parser.add_argument("-o", "--out", **out_option)


def _add_core_compare_arguments(parser: argparse.ArgumentParser):
    """Add the options used for Core policy comparison training.

    Registers ``--percentages``, ``--runs`` and a ``-c`` / ``--config``
    option that accepts one or more configuration files.

    Args:
        parser: Parser that receives the arguments. It is presumably expected
            to use ``conflict_handler="resolve"`` if a ``--config`` option
            was already registered on it - confirm against the caller.
    """
    parser.add_argument(
        "--percentages",
        nargs="*",
        type=int,
        default=[0, 5, 25, 50, 70, 90, 95],
        help="Range of exclusion percentages")
    parser.add_argument(
        "--runs",
        type=int,
        default=3,
        help="Number of runs for experiments")
    parser.add_argument(
        "-c", "--config",
        nargs='+',
        default=[DEFAULT_CONFIG_PATH],
        # Adjacent literals are concatenated verbatim, so the separating
        # space after the first sentence has to be written explicitly.
        help="The policy and NLU pipeline configuration of your bot. "
             "If multiple configuration files are provided, multiple dialogue "
             "models are trained to compare policies.")


def train(args: argparse.Namespace) -> Optional[Text]:
    """Run the combined training from parsed command line arguments.

    The domain, config and every entry of ``args.data`` are passed through
    ``get_validated_path`` before being handed to ``rasa.train`` together
    with ``args.out`` and ``args.force``.

    Args:
        args: Parsed arguments of the ``train`` command.

    Returns:
        Whatever ``rasa.train`` returns.
    """
    import rasa

    domain_path = get_validated_path(
        args.domain, "domain", DEFAULT_DOMAIN_PATH)
    config_path = get_validated_path(
        args.config, "config", DEFAULT_CONFIG_PATH)

    data_paths = []
    for entry in args.data:
        data_paths.append(
            get_validated_path(entry, "data", DEFAULT_DATA_PATH))

    return rasa.train(
        domain_path, config_path, data_paths, args.out, args.force)


def train_core(args: argparse.Namespace,
               train_path: Optional[Text] = None) -> Optional[Text]:
    """Train a Core model, or run comparison training for several configs.

    ``args.domain`` is replaced in place by its validated path. If
    ``args.config`` is a single-item list it is replaced in place by that
    item and a normal training is run; if it holds any other number of items
    the comparison training is run instead.

    Args:
        args: Parsed arguments of the ``train core`` command.
        train_path: Optional directory passed on to the trainer. When given,
            it is also used as output location instead of ``args.out``.

    Returns:
        The result of ``rasa.train.train_core`` for a normal training, or
        ``None`` after a comparison training.
    """
    # Aliased so the imported trainer does not shadow this CLI entry point.
    from rasa.train import train_core as _run_core_training

    output = train_path or args.out

    args.domain = get_validated_path(
        args.domain, "domain", DEFAULT_DOMAIN_PATH)
    stories = get_validated_path(args.stories, "stories", DEFAULT_DATA_PATH)

    # Policies might be a list for the compare training. Do normal training
    # if only one list item was passed.
    if isinstance(args.config, list) and len(args.config) == 1:
        args.config = args.config[0]

    if not isinstance(args.config, list):
        config = get_validated_path(
            args.config, "config", DEFAULT_CONFIG_PATH)
        return _run_core_training(
            args.domain, config, stories, output, train_path)

    # The event loop is only needed for the asynchronous compare training,
    # so it is fetched here rather than on every call.
    import asyncio
    from rasa.core.train import do_compare_training

    loop = asyncio.get_event_loop()
    loop.run_until_complete(do_compare_training(args, stories, None))
    return None


def train_nlu(args: argparse.Namespace, train_path: Optional[Text] = None
              ) -> Optional["Interpreter"]:
    """Train an NLU model from parsed command line arguments.

    Args:
        args: Parsed arguments of the ``train nlu`` command.
        train_path: Optional directory passed on to the trainer. When given,
            it is also used as output location instead of ``args.out``.

    Returns:
        Whatever ``rasa.train.train_nlu`` returns.
    """
    from rasa.train import train_nlu as run_nlu_training

    target_dir = train_path if train_path else args.out

    config_path = get_validated_path(
        args.config, "config", DEFAULT_CONFIG_PATH)
    data_path = get_validated_path(args.nlu, "nlu", DEFAULT_DATA_PATH)

    return run_nlu_training(config_path, data_path, target_dir, train_path)