CopyNet Paper: Incorporating Copying Mechanism in Sequence-to-Sequence Learning.
The CopyNet mechanism is wrapped around an existing RNN cell and used as a normal RNN cell.
The official nmt is also modified to support the CopyNet mechanism.
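For intuition, here is a minimal pure-Python sketch of the copying idea described in the paper: generate-mode scores over the target vocabulary and copy-mode scores over source positions share one softmax, and the copy probabilities of all positions holding the same token are summed. It is an illustration only, not the code inside CopyNetWrapper. The function name `copynet_distribution`, its arguments, and the toy numbers are all assumptions made for this example.

```python
import math


def copynet_distribution(gen_scores, copy_scores, source_ids):
    """Mix generate-mode and copy-mode scores into one distribution.

    gen_scores:  unnormalized scores, one per target-vocabulary id.
    copy_scores: unnormalized scores, one per source position.
    source_ids:  token id at each source position; ids >= len(gen_scores)
                 stand for source-only words in an extended vocabulary
                 (an assumption of this sketch).
    """
    # Both modes share a single softmax normalizer.
    all_scores = list(gen_scores) + list(copy_scores)
    max_score = max(all_scores)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in all_scores]
    total = sum(exps)

    vocab_size = len(gen_scores)
    extended_size = max(vocab_size, max(source_ids, default=-1) + 1)
    probs = [0.0] * extended_size

    # Generate mode: probability mass for in-vocabulary tokens.
    for token_id in range(vocab_size):
        probs[token_id] += exps[token_id] / total

    # Copy mode: add the mass of every source position holding the token.
    for position, token_id in enumerate(source_ids):
        probs[token_id] += exps[vocab_size + position] / total

    return probs


# Toy example: 4-word target vocabulary, 3-token source sentence.
# Source id 5 lies outside the target vocabulary, so it can only be copied.
probs = copynet_distribution(
    gen_scores=[0.1, 0.2, 0.3, 0.4],
    copy_scores=[1.0, 0.5, 1.0],
    source_ids=[2, 5, 2],
)
```

In the toy call, token 2 appears twice in the source, so it collects probability from the generate mode and from both copy positions, while token 5 receives probability only through copying.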
Just wrap any existing RNN cell (BasicLSTMCell, AttentionWrapper, and so on).
cell = any_rnn_cell
copynet_cell = CopyNetWrapper(cell, encoder_outputs, encoder_input_ids,
    encoder_vocab_size, decoder_vocab_size)
decoder_initial_state = copynet_cell.zero_state(batch_size,
    tf.float32).clone(cell_state=decoder_initial_state)
helper = tf.contrib.seq2seq.TrainingHelper(...)
decoder = tf.contrib.seq2seq.BasicDecoder(copynet_cell, helper,
    decoder_initial_state, output_layer=None)
decoder_outputs, final_state, coder_seq_length = tf.contrib.seq2seq.dynamic_decode(decoder=decoder)
decoder_logits, decoder_ids = decoder_outputs
Just add the --copynet argument to the nmt command line; full nmt usage is in nmt.
python nmt.nmt.nmt.py --copynet ...other_nmt_arguments