self.intermediate_dense = tf_keras.layers.EinsumDense(
        "abc,cd->abd",
        output_shape=(None, self.intermediate_size),
        bias_axes="d",
        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),