mfinzi/LieConv

View on GitHub
lie_conv/datasets.py

Summary

Maintainability
C
1 day
Test Coverage
import math
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import warnings
import h5py
import os
from torch.utils.data import Dataset
from .utils import Named, export, Expression, FixedNumpySeed, RandomZrotation, GaussianNoise
from oil.datasetup.datasets import EasyIMGDataset
from lie_conv.hamiltonian import HamiltonianDynamics, KeplerH, SpringH
from lie_conv.lieGroups import SO3
from torchdiffeq import odeint_adjoint as odeint
from corm_data.utils import initialize_datasets
import torchvision


#ModelNet40 code adapted from 
#https://github.com/DylanWusee/pointconv_pytorch/blob/master/data_utils/ModelNetDataLoader.py

def load_h5(h5_filename):
    f = h5py.File(h5_filename)
    data = f['data'][:]
    label = f['label'][:]
    seg = []
    return (data, label, seg)

def _load_data_file(name):
    f = h5py.File(name)
    data = f["data"][:]
    label = f["label"][:]
    return data, label

def load_data(dir,classification = False):
    data_train0, label_train0,Seglabel_train0  = load_h5(dir + 'ply_data_train0.h5')
    data_train1, label_train1,Seglabel_train1 = load_h5(dir + 'ply_data_train1.h5')
    data_train2, label_train2,Seglabel_train2 = load_h5(dir + 'ply_data_train2.h5')
    data_train3, label_train3,Seglabel_train3 = load_h5(dir + 'ply_data_train3.h5')
    data_train4, label_train4,Seglabel_train4 = load_h5(dir + 'ply_data_train4.h5')
    data_test0, label_test0,Seglabel_test0 = load_h5(dir + 'ply_data_test0.h5')
    data_test1, label_test1,Seglabel_test1 = load_h5(dir + 'ply_data_test1.h5')
    train_data = np.concatenate([data_train0,data_train1,data_train2,data_train3,data_train4])
    train_label = np.concatenate([label_train0,label_train1,label_train2,label_train3,label_train4])
    train_Seglabel = np.concatenate([Seglabel_train0,Seglabel_train1,Seglabel_train2,Seglabel_train3,Seglabel_train4])
    test_data = np.concatenate([data_test0,data_test1])
    test_label = np.concatenate([label_test0,label_test1])
    test_Seglabel = np.concatenate([Seglabel_test0,Seglabel_test1])

    if classification:
        return train_data, train_label, test_data, test_label
    else:
        return train_data, train_Seglabel, test_data, test_Seglabel


@export
class ModelNet40(Dataset):
    ignored_index = -100
    class_weights = None
    stratify=True
    num_targets=40
    classes=['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car',
        'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot',
        'glass_box', 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor',
        'night_stand', 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink',
        'sofa', 'stairs', 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase',
        'wardrobe', 'xbox']
    default_root_dir = '~/datasets/ModelNet40/'
    def __init__(self,root_dir=default_root_dir,train=True,transform=None,size=1024):
        super().__init__()
        #self.transform = torchvision.transforms.ToTensor() if transform is None else transform
        train_x,train_y,test_x,test_y = load_data(os.path.expanduser(root_dir),classification=True)
        self.coords = train_x if train else test_x
        # SWAP y and z so that z (gravity direction) is in component 3
        self.coords[...,2] += self.coords[...,1]
        self.coords[...,1] = self.coords[...,2]-self.coords[...,1]
        self.coords[...,2] -= self.coords[...,1]
        # N x m x 3
        self.labels = train_y if train else test_y
        self.coords_std = np.std(train_x,axis=(0,1))
        self.coords /= self.coords_std
        self.coords = self.coords.transpose((0,2,1)) # B x n x c -> B x c x n
        self.size=size
        #pt_coords = torch.from_numpy(self.coords)
        #self.coords = FarthestSubsample(ds_frac=size/2048)((pt_coords,pt_coords))[0].numpy()

    def __getitem__(self,index):
        return torch.from_numpy(self.coords[index]).float(), int(self.labels[index])
    def __len__(self):
        return len(self.labels)
    def default_aug_layers(self):
        subsample = Expression(lambda x: x[:,:,np.random.permutation(x.shape[-1])[:self.size]])
        return nn.Sequential(subsample,RandomZrotation(),GaussianNoise(.01))#,augLayers.PointcloudScale())#


try: 
    import torch_geometric
    warnings.filterwarnings('ignore')
    @export
    class MNISTSuperpixels(torch_geometric.datasets.MNISTSuperpixels):
        ignored_index = -100
        class_weights = None
        stratify=True
        num_targets = 10
        # def __init__(self,*args,**kwargs):
        #     super().__init__(*args,**kwargs)
        # coord scale is 0-25, std of unif [0-25] is 
        def __getitem__(self,index):
            datapoint = super().__getitem__(int(index))
            coords = (datapoint.pos.T-13.5)/5 # 2 x M array of coordinates
            bchannel = (datapoint.x.T-.1307)/0.3081 # 1 x M array of blackwhite info
            label = int(datapoint.y.item())
            return ((coords,bchannel),label)
        def default_aug_layers(self):
            return nn.Sequential()
except ImportError:
    warnings.warn('torch_geometric failed to import MNISTSuperpixel cannot be used.', ImportWarning)

class RandomRotateTranslate(nn.Module):
    def __init__(self,max_trans=2):
        super().__init__()
        self.max_trans = max_trans
    def forward(self,img):
        if not self.training: return img
        bs,c,h,w = img.shape
        angles = torch.rand(bs)*2*np.pi
        affineMatrices = torch.zeros(bs,2,3)
        affineMatrices[:,0,0] = angles.cos()
        affineMatrices[:,1,1] = angles.cos()
        affineMatrices[:,0,1] = angles.sin()
        affineMatrices[:,1,0] = -angles.sin()
        affineMatrices[:,0,2] = (2*torch.rand(bs)-1)*self.max_trans/w
        affineMatrices[:,1,2] = (2*torch.rand(bs)-1)*self.max_trans/h
        flowgrid = F.affine_grid(affineMatrices.to(img.device), size = img.shape)
        transformed_img = F.grid_sample(img,flowgrid)
        return transformed_img

@export
class RotMNIST(EasyIMGDataset,torchvision.datasets.MNIST):
    """ Unofficial RotMNIST dataset created on the fly by rotating MNIST"""
    means = (0.5,)
    stds = (0.25,)
    num_targets = 10
    def __init__(self,*args,dataseed=0,**kwargs):
        super().__init__(*args,download=True,**kwargs)
        # xy = (np.mgrid[:28,:28]-13.5)/5
        # disk_cutout = xy[0]**2 +xy[1]**2 < 7
        # self.img_coords = torch.from_numpy(xy[:,disk_cutout]).float()
        # self.cutout_data = self.data[:,disk_cutout].unsqueeze(1)
        N = len(self)
        with FixedNumpySeed(dataseed):
            angles = torch.rand(N)*2*np.pi
        with torch.no_grad():
            # R = torch.zeros(N,2,2)
            # R[:,0,0] = R[:,1,1] = angles.cos()
            # R[:,0,1] = R[:,1,0] = angles.sin()
            # R[:,1,0] *=-1
            # Build affine matrices for random translation of each image
            affineMatrices = torch.zeros(N,2,3)
            affineMatrices[:,0,0] = angles.cos()
            affineMatrices[:,1,1] = angles.cos()
            affineMatrices[:,0,1] = angles.sin()
            affineMatrices[:,1,0] = -angles.sin()
            # affineMatrices[:,0,2] = -2*np.random.randint(-self.max_trans, self.max_trans+1, bs)/w
            # affineMatrices[:,1,2] = 2*np.random.randint(-self.max_trans, self.max_trans+1, bs)/h
            self.data = self.data.unsqueeze(1).float()
            flowgrid = F.affine_grid(affineMatrices, size = self.data.size())
            self.data = F.grid_sample(self.data, flowgrid)
    def __getitem__(self,idx):
        return (self.data[idx]-.5)/.25, int(self.targets[idx])
    def default_aug_layers(self):
        return RandomRotateTranslate(0)# no translation

from PIL import Image
from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \
    verify_str_arg
from torchvision.datasets.vision import VisionDataset
# !wget -nc http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip
# # uncompress the zip file
# !unzip -n mnist_rotation_new.zip -d mnist_rotation_new
class MnistRotDataset(VisionDataset):
    """ Official RotMNIST dataset."""
    ignored_index = -100
    class_weights = None
    balanced = True
    stratify = True
    means = (0.130,)
    stds = (0.297,)
    num_targets=10
    resources = ["http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip"]
    training_file = 'mnist_all_rotation_normalized_float_train_valid.amat'
    test_file = 'mnist_all_rotation_normalized_float_test.amat'
    def __init__(self,root, train=True, transform=None,download=True):
        if transform is None:
            normalize = transforms.Normalize(self.means, self.stds)
            transform = transforms.Compose([transforms.ToTensor(),normalize])
        super().__init__(root,transform=transform)
        self.train = train
        if download:
            self.download()
        if train:
            file=os.path.join(self.raw_folder, self.training_file)
        else:
            file=os.path.join(self.raw_folder, self.test_file)
        
        self.transform = transform

        data = np.loadtxt(file, delimiter=' ')
            
        self.images = data[:, :-1].reshape(-1, 28, 28).astype(np.float32)
        self.labels = data[:, -1].astype(np.int64)
        self.num_samples = len(self.labels)
    
    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        image = Image.fromarray(image)
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    def _check_exists(self):
        return (os.path.exists(os.path.join(self.raw_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.raw_folder,
                                            self.test_file)))
    @property
    def raw_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'processed')
    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder,exist_ok=True)
        os.makedirs(self.processed_folder,exist_ok=True)

        # download files
        for url in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=None)
        print('Downloaded!')

    def __len__(self):
        return len(self.labels)
    def default_aug_layers(self):
        return RandomRotateTranslate(0)# no translation


class DynamicsDataset(Dataset):
    num_targets = 1

    def __len__(self):
        return self.Zs.shape[0]

    def __getitem__(self, i):
        inputs = (self.Zs[i, 0], self.SysP[i], self.Ts[i])
        targets = self.Zs[i]
        return inputs, targets

    def generate_trajectory_data(self, n_systems, sim_kwargs, batch_size=5000):
        """
        Parameters
        ----------
        n_systems: int
        batch_size: int

        Returns
        -------
        ts: torch.Tensor, [n_systems, traj_len]
        zs: torch.Tensor, [n_systems, traj_len, z_dim]
        sys_params: torch.Tensor, [n_systems, param_dim]
        """
        batch_size = min(batch_size, n_systems)
        n_gen = 0
        t_batches, z_batches, sysp_batches = [], [], []
        while n_gen < n_systems:
            z0s, sys_params = self.sample_system(n_systems=batch_size, space_dim=self.space_dim)
            dynamics = self._get_dynamics(sys_params)
            new_ts, new_zs = self.sim_trajectories(z0s, dynamics, **sim_kwargs)
            t_batches.append(new_ts)
            z_batches.append(new_zs)
            sysp_batches.append(torch.stack(sys_params, dim=-1))
            n_gen += new_ts.shape[0]
        print(n_gen)
        ts = torch.cat(t_batches, dim=0)[:n_systems]
        zs = torch.cat(z_batches, dim=0)[:n_systems]
        sys_params = torch.cat(sysp_batches, dim=0)[:n_systems]
        return ts, zs, sys_params

    def sim_trajectories(self, z0, dynamics, traj_len, delta_t):
        """
        This method should be implemented in a subclass with the following interface:
        Parameters
        ----------
        z0: torch.Tensor, [batch_size, z_dim]
        traj_len: int
        delta_t: float or torch.Tensor, [batch_size] (must be greater than 0)
        dynamics: function that computes dz/dt

        Returns
        -------
        ts: torch.Tensor, [batch_size, traj_len]
        zs: torch.Tensor, [batch_size, traj_len, z_dim]
        """
        batch_size, _ = z0.shape
        with torch.no_grad():
            ts = torch.linspace(0, traj_len * delta_t, traj_len).double()
            zs = odeint(dynamics, z0, ts, rtol=1e-8, method='rk4').detach()
            ts = ts.expand(batch_size, -1)
        zs = zs.transpose(1, 0)
        return ts, zs

    def format_training_data(self, ts, zs, chunk_len):
        """
        Randomly samples chunks of trajectory data, returns tensors shaped for training.
        Parameters
        ----------
        ts: torch.Tensor, [batch_size, traj_len]
        zs: torch.Tensor, [batch_size, traj_len, z_dim]
        chunk_len: int
        Returns
        -------
        chosen_ts: torch.Tensor, [batch_size, chunk_len]
        chosen_zs: torch.Tensor, [batch_size, chunk_len, z_dim]
        """
        batch_size, traj_len, z_dim = zs.shape
        n_chunks = traj_len // chunk_len
        chunk_idx = torch.randint(0, n_chunks, (batch_size,), device=zs.device).long()
        chunked_ts = torch.stack(ts.chunk(n_chunks, dim=1))
        chunked_zs = torch.stack(zs.chunk(n_chunks, dim=1))
        chosen_ts = chunked_ts[chunk_idx, range(batch_size)]
        chosen_zs = chunked_zs[chunk_idx, torch.arange(batch_size).long()]
        return chosen_ts, chosen_zs

    def sample_system(self, n_systems, space_dim, **kwargs):
        """
        This method should be implemented in a subclass with the following interface:
        Parameters
        ----------
        n_systems: int
        space_dim: int

        Returns
        -------
        z0: torch.Tensor, [n_systems, z_dim]
        sys_params: tuple (torch.Tensor, torch.Tensor, ...
        """
        raise NotImplementedError

    def _get_dynamics(self, sys_params):
        """
        Parameters
        ----------
        sys_params: tuple(torch.Tensor, torch.Tensor, ...)
        """
        raise NotImplementedError
    
@export
class SpringDynamics(DynamicsDataset):
    default_root_dir = os.path.expanduser('~/datasets/ODEDynamics/SpringDynamics/')
    sys_dim = 2
    
    def __init__(self, root_dir=default_root_dir, train=True, download=True, n_systems=100, space_dim=2, regen=False,
                 chunk_len=5):
        super().__init__()
        filename = os.path.join(root_dir, f"spring_{space_dim}D_{n_systems}_{('train' if train else 'test')}.pz")
        self.space_dim = space_dim
        if os.path.exists(filename) and not regen:
            ts, zs,self.SysP = torch.load(filename)
        elif download:
            sim_kwargs = dict(
                traj_len=500,
                delta_t=0.01,
            )
            ts, zs, self.SysP = self.generate_trajectory_data(n_systems=n_systems, sim_kwargs=sim_kwargs)
            os.makedirs(root_dir, exist_ok=True)
            print(filename)
            torch.save((ts, zs, self.SysP),filename)
        else:
            raise Exception("Download=False and data not there")
        self.sys_dim = self.SysP.shape[-1]
        self.Ts, self.Zs = self.format_training_data(ts, zs, chunk_len)
    
    def sample_system(self, n_systems, space_dim, ood=False):
        """
        See DynamicsDataset.sample_system docstring
        """
        n = np.random.choice([6]) #TODO: handle padding/batching with different n
        if ood: n = np.random.choice([4,8])
        masses = (3 * torch.rand(n_systems, n).double() + .1)
        k = 5*torch.rand(n_systems, n).double()
        q0 = .4*torch.randn(n_systems, n, space_dim).double()
        p0 = .6*torch.randn(n_systems, n, space_dim).double()
        p0 -= p0.mean(0,keepdim=True)
        z0 = torch.cat([q0.reshape(n_systems, n * space_dim), p0.reshape(n_systems, n * space_dim)], dim=1)
        return z0, (masses, k)

    def _get_dynamics(self, sys_params):
        H = lambda t, z: SpringH(z, *sys_params)
        return HamiltonianDynamics(H, wgrad=False)

@export
class NBodyDynamics(DynamicsDataset):
    default_root_dir = os.path.expanduser('~/datasets/ODEDynamics/NBodyDynamics/')

    def __init__(self, root_dir=default_root_dir, train=True, download=True, n_systems=100, regen=False,
                 chunk_len=5, space_dim=3, delta_t=0.01):
        super().__init__()
        filename = os.path.join(root_dir, f"n_body_{space_dim}D_{n_systems}_{('train' if train else 'test')}.pz")
        self.space_dim = space_dim

        if os.path.exists(filename) and not regen:
            ts, zs, self.SysP = torch.load(filename)
        elif download:
            sim_kwargs = dict(
                traj_len=200,
                delta_t=delta_t,
            )
            ts, zs, self.SysP = self.generate_trajectory_data(n_systems, sim_kwargs)
            os.makedirs(root_dir, exist_ok=True)
            print(filename)
            torch.save((ts, zs, self.SysP), filename)
        else:
            raise Exception("Download=False and data not there")
        self.sys_dim = self.SysP.shape[-1]
        self.Ts, self.Zs = self.format_training_data(ts, zs, chunk_len)

    def sample_system(self, n_systems, n_bodies=6, space_dim=3):
        """
        See DynamicsDataset.sample_system docstring
        """
        grav_const = 1.  # hamiltonian.py assumes G = 1
        star_mass = torch.tensor([[32.]]).expand(n_systems, -1, -1)
        star_pos = torch.tensor([[0.] * space_dim]).expand(n_systems, -1, -1)
        star_vel = torch.tensor([[0.] * space_dim]).expand(n_systems, -1, -1)

        planet_mass_min, planet_mass_max = 2e-2, 2e-1
        planet_mass_range = planet_mass_max - planet_mass_min

        planet_dist_min, planet_dist_max = 0.5, 4.
        planet_dist_range = planet_dist_max - planet_dist_min

        # sample planet masses, radius vectors
        planet_masses = planet_mass_range * torch.rand(n_systems, n_bodies - 1, 1) + planet_mass_min
        rho = torch.linspace(planet_dist_min, planet_dist_max, n_bodies - 1)
        rho = rho.expand(n_systems, -1).unsqueeze(-1)
        rho = rho + 0.3 * (torch.rand(n_systems, n_bodies - 1, 1) - 0.5) * planet_dist_range / (n_bodies - 1)
        planet_vel_magnitude = (grav_const * star_mass / rho).sqrt()

        if space_dim == 2:
            planet_pos, planet_vel = self._init_2d(rho, planet_vel_magnitude)
        elif space_dim == 3:
            planet_pos, planet_vel = self._init_3d(rho, planet_vel_magnitude)
        else:
            raise RuntimeError("only 2-d and 3-d systems are supported")

        # import pdb; pdb.set_trace()
        perm = torch.stack([torch.randperm(n_bodies) for _ in range(n_systems)])

        pos = torch.cat([star_pos, planet_pos], dim=1)
        pos = torch.stack([pos[i, perm[i]] for i in range(n_systems)]).reshape(n_systems, -1)
        momentum = torch.cat([star_mass * star_vel, planet_masses * planet_vel], dim=1)
        momentum = torch.stack([momentum[i, perm[i]] for i in range(n_systems)]).reshape(n_systems, -1)
        z0 = torch.cat([pos.double(), momentum.double()], dim=-1)

        masses = torch.cat([star_mass, planet_masses], dim=1).squeeze(-1).double()
        masses = torch.stack([masses[i, perm[i]] for i in range(n_systems)])

        return z0, (masses,)

    def _init_2d(self, rho, planet_vel_magnitude):
        n_systems, n_planets, _ = rho.shape
        # sample radial vectors
        theta = 2 * math.pi * torch.rand(n_systems, n_planets, 1)
        planet_pos = torch.cat([
            rho * torch.cos(theta),
            rho * torch.sin(theta)
        ], dim=-1)
        # get radial tangent vector, randomly flip orientation
        e_1 = torch.stack([-planet_pos[..., 1], planet_pos[..., 0]], dim=-1)
        flip_dir = 2 * (torch.bernoulli(torch.empty(n_systems, n_planets, 1).fill_(0.5)) - 0.5)
        e_1 = e_1 * flip_dir / e_1.norm(dim=-1, keepdim=True)
        planet_vel = planet_vel_magnitude * e_1
        return planet_pos, planet_vel

    def _init_3d(self, rho, planet_vel_magnitude):
        n_systems, n_planets, _ = rho.shape
        # sample radial vectors
        theta = 2 * math.pi * torch.rand(n_systems, n_planets, 1)
        phi = torch.acos(2 * torch.rand(n_systems, n_planets, 1) - 1)  # incorrect to sample \phi \in [0, \pi]
        planet_pos = torch.cat([
            rho * torch.sin(phi) * torch.cos(theta),
            rho * torch.sin(phi) * torch.sin(theta),
            rho * torch.cos(phi)
        ], dim=-1)

        # get radial tangent plane orthonormal basis
        e_1 = torch.stack([torch.zeros(n_systems, n_planets), -planet_pos[..., 2], planet_pos[..., 1]], dim=-1)
        e_2 = torch.cross(planet_pos, e_1, dim=-1)
        e_1 = e_1 / e_1.norm(dim=-1, keepdim=True)
        e_2 = e_2 / e_2.norm(dim=-1, keepdim=True)

        # sample initial velocity in tangent plane
        omega = 2 * math.pi * torch.rand(n_systems, n_planets, 1)
        planet_vel = torch.cos(omega) * e_1 + torch.sin(omega) * e_2
        planet_vel = planet_vel_magnitude * planet_vel
        return planet_pos, planet_vel

    def _get_dynamics(self, sys_params):
        H = lambda t, z: KeplerH(z, *sys_params)
        return HamiltonianDynamics(H, wgrad=False)

@export
class T3aug(nn.Module):
    def __init__(self,scale=.5,train_only=True):
        super().__init__()
        self.train_only = train_only
        self.scale=scale
    def forward(self,x):
        if not self.training and self.train_only: return x
        coords,vals,mask = x
        bs = coords.shape[0]
        unifs = torch.randn(bs,1,3,device=coords.device,dtype=coords.dtype)
        translations = self.scale*unifs
        return (coords+translations,vals,mask)
@export
class SO3aug(nn.Module):
    def __init__(self,train_only=True):
        super().__init__()
        self.train_only = train_only
    def forward(self,x):
        if not self.training and self.train_only: return x
        coords,vals,mask = x
        # coords (bs,n,c)
        Rs = SO3().sample(coords.shape[0],1,device=coords.device,dtype=coords.dtype)
        return ((Rs@coords.unsqueeze(-1)).squeeze(-1),vals,mask)
@export
def SE3aug(scale=.5,train_only=True):
    return nn.Sequential(T3aug(scale,train_only),SO3aug(train_only))

default_qm9_dir = '~/datasets/molecular/qm9/'
def QM9datasets(root_dir=default_qm9_dir):
    root_dir = os.path.expanduser(root_dir)
    filename= f"{root_dir}data.pz"
    if os.path.exists(filename):
        return torch.load(filename)
    else:
        datasets, num_species, charge_scale = initialize_datasets((-1,-1,-1),
         "data", 'qm9', subtract_thermo=True,force_download=True)
        qm9_to_eV = {'U0': 27.2114, 'U': 27.2114, 'G': 27.2114, 'H': 27.2114, 'zpve': 27211.4, 'gap': 27.2114, 'homo': 27.2114, 'lumo': 27.2114}
        for dataset in datasets.values():
            dataset.convert_units(qm9_to_eV)
            dataset.num_species = 5
            dataset.charge_scale = 9
        os.makedirs(root_dir, exist_ok=True)
        torch.save((datasets, num_species, charge_scale),filename)
        return (datasets, num_species, charge_scale)


# class SchPackQM9(Dataset):
#     default_qm9_dir = '~/datasets/molecular/qm9/'
#     max_atoms = 29
#     num_species = 5
#     charge_scale = 9
#     def __init__(self,root_dir=default_qm9_dir):
#         super().__init__()
#         filename = f"{root_dir}sch_data.pz"
#         if os.path.exists(filename):
#             self.data = torch.load(filename)
#         else:
#             schqm9 = schnetpack.datasets.QM9(os.path.join(root_dir,'qm9.db'),download=True)
#             self.data = self.collect_and_pad_data(schqm9)
#             os.makedirs(root_dir, exist_ok=True)
#             torch.save(self.data,filename)
#         self.calc_stats()
    
#     def __getitem__(self, idx):
#         return {key: val[idx] for key, val in self.data.items()}

#     def calc_stats(self):
#         self.stats = {key: (val.mean(), val.std()) for key, val in self.data.items() 
#                       if type(val) is torch.Tensor and val.dim() == 1 and val.is_floating_point()}
#         self.median_stats = {key: (val.median(), torch.median(torch.abs(val - val.median()))) for key, val in self.data.items() 
#                              if type(val) is torch.Tensor and val.dim() == 1 and val.is_floating_point()}

#     def collect_and_pad_data(self,sch_dataset):
#         datapoints = [sch_dataset[i] for i in range(len(sch_dataset))]
#         properties = datapoints[0].keys()-{'_cell_offset','_cell','_neighbors'}
#         batched = {prop: batch_stack([mol[prop] for mol in datapoints]) for prop in properties}
#         return batched


md17_subsets = {'benzene','uracil','naphthalene','aspirin','salicylic_acid',
               'malonaldehyde','ethanol','toluene','paracetamol','azobenzene'}
default_md17_dir = '~/datasets/molecular/md17'
def MD17datasets(root_dir=default_md17_dir,task='benzene'):
    root_dir = os.path.expanduser(root_dir)
    filename= f"{root_dir}data.pz"
    if os.path.exists(filename):
        return torch.load(filename)
    else:
        datasets, num_species, charge_scale = initialize_datasets((-1,-1,-1), 
        "data", 'md17',subset=task,force_download=True)
        mean_energy = datasets['train'].data['energies'].mean()
        for dataset in datasets.values():
            dataset.data['energies'] -= mean_energy
        os.makedirs(root_dir, exist_ok=True)
        torch.save((datasets,num_species,charge_scale),filename)
        return (datasets,num_species,charge_scale)

if __name__=='__main__':
    from mpl_toolkits import mplot3d
    import matplotlib.pyplot as plt
    import cv2
    
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    
    i = 0
    # a = load_data(os.path.expanduser('~/datasets/ModelNet40/'))[0]
    # a[...,2] += a[...,1]
    # a[...,1] = a[...,2]-a[...,1]
    # a[...,2] -= a[...,1]
    D = ModelNet40()
    def update_plot(e):
        global i
        if e.key == "right": i+=1
        elif e.key == "left": i-=1
        else:return
        ax.cla()
        xyz,label = D[i]#.T
        x,y,z = xyz.numpy()*D.coords_std[:,None]
        # d[2] += d[1]
        # d[1] = d[2]-d[1]
        # d[2] -= d[1]
        ax.scatter(x,y,z,c=z)
        ax.text2D(0.05, 0.95, D.classes[label], transform=ax.transAxes)
        #ax.contour3D(d[0],d[2],d[1],cmap='viridis',edgecolor='none')
        ax.set_xlim3d(-1,1)
        ax.set_ylim3d(-1,1)
        ax.set_zlim3d(-1,1)
        fig.canvas.draw()
    fig.canvas.mpl_connect('key_press_event',update_plot)
    plt.show()