hongbo-miao/hongbomiao.com

View on GitHub
cloud-platform/aws/amazon-sagemaker/pytorch-mnist/src/utils/average_gradients.py

Summary

Maintainability
A
0 mins
Test Coverage
import torch.distributed as dist


def average_gradients(model, group=None):
    """Average the gradients of ``model`` across all ranks of a process group, in place.

    Each parameter's gradient is summed over the ranks with ``all_reduce``.
    It is then divided by the group's world size, so afterwards every rank
    holds the mean gradient.

    Args:
        model: A module exposing ``parameters()``, such as a ``torch.nn.Module``.
        group: Optional process group to average over. ``None`` (the
            default) uses the default process group, as the original did.

    Notes:
        Parameters whose ``.grad`` is ``None`` (frozen or unused in the
        forward pass) are skipped instead of raising ``AttributeError``.
        NOTE(review): every rank must skip the same parameters. Otherwise
        the collective calls are issued in a different order on each rank
        and may hang. This presumably holds when all ranks run the same
        model and code path; confirm for models with data-dependent
        branches.
    """
    # Computed once; this is the divisor that turns the summed gradient into a mean.
    world_size = float(dist.get_world_size(group))
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            # No gradient to reduce (e.g. requires_grad=False or unused parameter).
            continue
        # Operate on .data so the collective and division bypass autograd tracking.
        grad_data = grad.data
        # ReduceOp replaces the deprecated dist.reduce_op alias.
        dist.all_reduce(grad_data, op=dist.ReduceOp.SUM, group=group)
        grad_data /= world_size