-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvector.cpp
More file actions
35 lines (30 loc) · 871 Bytes
/
vector.cpp
File metadata and controls
35 lines (30 loc) · 871 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include "vector.h"
std::vector<size_t> buildIndexer(const std::vector<size_t>* shape){
    // Builds the per-axis strides ("indexer") used to broadcast tensors.
    //
    // Each entry is the row-major stride of that axis (the product of all
    // trailing extents), except that any axis whose extent is not > 1 gets
    // stride 0, so stepping along it never moves the flat offset.
    //
    // Example:
    //   vec A: shape (2, 1, 3)
    //   indexer A (iA): (3, 0, 1)
    //   element [1, 1, 1] -> flat offset 1*3 + 1*0 + 1*1 = 4
    //
    // @param shape  Extents of each axis; must not be null.
    // @return       One stride per axis; empty when `shape` is empty.
    std::vector<size_t> indexer(shape->size());
    size_t prod = 1;
    // Walk axes innermost -> outermost. Counting down with `i-- > 0` avoids
    // the unsigned wrap-around of `size() - 1` when the shape is empty.
    for (size_t i = shape->size(); i-- > 0;) {
        const size_t extent = (*shape)[i];
        // Cumulative product, masked to 0 for broadcastable (extent <= 1) axes.
        indexer[i] = extent > 1 ? prod : 0;
        prod *= extent;
    }
    return indexer;
}
std::vector<size_t> vectorUnravel(const std::vector<size_t>* indexer, size_t idx){
    // Splits the flat offset `idx` into one coordinate per axis, using the
    // per-axis strides in `*indexer` (presumably produced by buildIndexer —
    // verify against callers). Axes whose stride is 0 always get coordinate 0.
    //
    // Strides are consumed outermost-first: each non-zero stride takes the
    // quotient as that axis's coordinate and passes the remainder inward.
    // NOTE(review): `divmod` is defined outside this view; it is assumed to
    // return {quotient, remainder}.
    const std::vector<size_t>& strides = *indexer;
    std::vector<size_t> coords(strides.size(), 0);
    size_t remaining = idx;
    for (size_t axis = 0; axis < strides.size(); ++axis) {
        const size_t stride = strides[axis];
        if (stride != 0) {
            auto [quot, rest] = divmod(remaining, stride);
            coords[axis] = quot;
            remaining = rest;
        }
    }
    return coords;
}