hongbo-miao/hongbomiao.com

View on GitHub
machine-learning/convolutional-neural-network/src/model/data_loader.py

Summary

Maintainability
A
0 mins
Test Coverage
import torch
import torchvision
import torchvision.transforms as transforms
from args import get_args

# Command-line options for this run. The `should_download_original_data`
# flag is passed as `download=` when the CIFAR-10 datasets are built below.
args = get_args()

# Mini-batch size shared by the training and validation loaders.
batch_size = 4

# Per-channel (R, G, B) statistics handed to `Normalize`, which computes
# (x - mean) / std. With ToTensor's [0, 1] output (per torchvision docs),
# mean = std = 0.5 maps pixels into [-1, 1].
_CHANNEL_MEAN = (0.5, 0.5, 0.5)
_CHANNEL_STD = (0.5, 0.5, 0.5)

# Preprocessing applied to every image: convert to a tensor, then normalize.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

# Directory where torchvision stores (or downloads) the CIFAR-10 files.
DATA_ROOT = "./data/processed"

# Number of worker processes each DataLoader uses to load batches.
NUM_WORKERS = 2


def _create_cifar10_data_loader(is_train: bool):
    """Build one CIFAR-10 split and a DataLoader over it.

    Both splits share the module-level `transform`, `batch_size`, download
    flag, data root and worker count. Only the split and the shuffling differ.

    Args:
        is_train: ``True`` for the training split (shuffled). ``False`` for
            the test split, which this module uses for validation (not
            shuffled).

    Returns:
        A ``(dataset, data_loader)`` tuple for the requested split.
    """
    dataset = torchvision.datasets.CIFAR10(
        root=DATA_ROOT,
        train=is_train,
        download=args.should_download_original_data,
        transform=transform,
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # Shuffle only the training data. Validation batches keep a fixed,
        # deterministic order.
        shuffle=is_train,
        num_workers=NUM_WORKERS,
    )
    return dataset, data_loader


# Public names consumed by the rest of the project. They are built in the
# same order as before: training split first, then validation.
train_set, train_data_loader = _create_cifar10_data_loader(is_train=True)
val_set, val_data_loader = _create_cifar10_data_loader(is_train=False)