-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocess_batch.py
More file actions
35 lines (24 loc) · 1.04 KB
/
process_batch.py
File metadata and controls
35 lines (24 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import tensorflow as tf
# Training corpus: a text file of whitespace-separated integer ids,
# parsed by read_data() in main().
TRAIN_DATA = './simple-examples/data/train_data_index.txt'

# Batch geometry handed to make_batches(): number of parallel rows per
# batch, and number of time steps (columns) per batch.
TRAIN_BATCH_SIZE, TRAIN_NUM_STEP = 20, 35
def read_data(file_path):
    """Read a file of whitespace-separated integer ids into a flat list.

    Line breaks carry no meaning: all lines are treated as one continuous
    stream of ids.

    Args:
        file_path: Path to a text file containing integers separated by any
            whitespace (spaces, tabs or newlines).

    Returns:
        A list of ``int`` ids in file order (empty for an empty file).

    Raises:
        ValueError: If a token cannot be parsed as an integer.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(file_path, 'r', encoding='utf-8') as fin:
        # str.split() with no argument splits on any run of whitespace,
        # newlines included, so per-line strip-and-join is unnecessary.
        return [int(token) for token in fin.read().split()]
def make_batches(id_list, batch_size, num_step):
    """Cut an id sequence into (input, target) mini-batches.

    The sequence is truncated to the largest multiple of
    ``batch_size * num_step`` that still leaves one extra id for the last
    target. It is reshaped into ``batch_size`` rows and then split along
    the column axis into chunks of ``num_step`` columns. Targets are the
    inputs shifted forward by one position.

    Args:
        id_list: Sequence of integer ids (anything sliceable that
            ``np.array`` accepts).
        batch_size: Number of rows per batch; must be positive.
        num_step: Number of columns per batch; must be positive.

    Returns:
        A list of ``(data, label)`` tuples of arrays, each of shape
        ``(batch_size, num_step)``. The list is empty when ``id_list`` is
        too short to fill a single batch.

    Raises:
        ValueError: If ``batch_size`` or ``num_step`` is not positive.
    """
    if batch_size <= 0 or num_step <= 0:
        raise ValueError(
            f'batch_size and num_step must be positive, '
            f'got {batch_size} and {num_step}')
    # One id is held back so every input position has a following target.
    # max(0, ...) covers an empty list, where (0 - 1) // k evaluates to -1.
    num_batches = max(0, (len(id_list) - 1) // (batch_size * num_step))
    if num_batches == 0:
        # np.split cannot split into zero sections, so there is nothing to return.
        return []
    num_ids = num_batches * batch_size * num_step
    row_len = num_batches * num_step
    # Separate arrays for data and label, so neither aliases the other.
    data = np.reshape(np.array(id_list[:num_ids]), [batch_size, row_len])
    label = np.reshape(np.array(id_list[1:num_ids + 1]), [batch_size, row_len])
    data_batches = np.split(data, num_batches, axis=1)
    label_batches = np.split(label, num_batches, axis=1)
    return list(zip(data_batches, label_batches))
def main():
    """Build the training batches from TRAIN_DATA.

    Placeholder entry point: the batches are constructed but nothing
    consumes them yet.
    """
    corpus_ids = read_data(TRAIN_DATA)
    train_batches = make_batches(corpus_ids, TRAIN_BATCH_SIZE, TRAIN_NUM_STEP)  # noqa: F841 (not used yet)


if __name__ == '__main__':
    main()