takuseno/minerva
minerva/dataset.py

# pylint: disable=no-name-in-module

import csv
import os
import tempfile
import zipfile

import numpy as np
from d3rlpy.dataset import MDPDataset
from PIL import Image
from tqdm import trange


def convert_ndarray_to_image(ndarray):
    # convert channel-first to channel-last
    if ndarray.shape[0] == 1:
        # single-channel image: drop the channel axis for grayscale PIL mode
        array = ndarray[0]
    else:
        array = np.transpose(ndarray, [1, 2, 0])
    return Image.fromarray(array)


def convert_image_to_ndarray(image):
    array = np.asarray(image)
    # restore the channel-first shape
    if image.mode == "L":
        # grayscale image: add the missing channel axis
        array = array.reshape((1, *array.shape))
    else:
        array = array.transpose([2, 0, 1])
    return array
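
# Round-trip sketch for the two converters above (assumes a channel-first
# uint8 array, e.g. a single 84x84 grayscale frame; `frame` is a
# hypothetical name):
#
#     frame = np.zeros((1, 84, 84), dtype=np.uint8)
#     image = convert_ndarray_to_image(frame)
#     restored = convert_image_to_ndarray(image)
#     assert restored.shape == frame.shape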


def export_mdp_dataset_as_csv(dataset, fname):
    if len(dataset.get_observation_shape()) > 1:
        # image observation
        export_image_observation_dataset_as_csv(dataset, fname)
    else:
        # vector observation
        export_vector_observation_dataset_as_csv(dataset, fname)
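
# Example usage (a minimal sketch; the arrays and the file name are
# hypothetical, built with the same d3rlpy MDPDataset constructor used
# by the importers below):
#
#     observations = np.random.rand(100, 4).astype(np.float32)
#     actions = np.random.rand(100, 2).astype(np.float32)
#     rewards = np.random.rand(100).astype(np.float32)
#     terminals = np.zeros(100, dtype=np.float32)
#     terminals[-1] = 1.0
#     dataset = MDPDataset(observations, actions, rewards, terminals)
#     export_mdp_dataset_as_csv(dataset, "dataset.csv")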


def _save_image_files(dataset, zip_path):
    with tempfile.TemporaryDirectory() as dname:
        with zipfile.ZipFile(zip_path, "w") as zip_fd:
            data_size = dataset.observations.shape[0]
            for i in trange(data_size, desc="saving images"):
                image = convert_ndarray_to_image(dataset.observations[i])
                image_path = os.path.join(dname, "observation_%d.png" % i)
                # PNG is lossless, so the JPEG-only quality option is unnecessary
                image.save(image_path)
                zip_fd.write(image_path, arcname="observation_%d.png" % i)


def export_image_observation_dataset_as_csv(dataset, fname):
    action_size = dataset.get_action_size()

    # derive the zip file name from the CSV file name
    csv_file_name = os.path.basename(fname)
    zip_name = os.path.splitext(csv_file_name)[0] + ".zip"
    zip_path = os.path.join(os.path.dirname(fname), zip_name)

    # save image files as a zip archive next to the CSV
    _save_image_files(dataset, zip_path)

    with open(fname, "w") as file:
        writer = csv.writer(file)

        # write header
        header = ["episode", "observation:0"]
        if dataset.is_action_discrete():
            header += ["action:0"]
        else:
            header += ["action:%d" % i for i in range(action_size)]
        header += ["reward"]
        writer.writerow(header)

        count = 0
        for i, episode in enumerate(dataset.episodes):
            # prepare data to write
            for j in range(episode.observations.shape[0]):
                row = []
                row.append(i)

                # add image path
                row.append("observation_%d.png" % count)
                count += 1

                row += episode.actions[j].reshape(-1).tolist()
                row.append(episode.rewards[j])
                writer.writerow(row)
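
# The resulting CSV references the saved PNG files by name, e.g. for a
# discrete action (a sketch, values are illustrative):
#
#     episode,observation:0,action:0,reward
#     0,observation_0.png,2,1.0
#     0,observation_1.png,0,0.0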


def export_vector_observation_dataset_as_csv(dataset, fname):
    observation_size = dataset.get_observation_shape()[0]
    action_size = dataset.get_action_size()

    with open(fname, "w") as file:
        writer = csv.writer(file)

        # write header
        header = ["episode"]
        header += ["observation:%d" % i for i in range(observation_size)]
        if dataset.is_action_discrete():
            header += ["action:0"]
        else:
            header += ["action:%d" % i for i in range(action_size)]
        header += ["reward"]
        writer.writerow(header)

        for i, episode in enumerate(dataset.episodes):
            # prepare data to write
            observations = np.asarray(episode.observations)
            episode_length = observations.shape[0]
            actions = np.asarray(episode.actions).reshape(episode_length, -1)
            rewards = episode.rewards.reshape(episode_length, 1)
            episode_ids = np.full([episode_length, 1], i)

            # write episode
            rows = np.hstack([episode_ids, observations, actions, rewards])
            writer.writerows(rows)
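
# The resulting CSV looks like this for a 2-dim observation and a 1-dim
# continuous action (a sketch, values are illustrative):
#
#     episode,observation:0,observation:1,action:0,reward
#     0,0.1,0.2,0.5,1.0
#     0,0.3,0.4,-0.5,0.0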


def import_csv_as_mdp_dataset(fname, image=False):
    if image:
        return import_csv_as_image_observation_dataset(fname)
    return import_csv_as_vector_observation_dataset(fname)
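
# Example usage (a minimal sketch; the file name is hypothetical):
#
#     vector_dataset = import_csv_as_mdp_dataset("dataset.csv")
#     image_dataset = import_csv_as_mdp_dataset("dataset.csv", image=True)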


def _load_image(path):
    image = Image.open(path)

    # resize image to the fixed (84, 84) observation size
    if image.size != (84, 84):
        image = image.resize((84, 84), Image.BICUBIC)

    return image


def import_csv_as_image_observation_dataset(fname):
    with open(fname, "r") as file:
        reader = csv.reader(file)
        rows = list(reader)

        # check header
        header = rows[0]
        _validate_csv_header(header)

        # get action size
        action_size = _get_action_size_from_header(header)

        data_size = len(rows) - 1

        observations = []
        actions = []
        rewards = []
        terminals = []
        for i, row in enumerate(rows[1:]):
            episode_id = row[0]

            # load image
            image_path = os.path.join(os.path.dirname(fname), row[1])
            if not os.path.exists(image_path):
                raise ValueError(f"{image_path} does not exist.")
            image = _load_image(image_path)

            # convert PIL.Image to ndarray
            array = convert_image_to_ndarray(image)

            observations.append(array)

            # get action columns
            actions.append(list(map(float, row[2 : 2 + action_size])))

            # get reward column
            rewards.append(float(row[-1]))

            # mark episode end at the last row or when the episode id changes
            # (rows[i + 2] is the next data row because rows[0] is the header)
            if i == data_size - 1 or episode_id != rows[i + 2][0]:
                terminals.append(1)
            else:
                terminals.append(0)

        # convert list to ndarray
        observations = np.array(observations, dtype=np.uint8)
        actions = np.array(actions)
        rewards = np.array(rewards, dtype=np.float32)
        terminals = np.array(terminals, dtype=np.float32)

        dataset = MDPDataset(
            observations=observations,
            actions=actions,
            rewards=rewards,
            terminals=terminals,
        )

    return dataset
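
# Note: export_image_observation_dataset_as_csv stores the PNG files in
# a zip archive next to the CSV, while this importer reads raw PNG files
# relative to the CSV, so the archive has to be extracted first. A
# minimal sketch (hypothetical file names):
#
#     with zipfile.ZipFile("dataset.zip") as zip_fd:
#         zip_fd.extractall(os.path.dirname(os.path.abspath("dataset.csv")))
#     dataset = import_csv_as_mdp_dataset("dataset.csv", image=True)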


def import_csv_as_vector_observation_dataset(fname):
    with open(fname, "r") as file:
        reader = csv.reader(file)
        rows = list(reader)

        # check header
        header = rows[0]
        _validate_csv_header(header)

        # retrieve data section
        csv_data = np.array(rows[1:], dtype=np.float32)

        # get observation columns
        observation_size = _get_observation_size_from_header(header)
        observation_last_index = observation_size + 1
        observations = csv_data[:, 1:observation_last_index]

        # get action columns
        action_size = _get_action_size_from_header(header)
        action_last_index = observation_last_index + action_size
        actions = csv_data[:, observation_last_index:action_last_index]

        # get reward column
        rewards = csv_data[:, -1]

        # make terminal flags
        episode_ids = csv_data[:, 0]
        terminals = np.zeros_like(episode_ids)
        for i, episode_id in enumerate(episode_ids):
            if i + 1 == len(episode_ids) or episode_id != episode_ids[i + 1]:
                terminals[i] = 1.0

        dataset = MDPDataset(
            observations=observations,
            actions=actions,
            rewards=rewards,
            terminals=terminals,
        )

    return dataset


def _validate_csv_header(header):
    assert header[0] == "episode", "column=0 must be 'episode'"

    # check observation section
    index = 1
    observation_index = 0
    while "action" not in header[index]:
        ref_name = "observation:%d" % observation_index
        message = "column=%d must be '%s'" % (index, ref_name)
        assert header[index] == ref_name, message
        index += 1
        observation_index += 1

    # check action section
    action_index = 0
    while header[index] != "reward":
        ref_name = "action:%d" % action_index
        message = "column=%d must be '%s'" % (index, ref_name)
        assert header[index] == ref_name, message
        index += 1
        action_index += 1
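
# A header that passes validation for a 3-dim observation and a 2-dim
# continuous action looks like (a sketch):
#
#     episode,observation:0,observation:1,observation:2,action:0,action:1,reward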


def _get_observation_size_from_header(header):
    # count the "observation:<index>" columns
    return sum(1 for column in header if "observation" in column)


def _get_action_size_from_header(header):
    # count the "action:<index>" columns
    return sum(1 for column in header if "action" in column)