RasaHQ/rasa_core

View on GitHub
rasa/core/policies/tf_utils.py

Summary

Maintainability
F
3 days
Test Coverage
from collections import namedtuple
import tensorflow as tf


class TimedNTM(object):
    """Timed Neural Turing Machine

    Inspired by paper:
        https://arxiv.org/pdf/1410.5401.pdf
    Implementation inspired by:
        https://github.com/carpedm20/NTM-tensorflow/blob/master/ntm_cell.py

    See our paper for details: https://arxiv.org/abs/1811.11707

    Turns attention scores of the current and past time steps into
    attention probabilities in five stages:
        1. interpolate past scores with the scores state of the previous
           step (interpolation gate);
        2. mask time steps (optional);
        3. softmax or sparsemax;
        4. location-based shift by a 1d convolution (optional);
        5. sharpening.
    """

    def __init__(self, attn_shift_range, sparse_attention, name):
        """Construct the `TimedNTM`.

        Args:
            attn_shift_range: Python int.
                A time range within which to attend to the memory by location.
                If falsy (e.g. `0` or `None`), no shift weighting is applied.
            sparse_attention: Python bool.
                If `True` use sparsemax instead of softmax for probs
            name: Name to use when creating ops. It is prefixed with
                'timed_ntm_'; the layers are named '<prefixed name>/...'.
        """

        self.name = 'timed_ntm_' + name

        # interpolation gate: a sigmoid weight in (0, 1) that mixes
        # the new past scores with the scores state of the previous step
        self._inter_gate = tf.layers.Dense(
            units=1,
            activation=tf.sigmoid,
            name=self.name + '/inter_gate'
        )
        # if use sparsemax instead of softmax for probs
        self._sparse_attention = sparse_attention

        # value used (negated) to mask out scores
        if sparse_attention:
            # sparsemax doesn't support inf
            self._inf = float(5000)
        else:
            self._inf = float('inf')

        # shift weighting if range is provided: a softmax over
        # 2 * attn_shift_range + 1 shift positions, used as conv filter
        if attn_shift_range:
            self._shift_weight = tf.layers.Dense(
                units=2 * attn_shift_range + 1,
                activation=tf.nn.softmax,
                name=self.name + '/shift_weight'
            )
        else:
            self._shift_weight = None

        # sharpening parameter: softplus(.) + 1, i.e. always larger than 1
        self._gamma_sharp = tf.layers.Dense(
            units=1,
            activation=lambda a: tf.nn.softplus(a) + 1,
            bias_initializer=tf.constant_initializer(1),
            name=self.name + '/gamma_sharp'
        )

    def __call__(self, attn_inputs, scores, scores_state, mask):
        """Compute attention probabilities from time-limited scores.

        Args:
            attn_inputs: Tensor fed to the interpolation gate, the shift
                weighting and the sharpening layers.
            scores: Tensor of scores whose last column is the current
                time step and whose preceding columns are past time steps.
            scores_state: Tensor of interpolated scores from the previous
                step. It is combined with `scores[:, :-1]`, so it presumably
                has one column fewer than `scores` — TODO confirm.
            mask: (optional) int32 Tensor, presumably the same shape as
                `scores`. Entries `> 0` are attended to; `None` disables
                masking.

        Returns:
            A tuple `(probs, next_scores_state)`. `next_scores_state` is the
            interpolated scores before masking.
        """
        # apply exponential moving average with interpolation gate weight
        # to scores from previous time which are equal to probs at this point
        # different from original NTM where it is applied after softmax
        i_g = self._inter_gate(attn_inputs)

        # scores limited by time; the current time step (last column)
        # is not interpolated
        scores = tf.concat([i_g * scores[:, :-1] + (1 - i_g) * scores_state,
                            scores[:, -1:]], 1)
        next_scores_state = scores

        if mask is not None:
            # apply mask to scores
            if self._shift_weight is not None:
                # rearrange scores to make them continuous for convolution
                # (masked scores are moved to the left and set to -_inf)
                scores = tf.map_fn(self._rearrange_fn,
                                   [scores, mask], dtype=scores.dtype)
            else:
                scores = tf.where(mask > 0,
                                  scores, -self._inf * tf.ones_like(scores))

        # create probabilities for attention
        if self._sparse_attention:
            probs = tf.contrib.sparsemax.sparsemax(scores)
        else:
            probs = tf.nn.softmax(scores)

        if self._shift_weight is not None:
            s_w = self._shift_weight(attn_inputs)

            # we want to go back in time during convolution
            conv_probs = tf.reverse(probs, axis=[1])

            # prepare probs for tf.nn.depthwise_conv2d
            # [in_width, in_channels=batch]
            conv_probs = tf.transpose(conv_probs, [1, 0])
            # [batch=1, in_height=1, in_width=time+1, in_channels=batch]
            conv_probs = conv_probs[tf.newaxis, tf.newaxis, :, :]

            # [filter_height=1, filter_width=2*attn_shift_range+1,
            #   in_channels=batch, channel_multiplier=1]
            conv_s_w = tf.transpose(s_w, [1, 0])
            conv_s_w = conv_s_w[tf.newaxis, :, :, tf.newaxis]

            # perform 1d convolution; every batch element is a separate
            # channel, so each one is convolved with its own shift weights
            # [batch=1, out_height=1, out_width=time+1, out_channels=batch]
            conv_probs = tf.nn.depthwise_conv2d_native(conv_probs, conv_s_w,
                                                       [1, 1, 1, 1], 'SAME')
            conv_probs = conv_probs[0, 0, :, :]
            conv_probs = tf.transpose(conv_probs, [1, 0])

            probs = tf.reverse(conv_probs, axis=[1])

            if mask is not None:
                # arrange probs back to their original time order
                # (masked positions get zero probability)
                probs = tf.map_fn(self._arrange_back_fn,
                                  [probs, mask], dtype=probs.dtype)

        # sharpening: raise to a power > 1 and renormalize;
        # 1e-32 avoids division by zero
        g_sh = self._gamma_sharp(attn_inputs)

        powed_probs = tf.pow(probs, g_sh)
        probs = powed_probs / (
            tf.reduce_sum(powed_probs, 1, keepdims=True) + 1e-32)

        return probs, next_scores_state

    def _rearrange_fn(self, list_tensor_1d_mask_1d):
        """Rearranges tensor_1d to put all the values
            where mask_1d=1 to the right and
            where mask_1d=0 to the left and sets them to -self._inf
            (a large finite value when sparse attention is used).

            `tf.dynamic_partition` keeps the relative order of values
            inside each of the two groups."""
        tensor_1d, mask_1d = list_tensor_1d_mask_1d

        partitioned_tensor = tf.dynamic_partition(tensor_1d,
                                                  mask_1d, 2)
        # values where mask_1d=0 are masked out
        partitioned_tensor[0] = \
            -self._inf * tf.ones_like(partitioned_tensor[0])

        return tf.concat(partitioned_tensor, 0)

    @staticmethod
    def _arrange_back_fn(list_tensor_1d_mask_1d):
        """Arranges back tensor_1d to restore original order
            modified by `_rearrange_fn` according to mask_1d:
            - the places where mask_1d=0 are set to zeros
              (the left-hand values of tensor_1d are discarded),
            - the last sum(mask_1d) values of tensor_1d are set, in order,
              to their corresponding places where mask_1d=1.

            NOTE(review): assumes mask_1d contains at least one 1; with
            sum(mask_1d) == 0 the slices `[:-0]` / `[-0:]` below select
            nothing / everything, and the two parts no longer match
            the partitioned indices."""
        tensor_1d, mask_1d = list_tensor_1d_mask_1d

        # original positions of the mask_1d=0 and mask_1d=1 entries
        mask_indices = tf.dynamic_partition(tf.range(tf.shape(tensor_1d)[0]),
                                            mask_1d, 2)

        mask_sum = tf.reduce_sum(mask_1d, axis=0)
        partitioned_tensor = [tf.zeros_like(tensor_1d[:-mask_sum]),
                              tensor_1d[-mask_sum:]]

        return tf.dynamic_stitch(mask_indices, partitioned_tensor)


def _compute_time_attention(attention_mechanism, attn_inputs, attention_state,
                            # time is added to calculate time attention
                            time, timed_ntm, time_mask, ignore_mask,
                            attention_layer):
    """Compute time-limited attention for one attention mechanism.

    Adapted from tensorflow's attention helper. Only scores of the current
    and previous time steps are passed through `timed_ntm`. The resulting
    probabilities (without the current time step) build the context vector.

    Args:
        attention_mechanism: called as
            `attention_mechanism(attn_inputs, state=attention_state)`;
            its `values` are used as the memory.
        attn_inputs: query Tensor for this step.
        attention_state: scores state from the previous step.
        time: current time step.
        timed_ntm: a `TimedNTM` instance.
        time_mask: int mask over memory time steps.
        ignore_mask: (optional) int mask of time steps to ignore, or None.
        attention_layer: (optional) layer applied to the concatenation of
            `attn_inputs` and the context, or None.

    Returns:
        A tuple `(attention, alignments, next_attention_state)`. The
        alignments and the state include the current time step.
    """
    all_scores, _ = attention_mechanism(attn_inputs, state=attention_state)

    # mask of past time steps, with ignored steps switched off
    past_mask = time_mask[:, :time]
    if ignore_mask is not None:
        past_mask = past_mask * (1 - ignore_mask[:, :time])
    # the current time step is always visible to the NTM
    ntm_mask = tf.concat([past_mask, tf.ones_like(time_mask[:, :1])], 1)

    # the NTM only sees scores from current and past time steps
    probs, next_scores_state = timed_ntm(attn_inputs,
                                         all_scores[:, :time + 1],
                                         attention_state[:, :time],
                                         ntm_mask)

    zeros = tf.zeros_like(all_scores)

    # the current time step does not contribute to the context
    context_alignments = tf.concat([probs[:, :-1], zeros[:, time:]], 1)
    # [batch_size, 1, memory_time] x [batch_size, memory_time, memory_size]
    #   -> [batch_size, 1, memory_size] -> [batch_size, memory_size]
    context = tf.squeeze(
        tf.matmul(tf.expand_dims(context_alignments, 1),
                  attention_mechanism.values),
        [1])

    if attention_layer is None:
        attention = context
    else:
        attention = attention_layer(tf.concat([attn_inputs, context], 1))

    # put the current time step back into alignments and state
    future_zeros = zeros[:, time + 1:]
    alignments = tf.concat([probs, future_zeros], 1)
    next_attention_state = tf.concat([next_scores_state, future_zeros], 1)
    return attention, alignments, next_attention_state


# noinspection PyProtectedMember
class TimeAttentionWrapperState(
    namedtuple("TimeAttentionWrapperState",
               tf.contrib.seq2seq.AttentionWrapperState._fields +
               ("all_time_masks", "all_cell_states"))):  # added
    """State of a `TimeAttentionWrapper`.

    Extends tensorflow's `tf.contrib.seq2seq.AttentionWrapperState`
    (see there for the inherited fields) with:
        - `all_time_masks`: masks applied to the memory
           that filter certain time steps
        - `all_cell_states`: states of the wrapped `RNNCell`
           at all the previous time steps.
    """

    def clone(self, **kwargs):
        """Return a copy of this state with the fields in `kwargs` replaced.

        Mirrors tensorflow's `AttentionWrapperState.clone`: wherever both
        the old and the new value are tensors, the new one is passed
        through `tf.contrib.framework.with_same_shape` against the old one.
        """
        replaced = super(TimeAttentionWrapperState, self)._replace(**kwargs)

        def _keep_shape(old_value, new_value):
            both_tensors = (isinstance(old_value, tf.Tensor) and
                            isinstance(new_value, tf.Tensor))
            if not both_tensors:
                return new_value
            return tf.contrib.framework.with_same_shape(old_value, new_value)

        return tf.contrib.framework.nest.map_structure(_keep_shape,
                                                      self,
                                                      replaced)


class TimeAttentionWrapper(tf.contrib.seq2seq.AttentionWrapper):
    """Custom AttentionWrapper that takes into account time
        when calculating attention.
        Attention is calculated before calling rnn cell.

        Modified from tensorflow's tf.contrib.seq2seq.AttentionWrapper.

        See our paper for details: https://arxiv.org/abs/1811.11707
    """

    def __init__(self, cell,
                 attention_mechanism,
                 sequence_len,
                 attn_shift_range=0,
                 sparse_attention=False,
                 attention_layer_size=None,
                 alignment_history=False,
                 rnn_and_attn_inputs_fn=None,
                 ignore_mask=None,
                 cell_input_fn=None,
                 index_of_attn_to_copy=None,
                 likelihood_fn=None,
                 tensor_not_to_copy=None,
                 output_attention=False,
                 initial_cell_state=None,
                 name=None,
                 attention_layer=None):
        """Construct the `TimeAttentionWrapper`.
            See the super class for the original arguments description.

        Additional args:
            sequence_len: Python integer.
                Maximum length of the sequence, used to create
                appropriate TensorArray for all cell states
                in TimeAttentionWrapperState
            attn_shift_range: Python integer (`0` by default), or a list
                or tuple of Python integers, one per attention mechanism.
                A time range within which to attend to the memory
                by location in Neural Turing Machine.
                A single value is shared by all attention mechanisms.
                The caller's list is not modified.
            sparse_attention: Python bool.
                A flag to use sparsemax (if `True`) instead of
                softmax (if `False`, default) for probabilities
            rnn_and_attn_inputs_fn: (optional) A `callable`.
                Called as `fn(inputs, cell_state)`. It returns the
                rnn inputs and the attention inputs tensors.
            ignore_mask: (optional) Tensor, or a list with one Tensor
                (or `None`) per attention mechanism.
                Determines which time steps to ignore in attention.
                `None` ignores nothing. A single (non-list) value is
                shared by all attention mechanisms.
            index_of_attn_to_copy: (optional) Python integer.
                An index of attention mechanism that picks
                which part of attention tensor to use for copying to output,
                the default is `None`, which turns off copying mechanism.
                Copy inspired by: https://arxiv.org/pdf/1603.06393.pdf
            likelihood_fn: (optional) A `callable`.
                A method to perform likelihood calculation to
                filter time step in copy mechanism.
                Returns a tuple of binary likelihood and likelihood
            tensor_not_to_copy: (optional) A Tensor.
                A tensor, which shouldn't be copied from previous time steps

        Modified args:
            output_attention: Python bool.  If `True`, the output at each
                time step is the concatenated cell outputs,
                attention values and additional values described in
                `additional_output_size()`, used in copy mechanism.

        Raises:
            ValueError: if several attention mechanisms are given and
                `attn_shift_range` has neither exactly one entry nor
                one entry per attention mechanism.
            TypeError: if `rnn_and_attn_inputs_fn` is given but
                is not callable.
        """
        super(TimeAttentionWrapper, self).__init__(
            cell,
            attention_mechanism,
            attention_layer_size,
            alignment_history,
            cell_input_fn,
            output_attention,
            initial_cell_state,
            name,
            attention_layer
        )
        self._sequence_len = sequence_len

        if isinstance(attn_shift_range, (list, tuple)):
            # work on a copy so that the caller's list is never mutated
            attn_shift_range = list(attn_shift_range)
        else:
            attn_shift_range = [attn_shift_range]
        self._timed_ntms = [TimedNTM(attn_shift_range[0],
                                     sparse_attention,
                                     name='0')]
        if self._is_multi:
            # if there are several attention mechanisms,
            # create additional TimedNTMs for them
            if len(attn_shift_range) == 1:
                # a single shift range is shared by all mechanisms
                attn_shift_range = attn_shift_range * len(attention_mechanism)
            elif len(attn_shift_range) != len(attention_mechanism):
                raise ValueError(
                    "If provided, `attn_shift_range` must contain exactly one "
                    "integer per attention_mechanism, saw: {} vs {}"
                    "".format(len(attn_shift_range), len(attention_mechanism))
                )
            for i in range(1, len(attention_mechanism)):
                self._timed_ntms.append(TimedNTM(attn_shift_range[i],
                                                 sparse_attention,
                                                 name=str(i)))

        if rnn_and_attn_inputs_fn is None:
            rnn_and_attn_inputs_fn = self._default_rnn_and_attn_inputs_fn
        else:
            if not callable(rnn_and_attn_inputs_fn):
                raise TypeError(
                    "`rnn_and_attn_inputs_fn` must be callable, saw type: {}"
                    "".format(type(rnn_and_attn_inputs_fn).__name__)
                )
        self._rnn_and_attn_inputs_fn = rnn_and_attn_inputs_fn

        def _as_int_mask(mask):
            # `None` means "ignore nothing" and is handled by
            # `_compute_time_attention`; `tf.cast` cannot convert `None`
            return None if mask is None else tf.cast(mask, tf.int32)

        if isinstance(ignore_mask, list):
            self._ignore_mask = [_as_int_mask(i_m) for i_m in ignore_mask]
        else:
            # `call` indexes `self._ignore_mask` once per attention
            # mechanism, so a single mask is shared by all of them
            self._ignore_mask = ([_as_int_mask(ignore_mask)] *
                                 len(self._attention_mechanisms))

        self._index_of_attn_to_copy = index_of_attn_to_copy

        self._likelihood_fn = likelihood_fn
        self._tensor_not_to_copy = tensor_not_to_copy

    @staticmethod
    def _default_rnn_and_attn_inputs_fn(inputs, cell_state):
        """Default split into rnn inputs and attention inputs.

        The rnn inputs are `inputs` unchanged. The attention inputs are
        `inputs` concatenated (last axis) with the LSTM `h` state, or with
        the whole cell state for non-LSTM cells.
        """
        if isinstance(cell_state, tf.contrib.rnn.LSTMStateTuple):
            return inputs, tf.concat([inputs, cell_state.h], -1)
        else:
            return inputs, tf.concat([inputs, cell_state], -1)

    @staticmethod
    def additional_output_size():
        """Number of additional outputs:

        likelihoods:
            attn_likelihood, state_likelihood
        debugging info:
            current_time_prob,
            bin_likelihood_not_to_copy, bin_likelihood_to_copy

        **Method should be static**
        """
        return 2 + 3

    @property
    def output_size(self):
        """Size of the output returned by `call`.

        Depends on `output_attention` and on whether the copy mechanism
        (`index_of_attn_to_copy`) is enabled; see `call` for the layout.
        """
        if self._output_attention:
            if self._index_of_attn_to_copy is not None:
                # output both raw rnn cell_output and
                # cell_output with copied attention
                # together with attention vector itself
                # and additional output
                return (2 * self._cell.output_size +
                        self._attention_layer_size +
                        self.additional_output_size())
            else:
                return self._cell.output_size + self._attention_layer_size
        else:
            return self._cell.output_size

    @property
    def state_size(self):
        """The `state_size` property of `TimeAttentionWrapper`.
        Returns:
            A `TimeAttentionWrapperState` tuple containing shapes
            used by this object.
        """

        # use AttentionWrapperState from superclass
        state_size = super(TimeAttentionWrapper, self).state_size

        all_cell_states = self._cell.state_size

        return TimeAttentionWrapperState(
            cell_state=state_size.cell_state,
            time=state_size.time,
            attention=state_size.attention,
            alignments=state_size.alignments,
            attention_state=state_size.attention_state,
            alignment_history=state_size.alignment_history,
            all_time_masks=self._sequence_len,
            all_cell_states=all_cell_states)

    def zero_state(self, batch_size, dtype):
        """Modified  from tensorflow's zero_state
            see there for description of the parameters

        In addition to the superclass zero state, creates:
            - `all_time_masks`: an int32 TensorArray of size
              `sequence_len + 1` with an all-zero mask written at index 0;
            - `all_cell_states`: TensorArray(s) of size `sequence_len + 1`
              (one per LSTM state component) with the zero cell state
              written at index 0.
        """

        # use AttentionWrapperState from superclass
        zero_state = super(TimeAttentionWrapper,
                           self).zero_state(batch_size, dtype)

        with tf.name_scope(type(self).__name__ + "ZeroState",
                           values=[batch_size]):
            # store time masks
            all_time_masks = tf.TensorArray(
                tf.int32,
                size=self._sequence_len + 1,
                dynamic_size=False,
                clear_after_read=False
            ).write(0, tf.zeros([batch_size, self.state_size.all_time_masks],
                                tf.int32))

            # store all cell states into a tensor array to allow
            # copy mechanism to go back in time
            if isinstance(self._cell.state_size,
                          tf.contrib.rnn.LSTMStateTuple):
                all_cell_states = tf.contrib.rnn.LSTMStateTuple(
                    tf.TensorArray(dtype, size=self._sequence_len + 1,
                                   dynamic_size=False,
                                   clear_after_read=False
                                   ).write(0, zero_state.cell_state.c),
                    tf.TensorArray(dtype, size=self._sequence_len + 1,
                                   dynamic_size=False,
                                   clear_after_read=False
                                   ).write(0, zero_state.cell_state.h)
                )
            else:
                # same size as in the LSTM branch: a non-resizable
                # TensorArray of size 0 cannot be written at index 0
                all_cell_states = tf.TensorArray(
                    dtype, size=self._sequence_len + 1,
                    dynamic_size=False,
                    clear_after_read=False
                ).write(0, zero_state.cell_state)

            return TimeAttentionWrapperState(
                cell_state=zero_state.cell_state,
                time=zero_state.time,
                attention=zero_state.attention,
                alignments=zero_state.alignments,
                attention_state=zero_state.attention_state,
                alignment_history=zero_state.alignment_history,
                all_time_masks=all_time_masks,
                all_cell_states=all_cell_states
            )

    def call(self, inputs, state):
        """Perform a step of attention-wrapped RNN.

        The order has changed:
        - Step 1: Calculate attention inputs based on the previous cell state
                  and current inputs
        - Step 2: Score the output with `attention_mechanism`.
        - Step 3: Calculate the alignments by passing the score through the
                  `normalizer` and limit them by time.
        - Step 4: Calculate the context vector as the inner product between the
                  alignments and the attention_mechanism's values (memory).
        - Step 5: Calculate the attention output by concatenating
                  the cell output and context through the attention layer
                  (a linear layer with `attention_layer_size` outputs).
        - Step 6: Mix the `inputs` and `attention` output via
                  `cell_input_fn` to get cell inputs.
        - Step 7: Call the wrapped `cell` with these cell inputs and
                  its previous state.
        - Step 8: (optional) Maybe copy output and cell state from history

        Args:
          inputs: (Possibly nested tuple of) Tensor,
                  the input at this time step.
          state: An instance of `TimeAttentionWrapperState`
                 containing tensors from the previous time step.

        Returns:
          A tuple `(attention_or_cell_output, next_state)`, where:

          - `attention_or_cell_output` depending on `output_attention`.
          - `next_state` is an instance of `TimeAttentionWrapperState`
             containing the state calculated at this time step.

        Raises:
          TypeError: If `state` is not an instance of
          `TimeAttentionWrapperState`.
        """
        if not isinstance(state, TimeAttentionWrapperState):
            raise TypeError("Expected state to be instance of "
                            "TimeAttentionWrapperState. "
                            "Received type {} instead.".format(type(state)))

        # Step 1: Calculate attention based on
        #         the previous output and current input
        cell_state = state.cell_state

        rnn_inputs, attn_inputs = self._rnn_and_attn_inputs_fn(inputs,
                                                               cell_state)

        cell_batch_size = (
            attn_inputs.shape[0].value or
            tf.shape(attn_inputs)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  "
            "Are you using "
            "the BeamSearchDecoder?  "
            "You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with tf.control_dependencies(
                self._batch_size_checks(cell_batch_size, error_message)):
            attn_inputs = tf.identity(
                attn_inputs, name="checked_attn_inputs")

        if self._is_multi:
            previous_attention_state = state.attention_state
            previous_alignment_history = state.alignment_history
        else:
            previous_attention_state = [state.attention_state]
            previous_alignment_history = [state.alignment_history]

        all_alignments = []
        all_attentions = []
        all_attention_states = []
        maybe_all_histories = []

        # masks of all steps so far; the last one is the current mask
        prev_time_masks = self._read_from_tensor_array(state.all_time_masks,
                                                       state.time)
        prev_time_mask = prev_time_masks[:, -1, :]

        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            # Steps 2 - 5 are performed inside `_compute_time_attention`
            (attention, alignments,
             next_attention_state) = _compute_time_attention(
                attention_mechanism, attn_inputs,
                previous_attention_state[i],
                # time is added to calculate time attention
                state.time, self._timed_ntms[i],
                # provide boolean masks, to ignore some time steps
                prev_time_mask, self._ignore_mask[i],
                self._attention_layers[i]
                if self._attention_layers else None)

            alignment_history = previous_alignment_history[i].write(
                state.time, alignments) if self._alignment_history else ()

            all_attention_states.append(next_attention_state)
            all_alignments.append(alignments)
            all_attentions.append(attention)
            maybe_all_histories.append(alignment_history)

        attention = tf.concat(all_attentions, 1)

        # Step 6: Mix the `inputs` and `attention` output via
        #         `cell_input_fn` to get cell inputs.
        cell_inputs = self._cell_input_fn(rnn_inputs, attention)

        # Step 7: Call the wrapped `cell` with these cell inputs and
        #         its previous state.
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        prev_all_cell_states = state.all_cell_states

        # mark the current time step as visible in the memory mask
        time_mask = tf.concat([prev_time_mask[:, :state.time],
                               tf.ones_like(prev_time_mask[:, :1]),
                               prev_time_mask[:, state.time + 1:]], 1)

        if self._index_of_attn_to_copy is not None:
            # Step 8: Maybe copy output and cell state from history

            # get relevant previous outputs from history
            attn_to_copy = all_attentions[self._index_of_attn_to_copy]
            # copy them to current output
            cell_output_with_attn = cell_output + attn_to_copy

            memory_probs = self._get_memory_probs(all_alignments, state.time)

            # check that we do not pay attention to `tensor_not_to_copy`
            bin_likelihood_not_to_copy, _ = self._likelihood_fn(
                cell_output_with_attn, self._tensor_not_to_copy)
            # recalculate probs
            memory_probs *= 1 - bin_likelihood_not_to_copy

            history_alignments = self._history_alignments(memory_probs)

            # get previous output from the history
            prev_output = self._prev_output(cell_output_with_attn,
                                            history_alignments,
                                            state.time)

            # check that current output is close to
            # the one in the history to which we pay attention to
            bin_likelihood_to_copy, _ = self._likelihood_fn(
                cell_output_with_attn, prev_output)
            # recalculate probs
            memory_probs *= bin_likelihood_to_copy

            history_alignments = self._history_alignments(memory_probs)
            current_time_prob = history_alignments[:, -1:]

            # create additional likelihoods to maximize
            attn_likelihood = self._additional_likelihood(
                attn_to_copy,
                prev_output,
                current_time_prob
            )
            state_likelihood = self._additional_likelihood(
                cell_output + tf.stop_gradient(attn_to_copy),
                prev_output,
                current_time_prob
            )

            # recalculate time_mask
            time_mask = self._apply_alignments_to_history(
                tf.cast(history_alignments, time_mask.dtype),
                prev_time_masks[:, :-1, :],
                time_mask
            )

            # recalculate new next_cell_state based on history_alignments
            next_cell_state = self._new_next_cell_state(
                prev_all_cell_states,
                next_cell_state,
                cell_output_with_attn,
                history_alignments,
                state.time
            )

            all_cell_states = self._all_cell_states(
                prev_all_cell_states,
                next_cell_state,
                state.time
            )

            if self._output_attention:
                # concatenate cell outputs, attention, additional likelihoods
                # and copy_attn_debug
                output = tf.concat([cell_output_with_attn,
                                    cell_output,
                                    attention,
                                    # additional likelihoods
                                    attn_likelihood, state_likelihood,
                                    # copy_attn_debug
                                    bin_likelihood_not_to_copy,
                                    bin_likelihood_to_copy,
                                    current_time_prob], 1)
            else:
                output = cell_output_with_attn

        else:
            # do not waste resources on storing history
            all_cell_states = prev_all_cell_states

            if self._output_attention:
                output = tf.concat([cell_output, attention], 1)
            else:
                output = cell_output

        all_time_masks = state.all_time_masks.write(state.time + 1, time_mask)

        next_state = TimeAttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            attention_state=self._item_or_tuple(all_attention_states),
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(maybe_all_histories),
            all_time_masks=all_time_masks,
            all_cell_states=all_cell_states
        )
        return output, next_state

    # helper for TensorArray
    @staticmethod
    def _read_from_tensor_array(tensor_array, time):
        """TensorArray time reader.

        Gathers entries `0..time` (inclusive) and swaps the first two axes,
        so that the time axis becomes axis 1. The `[1, 0, 2]` permutation
        requires every entry to be a rank-2 tensor.
        """
        return tf.transpose(tensor_array.gather(tf.range(0, time + 1)),
                            [1, 0, 2])

    # helper methods for copy mechanism
    def _get_memory_probs(self, all_alignments, time):
        """Binarize the past alignments of the copy attention mechanism.

        Takes the alignments of the mechanism at `index_of_attn_to_copy` for
        the past time steps `[:, :time]` (gradient stopped). Returns 1 where
        an alignment equals its row maximum and 0 elsewhere. Rows whose
        maximum is not larger than the margin 0.1 become all zeros.
        """

        memory_probs = tf.stop_gradient(
            all_alignments[self._index_of_attn_to_copy][:, :time])

        # binarize memory_probs only if max value is larger than margin=0.1
        memory_probs_max = tf.reduce_max(memory_probs, axis=1, keepdims=True)
        # alignments are probabilities (>= 0), so a sentinel of -1 never
        # matches. Negating the max would not work for an all-zero row,
        # because 0 == -0 would select every past time step.
        memory_probs_max = tf.where(memory_probs_max > 0.1,
                                    memory_probs_max,
                                    -tf.ones_like(memory_probs_max))

        return tf.where(tf.equal(memory_probs, memory_probs_max),
                        tf.ones_like(memory_probs),
                        tf.zeros_like(memory_probs))

    @staticmethod
    def _history_alignments(memory_probs):
        """Append the current time probability `1 - sum(memory_probs)`
        as the last column of `memory_probs`."""

        current_time_prob = 1 - tf.reduce_sum(memory_probs, 1, keepdims=True)
        return tf.concat([memory_probs, current_time_prob], 1)

    @staticmethod
    def _apply_alignments_to_history(alignments, history_states, state):
        """Helper method to apply attention probabilities to rnn history

        copied from tf's `_compute_attention(...)`

        `state` is appended to `history_states` as the last time step, so
        `alignments` needs one column per history step plus one for it.
        The alignments are not differentiated through (stop_gradient).
        """

        expanded_alignments = tf.stop_gradient(tf.expand_dims(alignments, 1))

        history_states = tf.concat([history_states,
                                    tf.expand_dims(state, 1)], 1)

        # Context is the inner product of alignments and values along the
        # memory time dimension.
        # expanded_alignments shape is
        #   [batch_size, 1, memory_time]
        # history_states shape is
        #   [batch_size, memory_time, memory_size]
        # the batched matmul is over memory_time, so the output shape is
        #   [batch_size, 1, memory_size].
        # we then squeeze out the singleton dim.

        return tf.squeeze(tf.matmul(expanded_alignments, history_states), [1])

    def _prev_output(self, state, alignments, time):
        """Helper method to get previous output from memory.

        `state` is the current output; it is appended to the past memory
        values of the copy attention mechanism before applying `alignments`.
        """

        # get all previous outputs from appropriate
        # attention mechanism's memory limited by current time
        prev_outputs = tf.stop_gradient(self._attention_mechanisms[
                                        self._index_of_attn_to_copy].values[
                                        :, :time, :])

        # multiply by alignments to get one vector from one time step
        return self._apply_alignments_to_history(alignments,
                                                 prev_outputs,
                                                 state)

    def _additional_likelihood(self, output, prev_output, current_time_prob):
        """Helper method to create additional likelihood to maximize.

        Uses the likelihood between `output` and `prev_output` (gradient
        stopped) where `current_time_prob < 0.5`, and 1 otherwise.
        """

        _, likelihood = self._likelihood_fn(
            output, tf.stop_gradient(prev_output))
        return tf.where(current_time_prob < 0.5,
                        likelihood, tf.ones_like(likelihood))

    def _new_hidden_state(self, prev_all_cell_states,
                          new_state, alignments, time):
        """Helper method to look into rnn history.

        Mixes the stored states of the past time steps and `new_state`
        (as the current time step) with `alignments`.
        """

        # read states 0..time as [batch, time + 1, ...] and
        # do not include current time because
        # we do not want to pay attention to it,
        # but we need to read it instead of
        # adding conditional flow if time == 0
        prev_cell_states = self._read_from_tensor_array(prev_all_cell_states,
                                                        time)[:, :-1, :]

        return self._apply_alignments_to_history(alignments,
                                                 prev_cell_states,
                                                 new_state)

    def _new_next_cell_state(self, prev_all_cell_states,
                             next_cell_state, new_cell_output,
                             alignments, time):
        """Helper method to recalculate new next_cell_state.

        For an LSTM, `c` is recomputed from `next_cell_state.c` and `h`
        from `new_cell_output`. For other cells the whole state is
        recomputed from `new_cell_output`.
        """

        if isinstance(next_cell_state, tf.contrib.rnn.LSTMStateTuple):
            next_cell_state_c = self._new_hidden_state(
                prev_all_cell_states.c,
                next_cell_state.c,
                alignments,
                time
            )
            next_cell_state_h = self._new_hidden_state(
                prev_all_cell_states.h,
                new_cell_output,
                alignments,
                time
            )
            return tf.contrib.rnn.LSTMStateTuple(next_cell_state_c,
                                                 next_cell_state_h)
        else:
            # argument order matches `_new_hidden_state`'s signature:
            # (prev_all_cell_states, new_state, alignments, time)
            return self._new_hidden_state(prev_all_cell_states,
                                          new_cell_output, alignments, time)

    @staticmethod
    def _all_cell_states(prev_all_cell_states, next_cell_state, time):
        """Helper method to recalculate all_cell_states tensor array:
        writes `next_cell_state` at index `time + 1`."""

        if isinstance(next_cell_state, tf.contrib.rnn.LSTMStateTuple):
            return tf.contrib.rnn.LSTMStateTuple(
                prev_all_cell_states.c.write(time + 1, next_cell_state.c),
                prev_all_cell_states.h.write(time + 1, next_cell_state.h)
            )
        else:
            return prev_all_cell_states.write(time + 1, next_cell_state)


class ChronoBiasLayerNormBasicLSTMCell(tf.contrib.rnn.LayerNormBasicLSTMCell):
    """LayerNormBasicLSTMCell with configurable input gate bias.

        On top of the parent cell this adds:
            - `input_bias`, a constant added to the input gate
              pre-activation (together with `forget_bias` this allows
              chrono initialization of gate biases),
            - dropout on the new hidden state `h`,
            - an optional dense projection of the output
              (`out_layer_size`).

        See super class for description of the remaining arguments.

        See https://arxiv.org/abs/1804.11188
        for details about chrono initialization
    """

    def __init__(self,
                 num_units,
                 forget_bias=1.0,
                 input_bias=0.0,
                 activation=tf.tanh,
                 layer_norm=True,
                 norm_gain=1.0,
                 norm_shift=0.0,
                 dropout_keep_prob=1.0,
                 dropout_prob_seed=None,
                 out_layer_size=None,
                 reuse=None):
        """Initialize the cell.

        Additional args:
            input_bias: float, constant added to the input gate
                pre-activation.
            out_layer_size: (optional) integer, number of units of the
                extra dense output layer; when falsy the output size is
                `num_units`.
        """
        parent = super(ChronoBiasLayerNormBasicLSTMCell, self)
        parent.__init__(num_units,
                        forget_bias=forget_bias,
                        activation=activation,
                        layer_norm=layer_norm,
                        norm_gain=norm_gain,
                        norm_shift=norm_shift,
                        dropout_keep_prob=dropout_keep_prob,
                        dropout_prob_seed=dropout_prob_seed,
                        reuse=reuse)
        # set after the parent constructor has run
        self._input_bias = input_bias
        self._out_layer_size = out_layer_size

    @property
    def output_size(self):
        # a falsy out_layer_size (None or 0) falls back to num_units
        if self._out_layer_size:
            return self._out_layer_size
        return self._num_units

    @property
    def state_size(self):
        # c keeps num_units, h has the (possibly projected) output size
        c_size = self._num_units
        h_size = self.output_size
        return tf.contrib.rnn.LSTMStateTuple(c_size, h_size)

    @staticmethod
    def _dense_layer(args, layer_size):
        """Optional output projection computing `args @ kernel + bias`."""
        input_dim = args.get_shape()[-1]
        kernel = tf.get_variable("kernel",
                                 [input_dim, layer_size],
                                 dtype=args.dtype)
        bias = tf.get_variable("bias",
                               [layer_size],
                               dtype=args.dtype)
        return tf.nn.bias_add(tf.matmul(args, kernel), bias)

    def call(self, inputs, state):
        """LSTM cell with layer normalization and recurrent dropout.

        Returns:
            A pair `(new_h, LSTMStateTuple(new_c, new_h))`; when
            `out_layer_size` is set, `new_h` is the projected hidden state.
        """
        c, h = state
        cell_inputs = tf.concat([inputs, h], 1)
        gates = self._linear(cell_inputs)
        dtype = cell_inputs.dtype

        i, j, f, o = tf.split(value=gates, num_or_size_splits=4, axis=1)
        if self._layer_norm:
            i, j, f, o = [
                self._norm(gate, scope_name, dtype=dtype)
                for gate, scope_name in zip(
                    (i, j, f, o),
                    ("input", "transform", "forget", "output"))
            ]

        # a non-float keep_prob (e.g. a tensor) always enables dropout
        apply_dropout = (not isinstance(self._keep_prob, float)
                         or self._keep_prob < 1)

        candidate = self._activation(j)
        if apply_dropout:
            candidate = tf.nn.dropout(candidate, self._keep_prob,
                                      seed=self._seed)

        # input gate additionally receives the constant input_bias
        new_c = (c * tf.sigmoid(f + self._forget_bias) +
                 candidate * tf.sigmoid(i + self._input_bias))

        # layer normalization is deliberately not applied to new_c
        # (original note: "because there are no trainable weights")

        new_h = self._activation(new_c) * tf.sigmoid(o)

        # extra dropout on the hidden state h
        if apply_dropout:
            new_h = tf.nn.dropout(new_h, self._keep_prob, seed=self._seed)

        # optional postprocessing of the output
        if self._out_layer_size is not None:
            with tf.variable_scope('out_layer'):
                new_h = self._dense_layer(new_h, self._out_layer_size)

        return new_h, tf.contrib.rnn.LSTMStateTuple(new_c, new_h)