# cloud-platform/aws/amazon-sagemaker/pytorch-mnist/src/utils/average_gradients.py
import torch.distributed as dist
def average_gradients(model) -> None:
    """Average the gradients of *model* across all processes, in place.

    Each parameter's gradient is summed over the default process group with
    an all-reduce and then divided by the world size, so every process ends
    up holding the mean gradient.

    Parameters whose ``grad`` is ``None`` (for example frozen parameters, or
    parameters that did not take part in the backward pass) are skipped
    instead of raising ``AttributeError``.

    Args:
        model: Object exposing ``parameters()``; each yielded parameter must
            have a ``grad`` attribute that is either ``None`` or a tensor.
    """
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            # NOTE: all_reduce is a collective call, so the set of parameters
            # that have gradients is assumed to be the same on every process.
            continue
        grad = param.grad.data
        # ReduceOp is the supported spelling; reduce_op is a deprecated alias.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad /= world_size