
# View on GitHub


# 25 mins
# Test Coverage
import typing
from typing import Union, Any
import pathlib
from functools import lru_cache
import importlib
import warnings
from typing_extensions import Literal
import numpy as np

from import DataSetException, DataSet
from libertem.common.async_utils import sync_to_async
from libertem.common.scheduler import Scheduler

# Registry of known dataset types: maps a filetype name to either a dotted
# "module.ClassName" path of its DataSet implementation (imported on demand by
# get_dataset_cls()) or a DataSet class object. Insertion order is also the
# default auto-detection order used by get_search_order().
# NOTE(review): verify these class paths against the libertem.io.dataset package.
filetypes = {
    "hdf5": "libertem.io.dataset.hdf5.H5DataSet",
    "raw": "libertem.io.dataset.raw.RawFileDataSet",
    "raw_csr": "libertem.io.dataset.raw_csr.RawCSRDataSet",
    "mib": "libertem.io.dataset.mib.MIBDataSet",
    "blo": "libertem.io.dataset.blo.BloDataSet",
    "k2is": "libertem.io.dataset.k2is.K2ISDataSet",
    "ser": "libertem.io.dataset.ser.SERDataSet",
    "frms6": "libertem.io.dataset.frms6.FRMS6DataSet",
    "empad": "libertem.io.dataset.empad.EMPADDataSet",
    "dm": "libertem.io.dataset.dm.DMDataSet",
    "seq": "libertem.io.dataset.seq.SEQDataSet",
    "mrc": "libertem.io.dataset.mrc.MRCDataSet",
    "tvips": "libertem.io.dataset.tvips.TVIPSDataSet",
    "npy": "libertem.io.dataset.npy.NPYDataSet",
    "dask": "libertem.io.dataset.dask.DaskDataSet",
    "memory": "libertem.io.dataset.memory.MemoryDataSet",
}

def build_extension_map() -> dict[str, list[str]]:
    """
    Map each supported file extension to the filetypes that claim it.

    Extensions are lower-cased, so lookups with a normalized (lower-cased)
    file suffix match regardless of the spelling a DataSet class uses in
    ``get_supported_extensions()``. For every extension, the claiming
    filetypes are listed once each, in ``filetypes`` registration order.

    Note that this resolves (and thus imports) the DataSet class of every
    registered filetype via ``get_dataset_cls``.
    """
    ext_map: dict[str, list[str]] = {}
    for typ_ in filetypes:
        cls = get_dataset_cls(typ_)
        for ext in cls.get_supported_extensions():
            claimants = ext_map.setdefault(ext.lower(), [])
            # a class may list the same extension in several spellings
            # (e.g. "h5" and "H5"); record each filetype only once
            if typ_ not in claimants:
                claimants.append(typ_)
    return ext_map

@typing.overload
def _auto_load(
    path: str, enable_async: Literal[True], *args, executor, **kwargs,
) -> typing.Awaitable[DataSet]:
    ...


@typing.overload
def _auto_load(
    path: str, enable_async: Literal[False], *args, executor, **kwargs,
) -> DataSet:
    ...


@typing.overload
def _auto_load(
    path: str, enable_async: bool, *args, executor, **kwargs,
) -> typing.Union[DataSet, typing.Awaitable[DataSet]]:
    ...


def _auto_load(path, *args, executor, **kwargs):
    """
    Detect the filetype of ``path`` and load it through ``load``.

    ``path`` is forwarded as the first positional dataset argument; all other
    arguments (including ``enable_async``) are passed to ``load`` unchanged.

    Raises
    ------
    DataSetException
        If ``path`` is None, or if ``detect`` does not report a filetype.
    """
    if path is None:
        raise DataSetException(
            "please specify the `path` argument to allow auto detection"
        )
    detected_params = detect(path, executor=executor)
    filetype_detected: typing.Optional[str] = detected_params.get('type', None)
    if filetype_detected is None:
        raise DataSetException(
            "could not determine DataSet type for file '%s'" % path,
        )
    return load(
        filetype_detected, path, *args, executor=executor, **kwargs
    )

@typing.overload
def load(
    filetype: str, *args, enable_async: Literal[True], executor, **kwargs,
) -> typing.Awaitable[DataSet]:
    ...


@typing.overload
def load(
    filetype: str, *args, enable_async: Literal[False], executor, **kwargs,
) -> DataSet:
    ...


@typing.overload
def load(
    filetype: str, *args, enable_async: bool, executor, **kwargs,
) -> typing.Union[DataSet, typing.Awaitable[DataSet]]:
    ...


def load(
    filetype: str,
    *args,
    enable_async: bool = False,
    executor,
    **kwargs
):
    """
    Low-level method to load a dataset. Usually you will want
    to use Context.load instead!

    Parameters
    ----------
    filetype : str or DataSet type
        see ``filetypes`` in this module for supported types, example: 'hdf5'.
        The special value 'auto' detects the type from the ``path`` argument.

    executor : JobExecutor

    enable_async : bool
        If True, return a coroutine instead of blocking until the loading has
        finished.

    additional parameters are passed to the concrete DataSet implementation
    """
    if filetype == "auto":
        return _auto_load(*args, executor=executor, enable_async=enable_async, **kwargs)

    cls = get_dataset_cls(filetype)

    async def _init_async():
        ds = cls(*args, **kwargs)
        ds = await sync_to_async(ds.initialize, executor=executor.ensure_sync())
        workers = await executor.get_available_workers()
        scheduler = Scheduler(workers)
        # FIXME the partitioning should be dynamic
        # since the number of eligible workers may depend on
        # the set of UDFs that may or may not run on CPU or GPU
        # This is a workaround with a "best guess compromise"
        ds.set_num_cores(scheduler.effective_worker_count())
        await executor.run_function(ds.check_valid)
        return ds

    if enable_async:
        return _init_async()
    else:
        ds = cls(*args, **kwargs)
        ds = ds.initialize(executor)
        workers = executor.get_available_workers()
        scheduler = Scheduler(workers)
        # same best-guess partitioning workaround as in _init_async
        ds.set_num_cores(scheduler.effective_worker_count())
        # validate in the blocking path too, mirroring _init_async
        executor.run_function(ds.check_valid)
        return ds

def register_dataset_cls(filetype: str, cls: str) -> None:
    """
    Make ``cls`` available under the name ``filetype``.

    ``cls`` is normally a dotted ``"module.ClassName"`` path that
    ``get_dataset_cls`` imports on demand; a non-string value is handed back
    by ``get_dataset_cls`` as-is. The name is stored exactly as given, and an
    existing entry with the same name is overwritten.
    """
    filetypes.update({filetype: cls})

def unregister_dataset_cls(filetype: str) -> None:
    """
    Remove the registry entry stored under exactly ``filetype``.

    Raises KeyError if no entry with that name exists.
    """
    filetypes.pop(filetype)

def get_dataset_cls(filetype: Union[str, type[DataSet]]) -> type[DataSet]:
    """
    Resolve ``filetype`` to a DataSet class.

    Parameters
    ----------
    filetype : str or DataSet type
        Name of a registered filetype, matched case-insensitively. Anything
        that is not a string is returned unchanged.

    Returns
    -------
    type[DataSet]
        The registered class. String registrations of the form
        ``"module.ClassName"`` are imported on demand.

    Raises
    ------
    DataSetException
        If the name is not registered, or its class cannot be imported.
    """
    if not isinstance(filetype, str):
        return filetype
    key = filetype.lower()
    if key in filetypes:
        ft = filetypes[key]
    else:
        # register_dataset_cls() stores names verbatim; fall back to a
        # case-insensitive match so mixed-case registrations stay reachable
        for name, value in filetypes.items():
            if name.lower() == key:
                ft = value
                break
        else:
            raise DataSetException("unknown filetype: %s" % filetype)
    if not isinstance(ft, str):
        return ft
    module_name, _, cls_name = ft.rpartition(".")
    if not module_name or not cls_name:
        # import_module("") would raise ValueError, which callers don't expect
        raise DataSetException(
            "could not load dataset: %r is not a 'module.ClassName' path" % ft
        )
    try:
        module = importlib.import_module(module_name)
    except ImportError as e:
        raise DataSetException("could not load dataset: %s" % str(e)) from e
    try:
        cls: type[DataSet] = getattr(module, cls_name)
    except AttributeError as e:
        raise DataSetException("could not load dataset: %s" % str(e)) from e
    return cls

def get_search_order(path: Union[str, np.ndarray]) -> list[str]:
    """
    Return the keys from filetypes in an order which
    is perhaps optimal for dataset auto-detection.

    * Filetypes that claim ``path``'s file extension come first, in the
      order in which they were registered.
    * If ``path`` has a ``shape`` attribute (array-like input), ``'memory'``
      is moved to the very front and a warning is emitted.
    * All remaining filetypes follow in registration order.

    Each filetype appears at most once in the result.
    """
    extension_map = build_extension_map()
    search_order = list(filetypes.keys())
    try:
        file_format = pathlib.Path(path).suffix.strip().lstrip('.').lower()
    except (TypeError, ValueError):
        # Let downstream code handle the fact that
        # path cannot be cast to pathlib.Path or provide a suffix
        pass
    else:
        # If the file format is registered, float the associated
        # datasets to the top of the search order (maintaining
        # the order in which they were first registered), removing
        # them from their old position so they are not probed twice
        preferred = list(dict.fromkeys(extension_map.get(file_format, [])))
        search_order = preferred + [
            key for key in search_order if key not in preferred
        ]
    try:
        # If path has a shape attribute there is good chance
        # it implements the array interface and as such we should
        # check MemoryDataSet first
        _ = path.shape
    except AttributeError:
        # Cannot interpret as a memory dataset
        pass
    else:
        search_order = ['memory'] + [
            key for key in search_order if key != 'memory'
        ]
        warnings.warn('Auto-loading a MemoryDataSet is currently unsupported, '
                      'use ctx.load("memory", data=array).')
    return search_order

def detect(path: Union[str, np.ndarray], executor) -> dict[str, Any]:
    """
    Returns dataset's detected type, parameters and
    additional info.

    Filetypes are tried in the order given by ``get_search_order``. The first
    one whose ``detect_params`` returns a non-empty result wins; that result
    is returned with an added ``"type"`` key naming the filetype. Filetypes
    whose class lookup or detection raises ``NotImplementedError`` or
    ``DataSetException`` are skipped.

    Returns an empty dict if no filetype recognizes ``path``.
    """
    search_order = get_search_order(path)
    for filetype in search_order:
        try:
            cls = get_dataset_cls(filetype)
            params = cls.detect_params(path, executor)
        except (NotImplementedError, DataSetException):
            continue
        if not params:
            continue
        params.update({"type": filetype})
        return params
    return {}

def get_extensions() -> set[str]:
    """
    Return supported extensions as a set of strings.

    Plain extensions only, no pattern! Extensions are lower-cased and
    collected from ``get_supported_extensions()`` of every registered
    filetype; this resolves (and thus imports) each DataSet class.
    """
    types: set[str] = set()
    for filetype in filetypes:
        cls = get_dataset_cls(filetype)
        # grow the set in place instead of rebuilding it with union()
        types.update(ext.lower() for ext in cls.get_supported_extensions())
    return types