forked from bargavj/distributedMachineLearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodelAggregate.oc
More file actions
71 lines (59 loc) · 1.84 KB
/
modelAggregate.oc
File metadata and controls
71 lines (59 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include "modelAggregate.oh"
/*
 * aggregate - protocol body that averages the locally trained models and
 * perturbs the average with noise derived from the supplied random values.
 *
 * args: pointer to a protocolIO. Fields read here: sizes[M], beta1[M][D],
 *       beta2[M][D], random_vals[M][D], lambda and epsilon. Field written
 *       here: beta_avg[D].
 *
 * Steps, in the order they are executed:
 *   1. Feed the M data-set sizes (from party 1) and take their minimum.
 *   2. For every coordinate j, sum over the M models the XOR of the party-1
 *      value beta1[i][j] and the party-2 value beta2[i][j], then divide the
 *      sum by M. The random values (fed by party 1) are XOR-ed together.
 *   3. Compute the noise scale b = 2 * SCALE / (M * n_min * lambda * epsilon).
 *   4. For every coordinate add (ln(random) - ln(RAND_MAX)) * b * sign, i.e.
 *      presumably a Laplace sample with scale b where random / RAND_MAX is
 *      expected to lie in (0, 1] -- TODO confirm against the protocol spec.
 *   5. Reveal the D noisy averages into io->beta_avg (party argument 0).
 *
 * The gate count spent in each logarithm is written to stderr and the total
 * gate count is printed through printGateCount().
 */
void aggregate(void* args)
{
    protocolIO *io = args;
    obliv int64_t beta_avg[D], random_vals[D], n_min, sizes[M], b, sign, res, res2;
    obig randval, logval, randval2, logval2;
    int64_t prec, prec2, gates;

    obig_init(&randval, MAXN);
    obig_init(&logval, MAXN);
    obig_init(&randval2, MAXN);
    obig_init(&logval2, MAXN);

    /* Data-set sizes are supplied by party 1. */
    for (int i = 0; i < M; i++) {
        sizes[i] = feedOblivLLong(io->sizes[i], 1);
    }

    /* Oblivious minimum over all sizes. */
    n_min = sizes[0];
    for (int i = 1; i < M; i++) {
        obliv if (n_min > sizes[i]) {
            n_min = sizes[i];
        }
    }

    for (int j = 0; j < D; j++) {
        beta_avg[j] = 0;
        random_vals[j] = 0;
        for (int i = 0; i < M; i++) {
            /* beta1 comes from party 1 and beta2 from party 2; the two are
             * combined with XOR before being accumulated. */
            beta_avg[j] += (feedOblivLLong(io->beta1[i][j], 1) ^ feedOblivLLong(io->beta2[i][j], 2));
            random_vals[j] ^= feedOblivLLong(io->random_vals[i][j], 1);
        }
        beta_avg[j] /= M;
    }

    /* Noise scale: b = 2 * SCALE / (m * n * lambda * epsilon). */
    b = 2 * SCALE / (M * n_min * io->lambda * io->epsilon);

    /*
     * Noise term: (+/-) * (ln(random_vals) - ln(RAND_MAX)) * b.
     * ln(RAND_MAX) does not depend on j, so it is computed once up front.
     * The logarithm result is divided by 2^prec2, which drops its
     * fractional bits (integer division).
     */
    obig_import_onative(&randval2, RAND_MAX);
    ofixed_ln(&logval2, &prec2, randval2, 0);
    res2 = obig_export_onative_signed(logval2);
    res2 /= (1LL << prec2);

    for (int j = 0; j < D; j++) {
        /* The sign of the combined random value selects the noise sign. */
        sign = 1;
        obliv if (random_vals[j] < 0) {
            sign = -1;
        }

        /* NOTE(review): a negative random_vals[j] is imported unchanged, not
         * as its magnitude -- confirm this is what ofixed_ln should receive. */
        obig_import_onative(&randval, random_vals[j]);

        /* Measure only the gates spent inside the logarithm. */
        gates = yaoGateCount();
        ofixed_ln(&logval, &prec, randval, 0);
        gates = yaoGateCount() - gates;

        res = obig_export_onative_signed(logval);
        res /= (1LL << prec);

        /* gates is an int64_t; cast so the argument matches %llu exactly. */
        fprintf(stderr, "Log gate count: %llu\n", (unsigned long long)gates);

        beta_avg[j] += (res - res2) * b * sign;
    }

    printGateCount();

    /* Publish the noisy averages. */
    for (int j = 0; j < D; j++) {
        revealOblivLLong(&io->beta_avg[j], beta_avg[j], 0);
    }

    obig_free(&randval);
    obig_free(&logval);
    obig_free(&randval2);
    obig_free(&logval2);
}