Forced-Alignment-and-Vowel-Extraction/new-fave

View on GitHub
src/new_fave/measurements/decorators.py

Summary

Maintainability
A
35 mins
Test Coverage
from functools import cached_property, update_wrapper
import numpy as np
from new_fave.measurements.calcs import mahalanobis, \
    mahal_log_prob,\
    param_to_cov,\
    cov_to_icov,\
    clear_cached_properties

from typing import Any
from nptyping import NDArray, Shape, Float
from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from new_fave.measurements.vowel_measurement import VowelMeasurement,\
        VowelClass,\
        VowelClassCollection

class OptimWrapper:
    """
    Base class for wrappers that mark a function as an
    optimization parameter.

    An instance stands in for the wrapped function: it carries
    the function's metadata and forwards every call to it.
    Subclasses are used as type tags that can be detected
    with `isinstance`.
    """
    def __init__(self, func):
        # Metadata is copied *before* `self.func` is assigned.
        # `update_wrapper` also merges `func.__dict__` into this
        # instance, so assigning afterwards guarantees `self.func`
        # always refers to the function that was passed in.
        update_wrapper(wrapper=self, wrapped=func)
        self.func = func

    def __call__(self, *args, **kwargs):
        # Pure pass-through to the wrapped function.
        result = self.func(*args, **kwargs)
        return result

class MahalWrap(OptimWrapper):
    """Mahalanobis wrapper (plain property)

    Marks a parameter-returning function as one to be used
    for mahalanobis distance estimation. Functions marked
    this way are meant to be exposed through `property`.
    """
    # Descriptor type paired with this marker.
    prop: type[property] = property

    def __init__(self, func):
        super().__init__(func)

class MahalCacheWrap(OptimWrapper):
    """Cached Mahalanobis wrapper

    Marks a parameter-returning function as one to be used
    for mahalanobis distance estimation. Functions marked
    this way are meant to be exposed through
    `functools.cached_property`.
    """
    # Descriptor type paired with this marker.
    prop: type[cached_property] = cached_property

    def __init__(self, func):
        super().__init__(func)

class PropertyFactory():
    """ A property factory

    A property factory to generate
    the necessary properties in aggregated classes
    (VowelClass and VowelClassCollection)
    as well as log probability properties
    in VowelMeasurement.

    Every public property on this class returns a one-argument
    function (`parameterized_func`) that is closed over `self.attr`
    and is suitable for use as a property getter.
    Each getter returns `None` when the object it is called on
    lacks the attribute the getter needs
    (`vowel_measurements` or `candidates`).

    Args:
        attr (str):
            Name of the attribute the generated getters read.
        *args:
            Accepted for call compatibility and ignored.
    """
    def __init__(self, attr:str, *args):
        self.attr = attr

    @property
    def winner_factory(self):
        """The parameters of the winners

        The generated getter reads `self.attr` from every entry of
        `obj.vowel_measurements`, keeps the slice at that entry's
        `winner_index` along the last axis, and stacks the results
        so that the measurements run along the last axis.
        """
        def parameterized_func(
                obj:'VowelClass|VowelClassCollection'
            ):
            if not hasattr(obj, "vowel_measurements"):
                return None
            return np.array([
                getattr(x, self.attr)[...,x.winner_index].T
                for x in obj.vowel_measurements
            ]).T

        return parameterized_func

    @property
    def mean_factory(self):
        """The mean parameters of the winners

        The generated getter reads `self.attr` from `obj`,
        flattens every axis except the last, averages over
        the last axis and returns the result as a column vector.
        """

        def parameterized_func(obj:'VowelClass|VowelClassCollection'):
            if not hasattr(obj, "vowel_measurements"):
                return None

            winner_array  = getattr(obj, self.attr)
            N = winner_array.shape[-1]
            square_array = winner_array.reshape(-1, N)
            winner_mean =  square_array.mean(axis = 1)
            # Keep the result two dimensional: (n_params, 1)
            winner_mean = winner_mean[:, np.newaxis]
            return winner_mean

        return parameterized_func

    @property
    def icov_factory(self):
        """The inverse covariance matrix of the winners.

        The generated getter passes `getattr(obj, self.attr)`
        through `param_to_cov` and then `cov_to_icov`.
        """

        def parameterized_func(obj:'VowelClass|VowelClassCollection'):
            if not hasattr(obj, "vowel_measurements"):
                return None
            cov_mat = param_to_cov(getattr(obj, self.attr))
            icov_mat = cov_to_icov(cov_mat)
            return icov_mat

        return parameterized_func

    def _candidate_logprob(self, obj:'VowelMeasurement', reference):
        """Log probability of the candidates of `obj` under `reference`.

        Shared implementation of `speaker_byvclass` and
        `speaker_global`.

        Args:
            obj (VowelMeasurement):
                The object whose `self.attr` candidate values
                are scored.
            reference:
                The object providing the mean vector and inverse
                covariance matrix. Their attribute names are derived
                from `self.attr` by replacing `cand_` with `winner_`
                and appending `_mean` / `_icov`.

        Returns:
            The value returned by `mahal_log_prob`.
        """
        mean_vec_name = self.attr.replace("cand_", "winner_") + "_mean"
        icov_mat_name = mean_vec_name.replace("_mean", "_icov")

        cand_vals = getattr(obj, self.attr)
        mean_vals = getattr(reference, mean_vec_name)
        icov_vals = getattr(reference, icov_mat_name)

        # Flatten every axis except the last one
        cand_vals = cand_vals.reshape(-1, cand_vals.shape[-1])

        mahals = mahalanobis(cand_vals, mean_vals, icov_vals)
        return mahal_log_prob(mahals, cand_vals)

    @property
    def speaker_byvclass(self):
        """log probability (based on mahalanobis distance)
        of candidate measurements from the vowel class
        within a speaker

        The generated getter returns `None` when `obj` has no
        `candidates` attribute or no vowel class.
        """
        def parameterized_func(obj:'VowelMeasurement'):
            if not hasattr(obj, "candidates"):
                return None
            if not obj.vowel_class:
                return None
            return self._candidate_logprob(obj, obj.vowel_class)

        return parameterized_func

    @property
    def speaker_global(self):
        """log probability (based on mahalanobis distance)
        of candidate measurements from the entire speaker's
        distribution

        The generated getter returns `None` when `obj` has no
        `candidates` attribute, no vowel class, or a vowel class
        without a vowel system.
        """
        def parameterized_func(obj:'VowelMeasurement'):
            if not hasattr(obj, "candidates"):
                return None
            # `or`, not `and`: either link missing means there is
            # nothing to score against. The vowel system is only
            # inspected once the vowel class is known to be present.
            if not obj.vowel_class or not obj.vowel_class.vowel_system:
                return None
            return self._candidate_logprob(
                obj,
                obj.vowel_class.vowel_system
            )

        return parameterized_func

def get_wrapped(cls:type, wrapper:MahalWrap|MahalCacheWrap) -> list[str]:
    """List the attributes of `cls` whose getter is a `wrapper` instance

    Only attributes defined directly on `cls` (its own `__dict__`)
    are inspected. A `property` matches when its `fget` is an
    instance of `wrapper`; a `cached_property` matches when its
    `func` is an instance of `wrapper`.

    Args:
        cls (type):
            The class whose attributes are inspected
        wrapper (MahalWrap|MahalCacheWrap):
            The wrapper class to check for

    Returns:
        (list[str]):
            Matching attribute names: plain properties first,
            then cached properties, each in definition order
    """
    plain_names = []
    cached_names = []

    for name, member in cls.__dict__.items():
        if isinstance(member, property) and isinstance(member.fget, wrapper):
            plain_names.append(name)
        if isinstance(member, cached_property) and isinstance(member.func, wrapper):
            cached_names.append(name)

    return plain_names + cached_names


def set_prop(
        self:'VowelMeasurement|VowelClass|VowelClassCollection', 
        lhs:str|list[str], 
        rhs:str|list[str], 
        wrapper:type[MahalWrap]|type[MahalCacheWrap], 
        factory:str
    ) -> None:
    """Set generated properties on the class of an object

    For every pair of names taken from `lhs` and `rhs`, a getter is
    built with `PropertyFactory(lhs_name)` and installed on
    `self.__class__` under `rhs_name`, wrapped in `wrapper.prop`.
    The properties are set on the class, so they are shared by
    every instance of it.

    Args:
        self (VowelMeasurement|VowelClass|VowelClassCollection):
            The object whose class receives the properties
        lhs (str | list[str]):
            Name(s) of the generating attribute(s)
        rhs (str | list[str]):
            Name(s) of the generated attribute(s), paired with
            `lhs` by position
        wrapper (type[MahalWrap] | type[MahalCacheWrap]):
            The wrapper class; its `prop` attribute chooses
            between `property` and `cached_property`
        factory (str): The name of the property factory
    """
    # A bare string is one attribute name. Without this, `zip`
    # would pair the two strings character by character.
    if isinstance(lhs, str):
        lhs = [lhs]
    if isinstance(rhs, str):
        rhs = [rhs]

    owner = self.__class__
    for l_attr, r_attr in zip(lhs, rhs):
        getter = getattr(PropertyFactory(l_attr), factory)
        descriptor = wrapper.prop(getter)
        setattr(owner, r_attr, descriptor)

        if wrapper.prop is cached_property:
            # A cached_property learns its attribute name through
            # `__set_name__`, which only runs automatically while a
            # class body is executed. It must be called by hand for
            # descriptors attached afterwards.
            descriptor.__set_name__(owner, r_attr)