diff --git a/golang/Makefile b/golang/Makefile new file mode 100644 index 000000000000..54019740c87a --- /dev/null +++ b/golang/Makefile @@ -0,0 +1,64 @@ +.PHONY: clean all + +TVM_BASE = $(CURDIR)/../ +TARGET = gotvm +LIBS = -lm -ldl +NATIVE_SRC = tvm_runtime_pack.cc + +GOPATH=$(CURDIR)/gopath +GOPATHDIR=${GOPATH}/src/${TARGET}/ +CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/" +CGO_CXXFLAGS="-std=c++11" +CGO_CFLAGS="-I${TVM_BASE}" +CGO_LDFLAGS="-ldl -lm" + +all: + @mkdir gopath 2>/dev/null || true + @mkdir gopath/src 2>/dev/null || true + @mkdir gopath/src/$(TARGET) 2>/dev/null || true + @cp src/$(TARGET).cc gopath/src/$(TARGET) + @cp src/$(TARGET).h gopath/src/$(TARGET) + @cp src/$(NATIVE_SRC) gopath/src/$(TARGET) + @cp src/*.go gopath/src/$(TARGET) + @export GOPATH=$(GOPATH); \ + export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ + export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \ + export CGO_CFLAGS=$(CGO_CFLAGS); \ + export CGO_LDFLAGS=$(CGO_LDFLAGS); \ + (cd $(GOPATHDIR) && go clean -cache \ + && golint && go build -o $(TARGET).a \ + && go install) + @find . -name gotvm.a + @#mkdir gopath/doc 2>/dev/null || true + @#godoc -html -goroot gopath/ gotvm | grep -v "for documentation on the gotvm command" > gopath/doc/gotvm.html + @#echo "Run 'godoc -http=:6060 -goroot=./gopath' for documentation" + +samples: all + cp gopath/pkg/linux_amd64/gotvm.a sample/ -rfa + make -C sample + +tests: all + @(cd sample; python3 deploy.py) + @export GOPATH=$(GOPATH); \ + export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ + export CGO_CXXFLAGS=$(CGO_CXXFLAGS); \ + export CGO_CFLAGS=$(CGO_CFLAGS); \ + export CGO_LDFLAGS=$(CGO_LDFLAGS); \ + (cd $(GOPATHDIR) \ + && cp ../../../sample/deploy.so . \ + && go test -v) + +clean: + @if [ -d $(GOPATHDIR) ] ; then \ + export GOPATH=$(GOPATH); \ + export CGO_CPPFLAGS=$(CGO_CPPFLAGS); \ + export CGO_CFLAGS=$(CGO_CFLAGS); \ + export CGO_LDFLAGS=$(CGO_LDFLAGS); \ + (cd $(GOPATHDIR) && go clean -cache); fi + @rm -rf gopath + @make -C sample clean + +lint: + @(cd src; golint) + @python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.cc + @python3 ${TVM_BASE}/dmlc-core/scripts/lint.py gotvm cpp src/*.h diff --git a/golang/README.md b/golang/README.md new file mode 100644 index 000000000000..9c152dd7365c --- /dev/null +++ b/golang/README.md @@ -0,0 +1,107 @@ +# gotvm - Golang Frontend for TVM Runtime + +This folder contain golang interface for TVM runtime. It brings TVM runtime to Golang. + +- It enable c runtime api of tvm exposed to golang. +- It enables module loading (lib, graph and params) and inference operations. + +## Installation + +### Requirements + +- go compiler (https://golang.org/) version 0.10 or above. + +### Modules + +- src + Module that generates golang package corresponding to the c runtime api exposed from tvm source tree. + This process build golang package _gotvm.a_ + +- samples + Sample golang reference application to inference through gotvm package. + +### Build + +Once the Requirements are installed + +To build _gotvm_ package + +```bash +make +``` + +To build and run internal tests + +```bash +make tests +``` + +To build sample apps. + +```bash +make samples +``` + +## Run + +To Demonstrates sample TVM module compilation using python and deploy via golang. +```bash +./simple +``` + +To deploy a realtime module with lib, graph and param. +```bash +./complex +``` + +To demonstrate go function closure conversion to packed function handle. + +```bash +./pack_func_convert +``` + +To demonstrate a packed function handle given as an argument. + +```bash +pack_func_handle_arg +``` + +To register go function with runtime as a global function. + +```bash +pack_func_register +``` + +To demonstrate function closure passed as argument to a function call. + +```bash +./pack_func_closure_arg +``` + +To demonstrate function closure returned from a packed function. + +```bash +./pack_func_closure_return +``` + +## Documentation +gotvm.go is documented with sufficient information about gotvm package. +A html version documentation can be accessed by running below command after building runtime. + +```bash +godoc -http=:6060 -goroot=./gopath +``` +After above command try http://127.0.0.1:6060 from any browser. + +Also please refer to the sample applications under sample folder. + +## Docker +Docker setup may need below additions for dependencies and environment preparation. + +Please refer ```docker/install/ubuntu_install_golang.sh``` for the packages dependencies. + +go compiler 1.10 on ubuntu doesn't install on standard path, hence an explicit export may be needed as shown below. + +```bash +export PATH="/usr/lib/go-1.10/bin:$PATH"``` +``` diff --git a/golang/sample/Makefile b/golang/sample/Makefile new file mode 100644 index 000000000000..8ebea49da42f --- /dev/null +++ b/golang/sample/Makefile @@ -0,0 +1,17 @@ +.PHONY: clean all + +SOURCES=$(wildcard *.go) +EXECUTABLE=$(patsubst %.go, %, $(SOURCES)) + +all: $(EXECUTABLE) + @golint + @python3 deploy.py + +%: %.o + @go tool link -linkmode external -extld "g++" -extldflags "-ldl" -o $@ $< + +%.o: %.go + @go tool compile -pack -o $@ $< + +clean: + @rm -f $(EXECUTABLE) *.so *.o *.a diff --git a/golang/sample/complex.go b/golang/sample/complex.go new file mode 100644 index 000000000000..7a8d0044375c --- /dev/null +++ b/golang/sample/complex.go @@ -0,0 +1,171 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief Sample golang application deployment over tvm. + * \file complex.go + */ + +package main + +import ( + "fmt" + "io/ioutil" + "math/rand" + "./gotvm" + "runtime" +) + +// NNVM compiled model paths. +const ( + modLib = "./mobilenet.so" + modJSON = "./mobilenet.json" + modParams = "./mobilenet.params" +) + +// main +func main() { + defer runtime.GC() + // Welcome + fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion) + fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion) + + // Query global functions available + funcNames, err := gotvm.FuncListGlobalNames() + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Global Functions:%v\n", funcNames) + + // Import tvm module (so) + modp, err := gotvm.LoadModuleFromFile(modLib) + if err != nil { + fmt.Print(err) + fmt.Printf("Please copy tvm compiled modules here and update the sample.go accordingly.\n") + fmt.Printf("You may need to update modLib, modJSON, modParams, tshapeIn, tshapeOut\n") + return + } + fmt.Printf("Module Imported:%p\n", modp) + bytes, err := ioutil.ReadFile(modJSON) + if err != nil { + fmt.Print(err) + return + } + jsonStr := string(bytes) + + // Load module on tvm runtime - call tvm.graph_runtime.create + funp, err := gotvm.GetGlobalFunction("tvm.graph_runtime.create") + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Calling tvm.graph_runtime.create\n") + // Call function + graphrt, err := funp.Invoke(jsonStr, modp, (int64)(gotvm.KDLCPU), (int64)(0)) + if err != nil { + fmt.Print(err) + return + } + graphmod := graphrt.AsModule() + fmt.Printf("Graph runtime Created\n") + + // Array allocation attributes + tshapeIn := []int64{1, 224, 224, 3} + tshapeOut := []int64{1, 1001} + + // Allocate input Array + inX, err := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0)) + if err != nil { + fmt.Print(err) + return + } + + // Allocate output Array + out, err := gotvm.Empty(tshapeOut) + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Input and Output Arrays allocated\n") + + // Get module function from graph runtime : load_params + // Read params + bytes, err = ioutil.ReadFile(modParams) + if err != nil { + fmt.Print(err) + } + + // Load Params + funp, err = graphmod.GetFunction("load_params") + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Func load_params:%p\n", funp) + + // Call function + _, err = funp.Invoke(bytes) + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Module params loaded\n") + + // Set some data in input Array + inSlice := make([]float32, (244 * 244 * 3)) + rand.Seed(10) + rand.Shuffle(len(inSlice), func(i, j int) {inSlice[i], + inSlice[j] = rand.Float32(), + rand.Float32() }) + inX.CopyFrom(inSlice) + + // Set Input + funp, err = graphmod.GetFunction("set_input") + if err != nil { + fmt.Print(err) + return + } + + // Call function + _, err = funp.Invoke("input", inX) + if err != nil { + fmt.Print(err) + return + } + + fmt.Printf("Module input is set\n") + + // Run + funp, err = graphmod.GetFunction("run") + if err != nil { + fmt.Print(err) + return + } + + // Call function + _, err = funp.Invoke() + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Module Executed \n") + + // Call runtime function get_output + funp, err = graphmod.GetFunction("get_output") + if err != nil { + fmt.Print(err) + return + } + + // Call function + _, err = funp.Invoke(int64(0), out) + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Got Module Output \n") + + // Print results + outIntf, _ := out.AsSlice() + outSlice := outIntf.([]float32) + fmt.Printf("Result:%v\n", outSlice[:10]) +} diff --git a/golang/sample/deploy.py b/golang/sample/deploy.py new file mode 100644 index 000000000000..065638299bc6 --- /dev/null +++ b/golang/sample/deploy.py @@ -0,0 +1,40 @@ +""" +Get Started with TVM Go +======================= +""" +from __future__ import absolute_import, print_function + +import tvm +import numpy as np + +# Global declarations of environment. + +tgt_host="llvm" +tgt="llvm" + +###################################################################### +# Describe the Computation +# ------------------------ +n = tvm.var("n") +A = tvm.placeholder((n,), name='A') +B = tvm.placeholder((n,), name='B') +C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C") + +###################################################################### +# Schedule the Computation +# ------------------------ +s = tvm.create_schedule(C.op) + +###################################################################### +# Compilation +# ----------- +fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd") + +###################################################################### +# Save Compiled Module +# -------------------- +from tvm.contrib import cc +from tvm.contrib import util + +fadd.save("deploy.o") +cc.create_shared("deploy.so", ["deploy.o"]) diff --git a/golang/sample/pack_func_closure_arg.go b/golang/sample/pack_func_closure_arg.go new file mode 100644 index 000000000000..b31113160586 --- /dev/null +++ b/golang/sample/pack_func_closure_arg.go @@ -0,0 +1,57 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief Sample golang application to demonstrate go-closure given to a packed function argument. + * \file pack_func_closure_arg.go + */ + +package main + +import ( + "fmt" + "./gotvm" +) + + +// sampleFunctionArg receives a Packed Function handle and calls it. +func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + // Call Packed Function + retVal, err = pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64()) + return +} + +// main +func main() { + // Not passing a function name implicitely + // picks the name from reflection as "main.sampleDunctionArg" + gotvm.RegisterFunction(sampleFunctionArg); + fmt.Printf("Registered: sampleFunctionArg\n") + + // Get registered global function. + funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg") + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("GetGlobalFunction: main.sampleFunctionArg - Success\n") + + // funccall is a simple golang callback function like C = A + B. + funccall := func (args ...*gotvm.Value) (retVal interface{}, err error) { + for _, v := range args { + fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) + } + val1 := args[0].AsInt64() + val2 := args[1].AsInt64() + retVal = int64(val1+val2) + return + } + + // Call function + result, err := funp.Invoke(funccall, 30, 50) + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Invoked sampleFunctionArg with function closure arg : Result:%v\n", result.AsInt64()) +} diff --git a/golang/sample/pack_func_closure_return.go b/golang/sample/pack_func_closure_return.go new file mode 100644 index 000000000000..98de8e2e5146 --- /dev/null +++ b/golang/sample/pack_func_closure_return.go @@ -0,0 +1,57 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief Sample golang application to demonstrate go-closure returned from a callback function. + * \file pack_func_closure_return.go + */ + +package main + +import ( + "fmt" + "./gotvm" +) + +// sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. +func sampleFunctionCb(args ...*gotvm.Value) (retVal interface{}, err error) { + funccall := func (cargs ...*gotvm.Value) (fret interface{}, ferr error) { + for _, v := range cargs { + fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) + } + val1 := cargs[0].AsInt64() + val2 := cargs[1].AsInt64() + fret = int64(val1+val2) + return + } + retVal = funccall + return +} + +// main +func main() { + // Not passing a function name implicitely + // picks the name from reflection as "main.sampleDunctionCb" + gotvm.RegisterFunction(sampleFunctionCb); + fmt.Printf("Registered: sampleFunctionCb\n") + + // Get registered global function + funp, err := gotvm.GetGlobalFunction("main.sampleFunctionCb") + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("GetGlobalFunction: main.sampleFunctionCb - Success\n") + + // Call function + result, err := funp.Invoke() + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Invoked main.sampleFunctionCb via Function handle\n") + + pfunc := result.AsFunction() + fmt.Printf("Function Handle received via Packed Function call:%T - %v \n", pfunc, pfunc) + + pfuncRet, err := pfunc.Invoke(30, 40) + fmt.Printf("Invoked closure inside sampleFunctionCb result:%v\n", pfuncRet.AsInt64()) +} diff --git a/golang/sample/pack_func_convert.go b/golang/sample/pack_func_convert.go new file mode 100644 index 000000000000..6748d67fe75f --- /dev/null +++ b/golang/sample/pack_func_convert.go @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief Sample golang application to demonstrate function conversion to packed function. + * \file pack_func_convert.go + */ + +package main + +import ( + "fmt" + "./gotvm" +) + +// sampleCb is a simple golang callback function like C = A + B. +func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { + for _, v := range args { + fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) + } + val1 := args[0].AsInt64() + val2 := args[1].AsInt64() + retVal = int64(val1+val2) + return +} + +// main +func main() { + // Welcome + + // Simple convert to a packed function + fhandle, err := gotvm.ConvertFunction(sampleCb) + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Converted function\n") + + retVal, err := fhandle.Invoke(10, 20) + fmt.Printf("Invoke Completed\n") + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Result:%v\n", retVal.AsInt64()) +} diff --git a/golang/sample/pack_func_handle_arg.go b/golang/sample/pack_func_handle_arg.go new file mode 100644 index 000000000000..ad1313f93f5f --- /dev/null +++ b/golang/sample/pack_func_handle_arg.go @@ -0,0 +1,60 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief Sample golang application to demonstrate converted packed + * function handle passed to another packed function. + * \file pack_func_handle_arg.go + */ + +package main + +import ( + "fmt" + "./gotvm" +) + +// sampleCb is a simple golang callback function like C = A + B. +func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { + for _, v := range args { + fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) + } + val1 := args[0].AsInt64() + val2 := args[1].AsInt64() + retVal = int64(val1+val2) + return +} + +// sampleFunctionArg receives a Packed Function handle and calls it. +func sampleFunctionArg(args ...*gotvm.Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + + // Call Packed Function + retVal, err = pfunc.Invoke(args[1], args[2]) + return +} + +// main +func main() { + // Simple convert to a packed function + fhandle, err := gotvm.ConvertFunction(sampleCb) + if err != nil { + fmt.Print(err) + return + } + + gotvm.RegisterFunction(sampleFunctionArg); + fmt.Printf("Registered: sampleFunctionArg\n") + + funp, err := gotvm.GetGlobalFunction("main.sampleFunctionArg") + if err != nil { + fmt.Print(err) + return + } + + retVal, err := funp.Invoke(fhandle, 10, 20) + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("Result:%v\n", retVal.AsInt64()) +} diff --git a/golang/sample/pack_func_register.go b/golang/sample/pack_func_register.go new file mode 100644 index 000000000000..5da67e00c16c --- /dev/null +++ b/golang/sample/pack_func_register.go @@ -0,0 +1,63 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief Sample golang application to demonstrate function register into TVM global functions. + * \file pack_func_register.go + */ + +package main + +import ( + "fmt" + "./gotvm" + "strings" +) + +// sampleCb is a simple golang callback function like C = A + B. +func sampleCb(args ...*gotvm.Value) (retVal interface{}, err error) { + for _, v := range args { + fmt.Printf("ARGS:%T : %v\n", v.AsInt64(), v.AsInt64()) + } + val1 := args[0].AsInt64() + val2 := args[1].AsInt64() + retVal = int64(val1+val2) + return +} + +// main +func main() { + // Register sampleCb with TVM packed function system and call and check Global Function List. + gotvm.RegisterFunction(sampleCb, "sampleCb"); + // Query global functions available + funcNames, err := gotvm.FuncListGlobalNames() + if err != nil { + fmt.Print(err) + return + } + + found := 0 + for ii := range (funcNames) { + if strings.Compare(funcNames[ii], "sampleCb") == 0 { + found = 1 + } + } + if found == 0 { + fmt.Printf("Function registerd but, not listed\n") + return + } + + + // Get "sampleCb" and verify the call. + funp, err := gotvm.GetGlobalFunction("sampleCb") + if err != nil { + fmt.Print(err) + return + } + + // Call function + result, err := funp.Invoke((int64)(10), (int64)(20)) + if err != nil { + fmt.Print(err) + return + } + fmt.Printf("sampleCb result: %v\n", result.AsInt64()) +} diff --git a/golang/sample/simple.go b/golang/sample/simple.go new file mode 100644 index 000000000000..ada3963662de --- /dev/null +++ b/golang/sample/simple.go @@ -0,0 +1,72 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief Sample golang application deployment over tvm. + * \file simple.go + */ + +package main + +import ( + "fmt" + "runtime" + "./gotvm" + "math/rand" +) + +// NNVM compiled model paths. +const ( + modLib = "./deploy.so" +) + +// main +func main() { + // Welcome + defer runtime.GC() + fmt.Printf("TVM Version : v%v\n", gotvm.TVMVersion) + fmt.Printf("DLPACK Version: v%v\n\n", gotvm.DLPackVersion) + + // Import tvm module (so) + modp, _ := gotvm.LoadModuleFromFile(modLib) + fmt.Printf("Module Imported\n") + + + // Allocate Array for inputs and outputs. + // Allocation by explicit type and context. + tshapeIn := []int64{4} + inX, _ := gotvm.Empty(tshapeIn, "float32", gotvm.CPU(0)) + + // Default allocation on CPU + inY, _ := gotvm.Empty(tshapeIn, "float32") + + // Default allocation to type "float32" and on CPU + out, _ := gotvm.Empty(tshapeIn) + fmt.Printf("Input and Output Arrays allocated\n") + + // Fill Input Data : inX , inY + inXSlice := make([]float32, 4) + inYSlice := make([]float32, 4) + for i := range inXSlice { + inXSlice[i] = rand.Float32() + inYSlice[i] = rand.Float32() + } + + + // Copy the data on target memory through runtime CopyFrom api. + inX.CopyFrom(inXSlice) + inY.CopyFrom(inYSlice) + fmt.Printf("X: %v\n", inXSlice) + fmt.Printf("Y: %v\n", inYSlice) + + // Get function "myadd" + funp, _ := modp.GetFunction("myadd") + + // Call function + funp.Invoke(inX, inY, out) + fmt.Printf("Module function myadd executed\n") + + // Get the output tensor as an interface holding a slice through runtime CopyTo api. + outSlice, _ := out.AsSlice() + + // Print results + fmt.Printf("Result:%v\n", outSlice.([]float32)) +} diff --git a/golang/src/array_test.go b/golang/src/array_test.go new file mode 100644 index 000000000000..6917dd14e373 --- /dev/null +++ b/golang/src/array_test.go @@ -0,0 +1,596 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package + * \file array_test.go + */ + + +package gotvm + +import ( + "testing" + "unsafe" + "math/rand" +) + +// Create an array and check size. +func TestArrayCreateSize(t *testing.T) { + _, err := Empty([]int64{4}) + if err != nil { + t.Error(err.Error()) + return + } + + _, err = Empty([]int64{4, 5, 6}) + if err != nil { + t.Error(err.Error()) + return + } + + _, err = Empty([]int64{}) + if err == nil { + t.Error("Expected err for empty Array created, but didn't got !!") + return + } +} + +// Check array creation via various different arguments. +func TestArrayCreateArgs(t *testing.T) { + _, err := Empty([]int64{4, 2}, "float32", CPU(0)) + if err != nil { + t.Error(err.Error()) + return + } + + _, err = Empty([]int64{4, 2}, "float32") + if err != nil { + t.Error(err.Error()) + return + } + + _, err = Empty([]int64{4, 2}, CPU(0)) + if err != nil { + t.Error(err.Error()) + return + } + + _, err = Empty([]int64{4, 2}, CPU(0), "float32") + if err != nil { + t.Error(err.Error()) + return + } +} + +// Create an array and check the NDim. +func TestArrayNDim(t *testing.T) { + arr, err := Empty([]int64{4, 5, 6}) + if err != nil { + t.Error(err.Error()) + return + } + + if 3 != arr.GetNdim() { + t.Errorf("GetNdim failed Expected: 3 Got :%v\n", arr.GetNdim()) + return + } +} + +// Create an array and check Shape. +func TestArrayShape(t *testing.T) { + arr, err := Empty([]int64{4, 5, 6}) + if err != nil { + t.Error(err.Error()) + return + } + + shape := arr.GetShape() + if len(shape) != 3 { + t.Errorf("Shape slice expected: 3 Got :%v\n", len(shape)) + return + } + + if shape[0] != 4 || shape[1] != 5 || shape[2] != 6 { + t.Errorf("Shape values expected {4, 5, 6} Got : %v\n", shape); + return + } +} + +// Create an array and check created Context. +func TestArrayCtx(t *testing.T) { + // TODO: Could some test cases for other targets + arr, err := Empty([]int64{4}, CPU(0)) + if err != nil { + t.Error(err.Error()) + return + } + + ctx := arr.GetCtx() + if ctx.DeviceType != KDLCPU { + t.Errorf("Ctx DeviceType expected: %v Got :%v\n", KDLCPU, ctx.DeviceType) + return + } + if ctx.DeviceID != 0 { + t.Errorf("Ctx DeviceID expected: %v Got :%v\n", KDLCPU, ctx.DeviceID) + return + } + + arr, err = Empty([]int64{4}, CPU(2)) + if err != nil { + t.Error(err.Error()) + return + } + + ctx = arr.GetCtx() + if ctx.DeviceType != KDLCPU { + t.Errorf("Ctx DeviceType expected: %v Got :%v\n", KDLCPU, ctx.DeviceType) + return + } + if ctx.DeviceID != 2 { + t.Errorf("Ctx DeviceID expected: %v Got :%v\n", KDLCPU, ctx.DeviceID) + return + } +} + +// Create array of different dtypes and check dtypes. +func TestArrayDType(t *testing.T) { + for _, dtype := range []string{"int8", "int16", "int32", "int64", + "uint8", "uint16", "uint32", "uint64", + "float32", "float64"} { + arr, err := Empty([]int64{4}, dtype) + if err != nil { + t.Error(err.Error()) + return + } + + if dtype != arr.GetDType() { + t.Errorf("Dtype expected: %v Got :%v\n", dtype, arr.GetDType()) + return + } + } +} + +// Copy Int8 data to created Array and verify. +func TestArrayCopySliceInt8(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "int8") + + if err != nil { + t.Error(err.Error()) + return + } + + bdata := make([]byte, dlen) + rand.Read(bdata) + data := (*[1<<31]int8)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []int8: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + + dataRet := ret.([]int8) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} + +// Copy Int16 data to created Array and verify. +func TestArrayCopySliceInt16(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "int16") + if err != nil { + t.Error(err.Error()) + return + } + + bdata := make([]byte, dlen*2) + rand.Read(bdata) + data := (*[1<<31]int16)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + switch ret.(type) { + case []int16: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + + dataRet := ret.([]int16) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} + +// Copy Int32 data to created Array and verify. +func TestArrayCopySliceInt32(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "int32") + if err != nil { + t.Error(err.Error()) + return + } + + bdata := make([]byte, dlen*4) + rand.Read(bdata) + data := (*[1<<31]int32)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []int32: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + dataRet := ret.([]int32) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} + +// Copy Int64 data to created Array and verify. +func TestArrayCopySliceInt64(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "int64") + if err != nil { + t.Error(err.Error()) + return + } + + bdata := make([]byte, dlen*8) + rand.Read(bdata) + data := (*[1<<31]int64)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []int64: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + dataRet := ret.([]int64) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} + +// Copy UInt8 data to created Array and verify. +func TestArrayCopySliceUInt8(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "uint8") + if err != nil { + t.Error(err.Error()) + return + } + + bdata := make([]byte, dlen) + rand.Read(bdata) + data := (*[1<<31]uint8)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []uint8: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + dataRet := ret.([]uint8) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} + +// Copy UInt16 data to created Array and verify. +func TestArrayCopySliceUInt16(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "uint16") + if err != nil { + t.Error(err.Error()) + return + } + + bdata := make([]byte, dlen*2) + rand.Read(bdata) + data := (*[1<<31]uint16)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []uint16: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + dataRet := ret.([]uint16) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} + +// Copy UInt32 data to created Array and verify. +func TestArrayCopySliceUInt32(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "uint32") + if err != nil { + t.Error(err.Error()) + return + } + + bdata := make([]byte, dlen*4) + rand.Read(bdata) + data := (*[1<<31]uint32)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []uint32: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + dataRet := ret.([]uint32) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} + +// Copy UInt64 data to created Array and verify. +func TestArrayCopySliceUInt64(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "uint64") + if err != nil { + t.Error(err.Error()) + return + } + + bdata := make([]byte, dlen*8) + rand.Read(bdata) + data := (*[1<<31]uint64)(unsafe.Pointer(&bdata[0]))[:dlen:dlen] + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []uint64: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + dataRet := ret.([]uint64) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} + +// Copy Float32 data to created Array and verify. +func TestArrayCopySliceFloat32(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "float32") + if err != nil { + t.Error(err.Error()) + return + } + + data := make([]float32, dlen) + + for i := range data { + data[i] = rand.Float32() + } + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []float32: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + dataRet := ret.([]float32) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v \nGot :%v \n", data, dataRet) + return + } + } +} + +// Copy Float64 data to created Array and verify. +func TestArrayCopySliceFloat64(t *testing.T) { + dlen := int64(32) + arr, err := Empty([]int64{4, dlen/4}, "float64") + if err != nil { + t.Error(err.Error()) + return + } + + data := make([]float64, dlen) + + for i := range data { + data[i] = rand.Float64() + } + + err = arr.CopyFrom(data) + if err != nil { + t.Error(err.Error()) + return + } + + ret, err := arr.AsSlice() + if err != nil { + t.Error(err.Error()) + return + } + + switch ret.(type) { + case []float64: + default: + t.Errorf("Expected : %T but got :%T\n", data, ret) + return + } + dataRet := ret.([]float64) + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v\n", data, dataRet) + return + } + } +} diff --git a/golang/src/bytearray.go b/golang/src/bytearray.go new file mode 100644 index 000000000000..e40a630223dc --- /dev/null +++ b/golang/src/bytearray.go @@ -0,0 +1,72 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package source for TVMByteArray interface. + * \file bytearray.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +import ( + "unsafe" +) + +// ByteArray type wraps the TVMByteArray of C runtime API. +// +// This can be used to hold raw data like params of a model. +type ByteArray uintptr + +// nativeCPtr returns the type freed unitptr for ByteArray. +func (tbytearray ByteArray) nativeCPtr() (retVal uintptr) { + retVal = (uintptr)(tbytearray) + return +} + +// SetData is used to intialize ByteArray from a golang string object. +// +// This method initialize both data and data size of the underlaying object. +// This function handles freeing old data object if any before allocating new. +// +// `val` is the golang string object from which the ByteArray is initialized. +func (tbytearray ByteArray) setData(val string) { + bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data + if bufPtr == (*_Ctype_char)(C.NULL) { + C.free(unsafe.Pointer(bufPtr)) + } + + ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data = C.CString(val) + ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size = C.ulong(len(val)) +} + +// getData returns the golang byte slice corresponding to the ByteArray. +func (tbytearray ByteArray) getData() (retVal []byte) { + val := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data + blen := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).size + retVal = C.GoBytes(unsafe.Pointer(val), C.int(blen)) + return +} + +// newByteArray initilizes the native TVMByteArray object with given byte slice +// +//`val` is the golang byte array used to initialize. +// +// returns newly created ByteArray. +func newByteArray(val []byte) (retVal ByteArray) { + handle := ByteArray(C.malloc(C.sizeof_TVMByteArray)) + ((*C.TVMByteArray)(unsafe.Pointer(handle))).data = (*_Ctype_char)(C.NULL) + ((*C.TVMByteArray)(unsafe.Pointer(handle))).size = 0 + handle.setData(string(val)) + retVal = handle + return +} + +// deleteTVMByteArray releases the allocated native object of ByteArray. +// +// This delete handles freeing of underlaying native data object too. +func (tbytearray ByteArray) deleteTVMByteArray() { + bufPtr := ((*C.TVMByteArray)(unsafe.Pointer(tbytearray))).data + C.free(unsafe.Pointer(bufPtr)) + C.free(unsafe.Pointer(tbytearray.nativeCPtr())) +} diff --git a/golang/src/bytearray_test.go b/golang/src/bytearray_test.go new file mode 100644 index 000000000000..f49e75ee2fa6 --- /dev/null +++ b/golang/src/bytearray_test.go @@ -0,0 +1,32 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package + * \file bytearray_test.go + */ + + +package gotvm + +import ( + "testing" + "math/rand" +) + +// Check ByteArray creation from byte slice and verify the data. +func TestByteArrayGet(t *testing.T) { + data := make([]byte, 1024) + rand.Read(data) + + barr := newByteArray(data) + dataRet := barr.getData() + if len(data) != len(dataRet) { + t.Errorf("Data expected Len: %v Got :%v\n", len(data), len(dataRet)) + return + } + for i := range data { + if data[i] != dataRet[i] { + t.Errorf("Data expected: %v Got :%v at : %v\n", data[i], dataRet[i], i) + return + } + } +} diff --git a/golang/src/context.go b/golang/src/context.go new file mode 100644 index 000000000000..8a3b613ea6b9 --- /dev/null +++ b/golang/src/context.go @@ -0,0 +1,89 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package source for TVMContext interface + * \file context.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +// KDLCPU is golang enum correspond to TVM device type kDLCPU. +var KDLCPU = int32(C.kDLCPU) +// KDLGPU is golang enum correspond to TVM device type kDLGPU. +var KDLGPU = int32(C.kDLGPU) +// KDLCPUPinned is golang enum correspond to TVM device type kDLCPUPinned. +var KDLCPUPinned = int32(C.kDLCPUPinned) +// KDLOpenCL is golang enum correspond to TVM device type kDLOpenCL. +var KDLOpenCL = int32(C.kDLOpenCL) +// KDLMetal is golang enum correspond to TVM device type kDLMetal. +var KDLMetal = int32(C.kDLMetal) +// KDLVPI is golang enum correspond to TVM device type kDLVPI. +var KDLVPI = int32(C.kDLVPI) +// KDLROCM is golang enum correspond to TVM device type kDLROCM. +var KDLROCM = int32(C.kDLROCM) +// KDLSDAccel is golang enum correspond to TVM device type kDLSDAccel. +var KDLSDAccel = int32(C.kDLSDAccel) +// KDLVulkan is golang enum correspond to TVM device type kDLVulkan. +var KDLVulkan = int32(C.kDLVulkan) +// KOpenGL is golang enum correspond to TVM device type kOpenGL. +var KOpenGL = int32(C.kOpenGL) +// KExtDev is golang enum correspond to TVM device type kDLExtDev. +var KExtDev = int32(C.kDLExtDev) + +// Context dtype corresponding to TVMContext aka DLContext +type Context struct { + DeviceType int32 + DeviceID int32 +} + +// CPU returns the Context object for CPU target on given index +func CPU(index int32) Context { + return Context{KDLCPU, index} +} + +// GPU returns the Context object for GPU target on given index +func GPU(index int32) Context { + return Context{KDLGPU, index} +} + +// CPUPinned returns the Context object for CPUPinned target on given index +func CPUPinned(index int32) Context { + return Context{KDLCPUPinned, index} +} + +// OpenCL returns the Context object for OpenCL target on given index +func OpenCL(index int32) Context { + return Context{KDLOpenCL, index} +} + +// Metal returns the Context object for Metal target on given index +func Metal(index int32) Context { + return Context{KDLMetal, index} +} + +// VPI returns the Context object for VPI target on given index +func VPI(index int32) Context { + return Context{KDLVPI, index} +} + +// ROCM returns the Context object for ROCM target on given index +func ROCM(index int32) Context { + return Context{KDLROCM, index} +} + +// SDAccel returns the Context object for SDAccel target on given index +func SDAccel(index int32) Context { + return Context{KDLSDAccel, index} +} + +// Vulkan returns the Context object for Vulkan target on given index +func Vulkan(index int32) Context { + return Context{KDLVulkan, index} +} + +// OpenGL returns the Context object for OpenGL target on given index +func OpenGL(index int32) Context { + return Context{KOpenGL, index} +} diff --git a/golang/src/error.go b/golang/src/error.go new file mode 100644 index 000000000000..00a24652953c --- /dev/null +++ b/golang/src/error.go @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package source for error related API interface. + * \file error.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +import ( + "unsafe" +) + +// getTVMLastError returns the detailed error string for any api called in TVM runtime. +// +// This is useful when any api returns non zero value. +// +// Returns golang string for the corresponding native error message. +func getTVMLastError() (retVal string) { + errStr := C.TVMGetLastError() + retVal = C.GoString(errStr) + return +} + +func setTVMLastError(errStr string) { + cstr := C.CString(errStr) + C.TVMAPISetLastError(cstr) + C.free(unsafe.Pointer(cstr)) +} diff --git a/golang/src/error_test.go b/golang/src/error_test.go new file mode 100644 index 000000000000..2a8c345b424b --- /dev/null +++ b/golang/src/error_test.go @@ -0,0 +1,28 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package + * \file error_test.go + */ + + +package gotvm + +import ( + "testing" + "strings" +) + +// Check err receiving from TVM global function. +func TestErrorTest(t *testing.T) { + _, err := LoadModuleFromFile("dummy.so") + if err == nil { + t.Error("Expected an error, but not received\n") + return + } + + errStr := err.Error() + if !(strings.Contains(errStr, string("cannot open shared object"))) { + t.Error("Ah! TVM didn't report an error\n") + } +} + diff --git a/golang/src/function.go b/golang/src/function.go new file mode 100644 index 000000000000..fa1c53a5917f --- /dev/null +++ b/golang/src/function.go @@ -0,0 +1,365 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package source for TVMFunction interface. + * \file function.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +import ( + "unsafe" + "encoding/binary" + "errors" + "runtime" + "reflect" + "fmt" +) + +// Function type in golang hold pointer for the TVMFunction handle. +type Function uintptr + +// nativeCPtr returns type freed uintptr for the Function. +func (tvmfunction Function) nativeCPtr() (retVal uintptr) { + retVal = (uintptr)(tvmfunction) + return +} + +// Invoke calls the TVM packed function referred by the handle with given arguments. +func (tvmfunction *Function) Invoke(args ...interface{}) (retVal *Value, err error) { + funccall := func (fargs ...interface{}) (*Value, error) { + return callNativeFunction(tvmfunction, fargs) + } + // Check is any args are contain any ValueArray + // Possible is it's a args forward from one packed function to another. + valueArrayFound := false + for ii := range args { + switch args[ii].(type) { + case []*Value: + valueArrayFound = true + } + } + + if !valueArrayFound { + return funccall(args...) + } + if len(args) != 1 { + err = fmt.Errorf("Not supported if packed function args are a mix of []Value and other types") + return + } + + valArray := args[0].([]*Value) + if len(valArray) > 0 { + newArgs := make([]interface{}, len(valArray)) + for ii := range valArray { + newVal := newTVMValue() + newVal.moveFrom(valArray[ii]) + newArgs[ii] = newVal + } + + return funccall(newArgs...) + } + return funccall() +} + +// FuncListGlobalNames is used to query global callable packed function names from TVM. +// +// returns slice of string holding function names and error if any. +func FuncListGlobalNames() (retVal []string, err error) { + var str string + ret := (int32)(C._TVMFuncListGlobalNames(unsafe.Pointer((&str)))) + if ret != 0 { + err = errors.New(getTVMLastError()) + return + } + + str = goStringFromNative(*(*string)(unsafe.Pointer(&str))) + bin := binary.LittleEndian + size := bin.Uint64([]byte(str[:8])) + str = str[8:] + retVal = make([]string, size) + for i := range retVal { + len := bin.Uint64([]byte(str[:8])) + str = str[8:] + retVal[i] = str[:len] + str = str[len:] + } + return +} + +// GetGlobalFunction is to get handle to the given global function name. +// +// `funcname` is the name of global packed function. +// +// returns a function closure with signature +// func (args ...interface{}) (interface{}, error) and error if any. +// +// The closure function can be used to call Function with arguments directly. +// +// Variadic arguments can be any type which can be embed into Value. +func GetGlobalFunction(funcname string) (retVal *Function, err error) { + var funp uintptr + + cfuncname := C.CString(funcname) + ret := (int32)(C.TVMFuncGetGlobal(cfuncname, + (*_Ctype_TVMFunctionHandle)(unsafe.Pointer(&funp)))) + C.free(unsafe.Pointer(cfuncname)) + + if ret != 0 { + err = errors.New(getTVMLastError()) + return + } + + handle := new(Function) + *handle = Function(funp) + finalizer := func(fhandle *Function) { + nativeTVMFuncFree(fhandle) + fhandle = nil + } + runtime.SetFinalizer(handle, finalizer) + retVal = handle + return +} + +// callNativeFunction is routine which calls gotvm native wrapper with given arguments. +// +// `handle` is the handle for Function. +// +// `args` are the variadic arguments to the Function. +// +// returns the interface for the return value from TVM if any and error if any. +func callNativeFunction(handle *Function, args []interface{}) (retVal *Value, err error) { + argsIn := make([]*Value, len(args)) + var typeCodes []int32 + if len(args) != 0 { + typeCodes = make([]int32, len(args)) + } else { + typeCodes = make([]int32, 1) + } + + for ii := range args { + argsIn[ii] = newTVMValue() + if typeCodes[ii], err = argsIn[ii].setValue(args[ii]); err != nil { + return + } + } + + retVal = newTVMValue() + argsOut := []*Value{retVal} + retTypeCode := KNull + err = nativeTVMFuncCall(handle, argsIn, typeCodes, argsOut, &retTypeCode) + if err != nil { + retVal = nil + return + } + retVal.isLocal = false + retVal.dtype = retTypeCode + return +} + +// nativeTVMFuncFree free the function handle allocated in TVM runtime. +// +// `funp` is the Function handle to be freed. +func nativeTVMFuncFree(funp *Function) (retVal int32) { + retVal = (int32) (C.TVMFuncFree(C.TVMFunctionHandle(funp.nativeCPtr()))) + return +} + +// nativeToGoSlice converts native TVMValue array to Golang slice of TVMValue +// +// +func nativeToGoSlice(nargValues (*C.void), argValues []*Value, typeCodes []int32) { + for ii := range argValues { + C._TVMValueNativeGet(unsafe.Pointer(argValues[ii].nativeCPtr()), + unsafe.Pointer(nargValues), + C.int(int32(ii))) + argValues[ii].dtype = typeCodes[ii] + } +} + +// nativeFromGoSlice converts golang slice of TVMValue to native TVMValue array. +// +// +func nativeFromGoSlice(argValues []*Value) (nptr (*C.void)) { + nargValues := ((uintptr)(C.malloc(C.ulong(C.sizeof_TVMValue * len(argValues))))) + for ii := range argValues { + C._TVMValueNativeSet(unsafe.Pointer(nargValues), + unsafe.Pointer(argValues[ii].nativeCPtr()), + C.int(int32(ii))) + } + nptr = (*C.void)(unsafe.Pointer(nargValues)) + return +} + +// nativeTVMFuncCall executes the function with given arguments +// +// `funp` Function handle to the packed function. +// +// `argValues` is the slice of Value which are arguments to the packed function. +// +// `typeCodes` is the alice of argument type codes corresponding to argValues. +// +// `retValues` is return argument which is slice of return values from the packed function. +// +// `retTypeCode` is int32 holding type codes for retValue +// +// Returns err indicating native error if any. +func nativeTVMFuncCall(funp *Function, argValues []*Value, typeCodes []int32, + retValues []*Value, retTypeCode *int32) (err error) { + nargValues := nativeFromGoSlice(argValues) + nretValues := nativeFromGoSlice(retValues) + result := (int32)(C.TVMFuncCall(_Ctype_TVMFunctionHandle(*funp), + (*_Ctype_TVMValue)(unsafe.Pointer(nargValues)), + (*_Ctype_int)(unsafe.Pointer(&(typeCodes[0]))), + C.int(len(argValues)), + (*_Ctype_TVMValue)(unsafe.Pointer(nretValues)), + (*_Ctype_int)(unsafe.Pointer(retTypeCode)))) + nativeToGoSlice(nargValues, argValues, typeCodes) + nativeToGoSlice(nretValues, retValues, (*[1<<31] int32)(unsafe.Pointer(retTypeCode))[:1:1]) + C.free(unsafe.Pointer(nargValues)) + C.free(unsafe.Pointer(nretValues)) + + if result != 0 { + err = errors.New(getTVMLastError()) + } + return +} + +// goCallBack is a structure holding the go callback function pointer. +// This wrapping is necessary as cgo doesn't support +// passing golang functions type conversion to native. +type goCallBack struct { + cb func (args ...*Value) (interface{}, error) +} + +//export goTVMCallback +func goTVMCallback(args C.native_voidp, typeCodes C.native_voidp, numArgs int32, + retArg C.native_voidp, resourceHandle C.native_voidp) (ret int32){ + fcb := (*goCallBack)(resourceHandle) + // Make Value Sice from native TVMValue pointer. + argValues := make([]*Value, numArgs) + + for ii := range argValues { + argValues[ii] = newTVMValue() + argValues[ii].isLocal = false + } + + // Prepare arguments for golang callback function + nativeToGoSlice((*C.void)(unsafe.Pointer(args)), argValues, + (*[1<<31] int32)(unsafe.Pointer(typeCodes))[:numArgs:numArgs]) + cbargs := argValues + + // Execute the callback + retVal, err := fcb.cb(cbargs...) + if err != nil { + errStr := err.Error() + setTVMLastError(errStr) + return -1 + } + + // It's possible a packed function directly return + // the return value of another packed function. + // + // Inside a packed func : + // ```return pfunc.Invoke(args)``` + // + // In this case pfunc returns nil which is + // returned as an interface holding nil *Value. + // Which becomes a valid retVal holding nil *Value. + isRetNull := false + switch retVal.(type) { + case *Value: + pRet := retVal.(*Value) + if pRet == nil { + isRetNull = true + } + } + + // Handle return value from callback function + if retVal != nil && !isRetNull { + var retTypeCode int32 + retValues := []*Value{newTVMValue()} + + retTypeCode, err = retValues[0].setValue(retVal) + if err != nil { + errStr := err.Error() + setTVMLastError(errStr) + return -1 + } + nretValues := nativeFromGoSlice(retValues) + + // Handle KStr, KBytes: Local finalizers shouldn't try freeing them. + retValues[0].isLocal = false + + apiRet := (int32) (C.TVMCFuncSetReturn(_Ctype_TVMRetValueHandle(retArg), + (*_Ctype_TVMValue)(unsafe.Pointer(nretValues)), + (*_Ctype_int)(unsafe.Pointer(&retTypeCode)), 1)) + C.free(unsafe.Pointer(nretValues)) + if apiRet != 0 { + errStr := string("TVMCFuncSetReturn failed ") + setTVMLastError(errStr) + } + } + return +} + +// ConvertFunction converts given golang function to TVM packed function. +// +// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})``` +// +// Returns Function handle and err if any. +func ConvertFunction(args ...interface{}) (retVal *Function, err error) { + function := args[0].(func (args ...*Value) (interface{}, error)) + fcb := &goCallBack{cb:function} + var funp uintptr + + result := (int32) (C._ConvertFunction(unsafe.Pointer(fcb), + unsafe.Pointer(&funp))) + if result != 0 { + err = errors.New(getTVMLastError()) + } + + handle := new(Function) + *handle = Function(funp) + finalizer := func(fhandle *Function) { + nativeTVMFuncFree(fhandle) + fhandle = nil + } + runtime.SetFinalizer(handle, finalizer) + retVal = handle + return +} + +// RegisterFunction registers the golang func in TVM runtime global space. +// +// `args[0]` function pointer for a type ```func (args ...interface{}) (interface{})``` +// +// `args[1]` Optional argument of function name with which it will be registered. +// If not passed we use function name from reflection. +// +// Returns err indicating native error if any. +func RegisterFunction(args ...interface{}) (err error) { + fhandle, err := ConvertFunction(args...) + if err != nil { + return + } + + funcname := runtime.FuncForPC(reflect.ValueOf(args[0]).Pointer()).Name() + if len(args) > 1 { + funcname = args[1].(string) + } + + cfuncname := C.CString(funcname) + result := (int32) (C.TVMFuncRegisterGlobal(cfuncname, + _Ctype_TVMFunctionHandle(*fhandle), + 0)); // Override = False + C.free(unsafe.Pointer(cfuncname)) + if result != 0 { + err = errors.New(getTVMLastError()) + } + // Clear the finalizer as we don't need to control it anymore. + runtime.SetFinalizer(fhandle, nil) + return +} diff --git a/golang/src/function_test.go b/golang/src/function_test.go new file mode 100644 index 000000000000..d53822837220 --- /dev/null +++ b/golang/src/function_test.go @@ -0,0 +1,331 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package + * \file function_test.go + */ + +package gotvm + +import ( + "testing" + "reflect" + "math/rand" + "strings" + "fmt" +) + +// Check global function list API +func TestFunctionGlobals(t *testing.T) { + funcNames, err := FuncListGlobalNames() + if err != nil { + t.Error(err.Error()) + return + } + if len(funcNames) < 1 { + t.Errorf("Global Function names received:%v\n", funcNames) + } +} + +// Check GetFunction API +func TestFunctionGlobalGet(t *testing.T) { + funp, err := GetGlobalFunction("tvm.graph_runtime.create") + if err != nil { + t.Error(err.Error()) + return + } + if reflect.TypeOf(funp).Kind() != reflect.Ptr { + t.Error("Function type mis matched\n") + return + } +} + +func TestFunctionModuleGet(t *testing.T) { + modp, err := LoadModuleFromFile("./deploy.so") + if err != nil { + t.Error(err.Error()) + return + } + funp, err := modp.GetFunction("myadd") + if err != nil { + t.Error(err.Error()) + return + } + if reflect.TypeOf(funp).Kind() != reflect.Ptr { + t.Error("Function type mis matched\n") + return + } + + dlen := int64(1024) + shape := []int64{dlen} + inX, _ := Empty(shape) + inY, _ := Empty(shape) + out, _ := Empty(shape) + dataX := make([]float32, (dlen)) + dataY := make([]float32, (dlen)) + outExpected := make([]float32, (dlen)) + + for i := range dataX { + dataX[i] = rand.Float32() + dataY[i] = rand.Float32() + outExpected[i] = dataX[i] + dataY[i] + } + + inX.CopyFrom(dataX) + inY.CopyFrom(dataY) + + funp.Invoke(inX, inY, out) + outi, _ := out.AsSlice() + outSlice := outi.([]float32) + if len(outSlice) != len(outExpected) { + t.Errorf("Data expected Len: %v Got :%v\n", len(outExpected), len(outSlice)) + return + } + for i := range outSlice { + if outExpected[i] != outSlice[i] { + t.Errorf("Data expected: %v Got :%v at index %v\n", outExpected[i], outSlice[i], i) + return + } + } +} + +// Check FunctionConvert API +func TestFunctionConvert(t *testing.T) { + sampleCb := func (args ...*Value) (retVal interface{}, err error) { + val1 := args[0].AsInt64() + val2 := args[1].AsInt64() + retVal = int64(val1+val2) + return + } + + fhandle, err := ConvertFunction(sampleCb) + if err != nil { + t.Error(err.Error()) + return + } + + retVal, err := fhandle.Invoke(10, 20) + if err != nil { + t.Error(err.Error()) + return + } + + if retVal.AsInt64() != int64(30) { + t.Errorf("Expected result :30 got:%v\n", retVal.AsInt64()) + return + } +} + +func TestFunctionError(t *testing.T) { + sampleCb := func (args ...*Value) (retVal interface{}, err error) { + err = fmt.Errorf("Sample Error XYZABC"); + return + } + + fhandle, err := ConvertFunction(sampleCb) + if err != nil { + t.Error(err.Error()) + return + } + + _, err = fhandle.Invoke() + if err == nil { + t.Error("Expected error but didn't received\n") + return + } + + if !strings.Contains(err.Error(), string("Sample Error XYZABC")) { + t.Errorf("Expected Error should contain :\"Sample Error XYZABC\" got :%v\n", err.Error()) + } +} + +// Check FunctionRegister +func TestFunctionRegister(t *testing.T) { + sampleCb := func (args ...*Value) (retVal interface{}, err error) { + val1 := args[0].AsInt64() + val2 := args[1].AsInt64() + retVal = int64(val1+val2) + return + } + + RegisterFunction(sampleCb, "TestFunctionRegister.sampleCb"); + // Query global functions available + funcNames, err := FuncListGlobalNames() + if err != nil { + t.Error(err.Error()) + return + } + + found := 0 + for ii := range (funcNames) { + if strings.Compare(funcNames[ii], "TestFunctionRegister.sampleCb") == 0 { + found = 1 + } + } + if found == 0 { + t.Error("Registered function not found in global function list.") + return + } + + // Get "sampleCb" and verify the call. + funp, err := GetGlobalFunction("TestFunctionRegister.sampleCb") + if err != nil { + t.Error(err.Error()) + return + } + + // Call function + result, err := funp.Invoke((int64)(10), (int64)(20)) + if err != nil { + t.Error(err.Error()) + return + } + if result.AsInt64() != int64(30) { + t.Errorf("Expected result :30 got:%v\n", result.AsInt64()) + return + } +} + +// Check packed function receiving go-closure as argument. +func TestFunctionClosureArg(t *testing.T) { + // sampleFunctionArg receives a Packed Function handle and calls it. + sampleFunctionArg := func (args ...*Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + + // Call Packed Function by Value + ret, err := pfunc.Invoke(args[1], args[2]) + if err != nil { + return + } + + // Call Packed Function with extracted values + ret1, err := pfunc.Invoke(args[1].AsInt64(), args[2].AsInt64()) + if err != nil { + return + } + if ret1.AsInt64() != ret.AsInt64() { + err = fmt.Errorf("Invoke with int64 didn't match with Value\n") + return + } + retVal = ret + return + } + + RegisterFunction(sampleFunctionArg, "TestFunctionClosureArg.sampleFunctionArg"); + funp, err := GetGlobalFunction("TestFunctionClosureArg.sampleFunctionArg") + if err != nil { + t.Error(err.Error()) + return + } + + // funccall is a simple golang callback function like C = A + B. + funccall := func (args ...*Value) (retVal interface{}, err error) { + val1 := args[0].AsInt64() + val2 := args[1].AsInt64() + retVal = int64(val1+val2) + return + } + + // Call function + result, err := funp.Invoke(funccall, 30, 50) + if err != nil { + t.Error(err.Error()) + return + } + + if result.AsInt64() != int64(80) { + t.Errorf("Expected result :80 got:%v\n", result.AsInt64()) + return + } +} + +// Check packed function returning a go-closure. +func TestFunctionClosureReturn(t *testing.T) { + // sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. + sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) { + funccall := func (cargs ...*Value) (fret interface{}, ferr error) { + val1 := cargs[0].AsInt64() + val2 := cargs[1].AsInt64() + fret = int64(val1+val2) + return + } + retVal = funccall + return + } + + RegisterFunction(sampleFunctionCb, "TestFunctionClosureReturn.sampleFunctionCb"); + funp, err := GetGlobalFunction("TestFunctionClosureReturn.sampleFunctionCb") + if err != nil { + t.Error(err.Error()) + return + } + + // Call function + result, err := funp.Invoke() + if err != nil { + t.Error(err.Error()) + return + } + + pfunc := result.AsFunction() + pfuncRet, err := pfunc.Invoke(30, 40) + if err != nil { + t.Error(err.Error()) + return + } + if pfuncRet.AsInt64() != int64(70) { + t.Errorf("Expected result :70 got:%v\n", pfuncRet.AsInt64()) + return + } +} + +// Check packed function with no arguments and no return values. +func TestFunctionNoArgsReturns(t *testing.T) { + sampleFunction := func (args ...*Value) (retVal interface{}, err error) { + return + } + + fhandle, err := ConvertFunction(sampleFunction) + if err != nil { + t.Error(err.Error()) + return + } + + _, err = fhandle.Invoke() + if err != nil { + t.Error(err.Error()) + return + } +} + +// Check packed function returning a go-closure with no arg and returns. +func TestFunctionNoArgsReturns2(t *testing.T) { + // sampleFunctionCb returns a function closure which is embed as packed function in TVMValue. + sampleFunctionCb := func (args ...*Value) (retVal interface{}, err error) { + funccall := func (cargs ...*Value) (fret interface{}, ferr error) { + return + } + retVal = funccall + return + } + + funp, err := ConvertFunction(sampleFunctionCb) + if err != nil { + t.Error(err.Error()) + return + } + + // Call function + result, err := funp.Invoke() + if err != nil { + t.Error(err.Error()) + return + } + + pfunc := result.AsFunction() + _, err = pfunc.Invoke() + if err != nil { + t.Error(err.Error()) + return + } +} diff --git a/golang/src/gotvm.cc b/golang/src/gotvm.cc new file mode 100644 index 000000000000..cf84e670df79 --- /dev/null +++ b/golang/src/gotvm.cc @@ -0,0 +1,195 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm native interface definition + * \file gotvm.cxx + */ + +// Standard includes +#include +#include +#include +#include +#include +#include + +// golang string compatible definition +typedef struct { char *p; int n; } _gostring_; +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// TVM runtime C interface +#include +#include + +/*! + * \brief Convert native char array to _gostring_ structure. + * _gostring_ structure represents the same memory footprint as golang string object. + * + * \param p is char pointer to a char array. + * \param l is the size of the char array. this method exclusively need length as + * its possible to have a bytearray in a string. + * + * \return _gostring_ object corresponding to native char array. + * Caller is responsible to free the memory block allocated here. + */ +static _gostring_ _native_to_gostring(const char *p, size_t l) { + _gostring_ ret; + ret.p = reinterpret_cast(malloc(l)); + if (NULL == ret.p) { + ret.n = 0; + return ret; + } + memcpy(ret.p, p, l); + ret.n = l; + return ret; +} + +/*! + * \brief embeds a 64bit uint value inside a string to serialize the data. + * + * \param s is string object. + * \param off is the offset in the string object. + * \param v is the uint64_t value which need to embed into given string. + */ +static void putuint64(std::string *s, size_t off, uint64_t v) { + for (int i = 0; i < 8; i++) { + (*s)[off + i] = (v >> (i * 8)) & 0xff; + } +} + +// TVM runtime C interface wrappers + +/*! + * \brief Native interface to query TVM_VERSION in golang string format. + * + * \return char pointer to TVM-VERSION + */ +const char* _TVM_VERSION(void) { + const char *version = TVM_VERSION; + return version; +} + +/*! + * \brief Native interface for getting TVMGlobal function list. + * + * \param names return by argument to return the function names. + * We wrap all strings into single string joined by (len+string) + * which is unpacked and processed in golang. + * + * \return c_runtime_api return status. + */ +int _TVMFuncListGlobalNames(_gostring_* names) { + int names_size; + char **names_array; + int result; + + result = TVMFuncListGlobalNames(&names_size, (char const ***)&names_array); + if (result) { + return result; + } + + size_t tot = 8; + for (int ii = 0; ii < names_size ; ++ii) { + tot += 8 + strlen(names_array[ii]); + } + + std::string str; + str.resize(tot); + putuint64(&str, 0, names_size); + size_t off = 8; + for (int64_t ii = 0; ii < names_size ; ++ii) { + putuint64(&str, off, strlen(names_array[ii])); + off += 8; + str.replace(off, strlen(names_array[ii]), names_array[ii]); + off += strlen(names_array[ii]); + } + *names = _native_to_gostring(str.data(), str.size()); + if (str.size() != names->n) { + TVMAPISetLastError("malloc failed during _native_to_gostring"); + result = 1; + } + return result; +} + +// Helpers for TVMValue + +/*! + * \brief Native helper to copy TVMValue from golang slice to native array. + * this helper is need as underlying momory for golang slice is not continueous. + * + * \param to_ptr is the native pointer of TVMValue array. + * \param from_ptr pointer to TVMValue in golang slice. + * \param array index in native array. + */ +void _TVMValueNativeSet(void* to_ptr, void* from_ptr, int ind) { + TVMValue *from_p = reinterpret_cast(from_ptr); + TVMValue *to_p = reinterpret_cast(to_ptr); + memcpy(to_p+ind, from_p, sizeof(TVMValue)); +} + +/*! + * \brief Native helper to copy TVMValue from golang slice to native array. + * this helper is need as underlying momory for golang slice is not continueous. + * + * \param to_ptr pointer to TVMValue in golang slice. + * \param from_ptr is the native pointer of TVMValue array. + * \param array index in native array. + */ +void _TVMValueNativeGet(void* to_ptr, void* from_ptr, int ind) { + TVMValue *from_p = reinterpret_cast(from_ptr); + TVMValue *to_p = reinterpret_cast(to_ptr); + memcpy(to_p, from_p+ind, sizeof(TVMValue)); +} + +extern int goTVMCallback(void*, void*, int, void*, void*); + +/*! + * \brief _TVMCallback is the TVM runtime callback function for PackedFunction system. + * + * \param args is an array of TVMValue + * \param type_codes is an array of int + * \param num_args is int representing number of in arguments + * \param ret is the return value handle to set the packed function return. + * \param resource_handle is the golang private data pointer. + * + * \returns the error status as TVM_DLL + */ +int _TVMCallback(TVMValue* args, + int* type_codes, + int num_args, + TVMRetValueHandle ret, + void* resource_handle) { + return goTVMCallback(args, type_codes, num_args, ret, resource_handle); +} + +/*! + * _TVMPackedCFuncFinalizer is finalizer for packed function system. + * + */ +void _TVMPackedCFuncFinalizer(void* resource_handle) { + return; +} + +/*! + * /brief _ConvertFunction creates a packed function for with given resource handle. + * + * /param fptr is the pointer to golang resource handle. + * /param *fhandle is the return argument holding packed function. + * + * /return is an int indicating the return status. + */ +int _ConvertFunction(void* fptr, TVMFunctionHandle *fhandle) { + int ret = TVMFuncCreateFromCFunc(_TVMCallback, + fptr, + _TVMPackedCFuncFinalizer, + fhandle); + return ret; +} + +#ifdef __cplusplus +} +#endif + diff --git a/golang/src/gotvm.go b/golang/src/gotvm.go new file mode 100644 index 000000000000..3f7aac93d769 --- /dev/null +++ b/golang/src/gotvm.go @@ -0,0 +1,24 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package + * \file gotvm.go + */ + + +// Package gotvm is TVM runtime interface definition for golang. +// +// Application need to import this package to access the c_runtime_api exposed by TVM. +package gotvm + +//#include "gotvm.h" +import "C" + +// DLPackVersion is the dlpack version of tvm runtime. +var DLPackVersion = int(C.DLPACK_VERSION) +// TVMVersion is the TVM runtime version. +var TVMVersion = getTVMVersion() + +func getTVMVersion() (retStr string) { + retStr = C.GoString(C._TVM_VERSION()) + return +} diff --git a/golang/src/gotvm.h b/golang/src/gotvm.h new file mode 100644 index 000000000000..e4487a362cca --- /dev/null +++ b/golang/src/gotvm.h @@ -0,0 +1,42 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm native interface declaration. + * \file gotvm.h + * + * These declarations are in cgo interface definition while calling API + * across golang and native C boundaries. + */ + +#ifndef GOTVM_GOTVM_H_ +#define GOTVM_GOTVM_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +#include +#include +#include +#include +#include + +// Some type definitions for golang "C" +typedef void* native_voidp; + +// Version +extern char* _TVM_VERSION(void); + +// Wrappers : For incompatible cgo API. +// To handle array of strings wrapped into __gostring__ +extern int _TVMFuncListGlobalNames(void*); +// To handle TVMValue slice to/from native sequential TVMValue array. +extern void _TVMValueNativeSet(void* to, void* from, int index); +extern void _TVMValueNativeGet(void* to, void* from, int index); + +// Callbacks +extern int _ConvertFunction(void* fptr, void* funp); + +#ifdef __cplusplus +} +#endif +#endif // GOTVM_GOTVM_H_ diff --git a/golang/src/gotvm_test.go b/golang/src/gotvm_test.go new file mode 100644 index 000000000000..5058de400ba7 --- /dev/null +++ b/golang/src/gotvm_test.go @@ -0,0 +1,30 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package + * \file gotvm_test.go + */ + + +package gotvm + +import ( + "testing" + "reflect" +) + +// Check TVMVersion API +func TestTVMVersion(t *testing.T) { + if len(TVMVersion) == 0 { + t.Error("TVMVersion not set\n") + } + if reflect.TypeOf(TVMVersion).Kind() != reflect.String { + t.Error("TVMVersion type mismatch\n") + } +} + +// Check DLPackVersion API +func TestDLPackVersion(t *testing.T) { + if reflect.TypeOf(DLPackVersion).Kind() != reflect.Int { + t.Error("TVMVersion type mismatch\n") + } +} diff --git a/golang/src/module.go b/golang/src/module.go new file mode 100644 index 000000000000..422cb6be20ff --- /dev/null +++ b/golang/src/module.go @@ -0,0 +1,121 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package source for TVMModule interface. + * \file module.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +import ( + "errors" + "runtime" + "unsafe" +) + +// Module type in golang hold pointer for the TVMModule handle. +// +// Module initialization happen through TVMModLoadFromFile api in TVM runtime. +type Module uintptr + +// nativeCPtr returns type freed uintptr for the Module. +func (tvmmodule *Module) nativeCPtr() (retVal uintptr) { + retVal = (uintptr)(*tvmmodule) + return +} + +// LoadModuleFromFile loads the given module in TVM runtime. +// +// `modpath` is the path to tvm module. +// +// `args` is an optional arguments of ["dll", "dylib", "dso", "so"] with default value "so" +// +// returns pointer to Module and err or if any. +func LoadModuleFromFile(modpath string, args ...interface{}) (retVal *Module, err error) { + modtype := "so" + if len(args) > 0 { + modtype = args[0].(string) + } + var modp uintptr + + cmodpath := C.CString(modpath) + cmodtype := C.CString(modtype) + + ret := (int32)(C.TVMModLoadFromFile(cmodpath, + cmodtype, + (*_Ctype_TVMModuleHandle)(unsafe.Pointer(&modp)))) + + C.free(unsafe.Pointer(cmodpath)) + C.free(unsafe.Pointer(cmodtype)) + + if ret != 0 { + err = errors.New(getTVMLastError()) + return + } + + handle := new(Module) + *handle = Module(modp) + finalizer := func(mhandle *Module) { + nativeTVMModFree(mhandle) + mhandle = nil + } + runtime.SetFinalizer(handle, finalizer) + retVal = handle + return +} + +// nativeTVMModFree free the module handle allocated in TVM runtime. +// +// `modp` is the Module handle to be freed. +func nativeTVMModFree(modp *Module) (retVal int32) { + retVal = (int32) (C.TVMModFree(C.TVMModuleHandle(modp.nativeCPtr()))) + return +} + +// GetFunction returns the function pointer from the module for given function name. +// +// `tvmmodule` is handle for Module +// +// `funcname` function name in module. +// +// `args` variadic args of `queryImport` +// +// returns function closure with signature +// func (args ...interface{}) (interface{}, error) and error if any. +// +// The closure function can be used to call Function with arguments directly. +// +// Variadic arguments can be any type which can be embed into Value. +func (tvmmodule *Module) GetFunction ( + funcname string, args ...interface{}) ( + retVal *Function, err error){ + queryImports := int32(1) + if len(args) > 0 { + queryImports = int32(args[1].(int)) + } + + var funp uintptr + cfuncname := C.CString(funcname) + ret := (int32)(C.TVMModGetFunction((_Ctype_TVMModuleHandle)(*tvmmodule), + cfuncname, + C.int(queryImports), + (*_Ctype_TVMFunctionHandle)(unsafe.Pointer(&funp)))) + C.free(unsafe.Pointer(cfuncname)) + + if ret != 0 { + err = errors.New(getTVMLastError()) + return + } + + handle := new(Function) + *handle = Function(funp) + finalizer := func(fhandle *Function) { + nativeTVMFuncFree(fhandle) + fhandle = nil + } + runtime.SetFinalizer(handle, finalizer) + retVal = handle + return +} diff --git a/golang/src/module_test.go b/golang/src/module_test.go new file mode 100644 index 000000000000..fac094438e96 --- /dev/null +++ b/golang/src/module_test.go @@ -0,0 +1,93 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package + * \file module_test.go + */ + + +package gotvm + +import ( + "testing" + "reflect" +) + +// Check module loading - dll +func TestModuleTestLoad1(t *testing.T) { + // dll + mod, err := LoadModuleFromFile("./deploy.so", "dll") + if err != nil { + t.Error(err.Error()) + return + } + if reflect.TypeOf(mod).Kind() != reflect.Ptr { + t.Error("Module type mis matched\n") + return + } +} + +// Check module loading - dylib +func TestModuleTestLoad2(t *testing.T) { + // dylib + mod, err := LoadModuleFromFile("./deploy.so", "dylib") + if err != nil { + t.Error(err.Error()) + return + } + if reflect.TypeOf(mod).Kind() != reflect.Ptr { + t.Error("Module type mis matched\n") + return + } +} + +func TestModuleTestLoad3(t *testing.T) { + // dso + mod, err := LoadModuleFromFile("./deploy.so", "dso") + if err != nil { + t.Error(err.Error()) + return + } + if reflect.TypeOf(mod).Kind() != reflect.Ptr { + t.Error("Module type mis matched\n") + return + } +} + +// Check module loading - so +func TestModuleTestLoad4(t *testing.T) { + // so + mod, err := LoadModuleFromFile("./deploy.so", "so") + if err != nil { + t.Error(err.Error()) + return + } + if reflect.TypeOf(mod).Kind() != reflect.Ptr { + t.Error("Module type mis matched\n") + return + } +} + +// Check module loading - default (so) +func TestModuleTestLoad5(t *testing.T) { + // default type as so + mod, err := LoadModuleFromFile("./deploy.so") + if err != nil { + t.Error(err.Error()) + return + } + if reflect.TypeOf(mod).Kind() != reflect.Ptr { + t.Error("Module type mis matched\n") + return + } +} + +// Check module loading err +func TestModuleTestLoadErr(t *testing.T) { + // Unknown file should return error + _, err := LoadModuleFromFile("xyzabc.so") + if err == nil { + t.Error("Expected an error, but not received\n") + return + } +} + diff --git a/golang/src/ndarray.go b/golang/src/ndarray.go new file mode 100644 index 000000000000..ceae7e58c203 --- /dev/null +++ b/golang/src/ndarray.go @@ -0,0 +1,329 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package source for TVMArray aka DLTensor + * \file ndarray.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +import ( + "unsafe" + "fmt" + "errors" + "runtime" + "reflect" +) + +// Array type in golang hold pointer for the TVMArray object from dlpack. +// +// Array initialization happen through Empty api +type Array uintptr + +// nativeCPtr returns type freed uintptr for the Array. +func (parray Array) nativeCPtr() (retVal uintptr) { + retVal = (uintptr)(parray) + return +} + +func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) { + ret := C.TVMArrayCopyFromBytes((*_Ctype_TVMArray)(unsafe.Pointer(parray.nativeCPtr())), + data, + C.ulong(datalen)) + if ret != 0 { + err = errors.New(getTVMLastError()) + } + return +} + +// CopyFrom copies given golang data slice into Array. +// +// `val` is interface homding a slice of Array data type. +// +// returns err is any. +// TOD: Use reflections for better handling +func (parray Array) CopyFrom(val interface{}) (err error) { + var data unsafe.Pointer + var datalen int + dtype := ((*_Ctype_TVMArray)(unsafe.Pointer(parray))).dtype + + switch val.(type) { + case []int8: + sliceVal := val.([]int8) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []int16: + sliceVal := val.([]int16) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []int32: + sliceVal := val.([]int32) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []int64: + sliceVal := val.([]int64) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []uint8: + sliceVal := val.([]uint8) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []uint16: + sliceVal := val.([]uint16) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []uint32: + sliceVal := val.([]uint32) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []uint64: + sliceVal := val.([]uint64) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []float32: + sliceVal := val.([]float32) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + case []float64: + sliceVal := val.([]float64) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + return parray.nativeCopyFrom(data, datalen) + default: + err = fmt.Errorf("Given type not supported : %v\n", reflect.TypeOf(val)) + return + } + return +} + +func (parray Array) nativeCopyTo (data unsafe.Pointer, datalen int) (err error){ + ret := C.TVMArrayCopyToBytes((*_Ctype_TVMArray)(unsafe.Pointer(parray.nativeCPtr())), + unsafe.Pointer(data), + C.ulong(datalen)) + + if ret != 0 { + err = errors.New(getTVMLastError()) + } + return +} + +// AsSlice returns the unitptr of for the data inside Array. +// +// returns the slice of array inside Array and err of any. +// TOD: Use reflections for better handling +func (parray Array) AsSlice() (retVal interface{}, err error) { + shape := parray.GetShape() + size := int64(1) + var data unsafe.Pointer + var datalen int + + for ii := range shape { + size *= shape[ii] + } + dtype := ((*_Ctype_TVMArray)(unsafe.Pointer(parray))).dtype + + switch parray.GetDType() { + case "int8": + sliceVal := make([]int8, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "int16": + sliceVal := make([]int16, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "int32": + sliceVal := make([]int32, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "int64": + sliceVal := make([]int64, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "uint8": + sliceVal := make([]uint8, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "uint16": + sliceVal := make([]uint16, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "uint32": + sliceVal := make([]uint32, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "uint64": + sliceVal := make([]uint64, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "float32": + sliceVal := make([]float32, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + case "float64": + sliceVal := make([]float64, size) + data = unsafe.Pointer(&sliceVal[0]) + datalen = len(sliceVal) * int(dtype.bits / 8) + err = parray.nativeCopyTo(data, datalen) + retVal = sliceVal + default: + err = fmt.Errorf("Given type not supported : %v\n", parray.GetDType()) + return + } + return +} + +// GetNdim returns the number of dimentions in Array +func (parray Array) GetNdim() (retVal int32) { + retVal = int32(((*_Ctype_TVMArray)(unsafe.Pointer(parray))).ndim) + return +} + +// GetShape returns the number of dimentions in Array +func (parray Array) GetShape() (retVal []int64) { + shapePtr := (*C.int64_t)(((*_Ctype_TVMArray)(unsafe.Pointer(parray))).shape) + ndim := parray.GetNdim() + + shapeSlice := (*[1<<31] int64)(unsafe.Pointer(shapePtr))[:ndim:ndim] + retVal = make([]int64, ndim) + copy(retVal, shapeSlice) + return +} + +// GetDType returns the number of dimentions in Array +func (parray Array) GetDType() (retVal string) { + ret := ((*_Ctype_TVMArray)(unsafe.Pointer(parray))).dtype + retVal, _ = dtypeFromTVMType(*(*pTVMType)(unsafe.Pointer(&ret))) + return +} + +// GetCtx returns the number of dimentions in Array +func (parray Array) GetCtx() (retVal Context) { + ret := ((*_Ctype_TVMArray)(unsafe.Pointer(parray))).ctx + retVal = *(*Context)(unsafe.Pointer(&ret)) + return +} + +// nativeTVMArrayAlloc is used to allocate TVMArray from given attributes. +// +// `shape` is int64 slice holding shape of the Array to be created. +// +// `ndim` is the rank of the Array to be created. +// +// `dtypeCode`, `dtypeBits` and `dtypeLanes` describe the data type in Array. +// +// `deviceType` indicates the device on whose memory the Array to allocated. +// +// `deviceID` indicates device index if multiple devices of same type present. +// +// return argument holding native pointer to newly created Array and error is any. +func nativeTVMArrayAlloc(shape []int64, ndim int32, + dtypeCode int32, dtypeBits int32, dtypeLanes int32, + deviceType int32, deviceID int32) (retVal uintptr, err error) { + ret := (int32)(C.TVMArrayAlloc((*_Ctype_long)(&(shape[0])), + C.int(ndim), + C.int(dtypeCode), + C.int(dtypeBits), + C.int(dtypeLanes), + C.int(deviceType), + C.int(deviceID), + (*_Ctype_TVMArrayHandle)(unsafe.Pointer(&retVal)))) + if ret != 0 { + err = errors.New(getTVMLastError()) + return + } + return +} + +// Empty is used to allocate TVM empty array of given epecification. +// +// `shape` is int64 slice holding shape of the Array +// +// `args` is variadic args for +// +// `args[0]` is string for data type. Default value is 'float32' +// +// `args[1]` is Context. Default value is '{KDLCPU, 0}' +// +// returns pointer to Array on successful execution and error if any. +func Empty(shape []int64, args ...interface{}) (parray *Array, err error) { + typeName := "float32" + ctx := Context{KDLCPU, 0} + + if len(shape) < 1 { + err = fmt.Errorf("Invalid shape for Array creation: %v\n", len(shape)) + return + } + + for i, val := range args { + switch val.(type) { + case string: + typeName = args[i].(string) + case Context: + ctx = args[i].(Context) + default: + err = fmt.Errorf("Invalid Optional Argument Type: %T\n", val) + return + } + } + + tvmType, err := dtypeToTVMType(typeName) + if err != nil { + return + } + ndim := int32(len(shape)) + newArray, err := nativeTVMArrayAlloc(shape, ndim, int32(tvmType.code), + int32(tvmType.bits), int32(tvmType.lanes), + ctx.DeviceType, ctx.DeviceID) + if err != nil { + return + } + handle := new(Array) + *handle = Array(newArray) + + finalizer := func (ahandle *Array) { + nativeTVMArrayFree(*ahandle) + ahandle = nil + } + runtime.SetFinalizer(handle, finalizer) + parray = handle + return +} + +// nativeTVMArrayFree is used to release the Array. +// +// `parray` is the Array handle. +// +// `ret` indicates the status of this api execution. +func nativeTVMArrayFree(parray Array) (retVal int32) { + retVal = (int32)(C.TVMArrayFree((*_Ctype_TVMArray)(unsafe.Pointer(parray.nativeCPtr())))) + return +} diff --git a/golang/src/tvm_runtime_pack.cc b/golang/src/tvm_runtime_pack.cc new file mode 100644 index 000000000000..718a79eb7445 --- /dev/null +++ b/golang/src/tvm_runtime_pack.cc @@ -0,0 +1,49 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief This is an all in one TVM runtime file. + * \file tvm_runtime_pack.cc + */ +#include "src/runtime/c_runtime_api.cc" +#include "src/runtime/cpu_device_api.cc" +#include "src/runtime/workspace_pool.cc" +#include "src/runtime/module_util.cc" +#include "src/runtime/module.cc" +#include "src/runtime/registry.cc" +#include "src/runtime/file_util.cc" +#include "src/runtime/threading_backend.cc" +#include "src/runtime/thread_pool.cc" +#include "src/runtime/ndarray.cc" + +// NOTE: all the files after this are optional modules +// that you can include remove, depending on how much feature you use. + +// Likely we only need to enable one of the following +// If you use Module::Load, use dso_module +// For system packed library, use system_lib_module +#include "src/runtime/dso_module.cc" +#include "src/runtime/system_lib_module.cc" + +// Graph runtime +#include "src/runtime/graph/graph_runtime.cc" + +// Uncomment the following lines to enable RPC +// #include "../../src/runtime/rpc/rpc_session.cc" +// #include "../../src/runtime/rpc/rpc_event_impl.cc" +// #include "../../src/runtime/rpc/rpc_server_env.cc" + +// These macros enables the device API when uncommented. +#define TVM_CUDA_RUNTIME 1 +#define TVM_METAL_RUNTIME 1 +#define TVM_OPENCL_RUNTIME 1 + +// Uncomment the following lines to enable Metal +// #include "../../src/runtime/metal/metal_device_api.mm" +// #include "../../src/runtime/metal/metal_module.mm" + +// Uncomment the following lines to enable CUDA +// #include "../../src/runtime/cuda/cuda_device_api.cc" +// #include "../../src/runtime/cuda/cuda_module.cc" + +// Uncomment the following lines to enable OpenCL +// #include "../../src/runtime/opencl/opencl_device_api.cc" +// #include "../../src/runtime/opencl/opencl_module.cc" diff --git a/golang/src/type.go b/golang/src/type.go new file mode 100644 index 000000000000..27364295bf8b --- /dev/null +++ b/golang/src/type.go @@ -0,0 +1,72 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package for TVMType interface + * \file type.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +import ( + "fmt" +) + +// pTVMType corresponding to data types. +type pTVMType struct { + code uint8 + bits uint8 + lanes uint16 +} + +// data type to pTVMType mapping +var dtypeMap = map[string] pTVMType { + "int8": pTVMType{0, 8, 1}, + "int16": pTVMType{0, 16, 1}, + "int32": pTVMType{0, 32, 1}, + "int64": pTVMType{0, 64, 1}, + "uint8": pTVMType{1, 8, 1}, + "uint16": pTVMType{1, 16, 1}, + "uint32": pTVMType{1, 32, 1}, + "uint64": pTVMType{1, 64, 1}, + "float32": pTVMType{2, 32, 1}, + "float64": pTVMType{2, 64, 1}, +} + +// dtypeFromTVMType return the pTVMType corresponding to given dtype +// +// `dtype` string for the given data type. +func dtypeFromTVMType(tvmtype pTVMType) (retVal string, err error) { + for k, v := range dtypeMap { + if v.code == tvmtype.code && v.bits == tvmtype.bits && v.lanes == tvmtype.lanes { + retVal = k + return + } + } + + err = fmt.Errorf("Cannot map TVMType:%v to dtype", tvmtype) + return +} + +// dtypeToTVMType return the pTVMType corresponding to given dtype +// +// `dtype` string for the given data type. +func dtypeToTVMType(args ...interface{}) (tvmtype pTVMType, err error) { + dtype := args[0].(string) + lanes := 1 + + if len(args) == 2 { + lanes = args[1].(int) + } + + for k, v := range dtypeMap { + if k == dtype { + tvmtype = v + tvmtype.lanes = uint16(lanes) + return + } + } + err = fmt.Errorf("Cannot map dtype:%v to TVMType", dtype) + return +} diff --git a/golang/src/util.go b/golang/src/util.go new file mode 100644 index 000000000000..aa5a6016c97f --- /dev/null +++ b/golang/src/util.go @@ -0,0 +1,24 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package source for common utilities + * \file util.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +import ( + "unsafe" +) + +// Native string map for go string +type nativeGoString struct { p uintptr; n int32 } + +func goStringFromNative (s string) (retStr string) { + p := *(*nativeGoString)(unsafe.Pointer(&s)) + retStr = string((*[0x7fffffff]byte)(unsafe.Pointer(p.p))[:p.n]) + C.free(unsafe.Pointer(p.p)) + return +} diff --git a/golang/src/value.go b/golang/src/value.go new file mode 100644 index 000000000000..2a953560f237 --- /dev/null +++ b/golang/src/value.go @@ -0,0 +1,360 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package source for TVMValue interface + * \file value.go + */ + +package gotvm + +//#include "gotvm.h" +import "C" + +import ( + "fmt" + "runtime" + "unsafe" +) + +// KHandle is golang type code for TVM enum kHandle. +var KHandle = int32(C.kHandle) +// KNull is golang type code for TVM kNull. +var KNull = int32(C.kNull) +// KTVMType is golang type code for TVM kTVMType. +var KTVMType = int32(C.kTVMType) +// KTVMContext is golang type code for TVM kTVMContext. +var KTVMContext = int32(C.kTVMContext) +// KArrayHandle is golang type code for TVM kArrayHandle. +var KArrayHandle = int32(C.kArrayHandle) +// KNodeHandle is golang type code for TVM kNodeHandle. +var KNodeHandle = int32(C.kNodeHandle) +// KModuleHandle is gonag type code for TVM kModuleHandle. +var KModuleHandle = int32(C.kModuleHandle) +// KFuncHandle is gonalg type code for TVM kFuncHandle. +var KFuncHandle = int32(C.kFuncHandle) +// KStr is golang type code for TVM kStr. +var KStr = int32(C.kStr) +// KBytes is golang type code for TVM kBytes. +var KBytes = int32(C.kBytes) +// KNDArrayContainer is golang typecode for kNDArrayContainer. +var KNDArrayContainer = int32(C.kNDArrayContainer) +// KExtBegin is golang enum corresponding to TVM kExtBegin. +var KExtBegin = int32(C.kExtBegin) +// KNNVMFirst is golang enum corresponding to TVM kNNVMFirst. +var KNNVMFirst = int32(C.kNNVMFirst) +// KNNVMLast is golang enum corresponding to TVM kNNVMLast. +var KNNVMLast = int32(C.kNNVMLast) +// KExtReserveEnd is golang enum corresponding to TVM kExtReserveEnd. +var KExtReserveEnd = int32(C.kExtReserveEnd) +// KExtEnd is golang enum corresponding to TVM kExtEnd. +var KExtEnd = int32(C.kExtEnd) +// KDLInt is golang type code for TVM kDLInt. +var KDLInt = int32(C.kDLInt) +// KDLUInt is golang type code for TVM kDLUInt. +var KDLUInt = int32(C.kDLUInt) +// KDLFloat is golang type code for TVM kDLFloat. +var KDLFloat = int32(C.kDLFloat) + +// Value Typemap for union exposed by TVM runtime API. +// +// gotvm maps it to a uintptr and then dynamically allocates memory by newTVMValue method. +type Value struct { + nptr uintptr + dtype int32 + isLocal bool +} + +// AsInt64 returns the int64 value inside the Value. +func (tvmval *Value) AsInt64() (retVal int64) { + retVal = tvmval.getVInt64() + return +} + +// AsFloat64 returns the Float64 value inside the Value. +func (tvmval *Value) AsFloat64() (retVal float64) { + retVal = tvmval.getVFloat64() + return +} + +// AsModule returns the Module inside the Value. +func (tvmval *Value) AsModule() (retVal *Module) { + mhandle := tvmval.getVMHandle() + retVal = &mhandle + return +} + +// AsFunction returns the Function inside the Value. +func (tvmval *Value) AsFunction() (retVal *Function) { + fhandle := tvmval.getVFHandle() + retVal = &fhandle + + return +} + +// AsBytes returns the byte slice value inside the Value. +func (tvmval *Value) AsBytes() (retVal []byte) { + retVal = tvmval.getVBHandle().getData() + return +} + +// AsStr returns the golang string in the Value. +func (tvmval *Value) AsStr() (retVal string) { + str := tvmval.getVStr() + retVal = str + return +} + +// nativeCPtr return the unitptr corresponding to Value type. +func (tvmval *Value) nativeCPtr() (ret uintptr) { + ret = (uintptr)(tvmval.nptr) + return +} + +// moveFrom copies the tvmval from other Value object. +func (tvmval *Value) moveFrom(fromval *Value) () { + C.memcpy(unsafe.Pointer(tvmval.nativeCPtr()), + unsafe.Pointer(fromval.nativeCPtr()), + C.sizeof_TVMValue) + + // Move the dtype too. + tvmval.dtype = fromval.dtype + fromval.dtype = KNull + return +} + +// setVInt64 initializes the Value object with given int64 value. +// +// `val` is the int64 value to initialize the Value +func (tvmval *Value) setVInt64(val int64) { + valp := (*C.int64_t)(unsafe.Pointer(tvmval.nativeCPtr())) + *valp = C.int64_t(val) + tvmval.dtype = KDLInt + return +} + + +// getVInt64 returns the int64 value inside the Value. +func (tvmval *Value) getVInt64() (retVal int64) { + valp := (*C.int64_t)(unsafe.Pointer(tvmval.nativeCPtr())) + retVal = int64(*valp) + return +} + +// setVFloat64 initializes the Value object with given float64 value. +// +// `val` is the float64 value to initialize the Value. +func (tvmval *Value) setVFloat64(val float64) { + valp := (*C.double)(unsafe.Pointer(tvmval.nativeCPtr())) + *valp = C.double(val) + tvmval.dtype = KDLFloat + return +} + +// getVFloat64 returns the float64 value inside Value. +func (tvmval *Value) getVFloat64() (retVal float64) { + valp := (*C.double)(unsafe.Pointer(tvmval.nativeCPtr())) + retVal = float64(*valp) + return +} + +// setVHandle initializes the handle inside the Value. +// +// Can be used to store any uintptr type object like +// module handle, function handle and any object's nativeCPtr. +// +// `val` is the uintptr type of given handle. +func (tvmval *Value) setVHandle(val uintptr) { + valp := (**C.void)(unsafe.Pointer(tvmval.nativeCPtr())) + *valp = (*C.void)(unsafe.Pointer(val)) +} + +// getVHandle returns the uintptr handle +func (tvmval *Value) getVHandle() (retVal uintptr) { + valp := (**C.void)(unsafe.Pointer(tvmval.nativeCPtr())) + retVal = uintptr(unsafe.Pointer(*valp)) + return +} + +// setVStr intializes the Value with given golang string object. +// +// `val` is the golang string object used to initialize the Value. +func (tvmval *Value) setVStr(val string) { + valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) + *valp = C.CString(val) + tvmval.dtype = KStr + return +} + + +// getVStr returns the golang string for the native string inside Value. +func (tvmval *Value) getVStr() (retVal string) { + valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) + retVal = C.GoString(*valp) + return +} + +// unSetVStr release the memory allocated in setVStr +func (tvmval *Value) unSetVStr() { + valp := (**C.char)(unsafe.Pointer(tvmval.nativeCPtr())) + C.free(unsafe.Pointer(*valp)) + tvmval.dtype = KNull +} + +// setVAHandle is used to set Array handle in Value. +// +// Application can call the setVHandle with nativeCPtr instead too. +// This is a wrapper to accept Array directly. +func (tvmval *Value) setVAHandle(ptvmarray Array) { + tvmval.setVHandle(ptvmarray.nativeCPtr()) + tvmval.dtype = KArrayHandle + return +} + +// getVAHandle is used to get Array handle in Value. +func (tvmval *Value) getVAHandle() (retVal Array) { + retVal = (Array)(tvmval.getVHandle()) + return +} + +// setVMHandle is used to set Module handle in Value. +// +// Application can call the setVHandle with nativeCPtr instead too. +// This is a wrapper to accept Module directly. +func (tvmval *Value) setVMHandle(tvmmodule Module) { + tvmval.setVHandle(tvmmodule.nativeCPtr()) + tvmval.dtype = KModuleHandle + return +} + +// getVMHandle is used to get Module handle in Value. +func (tvmval *Value) getVMHandle() (retVal Module) { + retVal = (Module)(tvmval.getVHandle()) + return +} + +// setVFHandle is used to set Function handle in Value. +// +// Application can call the setVHandle with nativeCPtr instead. +// This is a wrapper to accept Function directly. +func (tvmval *Value) setVFHandle(tvmfunction Function) { + tvmval.setVHandle(tvmfunction.nativeCPtr()) + tvmval.dtype = KFuncHandle + return +} + +// getVFHandle is used to get Function handle in Value. +func (tvmval *Value) getVFHandle() (retVal Function) { + retVal = (Function)(tvmval.getVHandle()) + return +} + +// setVBHandle is used to set ByteArray handle in Value. +// +// Application can call the setVHandle with nativeCPtr instead. +// This is a wrapper to accept ByteArray directly. +func (tvmval *Value) setVBHandle(tbytearray ByteArray) { + tvmval.setVHandle(tbytearray.nativeCPtr()) + tvmval.dtype = KBytes + return +} + +// getVBHandle is used to get ByteArray handle in Value. +func (tvmval *Value) getVBHandle() (retVal ByteArray) { + retVal = (ByteArray)(tvmval.getVHandle()) + return +} + +// setValue is used to set the given value in Value. +// +// `val` is value of types accepted by Value container or native union. +func (tvmval *Value) setValue(val interface{}) (retVal int32, err error) { + retVal = KNull + switch val.(type) { + case string: + tvmval.setVStr(val.(string)) + case uint8: + tvmval.setVInt64(int64(val.(uint8))) + case uint16: + tvmval.setVInt64(int64(val.(uint16))) + case uint32: + tvmval.setVInt64(int64(val.(uint32))) + case uint64: + tvmval.setVInt64(int64(val.(uint64))) + case int: + tvmval.setVInt64(int64(val.(int))) + case int8: + tvmval.setVInt64(int64(val.(int8))) + case int16: + tvmval.setVInt64(int64(val.(int16))) + case int32: + tvmval.setVInt64(int64(val.(int32))) + case int64: + tvmval.setVInt64(val.(int64)) + case float32: + tvmval.setVFloat64(float64(val.(float32))) + case float64: + tvmval.setVFloat64(val.(float64)) + case *Module: + tvmval.setVMHandle(*(val.(*Module))) + case *Function: + tvmval.setVFHandle(*(val.(*Function))) + case *ByteArray: + tvmval.setVBHandle(*(val.(*ByteArray))) + case []byte: + barray := newByteArray(val.([]byte)) + tvmval.setVBHandle(barray) + case *Array: + tvmval.setVAHandle(*(val.(*Array))) + case func (args ...*Value) (interface{}, error): + fhandle, apierr := ConvertFunction(val) + if apierr != nil { + err = fmt.Errorf("Given value Type not defined for Value: %v : %T\n", val, val); + return + } + tvmval.setVFHandle(*fhandle) + + // Clear the finalizer as we don't need to control it anymore. + runtime.SetFinalizer(fhandle, nil) + case *Value: + tvmval.moveFrom(val.(*Value)) + case Value: + fromval := val.(Value) + tvmval.moveFrom(&fromval) + default: + err = fmt.Errorf("Given value Type not defined for Value: %v : %T\n", val, val); + } + retVal = tvmval.dtype + return +} + +// newTVMValue initialize the TVMValue native object. +// +// This is intended to use as intermediate type between native and golang types. +// Allocated from FuncCall or Callback to handle conversions. +func newTVMValue() (retVal *Value) { + handle := new(Value) + + handle.nptr = (uintptr(C.malloc(C.sizeof_TVMValue))) + handle.dtype = KNull + handle.isLocal = true + finalizer := func(vhandle *Value) { + vhandle.deleteTVMValue() + vhandle = nil + } + runtime.SetFinalizer(handle, finalizer) + retVal = handle + return +} + +// deleteTVMValue free the native Value object which is allocated in newTVMValue. +func (tvmval Value) deleteTVMValue() { + if tvmval.isLocal == true { + if tvmval.dtype == KStr { + tvmval.unSetVStr() + } + if tvmval.dtype == KBytes { + tvmval.getVBHandle().deleteTVMByteArray() + } + } + + C.free(unsafe.Pointer(tvmval.nativeCPtr())) +} diff --git a/golang/src/value_test.go b/golang/src/value_test.go new file mode 100644 index 000000000000..251af82cb7b9 --- /dev/null +++ b/golang/src/value_test.go @@ -0,0 +1,237 @@ +/*! + * Copyright (c) 2018 by Contributors + * \brief gotvm package + * \file value_test.go + */ + +package gotvm + +import ( + "testing" + "math/rand" + "strings" +) + +// Check Int64 Value looping via packed function calling another packed function. +func TestValueLoopInt64(t *testing.T) { + // Receive a function Handle and argument and echo the Value on the handle. + sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + newArgs := args[1:] + + // Call Packed Function by Value + return pfunc.Invoke(newArgs) + } + + fhandle, err := ConvertFunction(sampleFunctionLoop) + if err != nil { + t.Error(err.Error()) + return + } + + // funccall is a simple golang callback function like C = A + B. + funccall := func (args ...*Value) (retVal interface{}, err error) { + retVal = args[0] + return + } + + result := rand.Int63() + retVal, err := fhandle.Invoke(funccall, result) + if err != nil { + t.Error(err.Error()) + return + } + if retVal.AsInt64() != result { + t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) + return + } +} + +// Check Int32 Value looping via packed function calling another packed function. +func TestValueLoopInt32(t *testing.T) { + // Receive a function Handle and argument and echo the Value on the handle. + sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + newArgs := args[1:] + + // Call Packed Function by Value + return pfunc.Invoke(newArgs) + } + + fhandle, err := ConvertFunction(sampleFunctionLoop) + if err != nil { + t.Error(err.Error()) + return + } + + // funccall is a simple golang callback function like C = A + B. + funccall := func (args ...*Value) (retVal interface{}, err error) { + retVal = args[0] + return + } + + result := rand.Int31() + retVal, err := fhandle.Invoke(funccall, result) + if err != nil { + t.Error(err.Error()) + return + } + + if retVal.AsInt64() != int64(result) { + t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) + return + } +} + +// Check Float32 Value looping via packed function calling another packed function. +func TestValueLoopFloat32(t *testing.T) { + // Receive a function Handle and argument and echo the Value on the handle. + sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + newArgs := args[1:] + // Call Packed Function by Value + return pfunc.Invoke(newArgs) + } + + fhandle, err := ConvertFunction(sampleFunctionLoop) + if err != nil { + t.Error(err.Error()) + return + } + + // funccall is a simple golang callback function like C = A + B. + funccall := func (args ...*Value) (retVal interface{}, err error) { + retVal = args[0] + return + } + + result := rand.Float32() + retVal, err := fhandle.Invoke(funccall, result) + if err != nil { + t.Error(err.Error()) + return + } + + if retVal.AsFloat64() != float64(result) { + t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) + return + } +} + +// Check Float64 Value looping via packed function calling another packed function. +func TestValueLoopFloat64(t *testing.T) { + // Receive a function Handle and argument and echo the Value on the handle. + sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + newArgs := args[1:] + // Call Packed Function by Value + return pfunc.Invoke(newArgs) + } + + fhandle, err := ConvertFunction(sampleFunctionLoop) + if err != nil { + t.Error(err.Error()) + return + } + + // funccall is a simple golang callback function like C = A + B. + funccall := func (args ...*Value) (retVal interface{}, err error) { + retVal = args[0] + return + } + + result := rand.Float64() + retVal, err := fhandle.Invoke(funccall, result) + if err != nil { + t.Error(err.Error()) + return + } + + if retVal.AsFloat64() != result { + t.Errorf("Expected : %v got:%v\n", result, retVal.AsInt64()) + return + } +} + +func TestValueLoopString(t *testing.T) { + // Receive a function Handle and argument and echo the Value on the handle. + sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + argStr := args[1].AsStr() + // Call Packed Function by Value + return pfunc.Invoke(argStr) + } + + fhandle, err := ConvertFunction(sampleFunctionLoop) + if err != nil { + t.Error(err.Error()) + return + } + + // funccall is a simple golang callback function like C = A + B. + funccall := func (args ...*Value) (retVal interface{}, err error) { + retVal = args[0].AsStr() + return + } + + retVal, err := fhandle.Invoke(funccall, "TestString") + if err != nil { + t.Error(err.Error()) + return + } + + vStr := retVal.AsStr() + if strings.Compare(vStr, string("TestString")) != 0 { + t.Errorf("Expected : %v got:%v\n", string("TestString"), vStr) + return + } +} + +// Check []byte Value looping via packed function calling another packed function. +func TestValueLoopByteSlice(t *testing.T) { + // Receive a function Handle and argument and echo the Value on the handle. + sampleFunctionLoop := func (args ...*Value) (retVal interface{}, err error) { + // Reveive Packed Function Handle + pfunc := args[0].AsFunction() + argBytes := args[1].AsBytes() + // Call Packed Function by Value + return pfunc.Invoke(argBytes) + } + + fhandle, err := ConvertFunction(sampleFunctionLoop) + if err != nil { + t.Error(err.Error()) + return + } + + // funccall is a simple golang callback function like C = A + B. + funccall := func (args ...*Value) (retVal interface{}, err error) { + retVal = args[0].AsBytes() + return + } + + result := make([]byte, 1024) + rand.Read(result) + retVal, err := fhandle.Invoke(funccall, result) + if err != nil { + t.Error(err.Error()) + return + } + + received := retVal.AsBytes() + if len(result) != len(received) { + t.Errorf("Data expected Len: %v Got :%v\n", len(result), len(received)) + return + } + for i := range result { + if result[i] != received[i] { + t.Errorf("Data expected: %v Got :%v at index %v\n", result[i], received[i], i) + return + } + } +}