hongbo-miao/hongbomiao.com

View on GitHub
machine-learning/hm-mlflow/experiments/classify-mnist/src/main.py

Summary

Maintainability
A
0 mins
Test Coverage
import lightning as L
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torchvision
from args import get_args
from lightning.pytorch.loggers.wandb import WandbLogger


class LitAutoEncoder(L.LightningModule):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps each flattened image (28 * 28 values) to a small
    embedding and the decoder maps the embedding back to 28 * 28 values.
    Both training and validation minimise / report the mean-squared error
    between the flattened input and its reconstruction.

    Args:
        hidden_dim: Width of the single hidden layer in the encoder and decoder.
        latent_dim: Size of the embedding produced by the encoder.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(
        self,
        hidden_dim: int = 128,
        latent_dim: int = 3,
        learning_rate: float = 1e-3,
    ):
        super().__init__()
        # Record constructor arguments so attached loggers/checkpoints see them.
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 28 * 28),
        )

    def forward(self, x):
        """Return the encoder embedding of ``x``.

        ``x`` must already be flattened so its last dimension has 28 * 28
        values; unlike the training/validation steps, no reshaping is done here.
        """
        return self.encoder(x)

    def _reconstruction_loss(self, batch):
        """Return the MSE between the flattened inputs of ``batch`` and their reconstructions."""
        # The batch is an (inputs, labels) pair; labels are unused because the
        # reconstruction target is the input itself.
        x, _ = batch
        # Flatten every sample to one vector; assumes each sample holds
        # 28 * 28 elements to match the first encoder layer.
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        return F.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the reconstruction loss."""
        loss = self._reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the reconstruction loss on a validation batch as ``val_loss``.

        Without this hook Lightning skips the validation loop entirely, even
        when validation dataloaders are passed to ``Trainer.fit``.
        """
        loss = self._reconstruction_loss(batch)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


def main():
    """Train ``LitAutoEncoder`` on MNIST, tracking the run in W&B and MLflow.

    The MNIST dataset in ``data/`` (downloaded if missing) is split into
    55,000 training and 5,000 validation images with a fixed seed, so repeated
    runs are evaluated on the same validation images.
    """
    project_name = "classify-mnist"
    # NOTE(review): batch size chosen as a sensible default; consider exposing
    # it through get_args() if it needs tuning.
    batch_size = 32
    split_seed = 42

    # Parse CLI arguments first so that, if parsing fails, it happens before
    # any tracking-server calls below.
    args = get_args()

    # W&B: Lightning sends metrics logged via self.log(...) to this logger.
    wandb_logger = WandbLogger(project=project_name)

    # MLflow: autolog records the upcoming trainer.fit run in this experiment.
    mlflow.set_tracking_uri("https://mlflow.hongbomiao.com")
    mlflow.set_experiment(experiment_name=project_name)
    mlflow.pytorch.autolog()

    dataset = torchvision.datasets.MNIST(
        "data/",
        download=True,
        transform=torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                # Presumably the MNIST training-set pixel mean / std.
                torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            ]
        ),
    )
    # A seeded generator makes the train/val split reproducible across runs.
    train_dataset, val_dataset = data.random_split(
        dataset, [55000, 5000], generator=torch.Generator().manual_seed(split_seed)
    )

    # DataLoader defaults to batch_size=1 and no shuffling; train on shuffled
    # mini-batches and validate in batches instead.
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = data.DataLoader(val_dataset, batch_size=batch_size)

    autoencoder = LitAutoEncoder()
    trainer = L.Trainer(
        devices="auto",
        accelerator="auto",
        max_epochs=args.max_epochs,
        check_val_every_n_epoch=1,
        logger=wandb_logger,
    )
    trainer.fit(autoencoder, train_loader, val_loader)


if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()