-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodels.diff
More file actions
97 lines (85 loc) · 3.7 KB
/
models.diff
File metadata and controls
97 lines (85 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
diff -uar models-my/official/resnet/cifar10_main.py models/official/resnet/cifar10_main.py
--- models/official/resnet/cifar10_main.py 2019-02-13 12:54:50.943172776 +0300
+++ models/official/resnet/cifar10_main.py 2019-02-13 12:56:06.408195123 +0300
@@ -19,6 +19,8 @@
from __future__ import print_function
import os
+import sys
+import json
from absl import app as absl_app
from absl import flags
@@ -29,6 +31,9 @@
from official.utils.flags import core as flags_core
from official.utils.logs import logger
+from tensorflow.contrib.ignite import IgniteDataset
+import tensorflow.contrib.ignite.python.ops.igfs_ops
+
HEIGHT = 32
WIDTH = 32
NUM_CHANNELS = 3
@@ -68,7 +73,7 @@
def parse_record(raw_record, is_training, dtype):
"""Parse CIFAR-10 image and label from a raw record."""
# Convert bytes to a vector of uint8 that is record_bytes long.
- record_vector = tf.decode_raw(raw_record, tf.uint8)
+ record_vector = raw_record
# The first byte represents the label, which we convert from uint8 to int32
# and then to one-hot.
@@ -125,8 +130,7 @@
Returns:
A dataset that can be used for iteration.
"""
- filenames = get_filenames(is_training, data_dir)
- dataset = tf.data.FixedLengthRecordDataset(filenames, _RECORD_BYTES)
+ dataset = IgniteDataset("TEST_DATA", local=True).map(lambda row: row['val'])
return resnet_run_loop.process_record_dataset(
dataset=dataset,
@@ -230,15 +234,14 @@
fine_tune=params['fine_tune']
)
-
def define_cifar_flags():
resnet_run_loop.define_resnet_flags()
flags.adopt_module_key_flags(resnet_run_loop)
flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
- model_dir='/tmp/cifar10_model',
+ model_dir='igfs:///tmp/cifar10_model',
resnet_size='56',
train_epochs=182,
- epochs_between_evals=10,
+ epochs_between_evals=1,
batch_size=128,
image_bytes_as_serving_input=False)
diff -uar models-my/official/resnet/resnet_run_loop.py models/official/resnet/resnet_run_loop.py
--- models/official/resnet/resnet_run_loop.py 2019-02-13 12:54:50.943172776 +0300
+++ models/official/resnet/resnet_run_loop.py 2019-02-13 12:58:10.701746369 +0300
@@ -27,6 +27,7 @@
import math
import multiprocessing
import os
+import json
# pylint: disable=g-bad-import-order
from absl import flags
@@ -476,7 +477,11 @@
# Creates a `RunConfig` that checkpoints every 24 hours which essentially
# results in checkpoints determined only by `epochs_between_evals`.
run_config = tf.estimator.RunConfig(
- train_distribute=distribution_strategy,
+ experimental_distribute=tf.contrib.distribute.DistributeConfig(
+ train_distribute=tf.contrib.distribute.CollectiveAllReduceStrategy(),
+ eval_distribute=tf.contrib.distribute.MirroredStrategy(),
+ remote_cluster=json.loads(os.environ['TF_CLUSTER'])
+ ),
session_config=session_config,
save_checkpoints_secs=60*60*24)
@@ -560,8 +565,11 @@
tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))
if num_train_epochs:
- classifier.train(input_fn=lambda: input_fn_train(num_train_epochs),
- hooks=train_hooks, max_steps=flags_obj.max_train_steps)
+ tf.estimator.train_and_evaluate(
+ classifier,
+ train_spec=tf.estimator.TrainSpec(input_fn=lambda: input_fn_train(num_train_epochs), hooks=train_hooks, max_steps=flags_obj.max_train_steps),
+ eval_spec=tf.estimator.EvalSpec(input_fn=input_fn_eval, steps=flags_obj.max_train_steps)
+ )
tf.logging.info('Starting to evaluate.')