Although the transformer has largely replaced RNN-based seq2seq, the latter still gets used in some places. For example, the GAN+RL text-generation papers I've been reading lately are mostly still RNN-based seq2seq. So before hacking on that kind of code yourself, it's worth a careful read of Google's official seq2seq implementation~

tensorflow.seq2seq

The tensorflow.contrib.seq2seq.python.ops seq2seq module contains 8 Python files. Below I describe what the classes/functions in each file do, and their design. (But what is "design"? Maybe I should read a design-patterns book first, or I won't know what I'm talking about... anyway.)

  • attention_wrapper.py
  • basic_decoder.py
  • beam_search_ops.py
  • beam_search_decoder.py
  • decoder.py
  • helper.py
  • loss.py
  • sample.py

First, the big picture: the whole purpose of this module is to make it easier to implement seq2seq with LuongAttention, BahdanauAttention, all kinds of fancy attention, sampling, and so on.

After a quick pass over the class/function docstrings, the reading order for these 8 source files is:

  • decoder.py defines the abstract class class Decoder(object):.

  • basic_decoder.py defines class BasicDecoder(decoder.Decoder):, which obviously inherits the Decoder class above. On top of it, it adds the __init__ initialization and the step implementation details. The decoders we build later are essentially all based on this class.

  • attention_wrapper.py: why read this next? Because the BasicDecoder constructor above needs a cell and a helper, and the cell is an RNNCell. A decoder with attention is in fact still RNN-based: the authors use an RNNCell that wraps an attention_mechanism, and it is still an RNNCell in essence.

  • helper.py: the BasicDecoder constructor also needs the helper argument. The file docstring tells you it covers the sampling operations of the training or inference phase. From the RNNCell we get this time step's cell_outputs and cell_state; we then sample the current time step's token from cell_outputs and pass things on to the next time step. All of that is done in the helper.

These 4 files form the main skeleton of seq2seq. With them understood, sample.py, beam_search, and loss.py are easy.

encoder

Before going through those 8 files in detail, a quick aside on the encoder. You may have noticed the code above contains no encoder: the encoder is a fairly basic network structure, and an ordinary BasicRNNCell or BasicLSTMCell can serve as one. What matters is that the encoder's output should have two components:
- the output at every time step
- the hidden state at the final time step

If we represent the encoder output with Python's namedtuple data structure, it can look like: EncoderOutput = namedtuple("EncoderOutput", ("outputs", "final_state"))
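As a minimal runnable sketch of that container (the wrapper is my own illustration, not from the TF source):

from collections import namedtuple
import tensorflow as tf

EncoderOutput = namedtuple("EncoderOutput", ("outputs", "final_state"))

gru = tf.keras.layers.GRU(units=32, return_sequences=True, return_state=True)
inputs = tf.random_normal((5, 10, 16), dtype=tf.float32)
outputs, final_state = gru(inputs)  # per-step outputs + last hidden state
enc = EncoderOutput(outputs=outputs, final_state=final_state)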

Note that when the encoder is a GRU, an LSTM, or a stacked RNN, the final_state of its last layer is different in each case. How it is passed to the decoder's init_state needs care.

For a single layer (uni- or bidirectional), you can use tf.keras.layers.GRU or tf.keras.layers.LSTM directly. For multiple layers, first define the RNNCells and tf.keras.layers.StackedRNNCells, then use tf.keras.layers.RNN.

gru_rnn = tf.keras.layers.GRU(units=32,
                              return_sequences=True,
                              return_state=True,
                              go_backwards=True)
lstm_rnn = tf.keras.layers.LSTM(units=32,
                                return_sequences=True,
                                return_state=True,
                                go_backwards=True)
tmp_inputs = tf.random_normal((5, 10, 16), dtype=tf.float32)
print(gru_rnn(tmp_inputs))
print(lstm_rnn(tmp_inputs))

[<tf.Tensor 'gru_1_1/transpose_1:0' shape=(5, 10, 32) dtype=float32>, <tf.Tensor 'gru_1_1/while/Exit_3:0' shape=(5, 32) dtype=float32>]
[<tf.Tensor 'lstm_1/transpose_1:0' shape=(5, 10, 32) dtype=float32>, <tf.Tensor 'lstm_1/while/Exit_3:0' shape=(5, 32) dtype=float32>, <tf.Tensor 'lstm_1/while/Exit_4:0' shape=(5, 32) dtype=float32>]

Both outputs are lists, but the LSTM's has 3 elements: [enc_output, final_h_state, final_c_state]. Note the state tensors are named while/Exit: the internals are implemented with tf.while_loop. Back to the earlier question: if the decoder's RNNCell is also an LSTM, how to define its init_state is the next thing to pay attention to.
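For example, a minimal sketch (my own code, not from the TF source) for the LSTM-to-LSTM case, packing the encoder's final h/c into the LSTMStateTuple that a decoder LSTMCell expects:

import tensorflow as tf

lstm_rnn = tf.keras.layers.LSTM(units=32, return_sequences=True, return_state=True)
tmp_inputs = tf.random_normal((5, 10, 16), dtype=tf.float32)
enc_output, final_h, final_c = lstm_rnn(tmp_inputs)

# A decoder LSTMCell takes its state as LSTMStateTuple(c, h), so the
# encoder's final states can be reused directly as the decoder init_state.
dec_init_state = tf.nn.rnn_cell.LSTMStateTuple(c=final_c, h=final_h)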

If you use tf.layers directly instead of tf.keras.layers, you have to manage the variable_scope yourself. It seems that versions before TF 1.14 don't play well with tf.keras.

decoder.py

class Decoder(object):
  """An RNN Decoder abstract interface object."""

The Decoder abstract class provides batch_size, output_size, output_dtype, initialize, step, etc. as unimplemented abstract members; a concrete Decoder class must implement all of them.

The key ones to pay attention to are the initialize and step functions.

  def initialize(self, name=None):
    """Called before any decoding iterations.

    Returns:
      `(finished, initial_inputs, initial_state)`: initial values of
      'finished' flags, inputs and state.
    """

  @abc.abstractmethod
  def step(self, time, inputs, state, name=None):
    """Called per step of decoding (but only once for dynamic decoding).

    Args:
      time: Scalar `int32` tensor. Current step number.
      inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time
        step.
      state: RNNCell state (possibly nested tuple of) tensor[s] from previous
        time step.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`: `outputs` is an object
      containing the decoder output, `next_state` is a (structure of) state
      tensors and TensorArrays, `next_inputs` is the tensor that should be used
      as input for the next step, `finished` is a boolean tensor telling whether
      the sequence is complete, for each sequence in the batch.
    """
    raise NotImplementedError

The initialize function provides the first decoding step's inputs, the initial state, and the finished flags, i.e. (finished, first_inputs, initial_state); the step function performs the decoding: given inputs and state, it yields the next time step's inputs and state.

basic_decoder.py

"""
A class of Decoders that may sample to generate the next input.
解码过程中的one step. 其可能涉及到各式各样的sample.
"""
__all__ = [
"BasicDecoderOutput",
"BasicDecoder",
]
class BasicDecoderOutput(
    collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
  pass

BasicDecoderOutput defines a namedtuple with two fields, rnn_output and sample_id. The names speak for themselves.

class BasicDecoder(decoder.Decoder):
  """Basic sampling decoder."""

  def __init__(self, cell, helper, initial_state, output_layer=None):
    """Initialize BasicDecoder."""

  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
      cell_outputs, cell_state = self._cell(inputs, state)
      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)
      sample_ids = self._helper.sample(
          time=time, outputs=cell_outputs, state=cell_state)
      (finished, next_inputs, next_state) = self._helper.next_inputs(
          time=time,
          outputs=cell_outputs,
          state=cell_state,
          sample_ids=sample_ids)
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)
    return (outputs, next_state, next_inputs, finished)

Reading this code, you can see a decoder step really does just two things.

  • Run the rnn_cell to get cell_outputs and cell_state. The cell can be a plain RNNCell, or an RNNCell wrapped with attention.

  • Use the helper's sample together with cell_outputs to get the next time step's input. In the training phase the next time step's input is read from the target sentence; in the inference phase it is obtained by sampling.

Notice that time and finished are both tied to helper.next_inputs. That is because helper.py decides the sampling strategy and also has to check whether the <eos> token was generated, so whether decoding is finished is determined in that function as well.
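Condensed from the helper implementations shown below, the two finished conditions look like this (a toy eager-mode sketch with made-up values):

import tensorflow as tf
tf.enable_eager_execution()

time, sequence_length = 3, tf.constant([5, 6, 2])
sample_ids, end_token = tf.constant([7, 9, 9]), 9

# Training (TrainingHelper): finished once the target length is exhausted.
train_finished = tf.greater_equal(time + 1, sequence_length)   # [False False True]
# Inference (GreedyEmbeddingHelper): finished once <eos> was sampled.
infer_finished = tf.math.equal(sample_ids, end_token)          # [False True True]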

attention_wrapper.py

This Python file contains quite a few classes/functions, but the main thing is to understand the 5 classes below. Together they implement an RNNCell that wraps attention.

__all__ = [
    "AttentionMechanism",
    "AttentionWrapper",
    "AttentionWrapperState",
    "LuongAttention",
    "BahdanauAttention"]

AttentionWrapper is the ready-made wrapper, while LuongAttention and BahdanauAttention correspond to the different attention_mechanisms.

I already went through this part of the source once; the notes are here: https://panxiaoxie.cn/2018/09/01/tensorflow-Attention-API/ The difference now is that the current version additionally inherits from keras. If you are writing on top of keras, use the V2 versions.

AttentionMechanism and _BaseAttentionMechanism

class AttentionMechanism(object):

  @property
  def alignments_size(self):
    raise NotImplementedError

  @property
  def state_size(self):
    raise NotImplementedError
class _BaseAttentionMechanismV2(AttentionMechanism, layers.Layer):
  """A base AttentionMechanism class providing common functionality.

  Common functionality includes:
    1. Storing the query and memory layers.
    2. Preprocessing and storing the memory.
  """

  def __init__(self,
               memory,
               probability_fn,
               query_layer=None,
               memory_layer=None,
               memory_sequence_length=None,
               **kwargs):
    """Construct base AttentionMechanism class."""

  def call(self, inputs, mask=None, setup_memory=False, **kwargs):
    """Setup the memory or query the attention.

    There are two case here, one for setup memory, and the second is query the
    attention score. `setup_memory` is the flag to indicate which mode it is.
    The input list will be treated differently based on that flag.

    Args:
      inputs: a list of tensor that could either be `query` and `state`, or
        `memory` and `memory_sequence_length`. `query` is the tensor of dtype
        matching `memory` and shape `[batch_size, query_depth]`. `state` is the
        tensor of dtype matching `memory` and shape `[batch_size,
        alignments_size]`. (`alignments_size` is memory's `max_time`). `memory`
        is the memory to query; usually the output of an RNN encoder. The
        tensor should be shaped `[batch_size, max_time, ...]`.
        `memory_sequence_length` (optional) is the sequence lengths for the
        batch entries in memory. If provided, the memory tensor rows are masked
        with zeros for values past the respective sequence lengths.
      mask: optional bool tensor with shape `[batch, max_time]` for the mask of
        memory. If it is not None, the corresponding item of the memory should
        be filtered out during calculation.
      setup_memory: boolean, whether the input is for setting up memory, or
        query attention.
      **kwargs: Dict, other keyword arguments for the call method.

    Returns:
      Either processed memory or attention score, based on `setup_memory`.
    """
    if setup_memory:
      if isinstance(inputs, list):
        if len(inputs) not in (1, 2):
          raise ValueError("Expect inputs to have 1 or 2 tensors, got %d" %
                           len(inputs))
        memory = inputs[0]
        memory_sequence_length = inputs[1] if len(inputs) == 2 else None
        memory_mask = mask
      else:
        memory, memory_sequence_length = inputs, None
        memory_mask = mask
      self._setup_memory(memory, memory_sequence_length, memory_mask)
      # We force the self.built to false here since only memory is initialized,
      # but the real query/state has not been call() yet. The layer should be
      # build and call again.
      self.built = False
      # Return the processed memory in order to create the Keras connectivity
      # data for it.
      return self.values
    else:
      if not self._memory_initialized:
        raise ValueError("Cannot query the attention before the setup of "
                         "memory")
      if len(inputs) not in (2, 3):
        raise ValueError("Expect the inputs to have query, state, and optional "
                         "processed memory, got %d items" % len(inputs))
      # Ignore the rest of the inputs and only care about the query and state
      query, state = inputs[0], inputs[1]
      return self._calculate_attention(query, state)

  def _calculate_attention(self, query, state):
    raise NotImplementedError(
        "_calculate_attention need to be implemented by subclasses.")

Compared with the original _BaseAttentionMechanism, _BaseAttentionMechanismV2 inherits keras.layers.Layer. Use it in TensorFlow going forward: you no longer manage variables yourself, Keras names them automatically.

The memory argument is the encoder output, [batch, max_times, enc_size]. probability_fn is usually softmax. query_layer and memory_layer are usually fully connected layers, and the two must have the same number of output units.

BahdanauAttentionV2 and LuongAttentionV2 implement _calculate_attention on top of this class.
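As a rough sketch of what those subclasses compute (simplified, my own code; the real implementations add options such as normalization and scaling), Bahdanau is additive attention and Luong is multiplicative:

import tensorflow as tf

batch, max_time, num_units = 5, 6, 16
keys = tf.random_normal((batch, max_time, num_units))    # memory_layer(memory)
values = tf.random_normal((batch, max_time, num_units))  # masked memory
query = tf.random_normal((batch, num_units))             # query_layer(query), Bahdanau only

# Bahdanau (additive): score = v^T tanh(keys + query)
v = tf.random_normal([num_units])  # stands in for the learned attention_v
bahdanau_score = tf.reduce_sum(
    v * tf.tanh(keys + tf.expand_dims(query, 1)), axis=2)   # [batch, max_time]

# Luong (multiplicative): score = keys . query
luong_score = tf.squeeze(
    tf.matmul(keys, tf.expand_dims(query, 2)), axis=2)      # [batch, max_time]

# alignments = probability_fn(score); context = alignments . values
alignments = tf.nn.softmax(bahdanau_score)                   # [batch, max_time]
context = tf.squeeze(
    tf.matmul(tf.expand_dims(alignments, 1), values), axis=1)  # [batch, num_units]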

BahdanauAttentionV2 and LuongAttentionV2

My earlier notes already cover the concrete implementations https://panxiaoxie.cn/2018/09/01/tensorflow-Attention-API/ so I won't repeat them here. The formulas in the papers make it clear enough~

As class instances, AttentionMechanism, BahdanauAttention and LuongAttention have the following properties:

  • query_layer: in BahdanauAttention this is usually a tf.layers.Dense instance with num_units output units, so the query fed to BahdanauAttention can have any dimension. In LuongAttention query_layer is None, so the query dimension must be num_units.
  • memory_layer: the same in both attentions, a tf.layers.Dense with num_units units.
  • alignments_size: the alignment size, i.e. the memory's max_times.
  • batch_size: the batch size.
  • values: the memory after masking. [batch, max_times, embed_size]
  • keys: the memory after the memory_layer dense transform. [batch, max_times, num_units].
  • state_size: equal to alignments_size.

Keys and values are derived from the memory. The query is derived from the target inputs (training) or the previous step's output (inference).

AttentionWrapper

With attention_mechanism and rnn_cell defined, AttentionWrapper wraps the two together, forming an RNNCell with an attention mechanism.

class AttentionWrapper(rnn_cell_impl.RNNCell):
  """Wraps another `RNNCell` with attention."""

  def __init__(self,
               cell,
               attention_mechanism,
               attention_layer_size=None,
               alignment_history=False,
               cell_input_fn=None,
               output_attention=True,
               initial_cell_state=None,
               name=None,
               attention_layer=None,
               attention_fn=None):
    """Construct the `AttentionWrapper`."""

  def call(self, inputs, state):
    """Perform a step of attention-wrapped RNN.

    - Step 1: Mix the `inputs` and previous step's `attention` output via
      `cell_input_fn`.
    - Step 2: Call the wrapped `cell` with this input and its previous state.
    - Step 3: Score the cell's output with `attention_mechanism`.
    - Step 4: Calculate the alignments by passing the score through the
      `normalizer`.
    - Step 5: Calculate the context vector as the inner product between the
      alignments and the attention_mechanism's values (memory).
    - Step 6: Calculate the attention output by concatenating the cell output
      and context through the attention layer (a linear layer with
      `attention_layer_size` outputs).
    """

Arguments:

  • cell: an instance of RNNCell; it can be a GRUCell, an LSTMCell, or a StackedRNNCells RNNCell.

  • attention_mechanism: a list of AttentionMechanism instances or a single instance. It can be a list, so you could fuse several kinds of attention here? Could you churn out papers by stacking attention the way machine reading comprehension does? Reviewers must be sick of it by now...

batch_size = 5
seq_len = [4,5,3,5,6]
max_times = 6
num_units = 16

enc_output = tf.random.normal((batch_size, max_times, num_units), dtype=tf.float32)
rnncell = rnn_cell.LSTMCell(num_units=16)
attention_mechanism = attention_wrapper.BahdanauAttention(
    num_units=num_units,
    memory=enc_output,
    memory_sequence_length=seq_len)
attnRNNCell = attention_wrapper.AttentionWrapper(
    cell=rnncell,
    attention_mechanism=attention_mechanism,
    alignment_history=True)
print(attnRNNCell.output_size, attnRNNCell.state_size)

print(attnRNNCell.zero_state(batch_size, tf.float32))

16 AttentionWrapperState(cell_state=LSTMStateTuple(c=16, h=16), attention=16, time=TensorShape([]), alignments=6, alignment_history=6, attention_state=6)
AttentionWrapperState(cell_state=LSTMStateTuple(c=<tf.Tensor 'AttentionWrapperZeroState/checked_cell_state:0' shape=(5, 16) dtype=float32>, h=<tf.Tensor 'AttentionWrapperZeroState/checked_cell_state_1:0' shape=(5, 16) dtype=float32>), attention=<tf.Tensor 'AttentionWrapperZeroState/zeros_2:0' shape=(5, 16) dtype=float32>, time=<tf.Tensor 'AttentionWrapperZeroState/zeros_1:0' shape=() dtype=int32>, alignments=<tf.Tensor 'AttentionWrapperZeroState/zeros:0' shape=(5, 6) dtype=float32>, alignment_history=<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x7fb556cd7f60>, attention_state=<tf.Tensor 'AttentionWrapperZeroState/zeros_3:0' shape=(5, 6) dtype=float32>)

helper.py

The BasicDecoder constructor needs cell and helper. The cell is what AttentionWrapper defined above. What about the helper?

Helper

class Helper(object):
  """Interface for implementing sampling in seq2seq decoders.

  Helper instances are used by `BasicDecoder`.
  """

  @abc.abstractproperty
  def batch_size(self):
    """Batch size of tensor returned by `sample`.

    Returns a scalar int32 tensor.
    """
    raise NotImplementedError("batch_size has not been implemented")

  @abc.abstractproperty
  def sample_ids_shape(self):
    """Shape of tensor returned by `sample`, excluding the batch dimension.

    Returns a `TensorShape`.
    """
    raise NotImplementedError("sample_ids_shape has not been implemented")

  @abc.abstractproperty
  def sample_ids_dtype(self):
    """DType of tensor returned by `sample`.

    Returns a DType.
    """
    raise NotImplementedError("sample_ids_dtype has not been implemented")

  @abc.abstractmethod
  def initialize(self, name=None):
    """Returns `(initial_finished, initial_inputs)`."""
    pass

  @abc.abstractmethod
  def sample(self, time, outputs, state, name=None):
    """Returns `sample_ids`."""
    pass

  @abc.abstractmethod
  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Returns `(finished, next_inputs, next_state)`."""
    pass

This defines an abstract class; the abstract methods initialize, sample, next_inputs, etc. are left for the user to implement.

Next, let's look at the TrainingHelper implementation. The training phase uses teacher forcing, so the inputs are read directly from the target sentence.

TrainingHelper

class TrainingHelper(Helper):
  """A helper for use during training. Only reads inputs.

  Returned sample_ids are the argmax of the RNN output logits.
  """

  def __init__(self, inputs, sequence_length, time_major=False, name=None):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      time_major: Python bool. Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
      inputs = ops.convert_to_tensor(inputs, name="inputs")
      self._inputs = inputs
      if not time_major:
        inputs = nest.map_structure(_transpose_batch_time, inputs)

      self._input_tas = nest.map_structure(_unstack_ta, inputs)
      self._sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")

      self._zero_inputs = nest.map_structure(
          lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

      self._batch_size = array_ops.size(sequence_length)

  def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
      finished = math_ops.equal(0, self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
      return (finished, next_inputs)

  def sample(self, time, outputs, name=None, **unused_kwargs):
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
      sample_ids = math_ops.cast(
          math_ops.argmax(outputs, axis=-1), dtypes.int32)
      return sample_ids

  def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs",
                        [time, outputs, state]):
      next_time = time + 1
      finished = (next_time >= self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      def read_from_ta(inp):
        return inp.read(next_time)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(read_from_ta, self._input_tas))
      return (finished, next_inputs, state)

initialize

Let's look at the concrete implementations of initialize, sample and next_inputs.

initialize does the initialization. One of the returned values is finished, the "decoding done" flags. In the training phase there is no need to check whether the <eos> token was generated, so "decoding done" only requires checking whether next time has reached self._sequence_length. The other returned value is next_inputs, a tf.cond: at initialization it checks whether self._sequence_length is 0. Normally it won't be 0, right? So the tf.cond usually returns the latter branch, and next_inputs comes from the TensorArray self._input_tas.

def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`."""
def _unstack_ta(inp):
  return tensor_array_ops.TensorArray(
      dtype=inp.dtype, size=array_ops.shape(inp)[0],
      element_shape=inp.get_shape()[1:]).unstack(inp)

self._input_tas = nest.map_structure(_unstack_ta, inputs)

def map_structure(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure."""

That is, the tensor inputs get converted into a TensorArray.

example:

import tensorflow as tf
from tensorflow.python.util import nest
tf.enable_eager_execution()

batch_size = 5
tgt_len = [5,6,2,7,4]
tgt_max_times = 7
num_units = 16
tgt_inputs = tf.random.normal((batch_size, tgt_max_times, num_units), dtype=tf.float32)

def _unstack_ta(inp):
  return tf.TensorArray(
      dtype=inp.dtype, size=tf.shape(inp)[0],
      element_shape=inp.get_shape()[1:]).unstack(inp)

next_inputs = nest.map_structure(_unstack_ta, tgt_inputs)
print(next_inputs)
print(next_inputs.stack().shape)

<tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x7f0c47dbe898>
(5, 7, 16)

sample

In the training phase, sampling is just argmax(outputs), which yields the sample_ids.

next_inputs

Again, this Helper only handles one time step. What needs to be passed to the next time step is (finished, next_inputs, state).

def read_from_ta(inp):
  return inp.read(next_time)

next_inputs = control_flow_ops.cond(
    all_finished, lambda: self._zero_inputs,
    lambda: nest.map_structure(read_from_ta, self._input_tas))

Easy to understand: the inputs are read out of the TensorArray self._input_tas one by one. Keep in mind a TensorArray quirk: by default (clear_after_read=True) an element is gone after read().

Note that it reads directly with read(next_time), so self._input_tas must have been converted to time_major:

if not time_major:
  inputs = nest.map_structure(_transpose_batch_time, inputs)
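A tiny eager-mode demo of the clear_after_read behavior mentioned above (my own sketch, not from the TF source):

import tensorflow as tf
tf.enable_eager_execution()

ta = tf.TensorArray(dtype=tf.float32, size=3).unstack(tf.constant([1., 2., 3.]))
print(ta.read(1))  # tf.Tensor(2.0, shape=(), dtype=float32)
# Reading index 1 a second time now fails: with the default
# clear_after_read=True, the element was cleared by the first read.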

GreedyEmbeddingHelper

class GreedyEmbeddingHelper(Helper):
  """A helper for use during inference.

  Uses the argmax of the output (treated as logits) and passes the
  result through an embedding layer to get the next input.
  """

  def __init__(self, embedding, start_tokens, end_token):
    """Initializer.

    Args:
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`. The returned tensor
        will be passed to the decoder input.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.

    Raises:
      ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
        scalar.
    """
    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._batch_size = array_ops.size(start_tokens)
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")
    self._start_inputs = self._embedding_fn(self._start_tokens)

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def sample_ids_shape(self):
    return tensor_shape.TensorShape([])

  @property
  def sample_ids_dtype(self):
    return dtypes.int32

  def initialize(self, name=None):
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

  def sample(self, time, outputs, state, name=None):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))
    sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
    return sample_ids

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """next_inputs_fn for GreedyEmbeddingHelper."""
    del time, outputs  # unused by next_inputs_fn
    # Check whether the end_token has been generated.
    finished = math_ops.equal(sample_ids, self._end_token)
    # all_finished collapses the per-example finished flags.
    all_finished = math_ops.reduce_all(finished)
    # If not all finished, use the second fn's result as next_inputs.
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)

In the inference phase there is no target sentence, so the constructor arguments embedding, start_tokens, end_token have to cover the <sos> and <eos> tokens: the start token serves as the first step's input, and the end token is used to detect <eos> and terminate decoding.

The differences from the training phase are how finished is determined, and that next_inputs is not read directly from inputs but is _embedding_fn(sample_ids).

decoder one step

Having gone through AttnRNNCell and the helper, let's come back to the step function of BasicDecoder in basic_decoder.py. It executes one step of the decoder.

class BasicDecoder(decoder.Decoder):
  """Basic sampling decoder."""

  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
      # One step of the RNNCell; with attention this is the AttentionWrapper.
      cell_outputs, cell_state = self._cell(inputs, state)
      # Without an output_layer, cell_outputs is used as-is.
      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)
      # The helper's sample function produces the token ids. In both training
      # and inference they come from cell_outputs; the simplest is argmax, but
      # it can also be scheduled sampling, etc.
      sample_ids = self._helper.sample(
          time=time, outputs=cell_outputs, state=cell_state)
      # The helper's next_inputs function drives the iteration: next_inputs is
      # computed by it, and next_state is just cell_state.
      (finished, next_inputs, next_state) = self._helper.next_inputs(
          time=time,
          outputs=cell_outputs,
          state=cell_state,
          sample_ids=sample_ids)
    outputs = BasicDecoderOutput(cell_outputs, sample_ids)
    return (outputs, next_state, next_inputs, finished)

dynamic_decode

The BasicDecoder wrapped above is only one decoder step. The full decoding process also needs the dynamic_decode function in decoder.py.

Arguments: decoder is the Decoder instance (our BasicDecoder). impute_finished means that, for batches of unequal lengths, the states of the already-finished (masked) entries are copied straight through, which makes the final state more accurate. maximum_iterations is the maximum decoding length; if it is not set, decoding stops only when the <eos> token is generated.

def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None,
                   **kwargs):
  """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object.

  Args:
    decoder: A `Decoder` instance.
    output_time_major: Python boolean. Default: `False` (batch major). If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean. If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out. This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
      steps. Default is `None` (decode until the decoder is fully done).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.
    **kwargs: dict, other keyword arguments for dynamic_decode. It might
      contain arguments for `BaseDecoder` to initialize, which takes all tensor
      inputs during call().

  Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar.
  """
  if not isinstance(decoder, (Decoder, BaseDecoder)):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

  with variable_scope.variable_scope(scope, "decoder") as varscope:
    # Determine context types.
    ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
    in_while_loop = (
        control_flow_util.GetContainingWhileContext(ctxt) is not None)
    # Properly cache variable values inside the while_loop.
    # Don't set a caching device when running in a loop, since it is possible
    # that train steps could be wrapped in a tf.while_loop. In that scenario
    # caching prevents forward computations in loop iterations from re-reading
    # the updated weights.
    if not context.executing_eagerly() and not in_while_loop:
      if varscope.caching_device is None:
        varscope.set_caching_device(lambda op: op.device)

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

    # Step 1: initialize.
    if isinstance(decoder, Decoder):
      initial_finished, initial_inputs, initial_state = decoder.initialize()
    else:
      # For BaseDecoder that takes tensor inputs during call.
      decoder_init_input = kwargs.pop("decoder_init_input", None)
      decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {})
      initial_finished, initial_inputs, initial_state = decoder.initialize(
          decoder_init_input, **decoder_init_kwargs)

    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

    if is_xla and maximum_iterations is None:
      raise ValueError("maximum_iterations is required for XLA compilation.")
    if maximum_iterations is not None:
      initial_finished = math_ops.logical_or(
          initial_finished, 0 >= maximum_iterations)
    initial_sequence_lengths = array_ops.zeros_like(
        initial_finished, dtype=dtypes.int32)
    initial_time = constant_op.constant(0, dtype=dtypes.int32)

    def _shape(batch_size, from_shape):
      if (not isinstance(from_shape, tensor_shape.TensorShape) or
          from_shape.ndims == 0):
        return None
      else:
        batch_size = tensor_util.constant_value(
            ops.convert_to_tensor(
                batch_size, name="batch_size"))
        return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

    dynamic_size = maximum_iterations is None or not is_xla

    # Create TensorArrays as output_ta, with max size maximum_iterations.
    def _create_ta(s, d):
      return tensor_array_ops.TensorArray(
          dtype=d,
          size=0 if dynamic_size else maximum_iterations,
          dynamic_size=dynamic_size,
          element_shape=_shape(decoder.batch_size, s))

    initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                            decoder.output_dtype)

    # Decide decoding status from finished; tf.logical_not returns a scalar
    # bool tensor tf.Tensor(True/False, shape=(), dtype=bool).
    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
      return math_ops.logical_not(math_ops.reduce_all(finished))

    # Loop body.
    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      # The core decoder step. next_outputs is the current step's output; see
      # below, it flows into emit and then into output_ta.
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      if decoder.tracks_own_finished:
        next_finished = decoder_finished
      else:
        next_finished = math_ops.logical_or(decoder_finished, finished)
      next_sequence_lengths = array_ops.where(
          math_ops.logical_not(finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish.
      # If an entry is already finished, zeros get written into outputs_ta.
      # This also explains the earlier tf.reduce_all(finished): examples in a
      # batch finish at different lengths, so decoding of the batch only ends
      # once the longest one is done.
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        # Examples whose finished flag is True keep cur instead of updating.
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)

    # tf.while_loop behaves like a while loop: it runs until condition is
    # False. The initial values are the variables in loop_vars (initial_time,
    # initial_outputs_ta, ...); they must match body's parameters and the
    # variables body returns.
    res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=(
            initial_time,
            initial_outputs_ta,
            initial_state,
            initial_inputs,
            initial_finished,
            initial_sequence_lengths,
        ),
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        swap_memory=swap_memory)

    final_outputs_ta = res[1]
    final_state = res[2]
    final_sequence_lengths = res[5]

    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

    try:
      final_outputs, final_state = decoder.finalize(
          final_outputs, final_state, final_sequence_lengths)
    except NotImplementedError:
      pass

    if not output_time_major:
      final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

  return final_outputs, final_state, final_sequence_lengths


The whole dynamic_decode process consists of initialize() plus iterating step().

initialize

decoder.initialize() is:
class BasicDecoder(decoder.Decoder):

  def initialize(self, name=None):
    """Initialize the decoder.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, first_inputs, initial_state)`.
    """
    return self._helper.initialize() + (self._initial_state,)

Now let's look at the helper's initialize function:

def initialize(self, name=None):
  with ops.name_scope(name, "TrainingHelperInitialize"):
    finished = math_ops.equal(0, self._sequence_length)
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished, lambda: self._zero_inputs,
        lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
    return (finished, next_inputs)

So initialization means: initialize the finished flags, next_inputs and initial_state. self._initial_state is set by the third constructor argument of BasicDecoder. This is a point to watch: keep enc_final_state and dec_init_state consistent.
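With an AttentionWrapper cell, one way to keep them consistent is AttentionWrapperState.clone(); a sketch (variable names follow the example at the end of this post, and enc_final_state is assumed to be the encoder's final state):

# enc_final_state: e.g. the LSTMStateTuple returned by the encoder LSTM.
init_state = attnRNNCell.zero_state(batch_size, tf.float32).clone(
    cell_state=enc_final_state)

train_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=training_helper,
    initial_state=init_state)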

step

Read this together with the comments in the source above~

The role of tf.reduce_all: the loop stops only when every example in the batch has finished. That is also where impute_finished and _maybe_copy_state come in.

tf.enable_eager_execution()
tgt_len = [4,5,3,0,6]
tgt_len = tf.convert_to_tensor(tgt_len, tf.float32)
finished = tf.math.equal(0, tgt_len)
all_finished = tf.reduce_all(finished)
print(all_finished)

###
tf.Tensor(False, shape=(), dtype=bool)

To summarize, the dynamic_decode process is:
- set up the condition (training: based on time and sequence_len; inference: based on maximum_iterations and the end_token)
- the decoder step, mainly computing outputs, next_state, next_inputs.

test example

import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.python.ops import rnn_cell
#
# tf.enable_eager_execution()
batch_size = 5
src_len = [4,5,3,5,6]
max_times = 6
num_units = 16
enc_output = tf.random.normal((batch_size, max_times, num_units), dtype=tf.float32)
#
# attnRNNCell
rnncell = rnn_cell.LSTMCell(num_units=16)
attention_mechanism = attention_wrapper.BahdanauAttention(
    num_units=num_units,
    memory=enc_output,
    memory_sequence_length=src_len)
attnRNNCell = attention_wrapper.AttentionWrapper(
    cell=rnncell,
    attention_mechanism=attention_mechanism,
    alignment_history=True)

# training
tgt_len = [5,6,2,7,4]
tgt_max_times = 7
tgt_inputs = tf.random.normal((batch_size, tgt_max_times, num_units), dtype=tf.float32)
training_helper = helper_py.TrainingHelper(tgt_inputs, tgt_len)

# train helper
train_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=training_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32)
)

# inference
embedding = tf.get_variable("embedding", shape=(10, 16), initializer=tf.random_uniform_initializer())
infer_helper = helper_py.GreedyEmbeddingHelper(
    embedding=embedding,  # can be a callable or an embedding matrix
    start_tokens=tf.zeros([batch_size], dtype=tf.int32),
    end_token=9
)
infer_decoder = basic_decoder.BasicDecoder(
    cell=attnRNNCell,
    helper=infer_helper,
    initial_state=attnRNNCell.zero_state(batch_size, tf.float32)
)
final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(
    train_decoder,
    maximum_iterations=None)

print(final_outputs.rnn_output)
print(final_outputs.sample_id)
print(final_state.cell_state)
print(final_sequence_lengths)

print("--------infer-------------")
final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(
    infer_decoder,
    maximum_iterations=None)

print(final_outputs.rnn_output)
print(final_outputs.sample_id)
print(final_state.cell_state)
print(final_sequence_lengths)

####

Tensor("decoder/transpose:0", shape=(5, ?, 16), dtype=float32)
Tensor("decoder/transpose_1:0", shape=(5, ?), dtype=int32)
LSTMStateTuple(c=<tf.Tensor 'decoder/while/Exit_4:0' shape=(5, 16) dtype=float32>, h=<tf.Tensor 'decoder/while/Exit_5:0' shape=(5, 16) dtype=float32>)
Tensor("decoder/while/Exit_13:0", shape=(5,), dtype=int32)
--------infer-------------
Tensor("decoder_1/transpose:0", shape=(5, ?, 16), dtype=float32)
Tensor("decoder_1/transpose_1:0", shape=(5, ?), dtype=int32)
LSTMStateTuple(c=<tf.Tensor 'decoder_1/while/Exit_4:0' shape=(5, 16) dtype=float32>, h=<tf.Tensor 'decoder_1/while/Exit_5:0' shape=(5, 16) dtype=float32>)
Tensor("decoder_1/while/Exit_13:0", shape=(5,), dtype=int32)