From c3e65e5e681d965060d993c47266581ee4c5ca6a Mon Sep 17 00:00:00 2001 From: xinyu-intel Date: Fri, 2 Mar 2018 13:37:42 +0800 Subject: [PATCH 1/5] parallelization for roipooling --- src/operator/roi_pooling.cc | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 10d1420950cc..46a556de3319 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -56,6 +56,8 @@ inline void ROIPoolForward(const Tensor &out, const int data_size = data.size(1) * data.size(2) * data.size(3); // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R for (int n = 0; n < num_rois; ++n) { + // Increment ROI data pointer + bottom_rois += n * bbox.size(1); int roi_batch_ind = bottom_rois[0]; int roi_start_w = round(bottom_rois[1] * spatial_scale_); int roi_start_h = round(bottom_rois[2] * spatial_scale_); @@ -74,7 +76,13 @@ inline void ROIPoolForward(const Tensor &out, const Dtype* batch_data = bottom_data + data_size * roi_batch_ind; + #pragma omp parallel for firstprivate(batch_data, top_data, argmax_data) for (int c = 0; c < channels_; ++c) { + // Increment all data pointers + batch_data += c * data.size(2) * data.size(3); + top_data += c * out.size(2) * out.size(3); + argmax_data += c * max_idx.size(2) * max_idx.size(3); + for (int ph = 0; ph < pooled_height_; ++ph) { for (int pw = 0; pw < pooled_width_; ++pw) { // Compute pooling region for this output unit: @@ -113,15 +121,19 @@ inline void ROIPoolForward(const Tensor &out, } } } - // Increment all data pointers by one channel - batch_data += data.size(2) * data.size(3); - top_data += out.size(2) * out.size(3); - argmax_data += max_idx.size(2) * max_idx.size(3); + // Decrement all data pointers + batch_data -= c * data.size(2) * data.size(3); + top_data -= c * out.size(2) * out.size(3); + argmax_data -= c * max_idx.size(2) * max_idx.size(3); } - // Increment ROI data pointer - bottom_rois += bbox.size(1); + // Increase data pointers by one bbox + batch_data += channels_ * data.size(2) * data.size(3); + top_data += channels_ * out.size(2) * out.size(3); + argmax_data += channels_ * max_idx.size(2) * max_idx.size(3); + // Decrease bbox pointers + bottom_rois -= n * bbox.size(1); } - + bottom_rois += num_rois * bbox.size(1); return; } From 1ac728db2a9a67afc8a673d059a96cb529777301 Mon Sep 17 00:00:00 2001 From: xinyu-intel Date: Fri, 2 Mar 2018 16:41:43 +0800 Subject: [PATCH 2/5] remove some useless computation --- src/operator/roi_pooling.cc | 45 +++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 46a556de3319..dd8b77e4e7e3 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -54,15 +54,18 @@ inline void ROIPoolForward(const Tensor &out, const int num_rois = bbox.size(0); const int data_size = data.size(1) * data.size(2) * data.size(3); + const int data_size_c = data.size(2) * data.size(3); + const int out_size_c = out.size(2) * out.size(3); + const int max_idx_size_c = max_idx.size(2) * max_idx.size(3); // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R for (int n = 0; n < num_rois; ++n) { // Increment ROI data pointer - bottom_rois += n * bbox.size(1); - int roi_batch_ind = bottom_rois[0]; - int roi_start_w = round(bottom_rois[1] * spatial_scale_); - int roi_start_h = round(bottom_rois[2] * spatial_scale_); - int roi_end_w = round(bottom_rois[3] * spatial_scale_); - int roi_end_h = round(bottom_rois[4] * spatial_scale_); + const Dtype *bottom_rois_n = bottom_rois + n * bbox.size(1); + int roi_batch_ind = bottom_rois_n[0]; + int roi_start_w = round(bottom_rois_n[1] * spatial_scale_); + int roi_start_h = round(bottom_rois_n[2] * spatial_scale_); + int roi_end_w = round(bottom_rois_n[3] * spatial_scale_); + int roi_end_h = round(bottom_rois_n[4] * spatial_scale_); assert(roi_batch_ind >= 0); assert(static_cast(roi_batch_ind) < data.size(0) /* batch size */); @@ -79,9 +82,9 @@ inline void ROIPoolForward(const Tensor &out, #pragma omp parallel for firstprivate(batch_data, top_data, argmax_data) for (int c = 0; c < channels_; ++c) { // Increment all data pointers - batch_data += c * data.size(2) * data.size(3); - top_data += c * out.size(2) * out.size(3); - argmax_data += c * max_idx.size(2) * max_idx.size(3); + const Dtype* batch_data_c = batch_data + c * data_size_c; + Dtype* top_data_c = top_data + c * out_size_c; + Dtype* argmax_data_c = argmax_data + c * max_idx_size_c; for (int ph = 0; ph < pooled_height_; ++ph) { for (int pw = 0; pw < pooled_width_; ++pw) { @@ -106,34 +109,26 @@ inline void ROIPoolForward(const Tensor &out, const int pool_index = ph * pooled_width_ + pw; if (is_empty) { - top_data[pool_index] = 0; - argmax_data[pool_index] = -1; + top_data_c[pool_index] = 0; + argmax_data_c[pool_index] = -1; } for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { const int index = h * width_ + w; - if (batch_data[index] > top_data[pool_index]) { - top_data[pool_index] = batch_data[index]; - argmax_data[pool_index] = index; + if (batch_data_c[index] > top_data_c[pool_index]) { + top_data_c[pool_index] = batch_data_c[index]; + argmax_data_c[pool_index] = index; } } } } } - // Decrement all data pointers - batch_data -= c * data.size(2) * data.size(3); - top_data -= c * out.size(2) * out.size(3); - argmax_data -= c * max_idx.size(2) * max_idx.size(3); } - // Increase data pointers by one bbox - batch_data += channels_ * data.size(2) * data.size(3); - top_data += channels_ * out.size(2) * out.size(3); - argmax_data += channels_ * max_idx.size(2) * max_idx.size(3); - // Decrease bbox pointers - bottom_rois -= n * bbox.size(1); + // Increase data pointers by one outsize + top_data += channels_ * out_size_c; + argmax_data += channels_ * max_idx_size_c; } - bottom_rois += num_rois * bbox.size(1); return; } From 5b0412b10676975b974cf637a7e86b6d1b8b223d Mon Sep 17 00:00:00 2001 From: xinyu-intel Date: Sun, 4 Mar 2018 20:59:44 +0800 Subject: [PATCH 3/5] remove useless muls --- src/operator/roi_pooling.cc | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index dd8b77e4e7e3..6b5c2ff4030f 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -56,11 +56,15 @@ inline void ROIPoolForward(const Tensor &out, const int data_size = data.size(1) * data.size(2) * data.size(3); const int data_size_c = data.size(2) * data.size(3); const int out_size_c = out.size(2) * out.size(3); + const int out_size = channels_ * out_size_c; const int max_idx_size_c = max_idx.size(2) * max_idx.size(3); + const int max_idx_size = channels_ * max_idx_size_c; // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R for (int n = 0; n < num_rois; ++n) { // Increment ROI data pointer const Dtype *bottom_rois_n = bottom_rois + n * bbox.size(1); + Dtype *top_data_n = top_data + n * out_size; + Dtype *argmax_data_n = argmax_data + n * max_idx_size; int roi_batch_ind = bottom_rois_n[0]; int roi_start_w = round(bottom_rois_n[1] * spatial_scale_); int roi_start_h = round(bottom_rois_n[2] * spatial_scale_); @@ -79,12 +83,12 @@ inline void ROIPoolForward(const Tensor &out, const Dtype* batch_data = bottom_data + data_size * roi_batch_ind; - #pragma omp parallel for firstprivate(batch_data, top_data, argmax_data) + #pragma omp parallel for for (int c = 0; c < channels_; ++c) { // Increment all data pointers const Dtype* batch_data_c = batch_data + c * data_size_c; - Dtype* top_data_c = top_data + c * out_size_c; - Dtype* argmax_data_c = argmax_data + c * max_idx_size_c; + Dtype* top_data_c = top_data_n + c * out_size_c; + Dtype* argmax_data_c = argmax_data_n + c * max_idx_size_c; for (int ph = 0; ph < pooled_height_; ++ph) { for (int pw = 0; pw < pooled_width_; ++pw) { @@ -125,9 +129,6 @@ inline void ROIPoolForward(const Tensor &out, } } } - // Increase data pointers by one outsize - top_data += channels_ * out_size_c; - argmax_data += channels_ * max_idx_size_c; } return; } From 94665f18c80fcd999b24aea625c29d3a84efa4de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E6=96=B0=E5=AE=87?= Date: Mon, 5 Mar 2018 22:05:42 +0800 Subject: [PATCH 4/5] add author and retriggering --- src/operator/roi_pooling.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 6b5c2ff4030f..5b95beb7fe68 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -21,7 +21,7 @@ * Copyright (c) 2015 by Contributors * \file roi_pooling.cc * \brief roi pooling operator - * \author Ross Girshick, Kye-Hyeon Kim, Jian Guo + * \author Ross Girshick, Kye-Hyeon Kim, Jian Guo, Xinyu Chen */ #include "./roi_pooling-inl.h" #include From 6dfac28ad4c19e5b67469ef8edc3b86de126f999 Mon Sep 17 00:00:00 2001 From: xinyu-intel Date: Tue, 6 Mar 2018 09:16:20 +0800 Subject: [PATCH 5/5] retrigger again --- src/operator/roi_pooling.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/roi_pooling.cc b/src/operator/roi_pooling.cc index 5b95beb7fe68..acff1f97dcce 100644 --- a/src/operator/roi_pooling.cc +++ b/src/operator/roi_pooling.cc @@ -93,8 +93,8 @@ inline void ROIPoolForward(const Tensor &out, for (int ph = 0; ph < pooled_height_; ++ph) { for (int pw = 0; pw < pooled_width_; ++pw) { // Compute pooling region for this output unit: - // start (included) = floor(ph * roi_height / pooled_height_) - // end (excluded) = ceil((ph + 1) * roi_height / pooled_height_) + // start (included) = floor(ph * roi_height / pooled_height_) + // end (excluded) = ceil((ph + 1) * roi_height / pooled_height_) int hstart = static_cast(floor(static_cast(ph) * bin_size_h)); int wstart = static_cast(floor(static_cast(pw)