hongbo-miao/hongbomiao.com

View on GitHub
machine-learning/triton/amazon-sagamaker-triton-resnet-50/set_up/workspace/export_pt.py

Summary

Maintainability
A
0 mins
Test Coverage
import torch
import torchvision.models as models


def main() -> None:
    model_name = "model.pt"

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    resnet50 = models.resnet50(pretrained=True)
    resnet50 = resnet50.eval()
    resnet50.to(device)

    resnet50_jit = torch.jit.script(resnet50)
    resnet50_jit.save(model_name)


if __name__ == "__main__":
    main()