Description
```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
in <cell line: 2>()
     20 
     21 ## sample from bootstrapped target distribution
---> 22 samples = bootstrap.sample(args.batch_size,
     23     condition_dict, next_condition_dict, discount=sample_discount)
     24 

/content/gamma-models/gamma/td/distributions.py in sample(self, batch_size, condition_dict, next_condition_dict, discount)
     36 
     37     s1 = self.dist1.sample(condition_dict)
---> 38     s2 = self.dist2.sample(batch_size, next_condition_dict)
     39 
     40     batch_size = len(s1)

/content/gamma-models/gamma/flows/conditional.py in sample(self, num_samples, condition)
     81 
     82     def sample(self, num_samples, condition=None):
---> 83         z = self.prior.sample((num_samples,))
     84         x = self.flow.transform(z, condition)
     85         return x

/usr/local/lib/python3.10/dist-packages/torch/distributions/transformed_distribution.py in sample(self, sample_shape)
    139             x = self.base_dist.sample(sample_shape)
    140             for transform in self.transforms:
--> 141                 x = transform(x)
    142             return x
    143 

/usr/local/lib/python3.10/dist-packages/torch/distributions/transforms.py in __call__(self, x)
    260 
    261     def __call__(self, x):
--> 262         assert self._inv is not None
    263         return self._inv._inv_call(x)
    264 

AssertionError:
```
I get this error while running the notebook, in the code block where the model is trained. Could you tell me whether this is a PyTorch-specific bug? It looks like torch can't sample from the inverse of the sigmoid transform.
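A minimal sketch suggesting the assertion comes from how PyTorch copies transforms rather than from the sigmoid math itself. It rests on two assumptions that I have not confirmed from the gamma-models source: (1) the flow's prior is a logistic distribution built as `TransformedDistribution(Uniform(0, 1), [SigmoidTransform().inv])`, and (2) the model, and therefore its prior, is duplicated with `copy.deepcopy` somewhere before sampling (e.g. for a target network). In PyTorch, `Transform.__getstate__` resets the `_inv` attribute to `None`, so a deep-copied inverse transform no longer holds the forward transform it wraps.

```python
import copy

import torch
from torch.distributions import TransformedDistribution, Uniform
from torch.distributions.transforms import SigmoidTransform

# Assumption: the prior is logit(Uniform(0, 1)), i.e. a standard logistic
# distribution obtained by pushing a Uniform base through SigmoidTransform().inv.
base = Uniform(torch.zeros(2), torch.ones(2))
prior = TransformedDistribution(base, [SigmoidTransform().inv])

# Sampling from the original object works: the inverse transform still
# holds a reference to the SigmoidTransform it inverts.
ok_samples = prior.sample((4,))

# Assumption: the model (and its prior) is deep-copied somewhere.
# Transform.__getstate__ sets `_inv` to None in the copied state, so the
# copied inverse transform loses its link to the forward transform.
prior_copy = copy.deepcopy(prior)
inverse_in_copy = prior_copy.transforms[0]

try:
    prior_copy.sample((4,))
    failed = False
except AssertionError:
    # The same `assert self._inv is not None` check seen in the traceback.
    failed = True

print(ok_samples.shape, inverse_in_copy._inv, failed)
```

If this matches what the notebook does, one possible workaround (not verified against gamma-models) is to avoid deep-copying the flow: construct a fresh model and copy the weights with `load_state_dict(model.state_dict())`, so the prior and its inverse transform are rebuilt by the constructor.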