Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions cmd/ncproxy/buckets.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package main

import (
bolt "go.etcd.io/bbolt"
)

const schemaVersion = "v1"

var (
bucketKeyVersion = []byte(schemaVersion)
bucketKeyComputeAgent = []byte("computeagent")
)

// Below is the current database schema. This should be updated any time the schema is
// changed or updated. The version should be incremented if breaking changes are made.
// └──v1 - Schema version bucket
// └──computeagent - Compute agent bucket
// └──containerID : <string> - Entry in compute agent bucket: Address to
// the compute agent for containerID

// taken from containerd/containerd/metadata/buckets.go
func getBucket(tx *bolt.Tx, keys ...[]byte) *bolt.Bucket {
bkt := tx.Bucket(keys[0])

for _, key := range keys[1:] {
if bkt == nil {
break
}
bkt = bkt.Bucket(key)
}

return bkt
}

// taken from containerd/containerd/metadata/buckets.go
func createBucketIfNotExists(tx *bolt.Tx, keys ...[]byte) (*bolt.Bucket, error) {
bkt, err := tx.CreateBucketIfNotExists(keys[0])
if err != nil {
return nil, err
}

for _, key := range keys[1:] {
bkt, err = bkt.CreateBucketIfNotExists(key)
if err != nil {
return nil, err
}
}

return bkt, nil
}

func createComputeAgentBucket(tx *bolt.Tx) (*bolt.Bucket, error) {
return createBucketIfNotExists(tx, bucketKeyVersion, bucketKeyComputeAgent)
}

func getComputeAgentBucket(tx *bolt.Tx) *bolt.Bucket {
return getBucket(tx, bucketKeyVersion, bucketKeyComputeAgent)
}
143 changes: 123 additions & 20 deletions cmd/ncproxy/ncproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,71 @@ type computeAgentCache struct {
// lock for synchronizing read/write access to `cache`
rw sync.RWMutex
// mapping of container ID to shim compute agent ttrpc service
cache map[string]computeagent.ComputeAgentService
cache map[string]*computeAgentClient
}

func newComputeAgentCache() *computeAgentCache {
return &computeAgentCache{
cache: make(map[string]computeagent.ComputeAgentService),
cache: make(map[string]*computeAgentClient),
}
}

func (c *computeAgentCache) get(cid string) (computeagent.ComputeAgentService, bool) {
func (c *computeAgentCache) getAllAndClear() ([]*computeAgentClient, error) {
// set c.cache to nil first so that subsequent attempts to reads and writes
// return an error
c.rw.Lock()
cacheCopy := c.cache
c.cache = nil
c.rw.Unlock()

if cacheCopy == nil {
return nil, errors.New("cannot read from a nil cache")
}

results := []*computeAgentClient{}
for _, agent := range cacheCopy {
results = append(results, agent)
}
return results, nil

}

func (c *computeAgentCache) get(cid string) (*computeAgentClient, error) {
c.rw.RLock()
defer c.rw.RUnlock()
result, ok := c.cache[cid]
return result, ok
if c.cache == nil {
return nil, errors.New("cannot read from a nil cache")
}
result := c.cache[cid]
return result, nil
}

func (c *computeAgentCache) put(cid string, agent computeagent.ComputeAgentService) {
func (c *computeAgentCache) put(cid string, agent *computeAgentClient) error {
c.rw.Lock()
defer c.rw.Unlock()
if c.cache == nil {
return errors.New("cannot write to a nil cache")
}
c.cache[cid] = agent
return nil
}

func (c *computeAgentCache) getAndDelete(cid string) (*computeAgentClient, error) {
c.rw.Lock()
defer c.rw.Unlock()
if c.cache == nil {
return nil, errors.New("cannot read from a nil cache")
}
result := c.cache[cid]
delete(c.cache, cid)
return result, nil
}

// GRPC service exposed for use by a Node Network Service.
type grpcService struct {
// containerIDToComputeAgent is a cache that stores the mappings from
// container ID to compute agent address is memory. This is repopulated
// on reconnect and referenced during client calls.
containerIDToComputeAgent *computeAgentCache
}

Expand All @@ -82,7 +123,11 @@ func (s *grpcService) AddNIC(ctx context.Context, req *ncproxygrpc.AddNICRequest
if req.ContainerID == "" || req.EndpointName == "" || req.NicID == "" {
return nil, status.Errorf(codes.InvalidArgument, "received empty field in request: %+v", req)
}
if agent, ok := s.containerIDToComputeAgent.get(req.ContainerID); ok {
agent, err := s.containerIDToComputeAgent.get(req.ContainerID)
if err != nil {
return nil, err
}
if agent != nil {
caReq := &computeagent.AddNICInternalRequest{
ContainerID: req.ContainerID,
NicID: req.NicID,
Expand Down Expand Up @@ -112,7 +157,11 @@ func (s *grpcService) ModifyNIC(ctx context.Context, req *ncproxygrpc.ModifyNICR
return nil, status.Error(codes.InvalidArgument, "received empty field in request")
}

if agent, ok := s.containerIDToComputeAgent.get(req.ContainerID); ok {
agent, err := s.containerIDToComputeAgent.get(req.ContainerID)
if err != nil {
return nil, err
}
if agent != nil {
caReq := &computeagent.ModifyNICInternalRequest{
NicID: req.NicID,
EndpointName: req.EndpointName,
Expand Down Expand Up @@ -188,7 +237,11 @@ func (s *grpcService) DeleteNIC(ctx context.Context, req *ncproxygrpc.DeleteNICR
if req.ContainerID == "" || req.EndpointName == "" || req.NicID == "" {
return nil, status.Errorf(codes.InvalidArgument, "received empty field in request: %+v", req)
}
if agent, ok := s.containerIDToComputeAgent.get(req.ContainerID); ok {
agent, err := s.containerIDToComputeAgent.get(req.ContainerID)
if err != nil {
return nil, err
}
if agent != nil {
caReq := &computeagent.DeleteNICInternalRequest{
ContainerID: req.ContainerID,
NicID: req.NicID,
Expand Down Expand Up @@ -598,15 +651,36 @@ func (s *grpcService) GetNetworks(ctx context.Context, req *ncproxygrpc.GetNetwo

// TTRPC service exposed for use by the shim.
type ttrpcService struct {
// containerIDToComputeAgent is a cache that stores the mappings from
// container ID to compute agent address is memory. This is repopulated
// on reconnect and referenced during client calls.
containerIDToComputeAgent *computeAgentCache
// agentStore refers to the database that stores the mappings from
// containerID to compute agent address persistently. This is referenced
// on reconnect and when registering/unregistering a compute agent.
agentStore *computeAgentStore
}

func newTTRPCService(agentCache *computeAgentCache) *ttrpcService {
func newTTRPCService(ctx context.Context, agent *computeAgentCache, agentStore *computeAgentStore) *ttrpcService {
return &ttrpcService{
containerIDToComputeAgent: agentCache,
containerIDToComputeAgent: agent,
agentStore: agentStore,
}
}

func getComputeAgentClient(agentAddr string) (*computeAgentClient, error) {
conn, err := winioDialPipe(agentAddr, nil)
if err != nil {
return nil, errors.Wrap(err, "failed to connect to compute agent service")
}
raw := ttrpcNewClient(
conn,
ttrpc.WithUnaryClientInterceptor(octtrpc.ClientInterceptor()),
ttrpc.WithOnClose(func() { conn.Close() }),
)
return &computeAgentClient{raw, computeagent.NewComputeAgentClient(raw)}, nil
}

func (s *ttrpcService) RegisterComputeAgent(ctx context.Context, req *ncproxyttrpc.RegisterComputeAgentRequest) (_ *ncproxyttrpc.RegisterComputeAgentResponse, err error) {
ctx, span := trace.StartSpan(ctx, "RegisterComputeAgent") //nolint:ineffassign,staticcheck
defer span.End()
Expand All @@ -616,22 +690,51 @@ func (s *ttrpcService) RegisterComputeAgent(ctx context.Context, req *ncproxyttr
trace.StringAttribute("containerID", req.ContainerID),
trace.StringAttribute("agentAddress", req.AgentAddress))

conn, err := winioDialPipe(req.AgentAddress, nil)
agent, err := getComputeAgentClient(req.AgentAddress)
if err != nil {
return nil, errors.Wrap(err, "failed to connect to compute agent service")
return nil, err
}
client := ttrpcNewClient(
conn,
ttrpc.WithUnaryClientInterceptor(octtrpc.ClientInterceptor()),
ttrpc.WithOnClose(func() { conn.Close() }),
)
// Add to global client map if connection succeeds. Don't check if there's already a map entry

if err := s.agentStore.updateComputeAgent(ctx, req.ContainerID, req.AgentAddress); err != nil {
return nil, err
}

// Add to client cache if connection succeeds. Don't check if there's already a map entry
// just overwrite as the client may have changed the address of the config agent.
s.containerIDToComputeAgent.put(req.ContainerID, computeagent.NewComputeAgentClient(client))
if err := s.containerIDToComputeAgent.put(req.ContainerID, agent); err != nil {
return nil, err
}

return &ncproxyttrpc.RegisterComputeAgentResponse{}, nil
}

func (s *ttrpcService) UnregisterComputeAgent(ctx context.Context, req *ncproxyttrpc.UnregisterComputeAgentRequest) (_ *ncproxyttrpc.UnregisterComputeAgentResponse, err error) {
ctx, span := trace.StartSpan(ctx, "UnregisterComputeAgent") //nolint:ineffassign,staticcheck
defer span.End()
defer func() { oc.SetSpanStatus(span, err) }()

span.AddAttributes(
trace.StringAttribute("containerID", req.ContainerID))

err = s.agentStore.deleteComputeAgent(ctx, req.ContainerID)
if err != nil {
log.G(ctx).WithField("key", req.ContainerID).WithError(err).Warn("failed to delete key from compute agent store")
}

// remove the agent from the cache and return it so we can clean up its resources as well
agent, err := s.containerIDToComputeAgent.getAndDelete(req.ContainerID)
if err != nil {
return nil, err
}
if agent != nil {
if err := agent.Close(); err != nil {
return nil, err
}
}

return &ncproxyttrpc.UnregisterComputeAgentResponse{}, nil
}

func (s *ttrpcService) ConfigureNetworking(ctx context.Context, req *ncproxyttrpc.ConfigureNetworkingInternalRequest) (_ *ncproxyttrpc.ConfigureNetworkingInternalResponse, err error) {
ctx, span := trace.StartSpan(ctx, "ConfigureNetworking") //nolint:ineffassign,staticcheck
defer span.End()
Expand Down
44 changes: 40 additions & 4 deletions cmd/ncproxy/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (
"github.com/Microsoft/go-winio/pkg/etwlogrus"
"github.com/Microsoft/go-winio/pkg/guid"
"github.com/Microsoft/hcsshim/cmd/ncproxy/nodenetsvc"
"github.com/Microsoft/hcsshim/internal/computeagent"
"github.com/Microsoft/hcsshim/internal/debug"
"github.com/Microsoft/hcsshim/internal/log"
"github.com/Microsoft/hcsshim/internal/oc"
"github.com/containerd/ttrpc"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/urfave/cli"
Expand All @@ -31,6 +33,18 @@ type nodeNetSvcConn struct {
grpcConn *grpc.ClientConn
}

type computeAgentClient struct {
raw *ttrpc.Client
computeagent.ComputeAgentService
}

func (c *computeAgentClient) Close() error {
if c.raw == nil {
return nil
}
return c.raw.Close()
}

var (
// Global object representing the connection to the node network service that
// ncproxy will be talking to.
Expand Down Expand Up @@ -61,6 +75,10 @@ and 'node network' services.`
Name: "log-directory",
Usage: "Directory to write ncproxy logs to. This is just panic logs.",
},
cli.StringFlag{
Name: "database-path",
Usage: "Path to database file storing information on container to compute agent mapping.",
},
cli.BoolFlag{
Name: "register-service",
Usage: "Register ncproxy as a Windows service.",
Expand Down Expand Up @@ -89,6 +107,7 @@ func run(clicontext *cli.Context) error {
var (
configPath = clicontext.GlobalString("config")
logDir = clicontext.GlobalString("log-directory")
dbPath = clicontext.GlobalString("database-path")
registerSvc = clicontext.GlobalBool("register-service")
unregisterSvc = clicontext.GlobalBool("unregister-service")
runSvc = clicontext.GlobalBool("run-service")
Expand Down Expand Up @@ -194,6 +213,24 @@ func run(clicontext *cli.Context) error {
}
}

// setup ncproxy databases
if dbPath == "" {
// default location for ncproxy database
binLocation, err := os.Executable()
if err != nil {
return err
}
dbPath = filepath.Dir(binLocation) + "networkproxy.db"
} else {
// If a db path was provided, make sure parent directories exist
dir := filepath.Dir(dbPath)
if _, err := os.Stat(dir); err != nil {
if err := os.MkdirAll(dir, 0); err != nil {
return errors.Wrap(err, "failed to make database directory")
}
}
}

log.G(ctx).WithFields(logrus.Fields{
"TTRPCAddr": conf.TTRPCAddr,
"NodeNetSvcAddr": conf.NodeNetSvcAddr,
Expand All @@ -207,10 +244,11 @@ func run(clicontext *cli.Context) error {
defer signal.Stop(sigChan)

// Create new server and then register NetworkConfigProxyServices.
server, err := newServer(ctx, conf)
server, err := newServer(ctx, conf, dbPath)
if err != nil {
return errors.New("failed to make new ncproxy server")
}
defer server.cleanupResources(ctx)

ttrpcListener, grpcListener, err := server.setup(ctx)
if err != nil {
Expand All @@ -232,9 +270,7 @@ func run(clicontext *cli.Context) error {
}

// Cancel inflight requests and shutdown services
if err := server.gracefulShutdown(ctx); err != nil {
return errors.Wrap(err, "ncproxy failed to shutdown gracefully")
}
server.gracefulShutdown(ctx)

return nil
}
Loading