diff --git a/lucent/optvis/render.py b/lucent/optvis/render.py index 4c15168..e7de171 100644 --- a/lucent/optvis/render.py +++ b/lucent/optvis/render.py @@ -136,6 +136,10 @@ def closure(): def tensor_to_img_array(tensor): image = tensor.cpu().detach().numpy() image = np.transpose(image, [0, 2, 3, 1]) + # Check if the image is single channel and convert to 3-channel + if len(image.shape) == 4 and image.shape[3] == 1: # Single channel image + image = np.repeat(image, 3, axis=3) + image = (image * 255).astype(np.uint8) return image