dvc/repo/experiments/queue/remove.py
from collections.abc import Collection, Iterable
from typing import TYPE_CHECKING, Union
from dvc.repo.experiments.exceptions import UnresolvedExpNamesError
from dvc.repo.experiments.queue.base import QueueDoneResult
if TYPE_CHECKING:
from dvc.repo.experiments.queue.base import QueueEntry
from dvc.repo.experiments.queue.celery import LocalCeleryQueue
from dvc.repo.experiments.stash import ExpStashEntry
def remove_tasks( # noqa: C901, PLR0912
celery_queue: "LocalCeleryQueue",
queue_entries: Iterable["QueueEntry"],
):
"""Remove tasks from task queue.
Arguments:
queue_entries: An iterable list of task to remove
"""
from celery.result import AsyncResult
stash_revs: dict[str, "ExpStashEntry"] = {}
failed_stash_revs: list["ExpStashEntry"] = []
done_entry_set: set["QueueEntry"] = set()
stash_rev_all = celery_queue.stash.stash_revs
failed_rev_all: dict[str, "ExpStashEntry"] = {}
if celery_queue.failed_stash:
failed_rev_all = celery_queue.failed_stash.stash_revs
for entry in queue_entries:
if entry.stash_rev in stash_rev_all:
stash_revs[entry.stash_rev] = stash_rev_all[entry.stash_rev]
else:
done_entry_set.add(entry)
if entry.stash_rev in failed_rev_all:
failed_stash_revs.append(failed_rev_all[entry.stash_rev])
try:
for msg, queue_entry in celery_queue._iter_queued():
if queue_entry.stash_rev in stash_revs and msg.delivery_tag:
celery_queue.celery.reject(msg.delivery_tag)
finally:
celery_queue.stash.remove_revs(list(stash_revs.values()))
try:
for msg, queue_entry in celery_queue._iter_processed():
if queue_entry not in done_entry_set:
continue
task_id = msg.headers["id"]
result: AsyncResult = AsyncResult(task_id)
if result is not None:
result.forget()
if msg.delivery_tag:
celery_queue.celery.purge(msg.delivery_tag)
finally:
if celery_queue.failed_stash:
celery_queue.failed_stash.remove_revs(failed_stash_revs)
def _get_names(entries: Iterable[Union["QueueEntry", "QueueDoneResult"]]):
names: list[str] = []
for entry in entries:
if isinstance(entry, QueueDoneResult):
if entry.result and entry.result.ref_info:
names.append(entry.result.ref_info.name)
continue
entry = entry.entry
name = entry.name
name = name or entry.stash_rev[:7]
names.append(name)
return names
def celery_clear(
self: "LocalCeleryQueue",
queued: bool = False,
failed: bool = False,
success: bool = False,
) -> list[str]:
"""Remove entries from the queue.
Arguments:
queued: Remove all queued tasks.
failed: Remove all failed tasks.
success: Remove all success tasks.
Returns:
Revisions which were removed.
"""
removed: list[str] = []
entry_list: list["QueueEntry"] = []
if queued:
queue_entries: list["QueueEntry"] = list(self.iter_queued())
entry_list.extend(queue_entries)
removed.extend(_get_names(queue_entries))
if failed:
failed_tasks: list["QueueDoneResult"] = list(self.iter_failed())
entry_list.extend([result.entry for result in failed_tasks])
removed.extend(_get_names(failed_tasks))
if success:
success_tasks: list["QueueDoneResult"] = list(self.iter_success())
entry_list.extend([result.entry for result in success_tasks])
removed.extend(_get_names(success_tasks))
remove_tasks(self, entry_list)
return removed
def celery_remove(self: "LocalCeleryQueue", revs: Collection[str]) -> list[str]:
"""Remove the specified entries from the queue.
Arguments:
revs: Stash revisions or queued exp names to be removed.
Returns:
Revisions (or names) which were removed.
"""
match_results = self.match_queue_entry_by_name(
revs, self.iter_queued(), self.iter_done()
)
remained: list[str] = []
removed: list[str] = []
entry_to_remove: list["QueueEntry"] = []
for name, entry in match_results.items():
if entry:
entry_to_remove.append(entry)
removed.append(name)
else:
remained.append(name)
if remained:
raise UnresolvedExpNamesError(remained)
if entry_to_remove:
remove_tasks(self, entry_to_remove)
return removed