if self.parameters.mode not in [base_layers.TFLITE, base_layers.PREDICT]:
      q_tensor = tf.reshape(q_tensor, [bsz, -1, self.num_heads, self.filters])
      kv_tensors = tf.reshape(kv_tensors,
                              [bsz, -1, 2, self.num_heads, self.filters])
      kv_tensors = tf.unstack(kv_tensors, axis=2)