From db8b05c541c5534f4daff7b588b7f6aeb1074718 Mon Sep 17 00:00:00 2001 From: Aleksey Morozov <36787333+amrzv@users.noreply.github.com> Date: Sat, 10 Jul 2021 12:14:19 +0300 Subject: [PATCH 1/3] added form for easier selecting available choices of parameters --- notebooks/rewriting-interface.ipynb | 40 ++++++++++++++++------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/notebooks/rewriting-interface.ipynb b/notebooks/rewriting-interface.ipynb index 3571309..6fb3f6e 100644 --- a/notebooks/rewriting-interface.ipynb +++ b/notebooks/rewriting-interface.ipynb @@ -17,7 +17,7 @@ "5. Click on a target region of an image and \"Paste\" to paste the pattern in a new place. Clicking around the red box can adjust it.\n", "6. Click \"Execute\" to see the effects. \"Toggle\" compares to the original model, and \"Revert\" discards the edit.\n", "\n", - "Editing a model is challenging, because you need to develop an understanding of the way the model organies its rules. You can build your intuition by using the \"Highlight\" button to see how the model generalizes regions that you select in the context.\n", + "Editing a model is challenging, because you need to develop an understanding of the way the model organizes its rules. You can build your intuition by using the \"Highlight\" button to see how the model generalizes regions that you select in the context.\n", "\n", "Particular model edits can be saved or loaded as json files; and other geeky details\n", "can be seen in the source code at http://github.com/davidbau/rewriting. This\n", @@ -46,11 +46,30 @@ " import google.colab, sys, torch\n", " sys.path.append('/content/tutorial_code')\n", " if not torch.cuda.is_available():\n", - " print(\"Change runtime type to include a GPU.\") \n", + " print(\"Change runtime type to include a GPU.\")\n", "except:\n", " pass" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Choices {run: \"auto\"}\n", + "\n", + "ganname = 'stylegan' #@param [\"stylegan\", \"proggan\"]\n", + "modelname = 'church' #@param [\"church\", \"faces\", \"horse\", \"kitchen\"]\n", + "#@markdown - layer 6,7,8,9,10 work OK for different things.\n", + "#@markdown - layer 8 is good for trees or domes in churches, and hats on horses\n", + "#@markdown - layer 6 is good for smiles on faces\n", + "#@markdown - layer 10 is good for hair on faces\n", + "layernum = 8 #@param {type:\"slider\", min:6, max:10, step:1}\n", + "#@markdown Number of images to sample when gathering statistics.\n", + "size = 1000 #@param {type:\"integer\"}" + ] + }, { "cell_type": "code", "execution_count": null, @@ -66,21 +85,6 @@ "import utils.stylegan2, utils.proggan\n", "from utils.stylegan2 import load_seq_stylegan\n", "\n", - "# Choices: ganname = 'stylegan' or ganname = 'proggan'\n", - "ganname = 'stylegan'\n", - "\n", - "# Choices: modelname = 'church' or faces' or 'horse' or 'kitchen'\n", - "modelname = 'church'\n", - "\n", - "# layer 6,7,8,9,10 work OK for different things.\n", - "# layer 8 is good for trees or domes in churches, and hats on horses\n", - "# layer 6 is good for smiles on faces\n", - "# layer 10 is good for hair on faces\n", - "layernum = 8\n", - "\n", - "# Number of images to sample when gathering statistics.\n", - "size = 1000\n", - "\n", "# Make a directory for caching some data.\n", "layerscheme = 'default'\n", "expdir = 'results/pgw/%s/%s/%s/layer%d' % (ganname, modelname, layerscheme, layernum)\n", @@ -93,7 +97,7 @@ "elif ganname == 'proggan':\n", " model = utils.proggan.load_pretrained(modelname)\n", " Rewriter = ganrewrite.ProgressiveGanRewriter\n", - " \n", + "\n", "# Create a Rewriter object - this implements our method.\n", "zds = zdataset.z_dataset_for_model(model, size=size)\n", "gw = Rewriter(\n", From 0ac5a413839dde67c1566dff4e02c2720ca6cd60 Mon Sep 17 00:00:00 2001 From: Aleksey Morozov <36787333+amrzv@users.noreply.github.com> Date: Sat, 10 Jul 2021 12:27:29 +0300 Subject: [PATCH 2/3] cleaned up imports --- notebooks/rewriting-interface.ipynb | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/notebooks/rewriting-interface.ipynb b/notebooks/rewriting-interface.ipynb index 6fb3f6e..4d94198 100644 --- a/notebooks/rewriting-interface.ipynb +++ b/notebooks/rewriting-interface.ipynb @@ -78,11 +78,10 @@ }, "outputs": [], "source": [ - "from utils import zdataset, show, labwidget\n", + "from utils import zdataset, show\n", "from rewrite import ganrewrite, rewriteapp\n", - "import torch, copy, os, json\n", - "from torchvision.utils import save_image\n", - "import utils.stylegan2, utils.proggan\n", + "import os\n", + "from utils.proggan import load_pretrained\n", "from utils.stylegan2 import load_seq_stylegan\n", "\n", "# Make a directory for caching some data.\n", @@ -95,7 +94,7 @@ " model = load_seq_stylegan(modelname, mconv='seq', truncation=0.50)\n", " Rewriter = ganrewrite.SeqStyleGanRewriter\n", "elif ganname == 'proggan':\n", - " model = utils.proggan.load_pretrained(modelname)\n", + " model = load_pretrained(modelname)\n", " Rewriter = ganrewrite.ProgressiveGanRewriter\n", "\n", "# Create a Rewriter object - this implements our method.\n", From cda137f0e4c2ef89cfc0d8b52b5c02ea5b1d666b Mon Sep 17 00:00:00 2001 From: Aleksey Morozov <36787333+amrzv@users.noreply.github.com> Date: Sat, 10 Jul 2021 13:49:58 +0300 Subject: [PATCH 3/3] update torch.symeig to silence the warning --- rewrite/ganrewrite.py | 4 +--- rewrite/rewriteapp.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/rewrite/ganrewrite.py b/rewrite/ganrewrite.py index bf6fda9..f43dcff 100644 --- a/rewrite/ganrewrite.py +++ b/rewrite/ganrewrite.py @@ -1,12 +1,10 @@ import copy import os import torch -import json import random import time import warnings from utils import nethook, renormalize, pbar, tally, imgviz -from collections import OrderedDict import torchvision @@ -819,7 +817,7 @@ def rank_one_conv(weight, direction): def zca_from_cov(cov): - evals, evecs = torch.symeig(cov.double(), eigenvectors=True) + evals, evecs = torch.linalg.eigh(cov.double(), UPLO='U') zca = torch.mm(torch.mm(evecs, torch.diag (evals.sqrt().clamp(1e-20).reciprocal())), evecs.t()).to(cov.dtype) diff --git a/rewrite/rewriteapp.py b/rewrite/rewriteapp.py index cf35e58..8e55167 100644 --- a/rewrite/rewriteapp.py +++ b/rewrite/rewriteapp.py @@ -742,7 +742,7 @@ def rank_one_conv(weight, direction): def zca_from_cov(cov): - evals, evecs = torch.symeig(cov.double(), eigenvectors=True) + evals, evecs = torch.linalg.eigh(cov.double(), UPLO='U') zca = torch.mm(torch.mm(evecs, torch.diag (evals.sqrt().clamp(1e-20).reciprocal())), evecs.t()).to(cov.dtype)