diff --git a/quality/eval.py b/quality/eval.py index fd1d524..405eaae 100644 --- a/quality/eval.py +++ b/quality/eval.py @@ -12,15 +12,15 @@ from PIL import Image -def read_img(imgPath): # read image & data pre-process - data = torch.randn(1, 3, 112, 112) +def read_img(imgPath, device): # read image & data pre-process + data = torch.randn(1, 3, 112, 112).to(device) transform = T.Compose([ T.Resize((112, 112)), T.ToTensor(), T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ]) img = Image.open(imgPath).convert("RGB") - data[0, :, :, :] = transform(img) + data[0, :, :, :] = transform(img).to(device) return data @@ -36,9 +36,9 @@ def network(eval_model, device): if __name__ == "__main__": imgpath = './demo_imgs/1.jpg' # [1,2,3.jpg] - device = 'cpu' # 'cpu' or 'cuda:x' + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") eval_model = './model/SDD_FIQA_checkpoints_r50.pth' # checkpoint net = network(eval_model, device) - input_data = read_img(imgpath) + input_data = read_img(imgpath, device) pred_score = net(input_data).data.cpu().numpy().squeeze() print(f"Quality score = {pred_score}")