mfinzi/pristine-ml

View on GitHub
oil/model_trainers/graphssl.py

Summary

Maintainability
A
0 mins
Test Coverage

from abc import ABCMeta, abstractmethod
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from functools import partial
from sklearn.metrics.pairwise import rbf_kernel

def sine_kernel(X1,X2,gamma):
    X1n = X1/np.linalg.norm(X1,axis=1)[:,None]
    X2n = X2/np.linalg.norm(X2,axis=1)[:,None]
    return np.exp(-gamma*(1-X1n@X2n.T))

def oh(a, num_classes):
    return np.squeeze(np.eye(num_classes)[a.reshape(-1)])

class GraphSSL(BaseEstimator,ClassifierMixin,metaclass=ABCMeta):
    
    def __init__(self,gamma=2,reg=1,kernel='sin'):
        super().__init__()
        if kernel=='sin': self.kernel = partial(sine_kernel,gamma=gamma)
        elif kernel=='rbf': self.kernel = partial(rbf_kernel,gamma=gamma)
        elif callable(kernel): self.kernel = kernel
        else: raise NotImplementedError(f"Unknown kernel {kernel}")
        self.reg=reg

    def fit(self,X,y):
        """ Assumes y is -1 for unlabeled """
        n,d = X.shape
        c = max(y)+1
        Wxx = self.kernel(X,X)
        Wxx -= np.diag(np.diag(Wxx))
        D = np.diag(np.sum(Wxx,axis=-1))
        self.dm2 = dm2 = np.sum(Wxx,axis=-1)[:,None]**-.5
        L = np.eye(n) - dm2*Wxx*dm2.T
        Y = np.zeros((n,c))
        Y[y!=-1] = oh(y[y!=-1],c)
        self.X = X
        self.Ys = np.linalg.solve(L+self.reg*np.eye(n),Y)

    def predict(self,X_test):
        Wtx = self.kernel(X_test,self.X)
        dm2t = np.sum(Wtx,axis=1)**-.5
        Stx =  dm2t[:,None]*Wtx*self.dm2.T
        return (Stx@self.Ys).argmax(-1)