diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d424dc6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+models
+myenv
\ No newline at end of file
diff --git a/README.md b/README.md
index 09417c8..ae0e7ec 100644
--- a/README.md
+++ b/README.md
@@ -54,11 +54,34 @@ We provide a simply script to get the visualization result on the CIHP dataset u
 # Example of inference
 python exp/inference/inference.py \
 --loadmodel /path_to_inference_model \
---img_path ./img/messi.jpg \
+--img_path ./img/ronaldo.jpg \
 --output_path ./img/ \
---output_name /output_file_name
+--output_name ronaldo_output
 ```
+💻 **Note for macOS users (Apple M1/M2/M3):**
+```shell
+# Example of inference
+python exp/inference/inference_mac.py \
+--loadmodel /path_to_inference_model \
+--img_path ./img/ronaldo.jpg \
+--output_path ./img/ \
+--output_name ronaldo_output
+```
+This script uses **PyTorch MPS (Metal Performance Shaders)** for GPU acceleration on macOS and falls back to the CPU when MPS is unavailable.
+
+
+### Blend (Overlay) Result
+After running inference, you can visualize the segmentation result by blending the original image and the mask together.
+```shell
+# Example of blending
+python exp/inference/inference_image_blend.py \
+--img_path ./img/ronaldo.jpg \
+--mask_path ./img/ronaldo_output.png \
+--output_path ./img/ \
+--output_name ronaldo_blend.png
+```
+
 
 ### Training
 #### Transfer learning
 1. Download the Pascal pretrained model(available soon).
diff --git a/dataloaders/.gitignore b/dataloaders/.gitignore
new file mode 100644
index 0000000..98cb45d
--- /dev/null
+++ b/dataloaders/.gitignore
@@ -0,0 +1,2 @@
+__pycache__
+/__pycache__
\ No newline at end of file
diff --git a/exp/inference/inference_image_blend.py b/exp/inference/inference_image_blend.py
new file mode 100644
index 0000000..7835b25
--- /dev/null
+++ b/exp/inference/inference_image_blend.py
@@ -0,0 +1,58 @@
+import cv2
+import os
+import argparse
+
+def create_overlay(img_path, mask_path, output_path, output_name):
+    """
+    Combines an original image and its parsed mask into a blended overlay.
+    Args:
+        img_path (str): Full path to the original image.
+        mask_path (str): Full path to the parsed mask image.
+        output_path (str): Directory where the overlay will be saved.
+        output_name (str): Filename for the final overlay result.
+    """
+
+    print("Starting overlay process...")
+    print(f"Original: {img_path}")
+    print(f"Mask: {mask_path}")
+
+    # Load images
+    original = cv2.imread(img_path)
+    mask = cv2.imread(mask_path)
+
+    # Validate that both images loaded successfully
+    if original is None or mask is None:
+        raise FileNotFoundError(f"Could not load image or mask.\noriginal={img_path}\nmask={mask_path}")
+
+    # Resize mask to match original
+    mask = cv2.resize(mask, (original.shape[1], original.shape[0]))
+
+    # Blend both (semi-transparent)
+    overlay = cv2.addWeighted(original, 0.6, mask, 0.4, 0)
+
+    # Create output directory if needed
+    os.makedirs(output_path, exist_ok=True)
+
+    # Save overlay
+    output_file = os.path.join(output_path, output_name)
+    cv2.imwrite(output_file, overlay)
+
+    print(f"Overlay saved to: {output_file}")
+    return output_file
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="Overlay original image and its parsed mask")
+    parser.add_argument("--img_path", required=True, help="Path to the original image")
+    parser.add_argument("--mask_path", required=True, help="Path to the mask image")
+    parser.add_argument("--output_path", required=True, help="Directory for saving output")
+    parser.add_argument("--output_name", required=True, help="Filename for the saved overlay")
+
+    args = parser.parse_args()
+
+    create_overlay(
+        img_path=args.img_path,
+        mask_path=args.mask_path,
+        output_path=args.output_path,
+        output_name=args.output_name
+    )
diff --git a/exp/inference/inference_mac.py b/exp/inference/inference_mac.py
new file mode 100644
index 0000000..408f2bd
--- /dev/null
+++ b/exp/inference/inference_mac.py
@@ -0,0 +1,188 @@
+import timeit
+import numpy as np
+from PIL import Image
+import os
+import sys
+sys.path.append('./')
+
+import torch
+from torchvision import transforms
+import cv2
+
+from networks import deeplab_xception_transfer, graph
+from dataloaders import custom_transforms as tr
+
+import argparse
+import torch.nn.functional as F
+
+# ---------- Device selection: MPS → CPU ----------
+if torch.backends.mps.is_available():
+    device = torch.device("mps")
+    print("Using Apple Metal (MPS)")
+else:
+    device = torch.device("cpu")
+    print("Using CPU")
+
+# RGB palette for the 20 CIHP classes
+label_colours = [(0,0,0)
+                , (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0)
+                , (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]
+
+def flip(x, dim):
+    indices = [slice(None)] * x.dim()
+    indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
+                                dtype=torch.long, device=x.device)
+    return x[tuple(indices)]
+
+def flip_cihp(tail_list):
+    '''
+    Swap the left/right paired CIHP channels (arms, legs, shoes) so that a
+    horizontally flipped prediction aligns with the unflipped one.
+
+    :param tail_list: tensor of size n_class x h x w
+    :return: tensor of the same size with paired channels swapped
+    '''
+    tail_list_rev = [None] * 20
+    for xx in range(14):
+        tail_list_rev[xx] = tail_list[xx].unsqueeze(0)
+    tail_list_rev[14] = tail_list[15].unsqueeze(0)
+    tail_list_rev[15] = tail_list[14].unsqueeze(0)
+    tail_list_rev[16] = tail_list[17].unsqueeze(0)
+    tail_list_rev[17] = tail_list[16].unsqueeze(0)
+    tail_list_rev[18] = tail_list[19].unsqueeze(0)
+    tail_list_rev[19] = tail_list[18].unsqueeze(0)
+    return torch.cat(tail_list_rev, dim=0)
+
+def decode_labels(mask, num_images=1, num_classes=20):
+    """Decode batch of segmentation masks.
+
+    Args:
+      mask: result of inference after taking argmax.
+      num_images: number of images to decode from the batch.
+      num_classes: number of classes to predict (including background).
+
+    Returns:
+      A batch with num_images RGB images of the same size as the input.
+    """
+    n, h, w = mask.shape
+    assert (n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (
+        n, num_images)
+    outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
+    for i in range(num_images):
+        img = Image.new('RGB', (len(mask[i, 0]), len(mask[i])))
+        pixels = img.load()
+        for j_, j in enumerate(mask[i, :, :]):
+            for k_, k in enumerate(j):
+                if k < num_classes:
+                    pixels[k_, j_] = label_colours[k]
+        outputs[i] = np.array(img)
+    return outputs
+
+def read_img(img_path):
+    _img = Image.open(img_path).convert('RGB')  # returned image is RGB
+    return _img
+
+def img_transform(img, transform=None):
+    sample = {'image': img, 'label': 0}
+
+    sample = transform(sample)
+    return sample
+
+@torch.no_grad()
+def inference(net, img_path='', output_path='./', output_name='f'):
+    # ----- build adjacency matrices on the correct device -----
+    adj2_ = torch.from_numpy(graph.cihp2pascal_nlp_adj).float().to(device)
+    adj2_test = adj2_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 20).transpose(2, 3)
+
+    adj1_ = torch.from_numpy(graph.preprocess_adj(graph.pascal_graph)).float().to(device)
+    adj3_test = adj1_.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7)
+
+    cihp_adj = graph.preprocess_adj(graph.cihp_graph)
+    adj3_ = torch.from_numpy(cihp_adj).float().to(device)
+    adj1_test = adj3_.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20)
+
+    # ----- multi-scale inputs -----
+    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
+    img = read_img(img_path)
+    testloader_list, testloader_flip_list = [], []
+
+    for pv in scale_list:
+        composed = transforms.Compose([
+            tr.Scale_only_img(pv),
+            tr.Normalize_xception_tf_only_img(),
+            tr.ToTensor_only_img()
+        ])
+        composed_flip = transforms.Compose([
+            tr.Scale_only_img(pv),
+            tr.HorizontalFlip_only_img(),
+            tr.Normalize_xception_tf_only_img(),
+            tr.ToTensor_only_img()
+        ])
+        testloader_list.append(img_transform(img, composed))
+        testloader_flip_list.append(img_transform(img, composed_flip))
+
+    start_time = timeit.default_timer()
+    net.eval()
+    outputs_final = None
+
+    for iii, (sb, sb_flip) in enumerate(zip(testloader_list, testloader_flip_list)):
+        inputs = sb['image'].unsqueeze(0).to(device)
+        inputs_f = sb_flip['image'].unsqueeze(0).to(device)
+        inputs = torch.cat((inputs, inputs_f), dim=0)
+
+        if iii == 0:
+            _, _, h, w = inputs.size()
+
+        # forward
+        outputs = net.forward(inputs, adj1_test, adj3_test, adj2_test)
+        # TTA: average the original and the flipped-back prediction
+        outputs = (outputs[0] + flip(flip_cihp(outputs[1]), dim=-1)) / 2
+        outputs = outputs.unsqueeze(0)
+
+        # Resize back to the base resolution and accumulate across scales
+        if iii > 0:
+            outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)
+            outputs_final = outputs_final + outputs
+        else:
+            outputs_final = outputs.clone()
+
+    predictions = torch.max(outputs_final, 1)[1]
+    results = predictions.detach().cpu().numpy()
+    vis_res = decode_labels(results)
+
+    os.makedirs(output_path, exist_ok=True)
+    Image.fromarray(vis_res[0]).save(os.path.join(output_path, f'{output_name}.png'))
+    # Cast to uint8: OpenCV cannot write the int64 array produced by argmax
+    cv2.imwrite(os.path.join(output_path, f'{output_name}_gray.png'), results[0, :, :].astype(np.uint8))
+
+    print('time used for the multi-scale image inference:', timeit.default_timer() - start_time)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--loadmodel', default='', type=str)
+    parser.add_argument('--img_path', default='', type=str)
+    parser.add_argument('--output_path', default='', type=str)
+    parser.add_argument('--output_name', default='', type=str)
+    args = parser.parse_args()
+
+    net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
+        n_classes=20, hidden_layers=128, source_classes=7
+    )
+
+    if args.loadmodel:
+        state = torch.load(args.loadmodel, map_location=device)
+        net.load_source_model(state)
+        print('Loaded model:', args.loadmodel)
+    else:
+        raise RuntimeError('No model supplied via --loadmodel')
+
+    net.to(device).eval()
+
+    if not args.img_path:
+        raise RuntimeError('Provide --img_path')
+    if not args.output_path:
+        args.output_path = './outputs'
+    if not args.output_name:
+        args.output_name = os.path.splitext(os.path.basename(args.img_path))[0]
+
+    inference(net=net, img_path=args.img_path, output_path=args.output_path, output_name=args.output_name)
diff --git a/img/ronaldo.jpg b/img/ronaldo.jpg
new file mode 100644
index 0000000..27dec76
Binary files /dev/null and b/img/ronaldo.jpg differ
diff --git a/img/ronaldo_blend.png b/img/ronaldo_blend.png
new file mode 100644
index 0000000..5671a01
Binary files /dev/null and b/img/ronaldo_blend.png differ
diff --git a/img/ronaldo_output.png b/img/ronaldo_output.png
new file mode 100644
index 0000000..891cd37
Binary files /dev/null and b/img/ronaldo_output.png differ
diff --git a/img/ronaldo_output_gray.png b/img/ronaldo_output_gray.png
new file mode 100644
index 0000000..b13c239
Binary files /dev/null and b/img/ronaldo_output_gray.png differ
diff --git a/networks/.gitignore b/networks/.gitignore
new file mode 100644
index 0000000..98cb45d
--- /dev/null
+++ b/networks/.gitignore
@@ -0,0 +1,2 @@
+__pycache__
+/__pycache__
\ No newline at end of file
diff --git a/sync_batchnorm/.gitignore b/sync_batchnorm/.gitignore
new file mode 100644
index 0000000..98cb45d
--- /dev/null
+++ b/sync_batchnorm/.gitignore
@@ -0,0 +1,2 @@
+__pycache__
+/__pycache__
\ No newline at end of file
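
Usage note: because the CLI in `inference_image_blend.py` is wrapped in an `if __name__ == '__main__'` guard, `create_overlay` can also be imported and called directly. A minimal sketch, assuming it is run from the repository root and that the files from the README example exist; the paths are illustrative:

```python
import sys

# inference_image_blend.py lives in exp/inference, which is not a package,
# so add that directory to sys.path before importing from it.
sys.path.append('./exp/inference')

from inference_image_blend import create_overlay

# Blend the inference output back onto the original image.
create_overlay(
    img_path='./img/ronaldo.jpg',
    mask_path='./img/ronaldo_output.png',
    output_path='./img/',
    output_name='ronaldo_blend.png',
)
```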