endremborza/parquetranger

parquetranger/core.py

import json
import pickle
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import cached_property, partial
from hashlib import md5
from itertools import groupby
from math import log10
from pathlib import Path
from threading import Lock
from typing import Callable, Iterable, Optional

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from atqo import acquire_lock, get_lock, parallel_map

EXTENSION = ".parquet"
DEFAULT_ENV = "default-env"
GB_KEY = "__gb_dict"


@dataclass
class HashPartitioner:
    col: str | None = None  # if None, the index is used
    num_groups: int = 128

    def __call__(self, df: pd.DataFrame) -> pd.Series:
        base = df.index if self.col is None else df.loc[:, self.col]
        return base.astype(str).map(self.h)

    def h(self, elem: str):
        num = int(md5(elem.encode()).hexdigest(), base=16) % self.num_groups
        return f"{num:0{self._w}d}"

    @property
    def key(self):
        return f"__pqr-hash-{self.col}-{self.num_groups}__"

    @cached_property
    def _w(self):
        return int(log10(self.num_groups)) + 1
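
# A minimal usage sketch of HashPartitioner (the column and frame below are
# hypothetical, not part of the library): hashing a column into num_groups
# buckets yields a zero-padded string label per row, which TableRepo can use
# as a directory name.
#
#   hp = HashPartitioner(col="user_id", num_groups=16)
#   labels = hp(pd.DataFrame({"user_id": ["a", "b", "c"]}))  # e.g. "07", "12", ...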


class TableRepo:
    """helps with storing, extending and reading tabular data in parquet format

    tries dividing based on group_cols, if that is None
    tries dividing based on max_records, if max records is 0
    just writes the file to root_path.parquet

    if both group_cols and max_records is given, it will create
    directories for the groups (nested directories if multiple columns given)
    """

    def __init__(
        self,
        root_path: Path | str,
        max_records: int = 0,
        group_cols: str | list | HashPartitioner | None = None,
        env_parents: dict[str, Path | str] | None = None,
        mkdirs=True,
        extra_metadata: dict | None = None,
        drop_group_cols: bool = False,
        fixed_metadata: dict | None = None,
        allow_metadata_extension: bool = False,
    ):
        self.max_records = max_records
        self.drop_group_cols = drop_group_cols  # dropped on write, restored from metadata on read
        self.extra_metadata = extra_metadata or {}
        self._raw_grouper = group_cols
        self._env_parents = env_parents or {}
        self._is_grouped = group_cols is not None
        self._is_single_file = not (max_records or self._is_grouped)
        self._remake_dirs = mkdirs

        self._fixed_meta = fixed_metadata
        self._allow_meta_extension = allow_metadata_extension

        _rp = Path(root_path)
        self.name = _rp.name
        _default_kv = filter(lambda kv: kv[1] == _rp.parent, self._env_parents.items())
        self._default_env, e_path = [*_default_kv, (DEFAULT_ENV, _rp.parent)][0]
        self._env_parents[self._default_env] = e_path
        self._current_env = self._default_env

        self.mkdirs()

    def extend(self, df: pd.DataFrame):
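        # grouped repos dispatch each group to its own sub-repo; single-file repos
        # append to root_path.parquet; otherwise rows go into max_records-sized files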
        if self._is_grouped:
            return self._gb_handle(df, self.extend)

        resolved_table = self._resolve_metadata(df)
        if self.max_records == 0:
            lock = acquire_lock(self._df_path)
            if self._df_path.exists():
                base_table = self.read_table_from_path(
                    self._df_path, lock, release=False
                )
                # BUG: https://github.com/apache/arrow/issues/34782
                if base_table.num_rows == 0:
                    out_table = resolved_table
                else:
                    out_table = pa.concat_tables([base_table, resolved_table])
            else:
                out_table = resolved_table
            return self._write_table_to_path(out_table, self._df_path, lock)

        with get_lock(f"{self.main_path} - ext"):
            self._extend_parts(resolved_table)

    def replace_records(self, df: pd.DataFrame, by_groups=False):
        """replace records in files based on index"""
        if by_groups:
            return self._gb_handle(df, self.replace_records)

        df = df.loc[~df.index.duplicated(keep="first"), :]

        for full_path in self.paths:
            lock = acquire_lock(full_path)
            odf = self.read_df_from_path(full_path, lock, release=False)
            inter_ind = odf.index.intersection(df.index)
            if len(inter_ind) == 0:
                lock.release()
                continue
            odf.loc[inter_ind, :] = df.loc[inter_ind, :]
            self._write_df_to_path(odf, path=full_path, lock=lock)
            df = df.drop(inter_ind)

        if df.shape[0] > 0:
            self.extend(df)
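
    # Sketch of replace_records semantics (new_df is hypothetical): rows of new_df
    # whose index values already exist in stored files overwrite those rows in
    # place; the remaining rows are appended via extend().
    #
    #   trepo.replace_records(new_df)
    #   trepo.replace_records(new_df, by_groups=True)  # route rows by group_cols first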

    def batch_extend(self, df_iterator, **para_kwargs):
        list(parallel_map(self.extend, df_iterator, **para_kwargs))

    def map_partitions(self, fun, level=None, **para_kwargs):
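        # group files that belong to the same partition directory: the slice selects
        # the group-level path components, skipping the trailing file name component
        # when max_records also splits each group into multiple files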
        _mi = int(self.max_records > 0)
        lev_ind = slice(-len(self.group_cols) - _mi, -_mi or None)
        if level is None:

            def _idf(p: Path):
                return p.parts[lev_ind]

        else:
            _idf = self._gb_cols_from_path_meta(level)

        p_iter = map(lambda t: list(t[1]), groupby(sorted(self.paths, key=_idf), _idf))

        return parallel_map(partial(self._map_paths, fun=fun), p_iter, **para_kwargs)
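
    # Sketch of map_partitions (the functions and the "city" group column are
    # hypothetical): apply a function to the concatenated DataFrame of each
    # partition, in parallel.
    #
    #   sizes = trepo.map_partitions(len)                          # rows per partition
    #   means = trepo.map_partitions(lambda df: df.mean(), level="city")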

    def replace_groups(self, df: pd.DataFrame):
        """replace files based on file name, only viable if `group_cols` is set"""
        return self._gb_handle(df, self.replace_all)

    def replace_all(self, df: pd.DataFrame):
        """purges everything and writes df instead"""
        self.purge()
        self.extend(df)

    def purge(self):
        """purges everything"""
        for p in self.paths:
            p.unlink()
        if self._meta_path.exists():
            self._meta_path.unlink()

    def get_full_df(self) -> pd.DataFrame:
        return self.get_full_table().to_pandas()

    def get_full_table(self) -> pa.Table:
        if plist := list(self.paths):
            return pa.concat_tables(map(self.read_table_from_path, plist))
        return pa.Table.from_pydict({})

    def get_partition_paths(
        self, partition_col: str
    ) -> Iterable[tuple[str, Iterable[Path]]]:
        _getkey = self._gb_cols_from_path_meta(partition_col)

        return groupby(sorted(self.paths, key=_getkey), _getkey)

    def get_partition_table(
        self, partition: str, partition_col: str | None = None
    ) -> pa.Table:
        pcol = partition_col or self.group_cols[0]
        ppaths = filter(lambda t: t[0] == partition, self.get_partition_paths(pcol))
        return pa.concat_tables(map(self.read_table_from_path, next(ppaths)[1]))

    def get_partition_df(
        self, partition: str, partition_col: str | None = None
    ) -> pd.DataFrame:
        return self.get_partition_table(partition, partition_col).to_pandas()

    def set_env(self, env: str):
        self._current_env = env
        self.mkdirs()

    def set_env_to_default(self):
        self.set_env(self._default_env)

    def read_table_from_path(
        self, path, lock: Optional[Lock] = None, release=True
    ) -> pa.Table:
        assert release or (lock is not None)
        if lock is None:
            lock = acquire_lock(path)
        try:
            out: pa.Table = pq.read_table(path)
        except Exception as e:  # pragma: no cover
            lock.release()
            raise e
        if release:
            lock.release()
        if not self.drop_group_cols:
            return out
        n = out.num_rows
        for k, v in (
            self._parse_metadata(out.schema.metadata or {}).get(GB_KEY, {}).items()
        ):
            out = out.append_column(k, pa.array(np.repeat(v, n)))
        return out

    def read_df_from_path(
        self, path: Path, lock: Optional[Lock] = None, release=True
    ) -> pd.DataFrame:
        return self.read_table_from_path(path, lock, release).to_pandas()

    def get_extending_dict_batch_writer(self, max_records=1_000_000):
        return RecordWriter(self, max_records)

    def get_extending_fixed_dict_batch_writer(self, cols, max_records=1_000_000):
        return FixedRecordWriter(trepo=self, record_limit=max_records, cols=cols)

    def get_extending_df_batch_writer(self, max_records=1_000_000):
        return DfBatchWriter(self, max_records)

    def get_replacing_dict_batch_writer(self, max_records=1_000_000):
        return RecordWriter(self, max_records, TableRepo.replace_records)

    def get_replacing_df_batch_writer(self, max_records=1_000_000):
        return DfBatchWriter(self, max_records, TableRepo.replace_records)

    @contextmanager
    def env_ctx(self, env_name):
        _base = self._current_env
        self.set_env(env_name)
        try:
            yield
        finally:
            self.set_env(_base)
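
    # Sketch of env_ctx (the "test" env name is hypothetical): temporarily point
    # the repo at a different parent directory registered in env_parents, then
    # switch back when the block exits.
    #
    #   with trepo.env_ctx("test"):
    #       trepo.extend(df)   # written under the "test" environment's parent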

    @property
    def main_path(self) -> Path:
        return self._current_env_parent / self.name

    @property
    def vc_path(self) -> Path:
        return self._df_path if self._is_single_file else self.main_path

    @property
    def paths(self) -> Iterable[Path]:
        if self._is_single_file:
            return iter([self._df_path] if self._df_path.exists() else [])

        return self.main_path.glob("**/*" + EXTENSION)

    @property
    def n_files(self) -> int:
        return len(list(self.paths))

    @property
    def dfs(self) -> Iterable[pd.DataFrame]:
        return map(self.read_df_from_path, self.paths)

    @property
    def tables(self) -> Iterable[pa.Table]:
        return map(self.read_table_from_path, self.paths)

    @property
    def full_metadata(self):
        return self._parse_metadata(pq.read_schema(next(self.paths)).metadata)

    @property
    def group_cols(self):
        if isinstance(self._raw_grouper, str):
            return [self._raw_grouper]
        elif isinstance(self._raw_grouper, HashPartitioner):
            return [self._raw_grouper.key]
        assert (
            isinstance(self._raw_grouper, list) or self._raw_grouper is None
        ), "must be a valid grouper"
        return self._raw_grouper

    def _write_table_to_path(self, table: pa.Table, path, lock: Optional[Lock] = None):
        new_meta = (table.schema.metadata or {}) | _render_metadata(self.extra_metadata)
        if lock is None:
            lock = acquire_lock(path)
        try:
            pq.write_table(table.replace_schema_metadata(new_meta), path)
        finally:
            lock.release()

    def _write_df_to_path(self, df, path, lock: Optional[Lock] = None):
        """if lock is given, it should already be acquired"""
        self._write_table_to_path(pa.Table.from_pandas(df), path, lock)

    def _extend_parts(self, table: pa.Table):
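        # top up the last (possibly partial) file to max_records, then write the
        # remaining rows in max_records-sized chunks as new numbered files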
        start_rec = 0
        if self.n_files:
            last_path = sorted(self.paths)[-1]
            with get_lock(last_path):
                missing = self.max_records - _read_size_from_path(last_path)
            if missing > 0:
                start_rec = missing
                old_table = self.read_table_from_path(last_path)
                _ctab = pa.concat_tables([old_table, table.slice(0, missing)])
                self._write_table_to_path(_ctab, last_path)

        for i in range(start_rec, table.num_rows, self.max_records):
            new_path = self.main_path / f"file-{self.n_files:020d}{EXTENSION}"
            self._write_table_to_path(table.slice(i, self.max_records), new_path)

    def _map_paths(self, paths, fun):
        return fun(pd.concat(map(self.read_df_from_path, paths)))

    def _gb_cols_from_path_meta(self, key):
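        # map a path to {group_col: value} using its trailing path components
        # (the file name itself carries the innermost value when max_records is 0),
        # then return the value for `key`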
        def d(path: Path):
            i = -1 - int(self.max_records > 0)
            for gc in self.group_cols[::-1]:
                gid = path.parts[i]
                yield gc, gid.replace(EXTENSION, "") if i == -1 else gid
                i -= 1

        return lambda p: dict(d(p))[key]

    def _gb_handle(self, df: pd.DataFrame, fun):
        if not self._is_grouped:
            raise TypeError("only works if group_cols is set")

        ignore_index = isinstance(df.index, pd.RangeIndex)
        _schema = pa.Schema.from_pandas(df.pipe(self._de_grc))
        new_fix_meta = self._get_full_meta_dict(_schema)

        grouper = (
            self._raw_grouper(df)
            if isinstance(self._raw_grouper, HashPartitioner)
            else self.group_cols
        )

        for gid, gdf in df.groupby(grouper):
            self._gapply(
                gdf.reset_index(drop=True) if ignore_index else gdf,
                gid,
                fun,
                new_fix_meta,
            )

    def _gapply(self, gdf: pd.DataFrame, gid_raw: tuple | str, fun, meta_dic):
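        # delegate to a sub-TableRepo rooted at the group's directory, carrying the
        # group values in the parquet metadata so they can be restored on read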
        if gdf.empty:
            return
        gid = gid_raw if isinstance(gid_raw, tuple) else (gid_raw,)
        gb_meta = dict(zip(self.group_cols, gid))
        gb_kwargs = dict(
            max_records=self.max_records,
            mkdirs=self._remake_dirs,
            extra_metadata=self.extra_metadata | {GB_KEY: gb_meta},
            drop_group_cols=False,
            fixed_metadata=meta_dic,
        )
        gpath = Path(self.main_path, *map(str, gid))
        _gtrepo_fun = getattr(TableRepo(gpath, **gb_kwargs), fun.__name__)
        _gtrepo_fun(gdf.pipe(self._de_grc))

    def _resolve_metadata(self, df: pd.DataFrame):
        # cast the new columns to the previously stored types;
        # columns present only in the stored schema are added to the new data as empty
        table = pa.Table.from_pandas(df)
        new_dict = _schema_to_dic(table.schema)
        full_dict = self._get_full_meta_dict(table.schema)
        if list(new_dict.items()) != list(full_dict.items()):
            return _cast_table(table, full_dict)
        return table

    def _get_full_meta_dict(self, new_schema: pa.Schema):
        with _try_lock(f"{self.main_path} - meta"):
            return self._inner_meta_dict(new_schema)

    def _inner_meta_dict(self, new_schema: pa.Schema):
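        # resolve the authoritative column -> dtype mapping: fixed_metadata wins if
        # given; otherwise the schema stored in the meta file (written on first use),
        # optionally extended with columns new to this write; if the resolved mapping
        # gains keys over the stored one, warn and recast the existing files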
        new_dict = _schema_to_dic(new_schema)
        if self._fixed_meta is not None:
            old_dict = self._fixed_meta
        else:
            first_path = self._meta_path
            with _try_lock(first_path):
                if first_path.exists():
                    old_dict = _schema_to_dic(pq.read_schema(first_path))
                else:
                    _pmeta = new_schema.metadata
                    new_schema = pa.schema(new_dict.items(), metadata=_pmeta)
                    rep_table = pa.Table.from_pylist([], schema=new_schema)
                    pq.write_table(rep_table, first_path)
                    old_dict = new_dict
        if self._fixed_meta:
            full_dict = self._fixed_meta
        elif self._allow_meta_extension:
            full_dict = old_dict | {
                k: v for k, v in new_dict.items() if k not in old_dict
            }
        else:
            full_dict = old_dict
        keydiff = full_dict.keys() - old_dict.keys()
        if keydiff:
            _w = "mismatched schemas: \n" + "\n".join(
                [f"new: {new_dict}", f"old: {old_dict}", f"full: {full_dict}"]
            )
            warnings.warn(f"key difference: {keydiff}", UserWarning)
            warnings.warn(_w, UserWarning)
            for path in self.paths:
                lock = acquire_lock(path)
                old_table = self.read_table_from_path(path, lock, release=False)
                self._write_table_to_path(_cast_table(old_table, full_dict), path, lock)
        return full_dict

    def mkdirs(self, force=False):
        if not (self._remake_dirs or force):
            return
        self._current_env_parent.mkdir(exist_ok=True, parents=True)
        if not self._is_single_file:
            self.main_path.mkdir(exist_ok=True)

    def _parse_metadata(self, meta_dic: dict):
        my_keys = [*self.extra_metadata.keys(), GB_KEY]
        return {
            k.decode("utf-8"): (
                pickle.loads(v) if k.decode("utf-8") in my_keys else json.loads(v)
            )
            for k, v in meta_dic.items()
        }

    def _de_grc(self, df):
        return df.drop(self.group_cols, axis=1) if self.drop_group_cols else df

    @property
    def _meta_path(self):
        if self._is_single_file:
            return self._df_path
        return self.main_path / "empty.meta"

    @property
    def _df_path(self):
        return self.main_path.with_suffix(EXTENSION)

    @property
    def _current_env_parent(self) -> Path:
        return self._env_parents[self._current_env]


@dataclass
class RecordWriter:
    trepo: TableRepo
    record_limit: int = 1_000_000
    writer_function: Callable = TableRepo.extend
    batch: list = field(default_factory=list)
    record_count: int = field(default=0, init=False)
    written_count: int = field(default=0, init=False)

    def __post_init__(self):
        self.record_count = len(self.batch)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()

    def add_to_batch(self, element):
        self.batch.append(self._parse_elem(element))
        self.record_count += self._rec_count_from_elem(element)
        if self.record_count >= self.record_limit:
            self._write()

    def close(self):
        if self.batch:
            self._write()
        self.record_count = 0

    def _write(self):
        for _ in range(2):
            try:
                self.writer_function(self.trepo, self._wrap_batch())
                break
            except FileNotFoundError:
                self.trepo.mkdirs(True)
        self.written_count += len(self.batch)
        self.batch = []
        self.record_count = 0

    def _wrap_batch(self):
        return pd.DataFrame(self.batch)

    def _parse_elem(self, elem):
        return elem

    def _rec_count_from_elem(self, elem):
        return 1
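
# A minimal sketch of the batch writers (records is a hypothetical iterable of
# dicts): RecordWriter buffers records and flushes them to the TableRepo as a
# DataFrame whenever record_limit is reached or the context exits.
#
#   with trepo.get_extending_dict_batch_writer(max_records=10_000) as writer:
#       for rec in records:
#           writer.add_to_batch(rec)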


@dataclass
class FixedRecordWriter(RecordWriter):
    cols: list[str] = field(default_factory=list)

    def _parse_elem(self, elem):
        return {k: elem.get(k) for k in self.cols}


@dataclass
class DfBatchWriter(RecordWriter):
    def _wrap_batch(self):
        ig_ind = isinstance(self.batch[0].index, pd.RangeIndex)
        return pd.concat(self.batch, ignore_index=ig_ind)

    def _rec_count_from_elem(self, elem: pd.DataFrame):
        return elem.shape[0]


@contextmanager
def _try_lock(lock_name):
    lock = acquire_lock(lock_name)
    try:
        yield
    finally:
        lock.release()


def _render_metadata(meta_dic):
    return {k: pickle.dumps(v) for k, v in meta_dic.items()}


def _schema_to_dic(sch):
    return dict(zip(sch.names, sch.types))


def _read_size_from_path(path):
    return pq.read_metadata(path).num_rows


def _cast_table(table: pa.Table, dic: dict[str, pa.DataType]):
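    # cast each column of `table` to the target dtype; columns that are missing
    # from the table, or that fail the cast, come out as all-null arrays of the
    # target type (a warning is emitted for failed casts)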
    arrs = []
    for k, v in dic.items():
        try:
            arrs.append(table[k].cast(v))
        except Exception as e:
            if not isinstance(e, KeyError):
                warnings.warn(f"couldn't cast {k} to {v}")
            arrs.append(pa.array(np.repeat(None, table.num_rows), type=v))
    # TODO: the pandas schema.metadata is not added here;
    # it is difficult to recreate
    return pa.Table.from_arrays(arrs, schema=pa.schema(dic.items()))