# machine-learning/convolutional-neural-network/src/evaluate.py
import torch
def evaluate(net: torch.nn.Module, data_loader, device) -> float:
    """Return the classification accuracy of *net* over *data_loader*.

    The model is switched to evaluation mode and is NOT switched back;
    callers that continue training afterwards must call ``net.train()``.

    Args:
        net: Model to evaluate. Called as ``net(inputs)``; the output is
            reduced with ``argmax`` over dimension 1 to obtain one predicted
            class index per sample.
        data_loader: Iterable yielding ``(inputs, labels)`` pairs of tensors.
        device: Device the inputs and labels are moved to before the
            forward pass. The model is not moved; it is presumably already
            on this device (confirm against the caller).

    Returns:
        Fraction of samples whose predicted class equals the label, in
        ``[0.0, 1.0]``. Returns ``0.0`` when *data_loader* yields no samples.
    """
    net.eval()
    correct = 0
    total = 0

    # Gradients are not needed for inference; disabling them saves memory.
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = net(inputs)
            # The class with the highest score is taken as the prediction.
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader, which would otherwise divide by zero.
    if total == 0:
        return 0.0
    return correct / total