Closed
Description
Using DyNet in Python, I am trying to modify the intermediate state of an LSTM with some calculations. After computing with some parameters and s.h()[-1].npvalue(), how can I still return an RNNState object so that it can continue to serve as input to the LSTM?
Here is my code to convert the state s:
import dynet as dy

def convert_dec_state(self, ctx_watt, s, bxm=None):
    # s is the previous decoder hidden state as a dy.Expression
    # (e.g. state.h()[-1]); everything below stays in the computation graph.
    w_cz = dy.parameter(self.dec_w_cz)
    w_hz = dy.parameter(self.dec_w_hz)
    b_z = dy.parameter(self.dec_b_z)
    w_cr = dy.parameter(self.dec_w_cr)
    w_hr = dy.parameter(self.dec_w_hr)
    b_r = dy.parameter(self.dec_b_r)
    w_ch = dy.parameter(self.dec_w_ch)
    w_hh = dy.parameter(self.dec_w_hh)
    b_h = dy.parameter(self.dec_b_h)
    # dy.colwise_add(w_cz * ctx_watt, b_z)
    logger.debug('w_cz: {} {}'.format(type(w_cz), w_cz.npvalue().shape))
    logger.debug('attention {} {}'.format(type(ctx_watt), ctx_watt.npvalue().shape))
    print((w_cz * ctx_watt).npvalue().shape)
    z = dy.logistic(w_cz * ctx_watt + w_hz * s + b_z)  # update gate (sigmoid)
    r = dy.logistic(w_cr * ctx_watt + w_hr * s + b_r)  # reset gate (sigmoid)
    c_h = w_ch * ctx_watt
    # Gates are applied element-wise with dy.cmult; dy.dot_product would
    # collapse the two vectors into a single scalar.
    hidden = dy.tanh(dy.cmult(w_hh * s + b_h, r) + c_h)
    hidden = dy.cmult(s, z) + dy.cmult(1. - z, hidden)
    if bxm is not None:
        # Check for None before using bxm. Assumes bxm is a dy.Expression
        # mask with the same shape as hidden (1 = take update, 0 = keep s).
        hidden = dy.cmult(bxm, hidden) + dy.cmult(1. - bxm, s)
    # hidden is still a dy.Expression: no numpy round trip through
    # matInput is needed, and gradients keep flowing through it.
    return hidden
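To check shapes and the gate arithmetic outside DyNet, here is a NumPy reference of the same GRU-style update, as a sketch only. It assumes a column layout of (hidden_dim, batch), matching DyNet's W * s, and the weight-dict keys are hypothetical stand-ins for the dec_w_* / dec_b_* parameters. The * operator here is element-wise, which corresponds to dy.cmult rather than dy.dot_product.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gated_update(ctx, s, W, mask=None):
    """NumPy reference for the update computed in convert_dec_state.

    ctx:  (ctx_dim, batch) attention context
    s:    (hid_dim, batch) previous decoder state
    W:    dict of weights/biases; keys are hypothetical names for the
          dec_w_* / dec_b_* parameters (biases have shape (hid_dim, 1))
    mask: optional (batch,) array of 0/1 values
    """
    z = sigmoid(W["cz"] @ ctx + W["hz"] @ s + W["bz"])  # update gate
    r = sigmoid(W["cr"] @ ctx + W["hr"] @ s + W["br"])  # reset gate
    # '*' is element-wise, like dy.cmult (not dy.dot_product)
    h_tilde = np.tanh((W["hh"] @ s + W["bh"]) * r + W["ch"] @ ctx)
    h = s * z + (1.0 - z) * h_tilde
    if mask is not None:
        # Broadcast the per-example mask across the hidden dimension:
        # masked-out columns keep the previous state s unchanged.
        h = mask[None, :] * h + (1.0 - mask)[None, :] * s
    return h
```

Because the mask multiplies whole columns, a batch entry with mask 0 comes out identical to its column in s, which is what the original bxm step appears to intend.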
I try to call
s = self.convert_dec_state(ctx_with_attend, s, bxm)
to update s, but it fails: the returned value is just a <type '_gdynet.Expression'>, not the <_gdynet.RNNState> that the LSTM expects. I do not know how to turn that Expression into an <_gdynet.RNNState>. Any reply will be appreciated.