def _reshape_keypoint_depth_weights(self, keys_to_tensors):
    """Reshape keypoint depth weights.

    The keypoint depth weights are reshaped to [num_instances, num_keypoints].
    The keypoint depth weights tensor is expected to have the same shape as the