# mfinzi/pristine-ml: oil/utils/metrics.py
import torch
import torch.nn as nn
import numpy as np
import os
import subprocess
from torchvision.models.inception import inception_v3
from scipy.stats import entropy
from scipy.linalg import norm,sqrtm
from ..utils.utils import Eval, Expression
#from torch.nn.functional import adaptive_avg_pool2d
#from .pytorch-fid.fid_score import calculate_frechet_distance
#from .pytorch-fid.inception import InceptionV3

# GAN Metrics

# TODO: cache logits for existing datasets
#       should be possible if we can serialize dataloaders
def get_inception():
    """ grabs the pytorch pretrained inception_v3 with resized inputs """
    inception = inception_v3(pretrained=True,transform_input=False)
    upsample = Expression(lambda x: nn.functional.interpolate(x,size=(299,299),mode='bilinear'))
    model = nn.Sequential(upsample,inception).cuda().eval()
    return model

def get_logits(model,loader):
    """ Extracts logits from a model over a dataloader; returns a numpy array
        of shape (N, K) where K is the number of classes """
    with torch.no_grad(), Eval(model): 
        model_logits = lambda mb: model(mb).cpu().data.numpy()
        logits = np.concatenate([model_logits(minibatch) for minibatch in loader],axis=0)
    return logits
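
# A minimal usage sketch (illustrative only, not part of the original module):
# `fake_loader` below is a stand-in for a real dataloader of normalized image
# minibatches; any iterable of (B,3,H,W) tensors on the right device works.
def _example_get_logits():
    model = get_inception()
    # ten minibatches of 32 random "images", already normalized to mean 0, std 1
    fake_loader = [torch.randn(32,3,32,32).cuda() for _ in range(10)]
    logits = get_logits(model, fake_loader)  # -> numpy array of shape (320, 1000)
    return logits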

def FID_from_logits(logits1,logits2):
    """ Computes the FID between logits1 and logits2
        Inputs: [logits1 (N,C)] [logits2 (N,C)] """
    mu1 = np.mean(logits1,axis=0)
    mu2 = np.mean(logits2,axis=0)
    sigma1 = np.cov(logits1, rowvar=False)
    sigma2 = np.cov(logits2, rowvar=False)

    covmean = sqrtm(sigma1@sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real # sqrtm can pick up small imaginary parts from numerical error
    tr = np.trace(sigma1 + sigma2 - 2*covmean)
    distance = norm(mu1-mu2)**2 + tr
    return distance
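
# A quick numerical sanity check (illustrative only): for two isotropic Gaussians
# N(0, s1^2 I) and N(m, s2^2 I) the exact Frechet distance is
# ||m||^2 + C*(s1 - s2)^2, which the sample estimate should approach for large N.
def _example_fid_sanity_check(N=20000, C=10):
    rng = np.random.RandomState(0)
    x1 = 1.0*rng.randn(N, C)            # samples from N(0, I)
    x2 = 2.0*rng.randn(N, C) + 3.0      # samples from N(3, 4I)
    est = FID_from_logits(x1, x2)
    exact = C*3.0**2 + C*(1.0-2.0)**2   # = 100 here
    print(est, exact)                   # the two should be close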

def IS_from_logits(logits):
    """ Computes the Inception score (IS) from logits of the dataset of size N with C classes.
        Inputs: [logits (N,C)], Outputs: [IS (scalar)]"""
    # log IS = E_z[KL(P(y|z)||P(y))] = mean_z [sum_y (Pyz log(Pyz) - Pyz log(Py))]
    Pyz = np.exp(logits - logits.max(1,keepdims=True)).transpose() # softmax numerator (stabilized), shape (C,N)
    Pyz /= Pyz.sum(0)[None,:]        # normalize to get P(y|z)
    Py = np.broadcast_to(Pyz.mean(-1)[:,None],Pyz.shape)  # marginal P(y), averaged over z
    logIS = entropy(Pyz,Py).mean()   # KL(P(y|z)||P(y)) per z, averaged over z
    return np.exp(logIS)
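
# Two limiting cases as a sanity check (illustrative only): identical logits for
# every sample give IS = 1, while sharp predictions spread evenly over all C
# classes give IS -> C.
def _example_is_sanity_check(N=1000, C=10):
    uniform = np.zeros((N, C))                  # P(y|z) uniform for every z
    print(IS_from_logits(uniform))              # ~1.0
    onehot = 50.0*np.eye(C)[np.arange(N) % C]   # near one-hot P(y|z), all classes used equally
    print(IS_from_logits(onehot))               # ~C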

cachedLogits = {} 
def FID(loader1,loader2):
    """ Computes the Frechet Inception Distance (FID) between the two image dataloaders
        using pytorch pretrained inception_v3. Requires >2048 images for a reliable
        comparison. Each dataloader should be an iterable of minibatched images,
        assumed to already be normalized with mean 0, std 1 (per channel).
        Logits for loader2 are cached across calls.
        """
    model = get_inception()
    logits1 = get_logits(model,loader1)
    if loader2 not in cachedLogits:
        cachedLogits[loader2] = get_logits(model,loader2)
    logits2 = cachedLogits[loader2]
    return FID_from_logits(logits1,logits2)
    
def IS(loader):
    """Computes the Inception score of a dataloader using pytorch pretrained inception_v3"""
    model = get_inception()
    logits = get_logits(model,loader)
    return IS_from_logits(logits)

   
def FID_and_IS(loader1,loader2):
    """Computes FID and IS score for loader1 against target loader2 """
    model = get_inception()
    logits1 = get_logits(model,loader1)
    if loader2 not in cachedLogits:
        cachedLogits[loader2] = get_logits(model,loader2)
    logits2 = cachedLogits[loader2]
    return FID_from_logits(logits1,logits2),IS_from_logits(logits1)
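
# Putting the GAN metrics together (illustrative only): `sample_loader` and
# `real_loader` are stand-ins for dataloaders of generated and real images.
def _example_fid_and_is(sample_loader, real_loader):
    fid, inception_score = FID_and_IS(sample_loader, real_loader)
    print("FID: {:.2f}, IS: {:.2f}".format(fid, inception_score))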

#TODO: Implement Kernel Inception Distance (KID) from (https://openreview.net/pdf?id=r1lUOzWCW)


def get_official_FID(loader,dataset='cifar10'):
    #TODO: clean up this function and verify that it still works
    if dataset not in ('cifar10',):
        raise NotImplementedError
    stats_dir = os.path.expanduser("~/olive-oil-ml/oil/utils/")
    path = stats_dir+"temp"
    loader.write_imgs(path)
    # assumes TTUR's fid.py is executable and the precomputed dataset stats exist
    score = subprocess.check_output(stats_dir+"TTUR/fid.py "+stats_dir+"fid_stats_{}_train.npz {}.npy"
                .format(dataset,path),shell=True)
    return score


# Semantic Segmentation Metrics
# Adapted from https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
def confusion_from_logits(logits,y_gt):
    """ Computes the segmentation confusion matrix from logits (B,C,H,W)
        and integer ground truth labels y_gt (B,H,W) """
    _, num_classes, _, _ = logits.shape
    pred_image = logits.max(1)[1].type_as(y_gt).cpu().data.numpy()
    gt_image = y_gt.cpu().data.numpy()
    return confusion_matrix(pred_image,gt_image,num_classes)

def confusion_matrix(pred_image,gt_image,num_classes):
    """Computes the confusion matrix from two numpy class images (integer values),
        ignoring pixels whose ground truth label is outside [0, num_classes)"""
    mask = (gt_image >= 0) & (gt_image < num_classes)
    label = num_classes * gt_image[mask].astype(int) + pred_image[mask]
    count = np.bincount(label, minlength=num_classes**2)
    confusion_matrix = count.reshape(num_classes, num_classes)
    return confusion_matrix # rows index ground truth, columns index predictions
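
# A tiny worked example (illustrative only): with 2 classes, entry (i,j) counts
# pixels with ground truth i that were predicted as class j.
def _example_confusion_matrix():
    gt   = np.array([0, 0, 1, 1, -1])   # the -1 pixel is ignored
    pred = np.array([0, 1, 1, 1,  0])
    cm = confusion_matrix(pred, gt, num_classes=2)
    print(cm)  # [[1 1]
               #  [0 2]]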

def meanIoU(confusion_matrix):
    """ Mean intersection-over-union: per-class IoU = TP/(TP+FP+FN),
        averaged over classes (nan for absent classes is ignored) """
    MIoU = np.diag(confusion_matrix) / (
                np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) -
                np.diag(confusion_matrix))
    MIoU = np.nanmean(MIoU)
    return MIoU

def freqIoU(confusion_matrix):
    """ Frequency-weighted IoU: per-class IoU weighted by each class's
        frequency in the ground truth """
    freq = np.sum(confusion_matrix, axis=1) / np.sum(confusion_matrix)
    iu = np.diag(confusion_matrix) / (
                np.sum(confusion_matrix, axis=1) + np.sum(confusion_matrix, axis=0) -
                np.diag(confusion_matrix))
    FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
    return FWIoU

def pixelAcc(confusion_matrix):
    """ Overall fraction of correctly classified pixels """
    return np.diag(confusion_matrix).sum() / confusion_matrix.sum()

def meanAcc(confusion_matrix):
    """ Per-class accuracy (recall), averaged over classes """
    return np.nanmean(np.diag(confusion_matrix) / np.sum(confusion_matrix, axis=1))
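
# Putting the segmentation metrics together (illustrative only): `seg_logits`
# and `labels` are stand-ins for a model's (B,C,H,W) output and the (B,H,W)
# ground truth batch.
def _example_segmentation_metrics(seg_logits, labels):
    cm = confusion_from_logits(seg_logits, labels)
    print("mIoU: {:.3f}, pixel acc: {:.3f}, mean acc: {:.3f}".format(
        meanIoU(cm), pixelAcc(cm), meanAcc(cm)))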


# def boundary_mIoU(confusion_matrix,epsilon)