tsdat/io/retrievers/storage_retriever.py

from datetime import timedelta
import pandas as pd
import xarray as xr
from pydantic import BaseModel, Field
from typing import (
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
)

from .global_arm_transform_params import GlobalARMTransformParams
from .global_fetch_params import GlobalFetchParams
from .storage_retriever_input import StorageRetrieverInput
from .perform_data_retrieval import perform_data_retrieval
from ...config.dataset import DatasetConfig
from ..base import (
    Retriever,
    Storage,
)
from ...const import InputKey


class StorageRetriever(Retriever):
    """Retriever API for pulling input data from the storage area."""

    class TransParameters(BaseModel):
        trans_params: Optional[GlobalARMTransformParams] = Field(
            default=None, alias="transformation_parameters"
        )
        fetch_params: Optional[GlobalFetchParams] = Field(
            default=None, alias="fetch_parameters"
        )

    # Optional transformation/fetch parameters; used to determine how much padding to
    # apply when fetching input data (see `_get_retrieval_padding`).
    parameters: Optional[TransParameters] = None

    def retrieve(
        self,
        input_keys: List[str],
        dataset_config: DatasetConfig,
        storage: Optional[Storage] = None,
        input_data_hook: Optional[
            Callable[[Dict[str, xr.Dataset]], Dict[str, xr.Dataset]]
        ] = None,
        **kwargs: Any,
    ) -> xr.Dataset:
        """------------------------------------------------------------------------------------
        Retrieves input data from the storage area.

        Note that each input key is expected to be formatted as:

        ```python
        "--key1 value1 --key2 value2"
        ```

        e.g.,

        ```python
        "--datastream sgp.met.b0 --start 20230801 --end 20230901"
        "--datastream sgp.met.b0 --start 20230801 --end 20230901 --location_id sgp --data_level b0"
        ```

        This format allows the retriever to pull datastream data from the Storage API
        for the desired dates for each desired input source.
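
        If provided, `input_data_hook` is called with the fetched datasets (keyed by
        input key) before retrieval rules run, and a non-None return value replaces the
        inputs. A minimal sketch (the `drop_empty` hook below is hypothetical):

        ```python
        def drop_empty(inputs: Dict[str, xr.Dataset]) -> Dict[str, xr.Dataset]:
            # Keep only inputs that actually contain time samples.
            return {k: ds for k, ds in inputs.items() if ds.sizes.get("time", 0) > 0}

        dataset = retriever.retrieve(
            input_keys, dataset_config, storage=storage, input_data_hook=drop_empty
        )
        ```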

        Args:
            input_keys (List[str]): A list of input keys formatted as described above.
            dataset_config (DatasetConfig): The output dataset configuration.
            storage (Optional[Storage]): Instance of a Storage class used to fetch saved
                data. Although optional in the signature, it must be provided; an
                assertion enforces this.
            input_data_hook (Optional[Callable]): Optional callable applied to the
                fetched input datasets; a non-None return value replaces the inputs.

        Returns:
            xr.Dataset: The retrieved dataset.

        ------------------------------------------------------------------------------------
        """
        assert storage is not None, "Missing required 'storage' parameter."

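        # Parse each raw input key string into a structured object exposing the
        # datastream, start/end timestamps, and any extra metadata kwargs.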
        storage_input_keys = [StorageRetrieverInput(key) for key in input_keys]

        input_data = self.__fetch_inputs(storage_input_keys, storage)

        if input_data_hook is not None:
            modded_input_data = input_data_hook(input_data)
            if modded_input_data is not None:
                input_data = modded_input_data

        # Perform coord/variable retrieval
        retrieved_data, retrieval_selections = perform_data_retrieval(
            input_data=input_data,
            coord_rules=self.coords,  # type: ignore
            data_var_rules=self.data_vars,  # type: ignore
        )
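
        # `retrieved_data` holds the selected coordinate and data variable arrays;
        # `retrieval_selections` records the matched retrieval rule (source input key
        # and data converters) for each variable.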

        # Ensure selected coords are indexed by themselves
        for name, coord_data in retrieved_data.coords.items():
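            # Skip coordinates that came back as the empty placeholder DataArray.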
            if coord_data.equals(xr.DataArray([])):
                continue
            new_coord = xr.DataArray(
                data=coord_data.data,
                coords={name: coord_data.data},
                dims=(name,),
                attrs=coord_data.attrs,
                name=name,
            )
            retrieved_data.coords[name] = new_coord
        # Q: Do data_vars need to be renamed or reindexed before data converters run?

        # Run data converters on coordinates, then on data variables
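        # Converters are applied in order; a converter may return None, in which case
        # the previously retrieved data is kept.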
        for name, coord_def in retrieval_selections.coords.items():
            for converter in coord_def.data_converters:
                coord_data = retrieved_data.coords[name]
                data = converter.convert(
                    data=coord_data,
                    variable_name=name,
                    dataset_config=dataset_config,
                    retrieved_dataset=retrieved_data,
                    time_span=(storage_input_keys[0].start, storage_input_keys[0].end),
                    input_dataset=input_data.get(coord_def.source),
                    retriever=self,
                    input_key=coord_def.source,
                )
                if data is not None:
                    retrieved_data.coords[name] = data

        for name, var_def in retrieval_selections.data_vars.items():
            for converter in var_def.data_converters:
                var_data = retrieved_data.data_vars[name]
                data = converter.convert(
                    data=var_data,
                    variable_name=name,
                    dataset_config=dataset_config,
                    retrieved_dataset=retrieved_data,
                    retriever=self,
                    input_dataset=input_data.get(var_def.source),
                    input_key=var_def.source,
                )
                if data is not None:
                    retrieved_data.data_vars[name] = data

        # Construct the retrieved dataset structure
        # TODO: validate dimension alignment
        retrieved_dataset = xr.Dataset(
            coords=retrieved_data.coords,
            data_vars=retrieved_data.data_vars,
        )

        # Double-check that the dataset is trimmed to the start and end time. This is
        # needed if adi_py is not used and more than one file is pulled in.
        retrieved_dataset = self.__trim_dataset(retrieved_dataset, storage_input_keys)

        # Cast each output variable to its configured dtype and record the dtype in the
        # variable's encoding so it is preserved on write. Note that the encoding must
        # be set on the variable stored in the dataset, not on the stale reference.
        for var_name, var_data in retrieved_dataset.data_vars.items():
            output_var_cfg = dataset_config.data_vars.get(var_name)
            if output_var_cfg is not None:
                dtype = output_var_cfg.dtype
                retrieved_dataset[var_name] = var_data.astype(dtype)
                retrieved_dataset[var_name].encoding["dtype"] = dtype

        return retrieved_dataset

    @staticmethod
    def _get_timedelta(time_string: str) -> pd.Timedelta:
        # Bare numeric strings (e.g., "30" or "2.5") are interpreted as seconds; all
        # other strings (e.g., "5min") are parsed directly by pandas.
        if time_string.replace(".", "").isnumeric():
            return pd.Timedelta(float(time_string), "s")
        return pd.Timedelta(time_string)

    def _get_retrieval_padding(self, input_key: str) -> Tuple[int, timedelta]:
        """Returns (direction, padding), where a direction of -1 pads only the start of
        the fetch window, 1 pads only the end, and 0 pads both sides."""
        if self.parameters is None:
            return 0, timedelta()
        elif self.parameters.fetch_params is not None:
            time_padding = self.parameters.fetch_params.time_padding
            direction, padding = self.parameters.fetch_params.get_direction(time_padding)
            return direction, self._get_timedelta(padding)
        elif self.parameters.trans_params is not None:
            params = self.parameters.trans_params.select_parameters(input_key)
            range_td = self._get_timedelta(params["range"].get("time", "0s"))
            width_td = self._get_timedelta(params["width"].get("time", "0s"))
            return 0, max(range_td, width_td)
        else:
            return 0, timedelta()

    def __fetch_inputs(
        self, input_keys: List[StorageRetrieverInput], storage: Storage
    ) -> Dict[InputKey, xr.Dataset]:
        input_data: Dict[InputKey, xr.Dataset] = {}
        for key in input_keys:
            direction, padding = self._get_retrieval_padding(key.input_key)
            retrieved_dataset = storage.fetch_data(
                start=key.start - padding if direction < 1 else key.start,
                end=key.end + padding if direction > -1 else key.end,
                datastream=key.datastream,
                metadata_kwargs=key.kwargs,
            )
            input_data[key.input_key] = retrieved_dataset
        return input_data

    @staticmethod
    def __trim_dataset(
        dataset: xr.Dataset, input_keys: List[StorageRetrieverInput]
    ) -> xr.Dataset:
        # Trim the dataset to the original start and end times; these do not change
        # between input keys, so the first key's bounds are used.
        start = input_keys[0].start
        end = input_keys[0].end
        return dataset.sel(time=slice(start, end))