monarch-initiative/N2V

View on GitHub
embiggen/edge_prediction/edge_prediction_sklearn/nu_svc_edge_prediction.py

Summary

Maintainability
A
2 hrs
Test Coverage
"""Submodule wrapping Nu Support Vector Machine for Edge-label prediction."""
from typing import Dict, Any, Union, List
from sklearn.svm import NuSVC
from embiggen.edge_prediction.edge_prediction_sklearn.sklearn_edge_prediction_adapter import SklearnEdgePredictionAdapter
from embiggen.utils import normalize_kwargs


class NuSVCEdgePrediction(SklearnEdgePredictionAdapter):
    """Edge prediction model backed by scikit-learn's Nu Support Vector Classifier.

    The estimator hyper-parameters are kept, after going through
    `normalize_kwargs`, in `self._kwargs`; everything else is delegated
    to `SklearnEdgePredictionAdapter`.
    """

    def __init__(
        self,
        nu: float = 0.1,
        kernel: str = "rbf",
        degree: int = 3,
        gamma: str = "scale",
        coef0: float = 0.0,
        shrinking: bool = True,
        probability: bool = False,
        tol: float = 1e-3,
        cache_size: int = 200,
        class_weight: Union[Dict, str] = "balanced",
        verbose: bool = False,
        max_iter: int = -1,
        decision_function_shape: str = "ovr",
        break_ties: bool = False,
        edge_embedding_methods: Union[List[str], str] = "Concatenate",
        training_unbalance_rate: float = 1.0,
        use_edge_metrics: bool = False,
        use_scale_free_distribution: bool = True,
        prediction_batch_size: int = 2**12,
        random_state: int = 42
    ):
        """Build the NuSVC estimator and hand it over to the sklearn adapter.

        Parameters
        ----------
        nu, kernel, degree, gamma, coef0, shrinking, probability, tol,
        cache_size, class_weight, verbose, max_iter,
        decision_function_shape, break_ties
            Hyper-parameters that are normalized with `normalize_kwargs`
            and then forwarded to `sklearn.svm.NuSVC`.
        edge_embedding_methods, training_unbalance_rate, use_edge_metrics,
        use_scale_free_distribution, prediction_batch_size
            Options forwarded as they are to the constructor of
            `SklearnEdgePredictionAdapter`.
        random_state: int = 42
            Value given both to the NuSVC estimator and to the adapter.
        """
        nu_svc_parameters = {
            "nu": nu,
            "kernel": kernel,
            "degree": degree,
            "gamma": gamma,
            "coef0": coef0,
            "shrinking": shrinking,
            "probability": probability,
            "tol": tol,
            "cache_size": cache_size,
            "class_weight": class_weight,
            "verbose": verbose,
            "max_iter": max_iter,
            "decision_function_shape": decision_function_shape,
            "break_ties": break_ties,
        }
        # Stored on the instance because `parameters` reports these values.
        self._kwargs = normalize_kwargs(self, nu_svc_parameters)

        # The random state is deliberately not part of `self._kwargs`:
        # it is handed separately to the estimator and to the adapter.
        estimator = NuSVC(random_state=random_state, **self._kwargs)

        super().__init__(
            estimator,
            edge_embedding_methods=edge_embedding_methods,
            training_unbalance_rate=training_unbalance_rate,
            use_edge_metrics=use_edge_metrics,
            use_scale_free_distribution=use_scale_free_distribution,
            prediction_batch_size=prediction_batch_size,
            random_state=random_state,
        )

    def parameters(self) -> Dict[str, Any]:
        """Return the parent's parameters merged with the NuSVC hyper-parameters.

        When a key appears in both, the value from `self._kwargs` wins.
        """
        merged = dict(super().parameters())
        merged.update(self._kwargs)
        return merged

    @classmethod
    def smoke_test_parameters(cls) -> Dict[str, Any]:
        """Return the parameters to use in a smoke test (a single iteration)."""
        return {"max_iter": 1}

    @classmethod
    def model_name(cls) -> str:
        """Return the human-readable name of this model."""
        return "Nu Support Vector Classifier"