machine-learning/convolutional-neural-network/src/main.py
import torch
import wandb
import yaml
from evaluate import evaluate
from model.data_loader import train_data_loader, val_data_loader
from model.net import Net
from torch import nn, optim
from train import train
from utils.device import device
from utils.writer import write_params
def main():
    """Train the network, log metrics to Weights & Biases, and checkpoint it.

    Steps:
      1. Load hyperparameters from ``params.yaml`` (relative to the current
         working directory) and hand them to ``write_params``.
      2. Open a W&B run configured with those hyperparameters.
      3. For each epoch, call ``train`` once, then ``evaluate`` on both the
         training and validation loaders, and log the results.
      4. Whenever the validation score exceeds the best seen so far
         (starting from 0.0), save the model's ``state_dict`` to
         ``output/models/model.pt`` and upload it with ``wb.save``.

    The config is expected to provide ``lr`` at the top level and
    ``train.epochs`` nested under ``train`` (these are the only keys read
    here).
    """
    # Local import keeps this block self-contained; pathlib is stdlib.
    from pathlib import Path

    # Explicit encoding so the YAML is decoded identically on every platform.
    with open("params.yaml", "r", encoding="utf-8") as f:
        params = yaml.safe_load(f)
    write_params(params)

    # Make sure the checkpoint directory exists *before* training starts, so a
    # missing directory cannot fail the run after an epoch has already been paid for.
    model_path = "output/models/model.pt"
    Path(model_path).parent.mkdir(parents=True, exist_ok=True)

    with wandb.init(
        entity="hongbo-miao", project="convolutional-neural-network", config=params
    ) as wb:
        config = wb.config
        net = Net().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=config["lr"])

        max_val_acc = 0.0
        for epoch in range(config["train"]["epochs"]):
            train_loss = train(net, train_data_loader, device, optimizer, criterion)
            train_acc = evaluate(net, train_data_loader, device)
            val_acc = evaluate(net, val_data_loader, device)
            print({"Train": train_acc, "Validation": val_acc})
            wb.log(
                {
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_acc": val_acc,
                }
            )

            # Only keep the weights that score best on the validation set;
            # each improvement overwrites the previous checkpoint file.
            if val_acc > max_val_acc:
                print("Found better model.")
                max_val_acc = val_acc
                torch.save(net.state_dict(), model_path)
                wb.save(model_path)
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()