aswan/depot/base.py
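"""Filesystem-backed depot of collection runs.

A depot root holds an object store, immutable per-run event archives
under ``runs/``, zipped status databases under ``statuses/``, and a
mutable ``current-run`` directory that ``DepotBase.save_current`` turns
into a new run.
"""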
import os
import pickle
import sys
import time
import zipfile
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field
from functools import partial, wraps
from hashlib import md5
from heapq import heappop, heappush
from itertools import islice
from pathlib import Path
from shutil import rmtree
from subprocess import CalledProcessError, check_output
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union
import sqlalchemy as db
import yaml
from sqlalchemy.orm import Session, sessionmaker
from structlog import get_logger
from ..constants import (
DEFAULT_DEPOT_ROOT,
DEPOT_ROOT_ENV_VAR,
SUCCESS_STATUSES,
Statuses,
)
from ..metadata_handling import (
get_grouped_surls,
get_next_batch,
integrate_events,
reset_surls,
)
from ..models import Base, CollEvent, RegEvent, partial_read, partial_read_path
from ..object_store import ObjectStore
from ..url_handler import ANY_HANDLER_T
DB_KIND = "sqlite" # :///
COMPRESS = zipfile.ZIP_DEFLATED
STATUS_DB_ZIP = f"db.{DB_KIND}.zip"
EVENTS_ZIP = "events.zip"
CONTEXT_YAML = "context.yaml"
_RUN_SPLIT = "-"
MySession = sessionmaker()
logger = get_logger("base depot")
def _get_git_hash():
# maybe as tag: git tag --sort=committerdate
    try:
        return check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
    except (CalledProcessError, FileNotFoundError):
        # not a git repository, or git is not installed
        return None
def _pip_freeze():
comm = [sys.executable, "-m", "pip", "freeze"]
return sorted(check_output(comm).decode("utf-8").strip().split("\n"))
def _hash_str(s: str):
return md5(s.encode("utf-8")).hexdigest()[:20]
class _DepotObj:
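    """Mixin for objects persisted as ``context.yaml`` inside a directory.

    ``__path`` is name-mangled, so every subclass reads and writes the
    same file name defined here.
    """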
__path = CONTEXT_YAML
@classmethod
def read(cls, dir_path: Path):
return cls(**yaml.safe_load((dir_path / cls.__path).read_text()))
def dump(self, dir_path: Path):
(dir_path / self.__path).write_text(yaml.dump(asdict(self)))
@dataclass
class Status(_DepotObj):
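    """A snapshot of the depot: a parent status plus runs integrated on top.

    The name is a content hash, so a given parent and run set always maps
    to the same status directory.
    """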
parent: Optional[str] = None
integrated_runs: list[str] = field(default_factory=list)
@property
def name(self):
_run_str = "-".join(sorted(self.integrated_runs))
return _hash_str(f"{self.parent}::{_run_str}")
@property
def is_root(self):
return (self.parent is None) and (len(self.integrated_runs) == 0)
@dataclass
class Run(_DepotObj):
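    """Provenance recorded for a run: git commit, pip freeze and start time."""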
    commit_hash: Optional[str] = field(default_factory=_get_git_hash)
pip_freeze: list[str] = field(default_factory=_pip_freeze)
start_timestamp: float = field(default_factory=time.time)
@dataclass
class StatusCache:
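    """Pickled index of known statuses and their parent -> child links,
    used to find leaf statuses without re-reading every context.yaml."""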
statuses: dict[str, Status] = field(default_factory=dict)
parent_keys: dict[str, set[str]] = field(default_factory=lambda: defaultdict(set))
def add(self, status: Status):
self.statuses[status.name] = status
self.parent_keys[status.parent].add(status.name)
def merge(self, other: "StatusCache"):
self.statuses.update(other.statuses)
for k, v in other.parent_keys.items():
self.parent_keys[k].update(v)
return self
@classmethod
def read(cls, path: Path):
return pickle.loads(path.read_bytes()) if path.exists() else cls()
def dump(self, path: Path):
path.write_bytes(pickle.dumps(self))
class Current:
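    """The mutable working directory of the depot: a live sqlite database
    plus an ``events`` directory of files not yet archived into a run."""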
def __init__(self, root: Path) -> None:
def _p(s) -> Path:
return root / s
self.root = root
self.db_path, self.parent, self.events, self.run_ctx = map(
_p, [f"db.{DB_KIND}", "parent", "events", CONTEXT_YAML]
)
self.db_constr = f"{DB_KIND}:///{self.db_path.as_posix()}"
        self.engine: Optional[db.engine.Engine] = None
self.next_batch = self._wrap(get_next_batch)
self.reset_surls = self._wrap(reset_surls)
def setup(self):
self.engine = db.create_engine(self.db_constr)
self.events.mkdir(parents=True, exist_ok=True)
Base.metadata.create_all(self.engine)
return self
def purge(self):
if self.root.exists():
rmtree(self.root)
def get_parent(self):
try:
return self.parent.read_text()
except FileNotFoundError:
return None
    def any_in_progress(self) -> bool:
        with self._get_session() as session:
            for status, _, count in get_grouped_surls(session):
                if status == Statuses.PROCESSING:
                    return count > 0
        return False
def integrate_events(self, events: Iterable[Union[CollEvent, RegEvent]]):
self._wrap(integrate_events)(events, self.events)
def get_run_name(self):
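        """Build the run id: ``<start timestamp>-<hash of context and event names>``."""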
        hash_base = self.run_ctx.read_text() + "::".join(
            sorted(p.name for p in self.events.iterdir())
        )
_ts = Run.read(self.root).start_timestamp
return _RUN_SPLIT.join(map(str, [_ts, _hash_str(hash_base)]))
    @contextmanager
    def _get_session(self):
        session: Session = MySession(bind=self.engine)
        try:
            yield session
        finally:
            session.close()
def _wrap(self, fun):
@wraps(fun)
def f(*args, **kwargs):
with self._get_session() as session:
return fun(session, *args, **kwargs)
return f
class DepotBase:
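    """Manages runs, statuses and the object store under one depot root.

    A minimal lifecycle sketch (the depot name is illustrative):

        depot = DepotBase("my-project").setup(init=True)
        # ... record events into depot.current ...
        status = depot.save_current()  # archives the run as a new status
    """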
def __init__(self, name: str, local_root: Optional[Path] = None) -> None:
self.name = name
self.root = (
Path(local_root or os.environ.get(DEPOT_ROOT_ENV_VAR) or DEFAULT_DEPOT_ROOT)
/ name
)
self.object_store_path = self.root / "object-store"
self.object_store = ObjectStore(self.object_store_path)
self.statuses_path = self.root / "statuses"
self.runs_path = self.root / "runs"
self.current = Current(self.root / "current-run")
self._cache_path = self.root / "status-cache.pkl"
self._status_cache = self._load_status_cache()
self._init_dirs = [self.runs_path, self.statuses_path, self.object_store_path]
def setup(self, init=False):
for p in self._init_dirs:
p.mkdir(exist_ok=True, parents=True)
if init:
self.init_w_complete()
return self
def purge(self):
if self.root.exists():
rmtree(self.root)
self._status_cache = self._load_status_cache()
return self
def init_w_complete(self):
self.set_as_current(self.get_complete_status())
return self
def get_complete_status(self) -> Status:
# either an existing, a new or a blank status
leaf, leaf_tree = self._get_leaf(needs_db=True)
missing_runs = self.get_all_run_ids() - leaf_tree
if missing_runs:
return self.integrate(leaf, missing_runs)
return leaf
def get_status(self, status_name):
status = self._status_cache.statuses.get(status_name)
if status is None:
status = Status.read(self.statuses_path / status_name)
self._status_cache.add(status)
return status
def get_all_run_ids(self):
        return {p.name for p in self.runs_path.iterdir()}
def set_as_current(self, status: Status):
self.current.setup()
if not status.is_root:
self.current.parent.write_text(status.name)
with self._status_db_zip(status.name, "r") as zfp:
zfp.extract(self.current.db_path.name, path=self.current.root)
Run().dump(self.current.root)
def integrate(self, status: Status, runs: Iterable[str]) -> Status:
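        """Replay the events of ``runs`` onto a copy of ``status``'s database
        and save the result as a new child status."""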
with TemporaryDirectory() as tmp_dir:
tmp_curr = Current(Path(tmp_dir)).setup()
try:
with self._status_db_zip(status.name, "r") as zfp:
                    zfp.extract(tmp_curr.db_path.name, path=tmp_dir)
parent_name = status.name
except FileNotFoundError:
logger.warn(f"integrating to an empty database {status.name}")
parent_name = None
for run_name in runs:
tmp_curr.integrate_events(self._get_run_events(run_name, True))
out = Status(parent_name, list(runs))
return self._save_status_from_current(tmp_curr, out)
    def save_current(self) -> Optional[Status]:
        # a run with zero events is not saved
        if not [*self.current.events.iterdir()]:
            return None
run_name = self.current.get_run_name()
run_dir = self.runs_path / run_name
run_dir.mkdir()
with self._run_events_zip(run_name, "w") as zfp:
for ev_path in self.current.events.iterdir():
zfp.write(ev_path, ev_path.name)
Run.read(self.current.root).dump(run_dir)
status = Status(self.current.get_parent(), [run_name])
return self._save_status_from_current(self.current, status)
def get_handler_events(
self,
handler: Optional[Union[str, ANY_HANDLER_T]] = None,
only_successful=True,
only_latest=True,
from_current: bool = False,
past_runs: Union[None, int, Iterable[str]] = None,
post_status: Optional[str] = None,
) -> Iterable["ParsedCollectionEvent"]:
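        """Yield parsed collection events, most recent first.

        ``past_runs`` may be None (all runs), an int (that many latest runs)
        or an iterable of run names; ``post_status`` restricts the scan to
        runs not yet integrated into the named status.
        """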
urls = set()
handler_name = (
handler
if (isinstance(handler, str) or handler is None)
else handler.__name__
)
def _filter(ev: CollEvent):
return (
((handler_name is None) or (ev.handler == handler_name))
and ((not only_successful) or (ev.status in SUCCESS_STATUSES))
and ((not only_latest) or (ev.extend().url not in urls))
)
if post_status is not None:
old_tree = self._get_full_run_tree(self.get_status(post_status))
past_runs = self.get_all_run_ids() - old_tree
if from_current:
event_iters = [map(partial_read_path, self.current.events.iterdir())]
elif isinstance(past_runs, int):
event_iters = islice(self._iter_runs(), past_runs)
elif past_runs is None:
event_iters = self._iter_runs()
else:
event_iters = map(self._get_run_events, sorted(past_runs, reverse=True))
for ev_iter in event_iters:
for ev in filter(_filter, get_sorted_coll_events(ev_iter)):
yield ParsedCollectionEvent(ev, self.object_store)
if only_latest:
urls.add(ev.url)
def cleanup_statuses(self):
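        """Drop status directories that fail to load, and their descendants,
        iterating to a fixed point; returns the collected errors."""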
errs = {}
err_set = set()
while True:
for st in self.statuses_path.iterdir():
try:
status = Status.read(st)
                    # a status whose parent is broken is also broken
                    assert status.parent not in errs
except Exception as e:
errs[st.name] = e
if err_set == set(errs.keys()):
break
err_set = set(errs.keys())
for err in err_set:
rmtree(self.statuses_path / err)
return errs
def _get_leaf(self, needs_db=False):
# just one that has no children
# and has the most runs in its tree
most_runs = 0
leaf, leaf_tree = Status(), set()
self.statuses_path.mkdir(exist_ok=True)
local_names = [sp.name for sp in self.statuses_path.iterdir()]
status_names = set([*local_names, *self._status_cache.statuses.keys()])
        # load everything first so all parent -> child links end up in the cache
        candidates = [self.get_status(status_name) for status_name in status_names]
        for candidate in candidates:
            if self._status_cache.parent_keys[candidate.name]:
                # has at least one child status, so it is not a leaf
                continue
            sdb = self.statuses_path / candidate.name / STATUS_DB_ZIP
            if needs_db and (not sdb.exists()):
                continue
candidate_tree = self._get_full_run_tree(candidate)
_run_count = len(candidate_tree)
if _run_count >= most_runs:
leaf, leaf_tree = candidate, candidate_tree
most_runs = _run_count
return leaf, leaf_tree
def _get_full_run_tree(self, status: Status) -> set[str]:
out = set(status.integrated_runs)
parent = status.parent
while parent:
parent_status = self.get_status(parent)
out |= set(parent_status.integrated_runs)
parent = parent_status.parent
return out
def _iter_runs(self) -> Iterable[Iterable[CollEvent]]:
runs = []
for run_path in self.runs_path.glob("*"):
heappush(runs, (-_start_timestamp_from_run_path(run_path), run_path.name))
while runs:
_, run_name = heappop(runs)
yield self._get_run_events(run_name)
def _get_run_events(self, run_name, extend=True):
with self._run_events_zip(run_name, "r") as zfp:
for event in zfp.filelist:
if extend:
yield partial_read(
event.filename, partial(zfp.read, event)
).extend()
else:
_fun = partial(_read_event_blob, self.runs_path, run_name, event)
yield partial_read(event.filename, _fun)
def _save_status_from_current(self, current: Current, status: Status):
status_dir = self.statuses_path / status.name
status_dir.mkdir(parents=True)
status.dump(status_dir)
with self._status_db_zip(status.name, "w") as zfp:
zfp.write(current.db_path, current.db_path.name)
self._status_cache.add(status)
return status
def _status_db_zip(self, status_name, mode):
return _zipfile(self.statuses_path, status_name, STATUS_DB_ZIP, mode)
def _run_events_zip(self, run_name, mode):
return _zipfile(self.runs_path, run_name, EVENTS_ZIP, mode)
def _load_status_cache(self) -> StatusCache:
return StatusCache.read(self._cache_path)
class ParsedCollectionEvent:
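    """Read-only view of a ``CollEvent`` that lazily loads the collected
    content from the object store."""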
def __init__(self, cev: "CollEvent", store: ObjectStore):
self.cev = cev
self.handler_name = cev.handler
self._ostore = store
self._time = cev.timestamp
self.status = cev.status
@property
def content(self):
self.cev.extend()
of = self.cev.output_file
return self._ostore.read(of) if of else None
@property
def url(self):
self.cev.extend()
return self.cev.url
def __repr__(self):
return f"{self.status}: {self.handler_name} - {self.url} ({self._time})"
def get_sorted_coll_events(event_iterator: Iterable) -> Iterable[CollEvent]:
coll_evs = []
for ev in event_iterator:
if isinstance(ev, CollEvent):
# ordered based on most recent
heappush(coll_evs, ev)
while coll_evs:
yield heappop(coll_evs)
def _read_event_blob(root, dirname, event_name):
with _zipfile(root, dirname, EVENTS_ZIP, "r") as zfp:
return zfp.read(event_name)
def _start_timestamp_from_run_path(p: Path):
return float(p.name.split(_RUN_SPLIT)[0])
def _zipfile(root, dirname, filename, mode) -> zipfile.ZipFile:
return zipfile.ZipFile(root / dirname / filename, mode, compression=COMPRESS)