Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions src/operator/roi_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* Copyright (c) 2015 by Contributors
* \file roi_pooling.cc
* \brief roi pooling operator
* \author Ross Girshick, Kye-Hyeon Kim, Jian Guo
* \author Ross Girshick, Kye-Hyeon Kim, Jian Guo, Xinyu Chen
*/
#include "./roi_pooling-inl.h"
#include <mshadow/base.h>
Expand Down Expand Up @@ -54,13 +54,22 @@ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,

const int num_rois = bbox.size(0);
const int data_size = data.size(1) * data.size(2) * data.size(3);
const int data_size_c = data.size(2) * data.size(3);
const int out_size_c = out.size(2) * out.size(3);
const int out_size = channels_ * out_size_c;
const int max_idx_size_c = max_idx.size(2) * max_idx.size(3);
const int max_idx_size = channels_ * max_idx_size_c;
// For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
for (int n = 0; n < num_rois; ++n) {
int roi_batch_ind = bottom_rois[0];
int roi_start_w = round(bottom_rois[1] * spatial_scale_);
int roi_start_h = round(bottom_rois[2] * spatial_scale_);
int roi_end_w = round(bottom_rois[3] * spatial_scale_);
int roi_end_h = round(bottom_rois[4] * spatial_scale_);
// Increment ROI data pointer
const Dtype *bottom_rois_n = bottom_rois + n * bbox.size(1);
Dtype *top_data_n = top_data + n * out_size;
Dtype *argmax_data_n = argmax_data + n * max_idx_size;
int roi_batch_ind = bottom_rois_n[0];
int roi_start_w = round(bottom_rois_n[1] * spatial_scale_);
int roi_start_h = round(bottom_rois_n[2] * spatial_scale_);
int roi_end_w = round(bottom_rois_n[3] * spatial_scale_);
int roi_end_h = round(bottom_rois_n[4] * spatial_scale_);
assert(roi_batch_ind >= 0);
assert(static_cast<index_t>(roi_batch_ind) < data.size(0) /* batch size */);

Expand All @@ -74,12 +83,18 @@ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,

const Dtype* batch_data = bottom_data + data_size * roi_batch_ind;

#pragma omp parallel for
for (int c = 0; c < channels_; ++c) {
// Increment all data pointers
const Dtype* batch_data_c = batch_data + c * data_size_c;
Dtype* top_data_c = top_data_n + c * out_size_c;
Dtype* argmax_data_c = argmax_data_n + c * max_idx_size_c;

for (int ph = 0; ph < pooled_height_; ++ph) {
for (int pw = 0; pw < pooled_width_; ++pw) {
// Compute pooling region for this output unit:
// start (included) = floor(ph * roi_height / pooled_height_)
// end (excluded) = ceil((ph + 1) * roi_height / pooled_height_)
// start (included) = floor(ph * roi_height / pooled_height_)
// end (excluded) = ceil((ph + 1) * roi_height / pooled_height_)
int hstart = static_cast<int>(floor(static_cast<Dtype>(ph)
* bin_size_h));
int wstart = static_cast<int>(floor(static_cast<Dtype>(pw)
Expand All @@ -98,30 +113,23 @@ inline void ROIPoolForward(const Tensor<cpu, 4, Dtype> &out,

const int pool_index = ph * pooled_width_ + pw;
if (is_empty) {
top_data[pool_index] = 0;
argmax_data[pool_index] = -1;
top_data_c[pool_index] = 0;
argmax_data_c[pool_index] = -1;
}

for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
const int index = h * width_ + w;
if (batch_data[index] > top_data[pool_index]) {
top_data[pool_index] = batch_data[index];
argmax_data[pool_index] = index;
if (batch_data_c[index] > top_data_c[pool_index]) {
top_data_c[pool_index] = batch_data_c[index];
argmax_data_c[pool_index] = index;
}
}
}
}
}
// Increment all data pointers by one channel
batch_data += data.size(2) * data.size(3);
top_data += out.size(2) * out.size(3);
argmax_data += max_idx.size(2) * max_idx.size(3);
}
// Increment ROI data pointer
bottom_rois += bbox.size(1);
}

return;
}

Expand Down