シーケンス間パート1モデル

すべての人に良い一日を!

そして、改訂されたデータサイエンティストコース用に新しいストリームが再びオープンしました。別の優れた教師 、更新に基づいたわずかに洗練されたプログラムです。 さて、いつものように、興味深い公開レッスンと興味深い資料のコレクション。 今日は、Tensor Flowからseq2seqモデルの分析を開始します。

行こう

RNNチュートリアルで既に説明したように(この記事を読む前にこれをよく理解することをお勧めします)、言語をモデル化するためにリカレントニューラルネットワークを教えることができます。 そして、興味深い質問が生じます。意味のある回答を生成するために特定のデータでネットワークをトレーニングすることは可能ですか? たとえば、英語からフランス語に翻訳するニューラルネットワークを教えることができますか? できることがわかりました。

このガイドでは、このようなエンドツーエンドシステムを作成およびトレーニングする方法を示します。 Tensor FlowコアリポジトリTensorFlowモデルリポジトリをGitHubからコピーします 。 次に、翻訳プログラムを開始して開始できます。

cd models/tutorials/rnn/translate python translate.py --data_dir [your_data_directory] 



彼女はWMT'15 Webサイトから英語からフランス語への翻訳用のデータをダウンロードし、トレーニングとトレーニングの準備をします。 これには、ハードドライブに約20GBが必要であり、ダウンロードと準備にかなりの時間がかかるため、すぐにプロセスを開始して、このチュートリアルを読み続けることができます。

マニュアルは次のファイルにアクセスします。

ファイル何が入っているの?
テンソルフロー/テンソルフロー/ python / ops / seq2seq.pyシーケンス間モデルを作成するためのライブラリ
モデル/チュートリアル/ rnn /翻訳/ seq2seq_model.pyシーケンス間ニューラル翻訳モデル
モデル/チュートリアル/ rnn / translate / data_utils.py翻訳データを準備するためのヘルパー関数
モデル/チュートリアル/ rnn / translate / translate.py翻訳モデルをトレーニングして実行するバイナリ

シーケンス間の基本

Cho et al。、2014pdf提示する基本的なシーケンス間モデルは、2つのリカレントニューラルネットワーク(RNN)で構成されています。入力データを処理するエンコーダー(エンコーダー)と、データを生成するデコーダー(デコーダー)出力。 基本的なアーキテクチャは次のとおりです。



上の図の各長方形は、RNN内のセル、通常はGRUセル(制御された繰り返しブロック)、またはLSTMセル(長期短期メモリ)を表します(詳細については、 RNNチュートリアルをご覧ください)。 エンコーダーとデコーダーは、共通の重みを持つか、より頻繁に異なるパラメーターのセットを使用できます。 多層セルは、 Sutskever et al。、2014pdf )の翻訳など、シーケンス間モデルで使用されています。

上記の基本モデルでは、各入力は固定サイズの状態ベクトルにエンコードされる必要があります。これは、これがデコーダーに送信される唯一のものであるためです。 デコーダが入力データにより直接アクセスできるようにするために、 Bahdanau et al。、2014pdf )に注意メカニズムが導入されました 。 アテンションメカニズムの詳細は説明しません(このため、ここでの作業に慣れることができます)。 デコーダーが各デコードステップで入力データを調べることができると言うだけで十分です。 LSTMセルとデコーダーのアテンションメカニズムを備えた多層シーケンスツーシーケンスネットワークは次のとおりです。



TensorFlowライブラリseq2seq

上記からわかるように、異なるシーケンス間モデルがあります。 それらはすべて異なるRNNセルを使用できますが、それらはすべてエンコーダー入力データとデコーダー入力データを受け入れます。 これは、TensorFlow seq2seqライブラリインターフェイスの基本です(tensorflow / tensorflow / python / ops / seq2seq.py)。 この基本的なRNN、コーデック、シーケンス間モデルは次のように機能します。

 outputs, states = basic_rnn_seq2seq(encoder_inputs, decoder_inputs, cell) 

上記の呼び出しでは、 encoder_inputsは、上の図の文字A、B、Cに対応するエンコーダー入力データを表すテンソルのリストです。 同様に、 decoder_inputsはデコーダー入力データを表すテンソルです。 最初の写真のGO、W、X、Y、Z。

cell引数は、モデルで使用されるセルを決定するtf.contrib.rnn.RNNCellクラスのインスタンスです。 GRUCellLSTMCellなどの既存のセルを使用するか、 GRUCellセルを作成できます。 さらに、 tf.contrib.rnnは、多層セルを作成し、セルの入出力に例外を追加したり、その他の変換を行うためのシェルを提供します。 例については、 RNNチュートリアルご覧ください。

basic_rnn_seq2seqの呼び出しは、 basic_rnn_seq2seqstates 2つの引数を返します。 これらは両方とも、 decoder_inputsと同じ長さのテンソルのリストを表します。 outputsは、各タイムステップのデコーダー出力データに対応します。最初の画像では、W、X、Y、Z、EOSです。 返されるstatesは、各タイムステップでのデコーダーの内部状態を表します。

シーケンス間モデルを使用する多くのアプリケーションでは、時間tでのデコーダー出力は時間t + 1でデコーダーへの入力に送信されます。 テスト中、シーケンスのデコード中に、新しいシーケンスが作成されます。 一方、トレーニング中は、デコーダーが以前に誤っていた場合でも、各タイムステップで正しい入力データをデコーダーに送信するのが一般的です。 seq2seq.py関数は、 feed_previous引数で両方のモードをサポートします。 たとえば、ネストされたRNNモデルの次の使用を検討してください。

 outputs, states = embedding_rnn_seq2seq( encoder_inputs, decoder_inputs, cell, num_encoder_symbols, num_decoder_symbols, embedding_size, output_projection=None, feed_previous=False) 

embedding_rnn_seq2seqモデルでは、すべての入力データ( encoder_inputsdecoder_inputs両方)は離散値を反映する整数テンソルです。 それらは密な表現に埋め込まれます(添付方法の詳細については、 ベクトル表現ガイドを参照してください)が、これらの添付ファイルを作成するには、個別の文字の最大数を指定する必要があります:エンコーダー側のnum_decoder_symbolsとデコーダー側のnum_decoder_symbols

上記の呼び出しでは、 feed_previousをFalseに設定します。 これは、デコーダーが提供されている形式でdecoder_inputsテンソルを使用することを意味します。 feed_previousをTrueに設定すると、デコーダーは最初のdecoder_inputs要素のみを使用します。 リストの他のすべてのテンソルは無視され、代わりにデコーダー出力の以前の値が使用されます。 これは、翻訳モデルの翻訳をデコードするために使用されますが、トレーニング中に使用して、モデルのエラーに対する安定性を向上させることもできます。 およそBengio et al。、2015pdf )のように。

上記で使用される別の重要な引数はoutput_projectionです。 明確化しないと、埋め込みモデルの結論は、生成された各シンボルの対数を表すため、num_decoder_symbolsごとのトレーニングサンプルの数のテンソルになります。 大きなnum_decoder_symbolsなどの大きな出力ディクショナリを含むモデルをトレーニングする場合、これらの大きなテンソルの保存は実用的ではなくなります。 代わりに、 output_projectionを使用して大きなテンソルに後で投影される小さなテンソルを返すことをおoutput_projectionます。 これにより、 ジャンらによって説明されているように、サンプリングされたソフトマックス損失でseq2seqモデルを使用できます al。、2014pdf )。

basic_rnn_seq2seqおよびembedding_rnn_seq2seq加えて、 basic_rnn_seq2seqにはさらにいくつかのシーケンス間モデルがseq2seq.pyます。 それらに注意してください。 それらはすべて同様のインターフェースを備えているため、詳細については掘り下げません。 以下の翻訳モデルでは、 embedding_attention_seq2seqを使用します。

継続する。

Source: https://habr.com/ru/post/J430780/


All Articles