williamfzc/stagesepx

View on GitHub
stagesepx/video.py

Summary

Maintainability
A
35 mins
Test Coverage
import os
import tempfile
import typing

import cv2
import imageio_ffmpeg
import numpy as np
from loguru import logger
import moviepy.editor as mpy

from stagesepx import toolbox

if typing.TYPE_CHECKING:
    from stagesepx.hook import BaseHook


class VideoFrame(object):
    def __init__(self, frame_id: int, timestamp: float, data: np.ndarray):
        self.frame_id: int = frame_id
        self.timestamp: float = timestamp
        self.data: np.ndarray = data

    def __str__(self):
        return f"<VideoFrame id={self.frame_id} timestamp={self.timestamp}>"

    @classmethod
    def init(cls, cap: cv2.VideoCapture, frame: np.ndarray) -> "VideoFrame":
        frame_id = toolbox.get_current_frame_id(cap)
        timestamp = toolbox.get_current_frame_time(cap)
        grey = toolbox.turn_grey(frame)
        logger.debug(f"new a frame: {frame_id}({timestamp})")
        return VideoFrame(frame_id, timestamp, grey)

    def copy(self):
        return VideoFrame(self.frame_id, self.timestamp, self.data[:])

    def contain_image(
        self, *, image_path: str = None, image_object: np.ndarray = None, **kwargs
    ) -> typing.Dict[str, typing.Any]:
        assert image_path or (
            image_object is not None
        ), "should fill image_path or image_object"

        if image_path:
            logger.debug(f"found image path, use it first: {image_path}")
            return toolbox.match_template_with_path(image_path, self.data, **kwargs)
        image_object = toolbox.turn_grey(image_object)
        return toolbox.match_template_with_object(image_object, self.data, **kwargs)


class _BaseFrameOperator(object):
    def __init__(self, video: "VideoObject"):
        # pointer
        self.cur_ptr: int = 0
        self.video: VideoObject = video

    def get_frame_by_id(self, frame_id: int) -> typing.Optional[VideoFrame]:
        raise NotImplementedError

    def get_length(self) -> int:
        return self.video.frame_count


class MemFrameOperator(_BaseFrameOperator):
    def get_frame_by_id(self, frame_id: int) -> typing.Optional[VideoFrame]:
        if frame_id > self.get_length():
            return None
        # list starts from zero, but frame starts from one
        frame_id = frame_id - 1
        return self.video.data[frame_id].copy()


class FileFrameOperator(_BaseFrameOperator):
    def get_frame_by_id(self, frame_id: int) -> typing.Optional[VideoFrame]:
        if frame_id > self.get_length():
            return None
        with toolbox.video_capture(self.video.path) as cap:
            toolbox.video_jump(cap, frame_id)
            success, frame = cap.read()
            video_frame = VideoFrame.init(cap, frame) if success else None
        return video_frame


class VideoObject(object):
    def __init__(
        self,
        path: typing.Union[str, os.PathLike],
        pre_load: bool = None,
        fps: int = None,
        *_,
        **__,
    ):
        assert os.path.isfile(path), f"video {path} not existed"
        self.path: str = str(path)
        self.data: typing.Optional[typing.Tuple[VideoFrame]] = tuple()
        self._hook_list: typing.List["BaseHook"] = []

        self.fps: int = fps
        if fps:
            video_path = os.path.join(tempfile.mkdtemp(), f"tmp_{fps}.mp4")
            logger.debug(f"convert video, and bind path to {video_path}")
            toolbox.fps_convert(
                fps, self.path, video_path, imageio_ffmpeg.get_ffmpeg_exe()
            )
            self.path = video_path

        with toolbox.video_capture(self.path) as cap:
            self.frame_count = toolbox.get_frame_count(cap)
            self.frame_size = toolbox.get_frame_size(cap)

        if pre_load is not None:
            logger.warning(
                f"`pre_load` has been deprecated. use `video.load_frames()` instead"
            )
        logger.info(
            f"video object generated, length: {self.frame_count}, size: {self.frame_size}"
        )

    def __str__(self):
        return f"<VideoObject path={self.path}>"

    __repr__ = __str__

    def sync_timestamp(self):
        vid = mpy.VideoFileClip(self.path)

        # moviepy start from 0, 0.0
        # but stagesepx is 1, 0.0
        assert self.data, "load_frames() first"
        for frame_id, (timestamp, _) in enumerate(vid.iter_frames(with_times=True)):
            if frame_id >= len(self.data):
                # ignore the rest
                break
            frame_id_real = frame_id + 1
            if not self.data[frame_id].timestamp:
                logger.debug(f"fix frame {frame_id_real}'s timestamp: {timestamp}")
                self.data[frame_id].timestamp = timestamp
        logger.info("sync timestamp with moviepy finished")

    def add_preload_hook(self, new_hook: "BaseHook"):
        """this hook only will be executed when preload"""
        self._hook_list.append(new_hook)

    def clean_frames(self):
        self.data = tuple()

    def load_frames(self, *args, **kwargs):
        logger.info(f"start loading {self.path} to memory ...")

        data: typing.List[VideoFrame] = []
        with toolbox.video_capture(self.path) as cap:
            # the first
            success, frame = cap.read()
            while success:
                frame_object = VideoFrame.init(cap, frame)
                # apply hooks
                for each_hook in self._hook_list:
                    frame_object = each_hook.do(frame_object, *args, **kwargs)
                data.append(frame_object)
                # read the next one
                success, frame = cap.read()

        # calculate memory cost
        each_cost = data[0].data.nbytes
        logger.debug(f"single frame cost: {each_cost} bytes")
        total_cost = each_cost * self.frame_count
        logger.debug(f"total frame cost: {total_cost} bytes")

        # lock the order
        self.data = tuple(data)
        # fix the length ( the last frame may be broken sometimes )
        self.frame_count = len(data)
        # and size (reversed, see: https://github.com/williamfzc/stagesepx/issues/132)
        self.frame_size = data[0].data.shape[::-1]
        logger.info(
            f"frames loaded. frame count: {self.frame_count}, size: {self.frame_size}, memory cost: {total_cost} bytes"
        )

        # sync timestamp for some newer versions opencv
        # see: #178, #181
        self.sync_timestamp()

    def _read_from_file(self) -> typing.Generator[VideoFrame, None, None]:
        with toolbox.video_capture(self.path) as cap:
            success, frame = cap.read()
            while success:
                yield VideoFrame.init(cap, frame)
                success, frame = cap.read()

    def _read_from_mem(self) -> typing.Generator[VideoFrame, None, None]:
        for each_frame in self.data:
            yield each_frame

    def _read(self) -> typing.Generator[VideoFrame, None, None]:
        if self.data:
            yield from self._read_from_mem()
        else:
            yield from self._read_from_file()

    def get_iterator(self) -> typing.Generator[VideoFrame, None, None]:
        return self._read()

    def get_operator(self) -> _BaseFrameOperator:
        if self.data:
            return MemFrameOperator(self)
        return FileFrameOperator(self)

    def __iter__(self):
        return self.get_iterator()