[Tutorial]NLP Sequence to sequence model for translation #1815
Merged
@@ -0,0 +1,238 @@
| """ | ||
| Keras LSTM Sequence to Sequence Model for Translation | ||
| ================================= | ||
| **Author**: `Siju Samuel <https://siju-samuel.github.io/>`_ | ||
|
|
||
| This script demonstrates how to implement a basic character-level sequence-to-sequence model. | ||
| We apply it to translating short English sentences into short French sentences, | ||
| character-by-character. | ||
|
|
||
| # Summary of the algorithm | ||
|
|
||
| - We start with input sequences from a domain (e.g. English sentences) | ||
| and corresponding target sequences from another domain | ||
| (e.g. French sentences). | ||
| - An encoder LSTM turns input sequences to 2 state vectors | ||
| (we keep the last LSTM state and discard the outputs). | ||
| - A decoder LSTM is trained to turn the target sequences into | ||
| the same sequence but offset by one timestep in the future, | ||
| a training process called "teacher forcing" in this context. | ||
| Is uses as initial state the state vectors from the encoder. | ||
| Effectively, the decoder learns to generate `targets[t+1...]` | ||
| given `targets[...t]`, conditioned on the input sequence. | ||
|
|
||
| This script loads the s2s.h5 model saved in repository | ||
| https://github.com/dmlc/web-data/raw/master/keras/models/s2s_translate/lstm_seq2seq.py | ||
| and generates sequences from it. It assumes that no changes have been made (for example: | ||
| latent_dim is unchanged, and the input data and model architecture are unchanged). | ||
|
|
||
| # References | ||
|
|
||
| - Sequence to Sequence Learning with Neural Networks | ||
| https://arxiv.org/abs/1409.3215 | ||
| - Learning Phrase Representations using | ||
| RNN Encoder-Decoder for Statistical Machine Translation | ||
| https://arxiv.org/abs/1406.1078 | ||
|
|
||
| See lstm_seq2seq.py for more details on the model architecture and how it is trained. | ||
| """ | ||
|
|
||
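# A minimal illustration (not part of the original tutorial) of the teacher-forcing
# offset described above, using a hypothetical target sentence: if the target is
# "Bonjour", the decoder input is "\tBonjour" and the training label is "Bonjour\n";
# at every timestep t the decoder sees targets[...t] and must predict targets[t+1].
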
from keras.models import Model, load_model
from keras.layers import Input
import random
import os
import numpy as np
import keras
import tvm
import nnvm

######################################################################
# Download required files
# -----------------------
# Download files listed below from dmlc web-data repo.
model_file = "s2s_translate.h5"
data_file = "fra-eng.txt"

# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/keras/models/s2s_translate/'
model_url = os.path.join(repo_base, model_file)
data_url = os.path.join(repo_base, data_file)
|
Contributor
The model and data files are not found in the repository yet. Can you share them?

# Download files listed below.
from mxnet.gluon.utils import download
download(model_url, model_file)
download(data_url, data_file)

latent_dim = 256  # Latent dimensionality of the encoding space.
test_samples = 10000  # Number of samples used for testing.

######################################################################
# Process the data file
# ---------------------
# Vectorize the data. We use the same approach as the training script.
# NOTE: the data must be identical, in order for the character -> integer
# mappings to be consistent.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_file, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
# The last element of `lines` is an empty string, so stop one entry early.
test_samples = min(test_samples, len(lines) - 1)
max_encoder_seq_length = 0
max_decoder_seq_length = 0
for line in lines[:test_samples]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    max_encoder_seq_length = max(max_encoder_seq_length, len(input_text))
    max_decoder_seq_length = max(max_decoder_seq_length, len(target_text))
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

# Reverse-lookup token index to decode sequences back to something readable.
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())

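# For illustration only (hypothetical values): with the sorted character sets
# above, input_token_index could look like {' ': 0, '!': 1, ..., 'A': 5, ...},
# so each character is later encoded as a one-hot vector of length
# num_encoder_tokens (target characters likewise, of length num_decoder_tokens).
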
######################################################################
# Load Keras Model
# ----------------
# Restore the model and construct the encoder and decoder.
model = load_model(model_file)
encoder_inputs = model.input[0]  # input_1

encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output  # lstm_1
encoder_states = [state_h_enc, state_c_enc]
encoder_model = Model(encoder_inputs, encoder_states)

decoder_inputs = model.input[1]  # input_2
decoder_state_input_h = Input(shape=(latent_dim,), name='input_3')
decoder_state_input_c = Input(shape=(latent_dim,), name='input_4')
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_lstm = model.layers[3]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.layers[4]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

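# encoder_model maps a full one-hot encoded input sentence to its final LSTM
# states (h, c); decoder_model takes a single target character together with
# the previous (h, c) and returns the next-character probabilities plus the
# updated states. This split is what enables the step-by-step inference loop
# further below.
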
######################################################################
# Compile both encoder and decoder model on NNVM
# ----------------------------------------------
# Creates NNVM graph definition from keras model file.
from tvm.contrib import graph_runtime
target = 'llvm'
ctx = tvm.cpu(0)

# Parse Encoder model
sym, params = nnvm.frontend.from_keras(encoder_model)
inp_enc_shape = (1, max_encoder_seq_length, num_encoder_tokens)
shape_dict = {'input_1': inp_enc_shape}

# Build Encoder model
with nnvm.compiler.build_config(opt_level=2):
    enc_graph, enc_lib, enc_params = nnvm.compiler.build(sym, target, shape_dict, params=params)
print("Encoder build ok.")

# Create graph runtime for encoder model
tvm_enc = graph_runtime.create(enc_graph, enc_lib, ctx)
tvm_enc.set_input(**enc_params)

# Parse Decoder model
inp_dec_shape = (1, 1, num_decoder_tokens)
shape_dict = {'input_2': inp_dec_shape,
              'input_3': (1, latent_dim),
              'input_4': (1, latent_dim)}

# Build Decoder model
sym, params = nnvm.frontend.from_keras(decoder_model)
with nnvm.compiler.build_config(opt_level=2):
    dec_graph, dec_lib, dec_params = nnvm.compiler.build(sym, target, shape_dict, params=params)
print("Decoder build ok.")

# Create graph runtime for decoder model
tvm_dec = graph_runtime.create(dec_graph, dec_lib, ctx)
tvm_dec.set_input(**dec_params)

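# Note on shapes: the encoder graph is compiled for a whole padded sentence,
# (1, max_encoder_seq_length, num_encoder_tokens), while the decoder graph is
# compiled for a single timestep, (1, 1, num_decoder_tokens). The loop below
# therefore feeds the decoder one character at a time and carries the LSTM
# states h and c from step to step.
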
# Decodes an input sequence.
def decode_sequence(input_seq):
    # Set the input for encoder model.
    tvm_enc.set_input('input_1', input_seq)

    # Run encoder model
    tvm_enc.run()

    # Get states from encoder network
    h = tvm_enc.get_output(0).asnumpy()
    c = tvm_enc.get_output(1).asnumpy()

    # Populate the first character of target sequence with the start character.
    sampled_token_index = target_token_index['\t']

    # Sampling loop (batch of size 1).
    decoded_sentence = ''
    while True:
        # Generate empty target sequence of length 1.
        target_seq = np.zeros((1, 1, num_decoder_tokens), dtype='float32')
        # Update the target sequence (of length 1).
        target_seq[0, 0, sampled_token_index] = 1.

        # Set the input and states for decoder model.
        tvm_dec.set_input('input_2', target_seq)
        tvm_dec.set_input('input_3', h)
        tvm_dec.set_input('input_4', c)
        # Run decoder model
        tvm_dec.run()

        output_tokens = tvm_dec.get_output(0).asnumpy()
        h = tvm_dec.get_output(1).asnumpy()
        c = tvm_dec.get_output(2).asnumpy()

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]

        # Exit condition: either hit max length or find stop character.
        if sampled_char == '\n':
            break

        # Update the sentence
        decoded_sentence += sampled_char
        if len(decoded_sentence) > max_decoder_seq_length:
            break
    return decoded_sentence

def generate_input_seq(input_text):
    # One-hot encode a single sentence into the shape the encoder expects.
    input_seq = np.zeros((1, max_encoder_seq_length, num_encoder_tokens), dtype='float32')
    for t, char in enumerate(input_text):
        input_seq[0, t, input_token_index[char]] = 1.
    return input_seq

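# Single-sentence usage (hypothetical input; every character must appear in the
# vectorized training data so that input_token_index knows it):
#   print(decode_sequence(generate_input_seq("Hi.")))
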
######################################################################
# Run the model
# -------------
# Randomly take some text from test samples and translate
for seq_index in range(100):
    # Take one sentence randomly and try to decode.
    index = random.randint(0, test_samples - 1)
    input_text, _ = lines[index].split('\t')
    input_seq = generate_input_seq(input_text)
    decoded_sentence = decode_sequence(input_seq)
    print((seq_index + 1), ": ", input_text, "==>", decoded_sentence)
The current version doesn't have a check of `input_shape[1] != 1`, so the input shape will be (1, m, n)?