# stagesepx/cutter/cut_result.py
import os
import typing
import cv2
import uuid
import json
import numpy as np
from loguru import logger
import difflib
from stagesepx import toolbox
from stagesepx.hook import BaseHook
from stagesepx.video import VideoObject, VideoFrame
from stagesepx.cutter.cut_range import VideoCutRange
class VideoCutResult(object):
    """
    Result of a `cut`: the analysed video plus the ranges produced for it.

    Provides helpers to split the ranges into stable / unstable lists, export
    picked frames and thumbnails, (de)serialize the result as JSON and compare
    it with another result.
    """

    def __init__(
        self,
        video: "VideoObject",
        range_list: typing.List["VideoCutRange"],
        cut_kwargs: typing.Optional[typing.Dict] = None,
    ):
        """
        :param video: the video this result belongs to
        :param range_list: ranges produced by the cutter
        :param cut_kwargs: kwargs sent to `cut` function, `None` becomes `{}`
        """
        self.video = video
        self.range_list = range_list
        # kwargs sent to `cut` function
        self.cut_kwargs = cut_kwargs or {}

    def get_target_range_by_id(self, frame_id: int) -> "VideoCutRange":
        """
        get target VideoCutRange by id (which belongs to)

        :param frame_id: frame id to look up
        :return: the first range in `range_list` whose `contain(frame_id)` is true
        :raise RuntimeError: no range contains this frame
        """
        for each in self.range_list:
            if each.contain(frame_id):
                return each
        raise RuntimeError(f"frame {frame_id} not found in video")

    @staticmethod
    def _length_filter(
        range_list: typing.List["VideoCutRange"], limit: int
    ) -> typing.List["VideoCutRange"]:
        """keep ranges whose length >= limit (order preserved)"""
        return [each for each in range_list if each.get_length() >= limit]

    def get_unstable_range(
        self,
        limit: typing.Optional[int] = None,
        range_threshold: typing.Optional[float] = None,
        **kwargs,
    ) -> typing.List["VideoCutRange"]:
        """
        return unstable range only

        :param limit: if set, drop merged ranges whose length < limit
        :param range_threshold: if set, drop merged ranges for which `is_loop(range_threshold)` is true
        :param kwargs: forwarded to `is_stable`, `can_merge` and `merge` of each range
        :return: merged unstable ranges, sorted by start
        """
        change_range_list = sorted(
            [i for i in self.range_list if not i.is_stable(**kwargs)],
            key=lambda x: x.start,
        )

        # video can be totally stable ( nothing changed )
        # or only one unstable range
        if len(change_range_list) <= 1:
            return change_range_list

        # merge: walk the sorted list and fold every run of mergeable neighbours into one range
        i = 0
        merged_change_range_list = list()
        while i < len(change_range_list) - 1:
            cur = change_range_list[i]
            while cur.can_merge(change_range_list[i + 1], **kwargs):
                # can be merged
                i += 1
                cur = cur.merge(change_range_list[i], **kwargs)

                # out of range
                if i + 1 >= len(change_range_list):
                    break
            merged_change_range_list.append(cur)
            i += 1

        # the outer loop never starts from the last element,
        # so append it when it has not been swallowed by the previous merge
        if change_range_list[-1].start > merged_change_range_list[-1].end:
            merged_change_range_list.append(change_range_list[-1])

        if limit:
            merged_change_range_list = self._length_filter(
                merged_change_range_list, limit
            )

        # merged range check
        if range_threshold:
            merged_change_range_list = [
                i for i in merged_change_range_list if not i.is_loop(range_threshold)
            ]

        logger.debug(
            f"unstable range of [{self.video.path}]: {merged_change_range_list}"
        )
        return merged_change_range_list

    def get_range(
        self,
        limit: typing.Optional[int] = None,
        unstable_limit: typing.Optional[int] = None,
        **kwargs,
    ) -> typing.Tuple[typing.List["VideoCutRange"], typing.List["VideoCutRange"]]:
        """
        return stable_range_list and unstable_range_list

        :param limit: ignore some ranges which are too short, 5 means ignore stable ranges which length < 5
        :param unstable_limit: ignore some ranges which are too short, 5 means ignore unstable ranges which length < 5
        :param kwargs:
            threshold: float, 0-1, default to 0.98. decided whether a range is stable. larger => more unstable ranges
            range_threshold:
                same as threshold, but it decided whether a merged range is stable.
                see https://github.com/williamfzc/stagesepx/issues/17 for details
            offset:
                it will change the way to decided whether two ranges can be merged
                before: first_range.end == second_range.start
                after: first_range.end + offset >= second_range.start
        :return: (stable ranges sorted by start, unstable ranges)
        """
        # videos have 4 kinds of status:
        #   - stable start + stable end (usually)
        #   - stable start + unstable end
        #   - unstable start + stable end
        #   - unstable start + unstable end
        # so, unstable range list can be:
        #   - start > 0, end < frame_count
        #   - start = 0, end < frame_count
        #   - start > 0, end = frame_count
        #   - start = 0, end = frame_count
        unstable_range_list = self.get_unstable_range(unstable_limit, **kwargs)

        # start point
        video_start_frame_id = 1
        video_start_timestamp = 0.0

        # end point
        video_end_frame_id = self.range_list[-1].end
        video_end_timestamp = self.range_list[-1].end_time

        # default values (stable ranges are built here, not measured)
        _default = {
            "ssim": [1.0],
            "mse": [0.0],
            "psnr": [0.0],
        }

        # stable all the time
        if len(unstable_range_list) == 0:
            logger.warning(
                "no unstable stage detected, seems nothing happened in your video"
            )
            return (
                # stable
                [
                    VideoCutRange(
                        video=self.video,
                        start=video_start_frame_id,
                        end=video_end_frame_id,
                        start_time=video_start_timestamp,
                        end_time=video_end_timestamp,
                        **_default,
                    )
                ],
                # unstable
                [],
            )

        # IMPORTANT: +1 and -1 easily cause error
        # end of first stable range == start of first unstable range
        first_stable_range_end_id = unstable_range_list[0].start - 1
        # start of last stable range == end of last unstable range
        end_stable_range_start_id = unstable_range_list[-1].end + 1

        # IMPORTANT: len(ssim_list) + 1 == video_end_frame_id
        range_list: typing.List["VideoCutRange"] = list()

        # stable start
        if first_stable_range_end_id >= 1:
            logger.debug("stable start")
            range_list.append(
                VideoCutRange(
                    video=self.video,
                    start=video_start_frame_id,
                    end=first_stable_range_end_id,
                    start_time=video_start_timestamp,
                    end_time=self.get_target_range_by_id(
                        first_stable_range_end_id
                    ).end_time,
                    **_default,
                )
            )
        # unstable start
        else:
            logger.debug("unstable start")

        # stable end
        if end_stable_range_start_id <= video_end_frame_id:
            logger.debug("stable end")
            range_list.append(
                VideoCutRange(
                    video=self.video,
                    start=end_stable_range_start_id,
                    end=video_end_frame_id,
                    start_time=self.get_target_range_by_id(
                        end_stable_range_start_id
                    ).end_time,
                    end_time=video_end_timestamp,
                    **_default,
                )
            )
        # unstable end
        else:
            logger.debug("unstable end")

        # diff range: the gap between each pair of neighbouring unstable ranges is stable
        for i in range(len(unstable_range_list) - 1):
            range_start_id = unstable_range_list[i].end + 1
            range_end_id = unstable_range_list[i + 1].start - 1

            # stable range's length is 1
            if range_start_id > range_end_id:
                range_start_id, range_end_id = range_end_id, range_start_id

            range_list.append(
                # IMPORTANT: frame's timestamp => start time of this frame
                # because frame 1's timestamp is 0.0
                # frame {range_start_id} start time - frame {range_end_id} start time
                VideoCutRange(
                    video=self.video,
                    start=range_start_id,
                    end=range_end_id,
                    start_time=self.get_target_range_by_id(range_start_id).start_time,
                    end_time=self.get_target_range_by_id(range_end_id).start_time,
                    **_default,
                )
            )

        # remove some ranges, which is limit
        if limit:
            range_list = self._length_filter(range_list, limit)
        logger.debug(f"stable range of [{self.video.path}]: {range_list}")
        stable_range_list = sorted(range_list, key=lambda x: x.start)
        return stable_range_list, unstable_range_list

    def get_stable_range(
        self, limit: typing.Optional[int] = None, **kwargs
    ) -> typing.List["VideoCutRange"]:
        """return stable range only"""
        return self.get_range(limit, **kwargs)[0]

    def get_range_dynamic(
        self,
        stable_num_limit: typing.List[int],
        threshold: float,
        step: float = 0.005,
        max_retry: int = 10,
        **kwargs,
    ) -> typing.Tuple[typing.List["VideoCutRange"], typing.List["VideoCutRange"]]:
        """
        this method was designed for supporting flexible threshold range

        `get_range` is called again and again, moving `threshold` by `step`
        each time, until the number of stable ranges fits `stable_num_limit`.

        :param stable_num_limit: [min, max] (inclusive) accepted number of stable ranges
        :param threshold: start threshold, must stay inside (0.0, 1.0)
        :param step: how far threshold moves per retry (up when too few stages, down when too many)
        :param max_retry: attempts left; a negative value never reaches 0, so only the threshold check stops it
        :param kwargs: forwarded to `get_range`
        :return: same as `get_range`
        :raise AssertionError: retries used up, threshold out of range, or bad `stable_num_limit`
        """
        # a loop (instead of recursion) keeps the caller's `step` for every retry
        while True:
            assert max_retry != 0, f"fail to get range dynamically: {stable_num_limit}"
            assert len(stable_num_limit) == 2, "num_limit should be something like [1, 3]"
            assert 0.0 < threshold < 1.0, "threshold out of range"

            stable, unstable = self.get_range(threshold=threshold, **kwargs)
            cur_num = len(stable)
            logger.debug(f"current stable range is {cur_num}")
            if stable_num_limit[0] <= cur_num <= stable_num_limit[1]:
                logger.debug("range num is fine")
                return stable, unstable

            # too fewer stages
            if cur_num < stable_num_limit[0]:
                logger.debug("too fewer stages")
                threshold += step
            # too many
            elif cur_num > stable_num_limit[1]:
                logger.debug("too many stages")
                threshold -= step

            max_retry -= 1

    def thumbnail(
        self,
        target_range: "VideoCutRange",
        to_dir: typing.Optional[str] = None,
        compress_rate: typing.Optional[float] = None,
        is_vertical: typing.Optional[bool] = None,
        *_,
        **__,
    ) -> np.ndarray:
        """
        build a thumbnail, for easier debug or something else

        :param target_range: VideoCutRange
        :param to_dir: your thumbnail will be saved to this path (created if missing)
        :param compress_rate: float, 0 - 1, about thumbnail's size, default to 0.1 (1/10)
        :param is_vertical: direction, frames are stacked vertically if true, horizontally otherwise
        :return: the stacked picture
        """
        if not compress_rate:
            compress_rate = 0.1

        # direction
        # NOTE(review): split lines are 2-D arrays, so stacking presumably expects
        # `toolbox.compress_frame` to return single-channel frames - confirm
        if is_vertical:
            stack_func = np.vstack

            def get_split_line(f):
                return np.zeros((5, f.shape[1]))

        else:
            stack_func = np.hstack

            def get_split_line(f):
                return np.zeros((f.shape[0], 5))

        frame_list = list()
        with toolbox.video_capture(self.video.path) as cap:
            toolbox.video_jump(cap, target_range.start)
            ret, frame = cap.read()
            count = 1
            length = target_range.get_length()
            while ret and count <= length:
                frame = toolbox.compress_frame(frame, compress_rate)
                frame_list.append(frame)
                # a split line follows every frame (including the last one)
                frame_list.append(get_split_line(frame))
                ret, frame = cap.read()
                count += 1
        merged = stack_func(frame_list)

        if to_dir:
            # create parent dir: cv2.imwrite will not create a missing directory
            os.makedirs(to_dir, exist_ok=True)
            target_path = os.path.join(
                to_dir, f"thumbnail_{target_range.start}-{target_range.end}.png"
            )
            cv2.imwrite(target_path, merged)
            logger.debug(f"save thumbnail to {target_path}")
        return merged

    def pick_and_save(
        self,
        range_list: typing.List["VideoCutRange"],
        frame_count: int,
        to_dir: typing.Optional[str] = None,
        prune: typing.Optional[float] = None,
        meaningful_name: typing.Optional[bool] = None,
        # in kwargs
        # compress_rate: float = None,
        # target_size: typing.Tuple[int, int] = None,
        # to_grey: bool = None,
        *args,
        **kwargs,
    ) -> str:
        """
        pick some frames from range, and save them as files

        :param range_list: VideoCutRange list
        :param frame_count: how many frames to pick from each range, 3 means you will get 3 frames for each range
        :param to_dir: will saved to this path, a timestamp-named dir is used if not set
        :param prune: float, 0-1. if set it 0.9, some stages which are too similar (ssim > 0.9) will be removed
        :param meaningful_name: bool, False by default. if true, image names will become meaningful (with timestamp/id or something else)
        :param args: forwarded to `pick` of each range
        :param kwargs: forwarded to `pick` of each range and to `toolbox.compress_frame`
        :return: the dir which contains one sub dir per stage
        """
        stage_list = list()
        # build tag and get frames
        for index, each_range in enumerate(range_list):
            picked = each_range.pick(frame_count, *args, **kwargs)
            picked_frames = each_range.get_frames(picked)
            logger.info(f"pick {picked} in range {each_range}")
            stage_list.append((str(index), picked_frames))

        # prune
        if prune:
            stage_list = self._prune(prune, stage_list)

        # create parent dir
        if not to_dir:
            to_dir = toolbox.get_timestamp_str()
        logger.debug(f"try to make dirs: {to_dir}")
        os.makedirs(to_dir, exist_ok=True)

        for each_stage_id, each_frame_list in stage_list:
            # create sub dir
            each_stage_dir = os.path.join(to_dir, str(each_stage_id))

            if os.path.isdir(each_stage_dir):
                logger.warning(f"sub dir [{each_stage_dir}] already existed")
                logger.warning(
                    "NOTICE: make sure your data will not be polluted by accident"
                )
            os.makedirs(each_stage_dir, exist_ok=True)

            # create image files
            for each_frame_object in each_frame_list:
                if meaningful_name:
                    # - video name
                    # - frame id
                    # - frame timestamp
                    image_name = (
                        f"{os.path.basename(os.path.splitext(self.video.path)[0])}"
                        f"_"
                        f"{each_frame_object.frame_id}"
                        f"_"
                        f"{each_frame_object.timestamp}"
                        f".png"
                    )
                else:
                    image_name = f"{uuid.uuid4()}.png"

                each_frame_path = os.path.join(each_stage_dir, image_name)
                compressed = toolbox.compress_frame(each_frame_object.data, **kwargs)
                cv2.imwrite(each_frame_path, compressed)
                logger.debug(
                    f"frame [{each_frame_object.frame_id}] saved to {each_frame_path}"
                )

        return to_dir

    @staticmethod
    def _prune(
        threshold: float,
        stages: typing.List[typing.Tuple[str, typing.List["VideoFrame"]]],
    ) -> typing.List[typing.Tuple[str, typing.List["VideoFrame"]]]:
        """
        drop every stage which is too similar to a LATER stage

        a stage is removed as soon as the min ssim between it and any later
        stage is larger than threshold; the last stage is always kept.
        """
        logger.debug(
            f"start pruning ranges, origin length is {len(stages)}, threshold is {threshold}"
        )
        after = list()
        for i, current in enumerate(stages):
            index, frames = current
            for next_index, next_frames in stages[i + 1 :]:
                ssim_list = toolbox.multi_compare_ssim(frames, next_frames)
                min_ssim = min(ssim_list)
                logger.debug(f"compare {index} with {next_index}: {ssim_list}")
                if min_ssim > threshold:
                    logger.debug(f"stage {index} has been pruned")
                    break
            else:
                # no later stage is similar enough => keep it
                after.append(current)
        return after

    def dumps(self) -> str:
        """serialize this result to a json string (keys sorted)"""

        # for np.ndarray
        def _handler(obj: object):
            if isinstance(obj, np.ndarray):
                # ignore
                return "<np.ndarray object>"
            # everything else json can not handle is dumped as its attribute dict
            return obj.__dict__

        return json.dumps(self, sort_keys=True, default=_handler)

    def dump(self, json_path: str, **kwargs):
        """
        write `dumps()` to a new file

        :param json_path: target file, must not exist
        :param kwargs: forwarded to `open`
        """
        logger.debug(f"dump result to {json_path}")
        assert not os.path.exists(json_path), f"{json_path} already existed"
        with open(json_path, "w+", **kwargs) as f:
            f.write(self.dumps())

    @classmethod
    def loads(cls, content: str) -> "VideoCutResult":
        """
        build a result from a json string (see `dumps`)

        :param content: json string with `video` and `range_list` (and optionally `cut_kwargs`)
        """
        json_dict: dict = json.loads(content)
        return cls(
            VideoObject(**json_dict["video"]),
            [VideoCutRange(**each) for each in json_dict["range_list"]],
            # `dumps` writes cut_kwargs too, keep it on the way back
            cut_kwargs=json_dict.get("cut_kwargs"),
        )

    @classmethod
    def load(cls, json_path: str, **kwargs) -> "VideoCutResult":
        """
        build a result from a json file

        :param json_path: file written by `dump`
        :param kwargs: forwarded to `open`
        """
        logger.debug(f"load result from {json_path}")
        with open(json_path, **kwargs) as f:
            return cls.loads(f.read())

    def diff(
        self,
        another: "VideoCutResult",
        auto_merge: typing.Optional[bool] = None,
        pre_hooks: typing.Optional[typing.List["BaseHook"]] = None,
        output_path: typing.Optional[str] = None,
        *args,
        **kwargs,
    ) -> "VideoCutResultDiff":
        """
        compare cut result with another one

        :param output_path: picked frames of both results are saved here (see `pick_and_save`)
        :param pre_hooks: passed to `VideoCutResultDiff.apply_diff`
        :param another: another VideoCutResult object
        :param auto_merge: bool, will auto merge diff result and make it simple
        :param args: forwarded to `get_range` of both results
        :param kwargs: forwarded to `get_range` of both results
        :return: VideoCutResultDiff, already applied
        """
        self_stable, _ = self.get_range(*args, **kwargs)
        another_stable, _ = another.get_range(*args, **kwargs)
        self.pick_and_save(self_stable, 3, to_dir=output_path)
        another.pick_and_save(another_stable, 3, to_dir=output_path)

        result = VideoCutResultDiff(self_stable, another_stable)
        result.apply_diff(pre_hooks)

        if auto_merge:
            # for each stage, keep only (best matching stage id, its best score)
            after = dict()
            for self_stage_name, each_result in result.data.items():
                max_one = sorted(each_result.items(), key=lambda x: max(x[1]))[-1]
                max_one = (max_one[0], max(max_one[1]))
                after[self_stage_name] = max_one
            result.data = after
        return result

    @staticmethod
    def range_diff(
        range_list_1: typing.List["VideoCutRange"],
        range_list_2: typing.List["VideoCutRange"],
        *args,
        **kwargs,
    ) -> typing.Dict[int, typing.Dict[int, typing.List[float]]]:
        """
        compare every range in range_list_1 with every range in range_list_2

        :param args: forwarded to `diff` of each range
        :param kwargs: forwarded to `diff` of each range
        :return: {index in list 1: {index in list 2: result of `diff`}}
        """
        # 1. stage length compare
        self_stable_range_count = len(range_list_1)
        another_stable_range_count = len(range_list_2)
        if self_stable_range_count != another_stable_range_count:
            logger.warning(
                f"stage counts not equal: {self_stable_range_count} & {another_stable_range_count}"
            )

        # 2. stage content compare
        # TODO will load these pictures in memory at the same time
        data = dict()
        for self_id, each_self_range in enumerate(range_list_1):
            data[self_id] = {
                another_id: each_self_range.diff(each_another_range, *args, **kwargs)
                for another_id, each_another_range in enumerate(range_list_2)
            }
        return data
class VideoCutResultDiff(object):
    """
    assume origin video's stages: 1 -> 2 -> 3 -> 4
    its diff can be:
    - stage new : 1-2-5-3-4
    - stage replace: 1-5-3-4
    - stage lost : 1-3-4
    https://github.com/williamfzc/stagesepx/issues/158
    """

    # a stage whose best score is below this value is regarded as lost
    threshold: float = 0.7
    # returned by `most_common` when no score beats `default_score`
    default_stage_id: int = -1
    default_score: float = -1.0

    def __init__(
        self,
        origin: typing.List["VideoCutRange"],
        another: typing.List["VideoCutRange"],
    ):
        """
        :param origin: stable ranges of the origin video
        :param another: stable ranges of the video to compare with
        """
        self.origin = origin
        self.another = another
        # {origin stage id: {another stage id: scores}}, filled by `apply_diff`
        self.data: typing.Optional[
            typing.Dict[int, typing.Dict[int, typing.List[float]]]
        ] = None

    def apply_diff(self, pre_hooks: typing.Optional[typing.List["BaseHook"]] = None):
        """compare `origin` with `another` via `VideoCutResult.range_diff` and store the result in `data`"""
        self.data = VideoCutResult.range_diff(self.origin, self.another, pre_hooks)

    def most_common(self, stage_id: int) -> typing.Tuple[int, float]:
        """
        find the stage in `another` which matches this origin stage best

        :param stage_id: key of `data` (origin stage id)
        :return: (stage id, score), the score is the max value of that stage's score list;
                 (default_stage_id, default_score) if nothing scores above default_score
        """
        assert stage_id in self.data
        ret_k, ret_v = self.default_stage_id, self.default_score
        for k, v in self.data[stage_id].items():
            cur = max(v)
            # strict compare: on a tie the earlier stage wins
            if cur > ret_v:
                ret_k = k
                ret_v = cur
        return ret_k, ret_v

    def is_stage_lost(self, stage_id: int) -> bool:
        """true if the best score of this stage is below `threshold`"""
        # what we care most
        _, v = self.most_common(stage_id)
        return v < self.threshold

    def any_stage_lost(self) -> bool:
        """true if AT LEAST ONE origin stage is lost (false when `data` is empty)"""
        return any(self.is_stage_lost(each) for each in self.data.keys())

    def stage_shift(self) -> typing.List[int]:
        """best matching stage id for every origin stage whose best score is above `threshold`"""
        ret = list()
        for k in self.data.keys():
            new_k, score = self.most_common(k)
            if score > self.threshold:
                ret.append(new_k)
        return ret

    def stage_diff(self):
        """
        `difflib.Differ` comparison between `stage_shift()` and the stage ids of `another`

        :return: generator of delta lines produced by `difflib.Differ().compare`
        """
        return difflib.Differ().compare(
            [str(each) for each in self.stage_shift()],
            [str(each) for each in range(len(self.another))],
        )