-
Notifications
You must be signed in to change notification settings - Fork 62
Open
Description
I am looking at GEMM computations in EVA.
EVA uses vectorised computations, though following the paper Secure Outsourced Matrix Computation and Application to Neural Networks, we can run a naive encrypted matmul by having the vector size be 1.
This issue asks if the EVA Extension Library (EXL) will be released, which may already implement this, however it is not current available as far as I know.
I have tried to implement this, however I am getting an error: RuntimeError: bad optional access.
You can see my example before, is there something I am missing?
#!/usr/bin/env python
from eva import EvaProgram, Input, Output, evaluate
from eva.ckks import CKKSCompiler
from eva.seal import generate_keys
from eva.metric import valuation_mse
import numpy as np
def get_gemm(N, K, M):
gemm = EvaProgram("gemm", vec_size=1)
with gemm:
outputs = [[0] * N] * M
for n in range(N):
for m in range(M):
for k in range(K):
x = Input(f"x_{n}_{k}")
w = Input(f"w_{k}_{m}")
outputs[n][m] += x * m
for n in range(N):
for m in range(M):
Output(f"out_{n}_{m}", outputs[n][m])
gemm.set_input_scales(25)
gemm.set_output_ranges(10)
return gemm
def generate_inputs(N, K):
inputs = dict()
i = 0
for n in range(N):
for k in range(K):
inputs[f"x_{n}_{k}"] = [i]
i += 1
return inputs
def generate_weights(K, M):
inputs = dict()
i = 0
for k in range(K):
for m in range(M):
inputs[f"w_{k}_{m}"] = [i]
i += 1
return inputs
def main():
N, K, M = 8, 8, 8
inputs = generate_inputs(N, K)
weights = generate_weights(K, M)
gemm = get_gemm(N, K, M)
data = {**weights, **inputs}
print(data)
for prog in [gemm]:
print(f"Compiling {prog.name}")
compiler = CKKSCompiler()
compiled, params, signature = compiler.compile(prog)
public_ctx, secret_ctx = generate_keys(params)
enc_inputs = public_ctx.encrypt(data, signature)
print("excuting GEMM")
enc_outputs = public_ctx.execute(compiled, enc_inputs)
outputs = secret_ctx.decrypt(enc_outputs, signature)
reference = evaluate(compiled, inputs)
print("MSE", valuation_mse(outputs, reference))
print()
if __name__ == "__main__":
main()Metadata
Metadata
Assignees
Labels
No labels