TensorFlow Attention API Source Code Reading, Part 2

Comments from the source code:

Perform a step of attention-wrapped RNN.

  • Step 1: Mix the inputs and previous step's attention output via cell_input_fn.
    Here cell_input_fn defaults to concatenation, i.e. tf.concat([inputs, attention], -1):
    cell_inputs = self._cell_input_fn(inputs, state.attention)
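As a quick illustration, the default cell_input_fn is just a concatenation along the feature axis (the shapes below are assumptions, not taken from the source):

```python
import tensorflow as tf

# Default cell_input_fn: concatenate inputs and attention on the last axis.
# Assumed shapes: batch=2, input_dim=3, attention_dim=4.
inputs = tf.zeros([2, 3])
attention = tf.zeros([2, 4])
cell_inputs = tf.concat([inputs, attention], -1)  # shape [2, 7]
```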

  • Step 2: Call the wrapped cell with this input and its previous state.
Step 2 simply runs the wrapped RNNCell: cell_output, next_cell_state = self._cell(cell_inputs, cell_state); a minimal sketch follows below.
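Here is a minimal sketch of this step, with a plain LSTMCell standing in for the wrapped self._cell (the choice of cell and the shapes are assumptions; AttentionWrapper accepts any RNNCell):

```python
import tensorflow as tf

# Stand-in for self._cell; any RNNCell works inside AttentionWrapper.
cell = tf.nn.rnn_cell.LSTMCell(num_units=8)
cell_inputs = tf.zeros([2, 7])  # output of Step 1 (assumed shapes)
cell_state = cell.zero_state(batch_size=2, dtype=tf.float32)

# Step 2: run the wrapped cell on the mixed inputs and its previous state.
cell_output, next_cell_state = cell(cell_inputs, cell_state)  # output: [2, 8]
```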

  • Step 3: Score the cell's output with attention_mechanism.
  • Step 4: Calculate the alignments by passing the score through the normalizer.
  • Step 5: Calculate the context vector as the inner product between the alignments and the attention_mechanism's values (memory).
  • Step 6: Calculate the attention output by concatenating the cell output and context through the attention layer (a linear layer with attention_layer_size outputs).
    Steps 3-6 are sketched in code right after this docstring.

Args:
  inputs: (Possibly nested tuple of) Tensor, the input at this time step.
  state: An instance of AttentionWrapperState containing tensors from the previous time step.

Returns:
  A tuple (attention_or_cell_output, next_state), where:
  • attention_or_cell_output depending on output_attention.
  • next_state is an instance of AttentionWrapperState containing the state calculated at this time step.

Raises:
  TypeError: If state is not an instance of AttentionWrapperState.
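To make Steps 3-6 concrete, below is a minimal sketch using a Luong-style dot-product score. This is an illustration, not the actual implementation: the real score depends on which attention_mechanism is used, and the batch size, memory length, and layer size here are assumptions.

```python
import tensorflow as tf

# Assumed shapes (illustrative only): batch=2, max_time=5, num_units=8.
cell_output = tf.zeros([2, 8])   # from Step 2
memory = tf.zeros([2, 5, 8])     # the attention_mechanism's values

# Step 3: score the cell output against each memory slot
# (dot-product score, as in LuongAttention; other mechanisms differ).
score = tf.squeeze(tf.matmul(memory, tf.expand_dims(cell_output, -1)), -1)  # [2, 5]

# Step 4: the default normalizer is a softmax over the time axis.
alignments = tf.nn.softmax(score)  # [2, 5]

# Step 5: context = alignment-weighted sum (inner product) over the memory.
context = tf.squeeze(tf.matmul(tf.expand_dims(alignments, 1), memory), 1)  # [2, 8]

# Step 6: project concat(cell_output, context) through the attention layer,
# a bias-free dense layer with attention_layer_size (assumed 8 here) outputs.
attention_layer = tf.layers.Dense(8, use_bias=False)
attention = attention_layer(tf.concat([cell_output, context], -1))  # [2, 8]
```

Depending on output_attention, the wrapper then returns either this attention tensor or cell_output as the step's output, together with the next AttentionWrapperState.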