Nikolay-Lysenko/rl-musician

View on GitHub
rlmusician/environment/environment.py

Summary

Maintainability
A
0 mins
Test Coverage
"""
Create environment with Gym API.

Author: Nikolay Lysenko
"""


from typing import Any, Dict, List, Tuple

import gym
import numpy as np

from rlmusician.environment.piece import Piece
from rlmusician.environment.evaluation import evaluate
from rlmusician.utils import convert_to_base


class CounterpointEnv(gym.Env):
    """
    An environment where counterpoint line is composed given cantus firmus.
    """

    reward_range = (-np.inf, np.inf)
    valid_durations = [1, 2, 4, 8]

    def __init__(
            self,
            piece: Piece,
            scoring_coefs: Dict[str, float],
            scoring_fn_params: Dict[str, Dict[str, Any]],
            reward_for_dead_end: float,
            verbose: bool = False
    ):
        """
        Initialize instance.

        :param piece:
            data structure representing musical piece with florid counterpoint
        :param scoring_coefs:
            mapping from scoring function names to their weights in final score
        :param scoring_fn_params:
            mapping from scoring function names to their parameters
        :param reward_for_dead_end:
            reward for situations where there aren't any allowed actions, but
            piece is not finished
        :param verbose:
            if it is set to `True`, breakdown of reward is printed at episode
            end
        """
        self.piece = piece
        self.scoring_coefs = scoring_coefs
        self.scoring_fn_params = scoring_fn_params
        self.reward_for_dead_end = reward_for_dead_end
        self.verbose = verbose

        n_actions = len(piece.all_movements) * len(self.valid_durations)
        self.action_space = gym.spaces.Discrete(n_actions)
        self.action_to_line_continuation = None
        self.__set_action_to_line_continuation()

        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=self.piece.piano_roll.shape,
            dtype=np.int32
        )

    def __set_action_to_line_continuation(self) -> None:
        """Create mapping from action to a pair of movement and duration."""
        base = len(self.piece.all_movements)
        required_len = 2
        raw_mapping = {
            action: (x for x in convert_to_base(action, base, required_len))
            for action in range(self.action_space.n)
        }
        offset = self.piece.max_skip
        action_to_continuation = {
            action: (movement_id - offset, 2 ** duration_id)
            for action, (duration_id, movement_id) in raw_mapping.items()
        }
        self.action_to_line_continuation = action_to_continuation

    @property
    def valid_actions(self) -> List[int]:
        """Get actions that are valid at the current step."""
        valid_actions = [
            i for i in range(self.action_space.n)
            if self.piece.check_validity(*self.action_to_line_continuation[i])
        ]
        return valid_actions

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
        """
        Run one step of the environment's dynamics.

        :param action:
            an action provided by an agent to the environment
        :return:
            a tuple of:
            - observation: agent's observation of the current environment,
            - reward: amount of reward returned after previous action,
            - done: whether the episode has ended, in which case further
                    `step()` calls will return undefined results,
            - info: list of next allowed actions
        """
        movement, duration = self.action_to_line_continuation[action]
        self.piece.add_line_element(movement, duration)

        observation = self.piece.piano_roll
        info = {'next_actions': self.valid_actions}

        past_duration = self.piece.current_time_in_eighths
        piece_duration = self.piece.total_duration_in_eighths
        finished = past_duration == piece_duration
        no_more_actions = len(info['next_actions']) == 0
        done = finished or no_more_actions

        if finished:
            reward = evaluate(
                self.piece,
                self.scoring_coefs,
                self.scoring_fn_params,
                self.verbose
            )
        elif no_more_actions:
            reward = self.reward_for_dead_end
        else:
            reward = 0

        return observation, reward, done, info

    def reset(self) -> np.ndarray:
        """
        Reset the state of the environment and return an initial observation.

        :return:
            initial observation
        """
        self.piece.reset()
        initial_observation = self.piece.piano_roll
        return initial_observation

    def render(self, mode='human'):  # pragma: no cover.
        """
        Save piece in various formats.

        :return:
            None
        """
        self.piece.render()