# machine-learning/graph-neural-network/src/main.py
import numpy as np
import torch
import torch.optim as optim
import wandb
from args import get_args
from model.data_loader import fetch_dataset, get_dataloaders
from model.gnn import GNN
from ogb.graphproppred import Evaluator
from tqdm import tqdm
# Module-level losses used by ``train``: binary cross-entropy on raw logits for
# classification task types, mean squared error for every other task type.
cls_criterion, reg_criterion = torch.nn.BCEWithLogitsLoss(), torch.nn.MSELoss()
def train(model, device, loader, optimizer, task_type):
    """Run one training epoch over ``loader`` and return the summed batch loss.

    Args:
        model: network being optimized; called as ``model(batch)``.
        device: torch device every batch is moved to.
        loader: iterable of batches exposing ``to``, ``x``, ``y`` and ``batch``
            (presumably PyTorch Geometric ``Batch`` objects -- confirm).
        optimizer: optimizer that updates ``model``'s parameters.
        task_type: dataset task type string. If it contains
            ``"classification"``, ``cls_criterion`` (BCE with logits) is used;
            otherwise ``reg_criterion`` (MSE).

    Returns:
        The sum (not the mean) of the per-batch losses over every batch that
        was actually trained on.
    """
    model.train()
    # The task type is fixed for the whole epoch, so pick the loss once.
    criterion = cls_criterion if "classification" in task_type else reg_criterion
    total_loss = 0
    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)
        # Skip batches with a single node, or whose last node belongs to graph 0
        # (i.e. a single-graph batch, assuming ``batch.batch`` is PyG's
        # node-to-graph index). Presumably this avoids size-1 BatchNorm inputs
        # in training mode -- TODO confirm against model.gnn.
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue
        # Unlabeled targets are NaN and NaN != NaN, so this keeps labeled entries.
        is_labeled = batch.y == batch.y
        # With no labeled target the masked loss is a mean over zero elements
        # (NaN), which would poison ``total_loss``; skip such batches.
        if not bool(is_labeled.any()):
            continue
        pred = model(batch)
        optimizer.zero_grad()
        loss = criterion(
            pred.to(torch.float32)[is_labeled],
            batch.y.to(torch.float32)[is_labeled],
        )
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    return total_loss
def evaluate(model, device, loader, evaluator):
    """Score ``model`` on ``loader`` using ``evaluator``.

    The model is switched to eval mode and run without gradient tracking.
    Batches containing a single node are skipped. Targets (reshaped to the
    prediction's shape) and predictions from the remaining batches are moved to
    the CPU, concatenated along dim 0, converted to NumPy, and handed to
    ``evaluator.eval`` as ``{"y_true": ..., "y_pred": ...}``.

    Returns:
        Whatever ``evaluator.eval`` returns; ``main`` indexes it by the
        dataset's ``eval_metric`` name.
    """
    model.eval()
    targets, predictions = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Iteration"):
            batch = batch.to(device)
            if batch.x.shape[0] == 1:
                continue
            out = model(batch)
            targets.append(batch.y.view(out.shape).detach().cpu())
            predictions.append(out.detach().cpu())
    y_true = torch.cat(targets, dim=0).numpy()
    y_pred = torch.cat(predictions, dim=0).numpy()
    return evaluator.eval({"y_true": y_true, "y_pred": y_pred})
def _build_model(config, num_tasks):
    """Instantiate the ``GNN`` variant named by ``config.gnn``.

    ``"<type>-virtual"`` names enable ``virtual_node``; the base name is passed
    as ``gnn_type``. Depth, width and dropout come from ``config``.

    Raises:
        ValueError: if ``config.gnn`` is not a supported variant.
    """
    # config.gnn -> (gnn_type, virtual_node)
    variants = {
        "gin": ("gin", False),
        "gin-virtual": ("gin", True),
        "gcn": ("gcn", False),
        "gcn-virtual": ("gcn", True),
    }
    try:
        gnn_type, virtual_node = variants[config.gnn]
    except KeyError:
        raise ValueError("Invalid GNN type") from None
    return GNN(
        gnn_type=gnn_type,
        num_tasks=num_tasks,
        num_layer=config.num_layer,
        emb_dim=config.emb_dim,
        drop_ratio=config.drop_ratio,
        virtual_node=virtual_node,
    )


def _is_improvement(score, best, higher_is_better):
    """Return True if ``score`` beats ``best``; ``best is None`` means no score yet."""
    if best is None:
        return True
    return score > best if higher_is_better else score < best


def main():
    """Train a GNN, evaluate it every epoch, and log everything to W&B.

    Settings come from ``get_args()`` and are read back through ``wb.config``.
    Each epoch:
      * trains once;
      * evaluates on the train/val/test loaders with an OGB ``Evaluator``;
      * logs the results;
      * saves ``model.pt`` (also uploaded with ``wb.save``) whenever the
        validation metric improves.

    Metrics of classification tasks are treated as higher-is-better and all
    others as lower-is-better. This rule is used both for checkpointing and for
    the reported best epoch, so the two agree. If ``config.filename`` is
    non-empty, a dict of best-epoch scores is written there with ``torch.save``.
    """
    args = get_args()
    with wandb.init(
        entity="hongbo-miao", project="graph-neural-network", config=args
    ) as wb:
        config = wb.config
        device = (
            torch.device(f"cuda:{config.device}")
            if torch.cuda.is_available()
            else torch.device("cpu")
        )

        dataset, split_idx = fetch_dataset(config)
        # Automatic evaluator; takes the dataset name as input.
        evaluator = Evaluator(config.dataset)
        dataloaders = get_dataloaders(dataset, split_idx, config)
        train_loader = dataloaders["train"]
        val_loader = dataloaders["val"]
        test_loader = dataloaders["test"]

        model = _build_model(config, dataset.num_tasks).to(device)
        wb.watch(model)
        optimizer = optim.Adam(model.parameters(), lr=config.lr)

        metric = dataset.eval_metric
        # A fixed 0.0 threshold with ">" would checkpoint the *worse* model for
        # lower-is-better (regression) metrics, so the direction is task-dependent.
        higher_is_better = "classification" in dataset.task_type

        train_curve = []
        val_curve = []
        test_curve = []
        best_val = None
        for epoch in range(1, config.epochs + 1):
            print(f"=====Epoch {epoch}")
            print("Training...")
            train_loss = train(
                model, device, train_loader, optimizer, dataset.task_type
            )

            print("Evaluating...")
            train_perf = evaluate(model, device, train_loader, evaluator)
            val_perf = evaluate(model, device, val_loader, evaluator)
            test_perf = evaluate(model, device, test_loader, evaluator)
            print({"Train": train_perf, "Validation": val_perf, "Test": test_perf})
            wb.log(
                {
                    "epoch": epoch,
                    "train_loss": train_loss,
                    "train_perf": train_perf,
                    "val_perf": val_perf,
                    "test_perf": test_perf,
                }
            )

            train_curve.append(train_perf[metric])
            val_curve.append(val_perf[metric])
            test_curve.append(test_perf[metric])

            # Save model
            if _is_improvement(val_perf[metric], best_val, higher_is_better):
                print("Found better model.")
                best_val = val_perf[metric]
                torch.save(model.state_dict(), "model.pt")
                wb.save("model.pt")

        if not val_curve:
            # np.argmax/np.argmin raise on an empty sequence (epochs == 0).
            print("No epochs were run; nothing to report.")
            return

        if higher_is_better:
            best_val_epoch = int(np.argmax(val_curve))
            best_train = max(train_curve)
        else:
            best_val_epoch = int(np.argmin(val_curve))
            best_train = min(train_curve)

        print("Finished training!")
        print(f"Best validation score: {val_curve[best_val_epoch]}")
        print(f"Test score: {test_curve[best_val_epoch]}")

        if config.filename:
            torch.save(
                {
                    "Val": val_curve[best_val_epoch],
                    "Test": test_curve[best_val_epoch],
                    "Train": train_curve[best_val_epoch],
                    "BestTrain": best_train,
                },
                config.filename,
            )
if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()