Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Classification_BatchDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from past.builtins import xrange
import numpy as np
import scipy.misc as misc
import imageio
import pandas as pa
import re
import os
Expand Down Expand Up @@ -121,7 +121,7 @@ def load_class(self, folder, class_index):
return None

def load_image(self,folder,image, class_index):
image = misc.imread(self.path + "/" + folder + "/" + image)
image = imageio.imread(self.path + "/" + folder + "/" + image)
nr_y = image.shape[0] // self.tile_size[0]
nr_x = image.shape[1] // self.tile_size[1]

Expand Down
10 changes: 7 additions & 3 deletions Example_Classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ def main(unused_argv):

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(5):

print("Training the network with %d training steps..." % FLAGS.num_steps)

for i in range(FLAGS.num_steps):
batch = data_reader.next_batch(FLAGS.batch_size)
if i % 1000 == 0:
train_accuracy = accuracy.eval(feed_dict={
Expand Down Expand Up @@ -139,8 +142,9 @@ def main(unused_argv):
help='Directory for storing input data')
parser.add_argument("--batch_size", type=int, default=2, help="batch size for training")
parser.add_argument("--test_batch_size", type=int, default=200, help="batch size for training")
parser.add_argument("--model_path", type=str, default="Models/deepscores_class.ckpt",
parser.add_argument("--num_steps", type=int, default=5, help="Number of training steps")
parser.add_argument("--model_path", type=str, default="./Models/deepscores_class.ckpt",
help="where to store the trained model")

FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)