Feature/pre train subreddit #15
Merged
Changes from all commits (28 commits)
db3ec23
merge
bc9154d
added method to add secondary output layer
1afbf9f
implemented pre-training on subreddits
95b779e
create methods for handling pre-training data
4afbf98
non-working commit to allow for debugging help
193811d
changed which variable was used to check if pre-training
e5be638
only add sec_output when pre-train and add an extra layer after
d94d187
changed config to not include unnecessary parameters
fa9859c
cleaned main.py by adding it to builder
1de1133
sec_output uses softmax since a title only has one subreddit
5c2db6c
refactored add output method to one method
2374ed4
Adds the option to choose between GRU and LSTM units
hsson f46c8ce
Adds a missing variable name change
hsson 634d79b
Changes default dataset in config template
hsson d73a5c4
merge
468d07d
added method to add secondary output layer
2e957a3
implemented pre-training on subreddits
796a81d
create methods for handling pre-training data
81c149d
non-working commit to allow for debugging help
f5c49ed
changed which variable was used to check if pre-training
8c5cb78
only add sec_output when pre-train and add an extra layer after
7383a0b
cleaned main.py by adding it to builder
157987d
sec_output uses softmax since a title only has one subreddit
90086ab
refactored add output method to one method
2561556
Merge branch 'feature/pre-train-subreddit' of github.com:kandidat-hig…
65cf80c
Refactors away redundant function in model builder
hsson fbe203f
Removes un-used constant
hsson ac40573
Removes unused epoch counting for pre-training
hsson
@@ -53,7 +53,8 @@ def __init__(self, config, session):
         self.learning_rate = config[LEARN_RATE]
         self.embedding_size = config[EMBEDD_SIZE]
         self.max_title_length = config[MAX_TITLE_LENGTH]
-        self.lstm_neurons = config[LSTM_NEURONS]
+        self.rnn_neurons = config[RNN_NEURONS]
+        self.rnn_unit = config[RNN_UNIT]
         self.batch_size = config[BATCH_SIZE]
         self.training_epochs = config[TRAINING_EPOCHS]
         self.use_l2_loss = config[USE_L2_LOSS]
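The config now carries both a neuron count (`RNN_NEURONS`) and a unit type (`RNN_UNIT`), which, together with the "Adds the option to choose between GRU and LSTM units" commit, suggests the builder dispatches on that string when constructing its recurrent cell. A minimal sketch of such a dispatch, assuming a string-valued `rnn_unit`; the cell classes here are stubs standing in for the framework's real GRU/LSTM cells:

```python
class GRUCell:
    """Stub standing in for a real GRU cell class."""
    def __init__(self, num_units):
        self.num_units = num_units

class LSTMCell:
    """Stub standing in for a real LSTM cell class."""
    def __init__(self, num_units):
        self.num_units = num_units

def make_rnn_cell(rnn_unit, rnn_neurons):
    """Return the recurrent cell named by the config string (hypothetical helper)."""
    units = {"gru": GRUCell, "lstm": LSTMCell}
    try:
        return units[rnn_unit.lower()](rnn_neurons)
    except KeyError:
        raise ValueError("unknown rnn_unit: %r" % rnn_unit)
```

Dispatching through a dict keeps adding a third unit type (e.g. a vanilla RNN cell) a one-line change.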
@@ -67,13 +68,17 @@ def __init__(self, config, session):
         self.use_constant_limit = config[USE_CONSTANT_LIMIT]
         self.constant_prediction_limit = config[CONSTANT_PREDICTION_LIMIT]
         self.use_concat_input = config[USE_CONCAT_INPUT]
+        self.use_pretrained_net = config[USE_PRETRAINED_NET]
+        self.subreddit_count = 0

         # Will be set in build_graph
         self.input = None
         self.subreddit_input = None
         self.target = None
+        self.sec_target = None
         self.sigmoid = None
         self.train_op = None
+        self.pre_train_op = None
         self.error = None
         self.init_op = None
         self.saver = None
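The new `sec_target` placeholder feeds the secondary subreddit head which, per the commit message, uses softmax because each title belongs to exactly one subreddit (the primary output stays sigmoid). A small NumPy illustration of why softmax fits the single-label case, with hypothetical logits — the probabilities over subreddits always sum to one, so exactly one class can dominate:

```python
import numpy as np

def softmax(logits):
    # subtract the max before exponentiating for numerical stability
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# hypothetical logits from the secondary output layer over 4 subreddits
logits = np.array([2.0, 1.0, 0.1, -1.0])
probs = softmax(logits)
```

By contrast, independent sigmoids would let several subreddits all score near 1, which only makes sense for multi-label targets.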
@@ -105,6 +110,7 @@ def __init__(self, config, session):
         with tf.device("/cpu:0"):
             self.data = data.Data(config)
+            self.subreddit_count = self.data.subreddit_count
             if self.use_pretrained:
                 self.vocabulary_size = len(self.data.embedding_matrix)
@@ -182,40 +188,41 @@ def validate_batch(self):
                                     self.subreddit_input: batch_sub,
                                     self.target: batch_label})

-    # TODO: the function does far too much;
-    # split up printing, computation, and training
-    def train(self):
+    # TODO: the function does far too much;
+    # split up printing, computation, and training
+    def train(self, use_pretrained_net=False):
         """ Trains the model on the dataset """
-        print("Starting training...")
+        if use_pretrained_net:
+            print("Pre-training on subreddits...")
+        else:
+            print("Starting training...")

-        if self.use_pretrained:
+        if self.use_pretrained and \
+                (self.use_pretrained_net and use_pretrained_net) or \
+                (not self.use_pretrained_net and not use_pretrained_net):
             self._session.run(self.embedding_init,
                               feed_dict={self.embedding_placeholder:
                                          self.data.embedding_matrix})
         self.train_writer = \

Contributor comment: Why are the writers moved to the model builder?

Author reply: Perhaps not necessary, but otherwise the writer is initialized twice, once for each time we train.
             tf.summary.FileWriter(self.logging_dir + '/' + TENSOR_DIR_TRAIN,
                                   self._session.graph)
         self.valid_writer = \
             tf.summary.FileWriter(self.logging_dir + '/' + TENSOR_DIR_VALID)

         old_epoch = 0

-        if self.epoch.eval(self._session) == 0:
+        if self.epoch.eval(self._session) == 0 and not use_pretrained_net:
             self.validate()

         # Train for a specified amount of epochs
-        for i in self.data.for_n_train_epochs(self.training_epochs,
-                                              self.batch_size):
+        for i in self.data.for_n_train_epochs(self.training_epochs, self.batch_size):
             # Debug print out
             epoch = self.data.completed_training_epochs
-            training_error = self.train_batch()
-            validation_error = self.validate_batch()
-
-            # Don't validate so often
-            if i % (self.data.train_size // self.batch_size // 10) == 0 and i:
-                done = self.data.percent_of_epoch
-                print("Validation error: {:f} | Training error: {:f} | Done: {:.0%}"
-                      .format(validation_error, training_error, done))
+            if not use_pretrained_net:
+                training_error = self.train_batch()
+                validation_error = self.validate_batch()
+
+                # Don't validate so often
+                if i % (self.data.train_size // self.batch_size // 10) == 0 and i:
+                    done = self.data.percent_of_epoch
+                    print("Validation error: {:f} | Training error: {:f} | Done: {:.0%}"
+                          .format(validation_error, training_error, done))
+            else:
+                self.train_batch(True)

             # Do a full evaluation once an epoch is complete
             if epoch != old_epoch:
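With the `use_pretrained_net` flag, the same `train()` loop serves both phases: batches either hit the pre-training op (secondary subreddit head) or the main op. A hypothetical driver showing the intended call pattern under that assumption — pre-train on subreddits first, then run regular training. `FakeModel` and `run_phases` are stand-ins for illustration; the real model runs TensorFlow ops:

```python
class FakeModel:
    """Stand-in for the real model: records which op each batch would hit."""
    def train_batch(self, pre_train_net=False):
        return "pre_train_op" if pre_train_net else "train_op"

def run_phases(model, pretrain_batches, train_batches):
    ops = []
    # phase 1: fit the secondary subreddit head
    for _ in range(pretrain_batches):
        ops.append(model.train_batch(pre_train_net=True))
    # phase 2: regular training on the primary target
    for _ in range(train_batches):
        ops.append(model.train_batch(pre_train_net=False))
    return ops
```

Routing both phases through one `train_batch` keeps the session, writers, and epoch counters shared instead of duplicated per phase, which matches the author's comment about the writer otherwise being initialized twice.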
@@ -231,21 +238,30 @@ def train(self):
                              epoch_top=self.epoch_top, prec_valid=self.prec_valid, prec_train=self.prec_train,
                              recall_valid=self.recall_valid, recall_train=self.recall_train)

-    def train_batch(self):
+    def train_batch(self, pre_train_net=False):
         """ Trains for one batch and returns cross entropy error """
         with tf.device("/cpu:0"):
-            batch_input, batch_sub, batch_label = \
-                self.data.next_train_batch()
-
-        self._session.run(self.train_op,
-                          {self.input: batch_input,
-                           self.subreddit_input: batch_sub,
-                           self.target: batch_label})
+            if not pre_train_net:
+                batch_input, batch_sub, batch_label = \
+                    self.data.next_train_batch()
+            else:
+                batch_input, batch_label = \
+                    self.data.next_pre_train_batch()
+
+        if pre_train_net:
+            self._session.run(self.pre_train_op,
+                              {self.input: batch_input,
+                               self.sec_target: batch_label})
+        else:
+            self._session.run(self.train_op,
+                              {self.input: batch_input,
+                               self.subreddit_input: batch_sub,
+                               self.target: batch_label})

         return self._session.run(self.error,
                                  feed_dict={self.input: batch_input,
                                             self.subreddit_input: batch_sub,
                                             self.target: batch_label})

     def close_writers(self):
         """ Close tensorboard writers """
         self.train_writer.close()
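`train_batch` now draws from two batch sources: regular batches carry a separate subreddit input, while pre-training batches do not (there the subreddit is itself the label fed to `sec_target`). A sketch of that routing with a stub data source; the names mirror the diff, but the stub return values are purely hypothetical:

```python
class FakeData:
    """Stub standing in for the data layer used in the diff above."""
    def next_train_batch(self):
        return ["title ids"], ["subreddit ids"], ["labels"]
    def next_pre_train_batch(self):
        # pre-training batches carry no separate subreddit input;
        # the subreddit itself serves as the target
        return ["title ids"], ["subreddit labels"]

def next_batch(data, pre_train_net=False):
    """Pick the batch source matching the training phase (hypothetical helper)."""
    if pre_train_net:
        batch_input, batch_label = data.next_pre_train_batch()
        batch_sub = None  # no subreddit input during pre-training
    else:
        batch_input, batch_sub, batch_label = data.next_train_batch()
    return batch_input, batch_sub, batch_label
```

Note that in the diff the final `self.error` run still feeds `batch_sub`, which the pre-training branch never assigns; normalizing the batch shapes up front, as in this sketch, is one way to keep the downstream feed dict uniform.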
Contributor comment: Is this variable used?

Author reply: Yes.