if len(h_target.shape) == 5:
            # (bs, ts, aligned, window, emb) -> (bs, ts, window, emb)
            h_target = h_target.sum(2, keepdim=False) / nb_alignments.unsqueeze(