AlessioZanga/PyEEGLab

pyeeglab/dataset/dataset.py

import os
import json
import logging
import hashlib
import pickle

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import reduce
from multiprocessing import Pool, cpu_count
from operator import add, and_
from uuid import uuid4, uuid5, NAMESPACE_X500

from typing import Dict, List, Tuple

import mne
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker, Query

from .declarative_base import Base
from .file import File
from .metadata import Metadata
from .annotation import Annotation

from ..pipeline import Pipeline


@dataclass(init=False)
class Dataset(ABC):
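    """Abstract base class describing an EEG data set on disk.

    On construction the data set directory is indexed into a local SQLite
    database (".pyeeglab/index.sqlite3"): a File record is created for each
    file, and Metadata and Annotation records are extracted with MNE for the
    files matching the configured extensions. The indexed records can then be
    filtered through the exclusion attributes, preprocessed with a Pipeline,
    and cached to disk by load().

    Concrete subclasses must implement download().

    Minimal usage sketch (illustrative only, assuming a hypothetical concrete
    subclass MyDataset and an already built Pipeline instance):

        dataset = MyDataset(path="/path/to/data", name="my_dataset", version="v1.0.0")
        dataset.set_pipeline(pipeline)
        data = dataset.load()
    """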
    path: str
    name: str
    version: str

    extensions: List[str]
    exclude_file: List[str]
    exclude_channels_set: List[str]
    exclude_channels_reference: List[str]
    exclude_sampling_frequency: List[int]
    minimum_annotation_duration: float

    session: Session
    query: Query

    pipeline: Pipeline = None

    def __init__(
            self,
            path: str,
            name: str,
            version: str = None,
            extensions: List[str] = [".edf"],
            exclude_file: List[str] = None,
            exclude_channels_set: List[str] = None,
            exclude_channels_reference: List[str] = None,
            exclude_sampling_frequency: List[int] = None,
            minimum_annotation_duration: float = None
        ) -> None:
        # Set basic attributes
        self.path = os.path.abspath(os.path.join(path, version if version else ""))
        self.name = name
        self.version = version

        # Set data set filter attributes
        self.extensions = extensions if extensions else []
        self.exclude_file = exclude_file if exclude_file else []
        self.exclude_channels_set = exclude_channels_set if exclude_channels_set else []
        self.exclude_channels_reference = exclude_channels_reference if exclude_channels_reference else []
        self.exclude_sampling_frequency = exclude_sampling_frequency if exclude_sampling_frequency else []
        self.minimum_annotation_duration = minimum_annotation_duration if minimum_annotation_duration else 0

        logging.info("Init dataset '%s'@'%s' at '%s'", self.name, self.version, self.path)

        # Make workspace directory
        logging.debug("Make .pyeeglab directory")
        workspace = os.path.join(self.path, ".pyeeglab")
        os.makedirs(workspace, exist_ok=True)
        logging.debug("Make .pyeeglab/cache directory")
        os.makedirs(os.path.join(workspace, "cache"), exist_ok=True)
        logging.debug("Set MNE log .pyeeglab/mne.log")
        mne.set_log_file(os.path.join(workspace, "mne.log"), overwrite=False)

        # Index data set files
        self.index()

    def __getstate__(self):
        # Workaround for unpicklable sqlalchemy.orm.Session objects
        # during multiprocess dataset loading
        state = self.__dict__.copy()
        for attribute in ["session", "query"]:
            if hasattr(self, attribute):
                del state[attribute]
        return state

    @abstractmethod
    def download(self, user: str = None, password: str = None) -> None:
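        """Fetch the remote data set, optionally authenticating with the given credentials. Must be implemented by concrete subclasses."""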
        pass

    def index(self) -> None:
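        """Index the data set directory into the local SQLite database.

        The data set path is walked to discover files, and metadata and
        annotations are extracted in a multiprocess pool for the files
        matching the configured extensions. Finally, the default query is
        built according to the exclusion filters.
        """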
        # Init index session
        logging.debug("Make index session")
        connection = os.path.join(self.path, ".pyeeglab", "index.sqlite3")
        connection = create_engine("sqlite:///" + connection)
        Base.metadata.create_all(connection)
        self.session = sessionmaker(bind=connection)()
        # Open multiprocess pool
        logging.info("Index data set directory")
        pool = Pool(cpu_count())
        # Get files path from data set path
        paths = [
            os.path.join(directory, filename)
            for directory, _, filenames in os.walk(self.path)
            for filename in filenames
        ]
        # Get File instances from paths, filtering out already indexed files
        files = self.session.query(File).all()
        files = [file.uuid for file in files]
        files = [
            file
            for file in pool.map(self._get_file, paths)
            if file.uuid not in files
        ]
        for file in files:
            logging.debug("Add file %s to index", file.uuid)
        # Filter raw data files by extension
        raws = [
            file
            for file in files
            if os.path.splitext(file.path)[-1] in self.extensions
        ]
        # Get metadata and annotation for data files
        metadatas = pool.map(self._get_metadata, raws)
        annotations = pool.map(self._get_annotation, raws)
        # Close multiprocess pool
        pool.close()
        pool.join()
        # Commit insertions to index
        commits = files + metadatas + reduce(add, annotations, [])
        if commits:
            logging.info("Commit insertions to index")
            self.session.add_all(commits)
            self.session.commit()
        logging.info("Index data set completed")
        # Init default query
        logging.debug("Init default query")
        self.query = self.session.query(File, Metadata, Annotation).\
            join(File.meta).\
            join(File.annotations).\
            filter(~Metadata.channels_reference.in_(self.exclude_channels_reference)).\
            filter(~Metadata.sampling_frequency.in_(self.exclude_sampling_frequency)).\
            filter((Annotation.end - Annotation.begin) >= self.minimum_annotation_duration)
        # Filter exclude file paths
        for file in self.exclude_file:
            self.query = self.query.filter(~File.path.like("%{}%".format(file)))
        logging.debug("SQL query representation: '%s'", str(self.query).replace("\n", ""))
    
    def _get_file(self, path: str) -> File:
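        """Build a File record with a deterministic UUID derived from its path, so already indexed files can be recognized and skipped."""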
        return File(
            uuid=str(uuid5(NAMESPACE_X500, path)),
            path=path,
            extension=os.path.splitext(path)[-1]
        )

    def _get_metadata(self, file: File) -> Metadata:
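        """Extract duration, channel set, sampling frequency and signal range for the given file through its MNE reader."""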
        logging.debug("Add file %s metadata to index", file.uuid)
        with file as reader:
            info = reader.info
            # Load the signal data once and reuse it for both min and max
            data = reader.get_data()
            metadata = Metadata(
                file_uuid=file.uuid,
                duration=reader.n_times / info["sfreq"],
                channels_set=json.dumps(info["ch_names"]),
                sampling_frequency=info["sfreq"],
                max_value=data.max(),
                min_value=data.min(),
            )
        return metadata

    def _get_annotation(self, file: File) -> List[Annotation]:
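        """Extract the annotations of the given file as Annotation records through its MNE reader."""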
        logging.debug("Add file %s annotations to index", file.uuid)
        with file as reader:
            # Each MNE annotation exposes onset, duration and description
            annotations = [
                Annotation(
                    uuid=str(uuid4()),
                    file_uuid=file.uuid,
                    begin=annotation["onset"],
                    end=annotation["onset"] + annotation["duration"],
                    label=annotation["description"],
                )
                for annotation in reader.annotations
            ]
        return annotations
    
    @property
    def environment(self) -> Dict:
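        """Aggregate values shared with the preprocessing pipeline via set_pipeline(): the maximal common channels subset, the lowest sampling frequency and the global signal min/max range."""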
        min_max = self.signal_min_max_range
        return {
            "channels_set": self.maximal_channels_subset,
            "lowest_frequency": self.lowest_frequency,
            "min_value": min_max[0],
            "max_value": min_max[1],
        }
    
    @property
    def lowest_frequency(self) -> float:
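        """The minimum sampling frequency across the files selected by the current query (0 if the query is empty)."""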
        frequency = self.query.all()
        frequency = min([
            f[1].sampling_frequency
            for f in frequency
        ], default=0)
        return frequency

    @property
    def maximal_channels_subset(self) -> List[str]:
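        """The channels common to every channel set selected by the current query, minus the excluded channels, sorted."""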
        channels = self.query.group_by(Metadata.channels_set).all()
        channels = [
            frozenset(json.loads(channel[1].channels_set))
            for channel in channels
        ]
        channels = reduce(and_, channels)
        channels = channels - frozenset(self.exclude_channels_set)
        channels = sorted(channels)
        return channels
    
    @property
    def signal_min_max_range(self) -> Tuple[float, float]:
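        """The (min, max) signal values across the files selected by the current query (both 0 if the query is empty)."""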
        min_max = self.query.all()
        min_max = [m[1] for m in min_max]
        min_max = tuple([
            min([m.min_value for m in min_max], default=0),
            max([m.max_value for m in min_max], default=0),
        ])
        return min_max
    
    def set_pipeline(self, pipeline: Pipeline) -> "Dataset":
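        """Attach the preprocessing pipeline and update its environment with the data set environment."""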
        self.pipeline = pipeline
        self.pipeline.environment.update(self.environment)
        return self
    
    def set_minimum_event_duration(self, duration: float) -> "Dataset":
        logging.warning("This function will be deprecated in the near future")
        self.minimum_annotation_duration = duration
        return self
    
    def load(self) -> Dict:
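        """Run the pipeline over the annotations selected by the current query, caching the result to disk.

        The cache key is derived from the data set and pipeline hashes, so a
        previously computed result is reused when neither the filters nor the
        pipeline have changed.
        """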
        # Compute cache path
        cache = os.path.join(self.path, ".pyeeglab", "cache")
        # Compute cache key
        logging.info("Compute cache key")
        name = self.__class__.__name__.lower()
        if name.endswith("dataset"):
            name = name[:-len("dataset")]
        key = [hash(self), hash(self.pipeline)]
        key = [str(k).encode() for k in key]
        key = [hashlib.md5(k).hexdigest()[:10] for k in key]
        key = list(zip(["loader", "pipeline"], key))
        key = ["_".join(k) for k in key]
        key = name + "_" + "_".join(key)
        logging.info("Computed cache key: %s", key)
        # Load file cache
        cache = os.path.join(cache, key + ".pkl")
        if os.path.exists(cache):
            logging.info("Cache file found at %s", cache)
            with open(cache, "rb") as reader:
                try:
                    logging.info("Loading cache file")
                    return pickle.load(reader)
                except Exception:
                    logging.error("Loading cache file failed")
        # Cache file not found, start preprocessing
        logging.info("Cache file not found, genereting new one")
        data = [row[2] for row in self.query.all()]
        data = self.pipeline.run(data)
        with open(cache, "wb") as file:
            logging.info("Dumping cache file")
            pickle.dump(data, file)
        return data
    
    def __hash__(self) -> int:
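        """Hash derived from the path, version and exclusion filters, used by load() to build the cache key."""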
        key = [self.path, self.version, self.minimum_annotation_duration]
        key += self.exclude_file
        key += self.exclude_channels_set
        key += self.exclude_channels_reference
        key += self.exclude_sampling_frequency
        key = json.dumps(key).encode()
        key = hashlib.md5(key).hexdigest()
        key = int(key, 16)
        return key