Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions example/gluon/tree_lstm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

# Tree-Structured Long Short-Term Memory Networks
This is a [MXNet Gluon](https://mxnet.io/) implementation of Tree-LSTM as described in the paper [Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks](http://arxiv.org/abs/1503.00075) by Kai Sheng Tai, Richard Socher, and Christopher Manning.

### Requirements
- Python (tested on **3.6.5**, should work on **>=2.7**)
- Java >= 8 (for Stanford CoreNLP utilities)
- Other dependencies are in `requirements.txt`
Note: Currently works with MXNet 1.3.0.

### Usage
Before delving into how to run the code, here is a quick overview of the contents:
- Use the script `fetch_and_preprocess.sh` to download the [SICK dataset](http://alt.qcri.org/semeval2014/task1/index.php?id=data-and-tools), [Stanford Parser](http://nlp.stanford.edu/software/lex-parser.shtml) and [Stanford POS Tagger](http://nlp.stanford.edu/software/tagger.shtml), and [Glove word vectors](http://nlp.stanford.edu/projects/glove/) (Common Crawl 840) -- **Warning:** this is a 2GB download!), and additionally preprocess the data, i.e. generate dependency parses using [Stanford Neural Network Dependency Parser](http://nlp.stanford.edu/software/nndep.shtml).
- `main.py`does the actual heavy lifting of training the model and testing it on the SICK dataset. For a list of all command-line arguments, have a look at `python main.py -h`.
- The first run caches GLOVE embeddings for words in the SICK vocabulary. In later runs, only the cache is read in during later runs.

Next, these are the different ways to run the code here to train a TreeLSTM model.
#### Local Python Environment
If you have a working Python3 environment, simply run the following sequence of steps:

```
- bash fetch_and_preprocess.sh
- python main.py
```


### Acknowledgments
- The Gluon version is ported from this implementation [dasguptar/treelstm.pytorch](https://github.com/dasguptar/treelstm.pytorch)
- Shout-out to [Kai Sheng Tai](https://github.com/kaishengtai/) for the [original LuaTorch implementation](https://github.com/stanfordnlp/treelstm), and to the [Pytorch team](https://github.com/pytorch/pytorch#the-team) for the fun library.
4 changes: 2 additions & 2 deletions example/gluon/tree_lstm/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def load_embedding(self, f, reset=[]):
self.add(word)
if word in self.tok2idx:
vectors[word] = [float(x) for x in tokens[1:]]
dim = len(vectors.values()[0])
dim = len(list(vectors.values())[0])
def to_vector(tok):
if tok in vectors and tok not in reset:
return vectors[tok]
Expand Down Expand Up @@ -154,7 +154,7 @@ def __init__(self, path, vocab, num_classes, shuffle=True):

def reset(self):
if self.shuffle:
mask = range(self.size)
mask = list(range(self.size))
random.shuffle(mask)
self.l_sentences = [self.l_sentences[i] for i in mask]
self.r_sentences = [self.r_sentences[i] for i in mask]
Expand Down
4 changes: 2 additions & 2 deletions example/gluon/tree_lstm/fetch_and_preprocess.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
# under the License.

set -e
python2.7 scripts/download.py
python scripts/download.py

CLASSPATH="lib:lib/stanford-parser/stanford-parser.jar:lib/stanford-parser/stanford-parser-3.5.1-models.jar"
javac -cp $CLASSPATH lib/*.java
python2.7 scripts/preprocess-sick.py
python scripts/preprocess-sick.py
1 change: 0 additions & 1 deletion example/gluon/tree_lstm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def train(epoch, ctx, train_data, dev_data):
net.embed.weight.set_data(vocab.embed.as_in_context(ctx[0]))
train_data.set_context(ctx[0])
dev_data.set_context(ctx[0])

# set up trainer for optimizing the network.
trainer = gluon.Trainer(net.collect_params(), optimizer, {'learning_rate': opt.lr, 'wd': opt.wd})

Expand Down
1 change: 0 additions & 1 deletion example/gluon/tree_lstm/scripts/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"""

from __future__ import print_function
import urllib2
import sys
import os
import shutil
Expand Down