if dtype is not None and dtype != tf.float32:
      forward_position_sequence = tf.cast(forward_position_sequence,
                                          dtype=dtype)
      backward_position_sequence = tf.cast(backward_position_sequence,