def _tf_ifft_for_rank(self, rank):
    """Select the `fft_ops` inverse-transform function matching `rank`.

    Args:
      rank: Number of transform dimensions; 1 selects `fft_ops.irfft`,
        2 selects `fft_ops.irfft2d`.

    Returns:
      The selected `fft_ops` callable itself (it is not invoked here).

    NOTE(review): `self` is not read by the lines shown. Despite "ifft" in
    the method name, the callables returned are the `irfft*` variants
    (presumably inverse real-input FFTs) -- confirm this is intended.
    Ranks other than 1 and 2 fall through the `if`/`elif` chain below; if
    the chain does not continue beyond these lines, such ranks get an
    implicit `None` -- consider raising `ValueError` instead.
    """
    if rank == 1:
      return fft_ops.irfft
    elif rank == 2:
      return fft_ops.irfft2d