# stagesepx/classifier/keras.py
import os
import pathlib
import typing

import cv2
import numpy as np
from loguru import logger

try:
    # only imported to verify that tensorflow is available
    import tensorflow
except ImportError:
    raise ImportError("KerasClassifier requires tensorflow. Install it first.")

try:
    from keras.preprocessing.image import ImageDataGenerator
except ImportError:
    # https://stackoverflow.com/questions/78145837/importerror-cannot-import-name-imagedatagenerator-from-keras-preprocessing-i
    from keras.src.legacy.preprocessing.image import ImageDataGenerator

# https://github.com/tensorflow/models/issues/6177
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K

from stagesepx.classifier.base import BaseModelClassifier
from stagesepx import toolbox
from stagesepx.video import VideoFrame
from stagesepx import constants


class KerasClassifier(BaseModelClassifier):
    UNKNOWN_STAGE_NAME = constants.UNKNOWN_STAGE_FLAG
    # https://github.com/williamfzc/stagesepx/issues/112
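    # the default model ends with Dense(6), so it supports at most 6 stage classes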
    MODEL_DENSE = 6

    def __init__(
        self,
        score_threshold: typing.Optional[float] = None,
        data_size: typing.Optional[typing.Sequence[int]] = None,
        nb_train_samples: typing.Optional[int] = None,
        nb_validation_samples: typing.Optional[int] = None,
        epochs: typing.Optional[int] = None,
        batch_size: typing.Optional[int] = None,
        *_,
        **__,
    ):
        super().__init__(*_, **__)

        # model
        self._model: typing.Optional[Sequential] = None
        # settings
        self.score_threshold: float = score_threshold or 0.0
        self.data_size: typing.Sequence[int] = data_size or (200, 200)
        self.nb_train_samples: int = nb_train_samples or 64
        self.nb_validation_samples: int = nb_validation_samples or 64
        self.epochs: int = epochs or 20
        self.batch_size: int = batch_size or 4

        logger.debug(f"score threshold: {self.score_threshold}")
        logger.debug(f"data size: {self.data_size}")
        logger.debug(f"nb train samples: {self.nb_train_samples}")
        logger.debug(f"nb validation samples: {self.nb_validation_samples}")
        logger.debug(f"epochs: {self.epochs}")
        logger.debug(f"batch size: {self.batch_size}")

    def clean_model(self):
        self._model = None

    def save_model(self, model_path: str, overwrite: typing.Optional[bool] = None):
        """
        save trained weights

        :param model_path: path the weights file will be written to
        :param overwrite: set True to replace an existing file
        :return:
        """
        logger.debug(f"save model to {model_path}")
        # refuse to overwrite an existing file unless asked to
        if os.path.isfile(model_path) and not overwrite:
            raise FileExistsError(
                f"model file {model_path} already exists, set `overwrite=True` to replace it"
            )
        # weights can only be saved from an existing model
        assert self._model, "model is empty"
        self._model.save_weights(model_path)

    def load_model(self, model_path: str, overwrite: typing.Optional[bool] = None):
        """
        load trained weights

        :param model_path: path of the weights file to load
        :param overwrite: set True to replace an already loaded model
        :return:
        """
        logger.debug(f"load model from {model_path}")
        # the weights file must exist
        assert os.path.isfile(model_path), f"model file {model_path} does not exist"
        # refuse to replace a loaded model unless asked to
        if self._model and not overwrite:
            raise RuntimeError(
                "model is not empty, set `overwrite=True` to replace it"
            )
        self._model = self.create_model()
        self._model.load_weights(model_path)

    def create_model(self) -> Sequential:
        """model structure. you can overwrite this method to build your own model"""
        logger.info(f"creating keras sequential model")
        if K.image_data_format() == "channels_first":
            input_shape = (1, *self.data_size)
        else:
            input_shape = (*self.data_size, 1)

        model = Sequential()
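        # three Conv2D + MaxPooling2D blocks for feature extraction,
        # followed by a small dense head with MODEL_DENSE output classes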
        model.add(Conv2D(32, (3, 3), input_shape=input_shape))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(32, (3, 3)))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(pool_size=(2, 2)))

        model.add(Conv2D(64, (3, 3)))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Flatten())
        model.add(Dense(64))
        model.add(Activation("relu"))
        model.add(Dropout(0.5))
        model.add(Dense(self.MODEL_DENSE))

        model.add(Activation("softmax"))

        model.compile(
            loss="sparse_categorical_crossentropy",
            optimizer="rmsprop",
            metrics=["accuracy"],
        )
        logger.info("model created")
        return model

    def train(self, data_path: typing.Optional[str] = None, *_, **__):
        """
        train the classifier with data. must be called before prediction

        :param data_path: dataset directory containing one sub-directory per stage
        :return:
        """

        def _data_verify(p: str):
            p = pathlib.Path(p)
            assert p.is_dir(), f"{p} is not a valid directory"
            # validate: at least two classes
            number_of_dir = len([each for each in os.listdir(p) if (p / each).is_dir()])
            assert (
                number_of_dir > 1
            ), f"dataset only contains one class, check the dataset path: {p}"

            # validate: no more than MODEL_DENSE classes
            assert number_of_dir <= self.MODEL_DENSE, (
                f"dataset has {number_of_dir} classes (more than {self.MODEL_DENSE}), "
                f"please see https://github.com/williamfzc/stagesepx/issues/112"
            )

        _data_verify(data_path)

        if not self._model:
            logger.debug("no model found, building a new one")
            self._model = self.create_model()
        else:
            logger.debug("model found")

        datagen = ImageDataGenerator(
            rescale=1.0 / 16, shear_range=0.2, zoom_range=0.2, validation_split=0.33
        )
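        # flow_from_directory infers classes from the sub-directory names;
        # class_mode="sparse" yields integer labels, matching the
        # sparse_categorical_crossentropy loss used in create_model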

        train_generator = datagen.flow_from_directory(
            data_path,
            target_size=self.data_size,
            batch_size=self.batch_size,
            color_mode="grayscale",
            class_mode="sparse",
            subset="training",
        )

        validation_generator = datagen.flow_from_directory(
            data_path,
            target_size=self.data_size,
            batch_size=self.batch_size,
            color_mode="grayscale",
            class_mode="sparse",
            subset="validation",
        )

        self._model.fit(
            train_generator,
            epochs=self.epochs,
            validation_data=validation_generator,
        )

        logger.debug("train finished")

    def predict(self, pic_path: str, *args, **kwargs) -> str:
        """
        predict a single picture

        :param pic_path:
        :return:
        """
        pic_object = toolbox.imread(pic_path)
        # fake VideoFrame for apply_hook
        fake_frame = VideoFrame(0, 0.0, pic_object)
        fake_frame = self._apply_hook(fake_frame, *args, **kwargs)
        return self.predict_with_object(fake_frame.data)

    def predict_with_object(self, frame: np.ndarray) -> str:
        """
        predict a single object

        :param frame:
        :return:
        """
        # resize for model
        frame = cv2.resize(frame, dsize=self.data_size)
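        # add batch and channel axes: (h, w) -> (1, h, w, 1) for a grayscale frame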
        frame = np.expand_dims(frame, axis=[0, -1])

        assert self._model, "model is empty, train or load it first"
        result = self._model.predict(frame)
        tag = str(np.argmax(result, axis=1)[0])
        confidence = result.max()
        logger.debug(f"confidence: {confidence}")
        if confidence < self.score_threshold:
            logger.warning(
                f"max score is lower than {self.score_threshold}, unknown class"
            )
            return self.UNKNOWN_STAGE_NAME
        return tag

    def _classify_frame(self, frame: VideoFrame, *_, **__) -> str:
        return self.predict_with_object(frame.data)
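

# A minimal usage sketch of the train -> save -> predict flow. The dataset
# layout ("./dataset" with one sub-directory per stage) and the file names
# below are illustrative assumptions, not paths shipped with stagesepx.
if __name__ == "__main__":
    cl = KerasClassifier(epochs=10)
    cl.train("./dataset")
    # Keras 3 expects weight files ending in ".weights.h5"; older versions
    # also accept a plain ".h5" path
    cl.save_model("./stagesepx.weights.h5", overwrite=True)

    # predicting a single image returns the stage index as a string, or
    # UNKNOWN_STAGE_NAME when confidence is below score_threshold
    print(cl.predict("./frame.png"))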