diff --git a/lucent/optvis/param/gan.py b/lucent/optvis/param/gan.py index c181172..fb684c0 100644 --- a/lucent/optvis/param/gan.py +++ b/lucent/optvis/param/gan.py @@ -27,6 +27,24 @@ "fc7": "https://onedrive.live.com/download?cid=9CFFF6BCB39F6829&resid=9CFFF6BCB39F6829%2145338&authkey=AJ0R-daUAVYjQIw", "fc8": "https://onedrive.live.com/download?cid=9CFFF6BCB39F6829&resid=9CFFF6BCB39F6829%2145340&authkey=AKIfNk7s5MGrRkU"} +def download_url_to_file_fake_request(url, dst): + """ + Download object at the given URL to a local path, using browser-like HTTP GET request. + """ + + import requests + from tqdm import tqdm + + # Imitate Chrome browser + headers = {"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_6) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/76.0.3809.132 Safari/537.36"} + + with requests.get(url, headers=headers, stream=True) as r: + r.raise_for_status() + with open(dst, 'wb') as f: + for chunk in tqdm(r.iter_content(chunk_size=8192)): + f.write(chunk) def load_statedict_from_online(name="fc6"): torchhome = torch.hub._get_torch_home() @@ -34,8 +52,12 @@ def load_statedict_from_online(name="fc6"): os.makedirs(ckpthome, exist_ok=True) filepath = join(ckpthome, "upconvGAN_%s.pt"%name) if not os.path.exists(filepath): - torch.hub.download_url_to_file(model_urls[name], filepath, hash_prefix=None, - progress=True) + print("Downloading %s"%model_urls[name]) + download_url_to_file_fake_request(model_urls[name], filepath) + + # this is blocked by onedrive + #torch.hub.download_url_to_file(model_urls[name], filepath, hash_prefix=None, + # progress=True) SD = torch.load(filepath) return SD diff --git a/lucent/optvis/transform.py b/lucent/optvis/transform.py index ad8aab3..0e6cfb2 100644 --- a/lucent/optvis/transform.py +++ b/lucent/optvis/transform.py @@ -22,6 +22,12 @@ import kornia from kornia.geometry.transform import translate +try: + from kornia import warp_affine, get_rotation_matrix2d +except ImportError: + from kornia.geometry.transform import warp_affine, get_rotation_matrix2d + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") KORNIA_VERSION = kornia.__version__ @@ -52,7 +58,7 @@ def random_scale(scales): def inner(image_t): scale = np.random.choice(scales) shp = image_t.shape[2:] - scale_shape = [_roundup(scale * d) for d in shp] + scale_shape = [int(_roundup(scale * d)) for d in shp] pad_x = max(0, _roundup((shp[1] - scale_shape[1]) / 2)) pad_y = max(0, _roundup((shp[0] - scale_shape[0]) / 2)) upsample = torch.nn.Upsample( @@ -76,8 +82,8 @@ def inner(image_t): center = torch.ones(b, 2) center[..., 0] = (image_t.shape[3] - 1) / 2 center[..., 1] = (image_t.shape[2] - 1) / 2 - M = kornia.get_rotation_matrix2d(center, angle, scale).to(device) - rotated_image = kornia.warp_affine(image_t.float(), M, dsize=(h, w)) + M = get_rotation_matrix2d(center, angle, scale).to(device) + rotated_image = warp_affine(image_t.float(), M, dsize=(h, w)) return rotated_image return inner diff --git a/setup.py b/setup.py index 597d0eb..8ec8692 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ install_requires=[ "torch>=1.5.0", "torchvision", - "kornia<=0.4.1", + "kornia>=0.4.1", "tqdm", "numpy", "ipython",