diff --git a/example/gluon/tree_lstm/README.md b/example/gluon/tree_lstm/README.md new file mode 100644 index 000000000000..e14ab4c70afc --- /dev/null +++ b/example/gluon/tree_lstm/README.md @@ -0,0 +1,29 @@ + +# Tree-Structured Long Short-Term Memory Networks +This is a [MXNet Gluon](https://mxnet.io/) implementation of Tree-LSTM as described in the paper [Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks](http://arxiv.org/abs/1503.00075) by Kai Sheng Tai, Richard Socher, and Christopher Manning. + +### Requirements +- Python (tested on **3.6.5**, should work on **>=2.7**) +- Java >= 8 (for Stanford CoreNLP utilities) +- Other dependencies are in `requirements.txt` +Note: Currently works with MXNet 1.3.0. + +### Usage +Before delving into how to run the code, here is a quick overview of the contents: + - Use the script `fetch_and_preprocess.sh` to download the [SICK dataset](http://alt.qcri.org/semeval2014/task1/index.php?id=data-and-tools), [Stanford Parser](http://nlp.stanford.edu/software/lex-parser.shtml) and [Stanford POS Tagger](http://nlp.stanford.edu/software/tagger.shtml), and [Glove word vectors](http://nlp.stanford.edu/projects/glove/) (Common Crawl 840) -- **Warning:** this is a 2GB download!), and additionally preprocess the data, i.e. generate dependency parses using [Stanford Neural Network Dependency Parser](http://nlp.stanford.edu/software/nndep.shtml). +- `main.py`does the actual heavy lifting of training the model and testing it on the SICK dataset. For a list of all command-line arguments, have a look at `python main.py -h`. +- The first run caches GLOVE embeddings for words in the SICK vocabulary. In later runs, only the cache is read in during later runs. + +Next, these are the different ways to run the code here to train a TreeLSTM model. +#### Local Python Environment +If you have a working Python3 environment, simply run the following sequence of steps: + +``` +- bash fetch_and_preprocess.sh +- python main.py +``` + + +### Acknowledgments +- The Gluon version is ported from this implementation [dasguptar/treelstm.pytorch](https://github.com/dasguptar/treelstm.pytorch) +- Shout-out to [Kai Sheng Tai](https://github.com/kaishengtai/) for the [original LuaTorch implementation](https://github.com/stanfordnlp/treelstm), and to the [Pytorch team](https://github.com/pytorch/pytorch#the-team) for the fun library. diff --git a/example/gluon/tree_lstm/dataset.py b/example/gluon/tree_lstm/dataset.py index 02c57c0cc809..5d6b766042d6 100644 --- a/example/gluon/tree_lstm/dataset.py +++ b/example/gluon/tree_lstm/dataset.py @@ -115,7 +115,7 @@ def load_embedding(self, f, reset=[]): self.add(word) if word in self.tok2idx: vectors[word] = [float(x) for x in tokens[1:]] - dim = len(vectors.values()[0]) + dim = len(list(vectors.values())[0]) def to_vector(tok): if tok in vectors and tok not in reset: return vectors[tok] @@ -154,7 +154,7 @@ def __init__(self, path, vocab, num_classes, shuffle=True): def reset(self): if self.shuffle: - mask = range(self.size) + mask = list(range(self.size)) random.shuffle(mask) self.l_sentences = [self.l_sentences[i] for i in mask] self.r_sentences = [self.r_sentences[i] for i in mask] diff --git a/example/gluon/tree_lstm/fetch_and_preprocess.sh b/example/gluon/tree_lstm/fetch_and_preprocess.sh index f372392830d0..a9b9d28612f3 100755 --- a/example/gluon/tree_lstm/fetch_and_preprocess.sh +++ b/example/gluon/tree_lstm/fetch_and_preprocess.sh @@ -18,8 +18,8 @@ # under the License. set -e -python2.7 scripts/download.py +python scripts/download.py CLASSPATH="lib:lib/stanford-parser/stanford-parser.jar:lib/stanford-parser/stanford-parser-3.5.1-models.jar" javac -cp $CLASSPATH lib/*.java -python2.7 scripts/preprocess-sick.py +python scripts/preprocess-sick.py diff --git a/example/gluon/tree_lstm/main.py b/example/gluon/tree_lstm/main.py index ad5d59f7a47d..53af3fa019e9 100644 --- a/example/gluon/tree_lstm/main.py +++ b/example/gluon/tree_lstm/main.py @@ -152,7 +152,6 @@ def train(epoch, ctx, train_data, dev_data): net.embed.weight.set_data(vocab.embed.as_in_context(ctx[0])) train_data.set_context(ctx[0]) dev_data.set_context(ctx[0]) - # set up trainer for optimizing the network. trainer = gluon.Trainer(net.collect_params(), optimizer, {'learning_rate': opt.lr, 'wd': opt.wd}) diff --git a/example/gluon/tree_lstm/scripts/download.py b/example/gluon/tree_lstm/scripts/download.py index 7ea930370175..6537ef1ff655 100644 --- a/example/gluon/tree_lstm/scripts/download.py +++ b/example/gluon/tree_lstm/scripts/download.py @@ -24,7 +24,6 @@ """ from __future__ import print_function -import urllib2 import sys import os import shutil