if x_batch_size != y_batch_size:
      raise ValueError('Cannot do batch_dot on inputs '
                       'with different batch sizes. '
                       'Received inputs with shapes ' +
                       str(x_shape) + ' and ' +