diff --git a/lucent/optvis/render.py b/lucent/optvis/render.py index 4c15168..1f18f66 100644 --- a/lucent/optvis/render.py +++ b/lucent/optvis/render.py @@ -149,6 +149,8 @@ def view(tensor): image = (image * 255).astype(np.uint8) if len(image.shape) == 4: image = np.concatenate(image, axis=1) + if len(image.shape) == 3 and image.shape[2] == 1: + image = image.squeeze(2) Image.fromarray(image).show()