-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path decoder_encoder.py
More file actions
33 lines (23 loc) · 881 Bytes
/
decoder_encoder.py
File metadata and controls
33 lines (23 loc) · 881 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
def encode(X, seq_len, vocab_size, dtype=np.float32):
    """One-hot encode a batch of integer-id sequences.

    Args:
        X: Sequence of sequences of non-negative integer ids. Rows may be
            shorter than ``seq_len``; the missing trailing positions are
            left as all-zero vectors.
        seq_len: Size of the time (second) axis of the output.
        vocab_size: Size of the one-hot (last) axis; every id must be
            ``< vocab_size``.
        dtype: Element type of the returned array. Defaults to
            ``np.float32``, the type this function has always produced.

    Returns:
        Array of shape ``(len(X), seq_len, vocab_size)`` holding a single 1
        at ``[row, pos, X[row][pos]]`` for every input position and 0
        everywhere else.

    Raises:
        ValueError: If any id is negative.
        IndexError: If a row is longer than ``seq_len`` or an id is
            ``>= vocab_size``.
    """
    x = np.zeros((len(X), seq_len, vocab_size), dtype=dtype)
    for row, seq in enumerate(X):
        for pos, token in enumerate(seq):
            # NumPy treats a negative index as counting from the end, so a
            # negative id would silently light up the wrong vocabulary slot.
            if token < 0:
                raise ValueError(
                    f"token id {token} at X[{row}][{pos}] is negative"
                )
            x[row, pos, token] = 1
    return x
def batch_gen(batch_size=32, seq_len=10, max_no=100):
    """Yield an endless stream of one-hot ``(x, y)`` batches.

    Each step draws ``batch_size`` rows of ``seq_len`` random integers in
    ``[0, max_no)`` with ``np.random.randint``, i.e. from NumPy's global RNG,
    so ``np.random.seed`` makes the stream reproducible. The targets are the
    same rows sorted ascending along the sequence axis.

    Args:
        batch_size: Number of sequences per batch.
        seq_len: Number of integers per sequence.
        max_no: Exclusive upper bound on the integers; also the size of
            the one-hot (last) axis.

    Yields:
        ``(x, y)`` float32 arrays, each of shape
        ``(batch_size, seq_len, max_no)``. ``x`` one-hot encodes the random
        inputs and ``y`` one-hot encodes them after sorting each row. Every
        batch is a freshly allocated pair, so arrays from earlier batches
        stay intact after the generator advances.
    """
    # Index grids that broadcast against a (batch_size, seq_len) id array,
    # so one fancy-index assignment sets every one-hot entry of a batch.
    rows = np.arange(batch_size)[:, None]
    cols = np.arange(seq_len)
    while True:
        # Generates a batch of input
        X = np.random.randint(max_no, size=(batch_size, seq_len))
        Y = np.sort(X, axis=1)
        # Allocate new output arrays every step. Reusing and zeroing one
        # pair of buffers would overwrite batches a consumer still holds,
        # e.g. in a prefetch queue or a list of collected batches.
        x = np.zeros((batch_size, seq_len, max_no), dtype=np.float32)
        y = np.zeros((batch_size, seq_len, max_no), dtype=np.float32)
        x[rows, cols, X] = 1
        y[rows, cols, Y] = 1
        yield x, y