Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 15 additions & 43 deletions pkg/adaptation/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,10 @@ type plugin struct {
rpcs *ttrpc.Server
events EventMask
closed bool
stub api.PluginService
regC chan error
closeC chan struct{}
r *Adaptation
wasm api.Plugin
impl *pluginType
}

// SetPluginRegistrationTimeout sets the timeout for plugin registration.
Expand Down Expand Up @@ -114,7 +113,7 @@ func (r *Adaptation) newLaunchedPlugin(dir, idx, base, cfg string) (p *plugin, r
idx: idx,
base: base,
r: r,
wasm: wasm,
impl: &pluginType{wasmImpl: wasm},
}, nil
}

Expand Down Expand Up @@ -258,7 +257,6 @@ func (p *plugin) connect(conn stdnet.Conn) (retErr error) {
rpcc.Close()
}
}()
stub := api.NewPluginClient(rpcc)

rpcs, err := ttrpc.NewServer(p.r.serverOpts...)
if err != nil {
Expand All @@ -279,7 +277,7 @@ func (p *plugin) connect(conn stdnet.Conn) (retErr error) {
p.rpcc = rpcc
p.rpcl = rpcl
p.rpcs = rpcs
p.stub = stub
p.impl = &pluginType{ttrpcImpl: api.NewPluginClient(rpcc)}

p.pid, err = getPeerPid(p.mux.Trunk())
if err != nil {
Expand All @@ -295,7 +293,7 @@ func (p *plugin) connect(conn stdnet.Conn) (retErr error) {
func (p *plugin) start(name, version string) (err error) {
// skip start for WASM plugins and head right to the registration for
// events config
if p.wasm == nil {
if p.impl.isTtrpc() {
var (
err error
timeout = getPluginRegistrationTimeout()
Expand Down Expand Up @@ -337,7 +335,7 @@ func (p *plugin) start(name, version string) (err error) {

// close a plugin shutting down its multiplexed ttrpc connections.
func (p *plugin) close() {
if p.wasm != nil {
if p.impl.isWasm() {
return
}

Expand All @@ -362,7 +360,7 @@ func (p *plugin) isClosed() bool {

// stop a plugin (if it was launched by us)
func (p *plugin) stop() error {
if p.isExternal() || p.cmd.Process == nil || p.wasm != nil {
if p.isExternal() || p.cmd.Process == nil || p.impl.isWasm() {
return nil
}

Expand Down Expand Up @@ -436,19 +434,15 @@ func (p *plugin) configure(ctx context.Context, name, version, config string) (e
ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout())
defer cancel()

var rpl *api.ConfigureResponse
req := &ConfigureRequest{
Config: config,
RuntimeName: name,
RuntimeVersion: version,
RegistrationTimeout: getPluginRegistrationTimeout().Milliseconds(),
RequestTimeout: getPluginRequestTimeout().Milliseconds(),
}
if p.wasm != nil {
rpl, err = p.wasm.Configure(ctx, req)
} else {
rpl, err = p.stub.Configure(ctx, req)
}

rpl, err := p.impl.Configure(ctx, req)
if err != nil {
return fmt.Errorf("failed to configure plugin: %w", err)
}
Expand Down Expand Up @@ -493,12 +487,7 @@ func (p *plugin) synchronize(ctx context.Context, pods []*PodSandbox, containers
log.Debugf(ctx, "sending sync message, %d/%d, %d/%d (more: %v)",
len(req.Pods), len(podsToSend), len(req.Containers), len(ctrsToSend), req.More)

if p.wasm != nil {
rpl, err = p.wasm.Synchronize(ctx, req)
} else {
rpl, err = p.stub.Synchronize(ctx, req)
}

rpl, err = p.impl.Synchronize(ctx, req)
if err == nil {
if !req.More {
break
Expand Down Expand Up @@ -574,19 +563,15 @@ func recalcObjsPerSyncMsg(pods, ctrs int, err error) (int, int, error) {
}

// Relay CreateContainer request to plugin.
func (p *plugin) createContainer(ctx context.Context, req *CreateContainerRequest) (rpl *CreateContainerResponse, err error) {
func (p *plugin) createContainer(ctx context.Context, req *CreateContainerRequest) (*CreateContainerResponse, error) {
if !p.events.IsSet(Event_CREATE_CONTAINER) {
return nil, nil
}

ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout())
defer cancel()

if p.wasm != nil {
rpl, err = p.wasm.CreateContainer(ctx, req)
} else {
rpl, err = p.stub.CreateContainer(ctx, req)
}
rpl, err := p.impl.CreateContainer(ctx, req)
if err != nil {
if isFatalError(err) {
log.Errorf(ctx, "closing plugin %s, failed to handle CreateContainer request: %v",
Expand All @@ -601,19 +586,15 @@ func (p *plugin) createContainer(ctx context.Context, req *CreateContainerReques
}

// Relay UpdateContainer request to plugin.
func (p *plugin) updateContainer(ctx context.Context, req *UpdateContainerRequest) (rpl *UpdateContainerResponse, err error) {
func (p *plugin) updateContainer(ctx context.Context, req *UpdateContainerRequest) (*UpdateContainerResponse, error) {
if !p.events.IsSet(Event_UPDATE_CONTAINER) {
return nil, nil
}

ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout())
defer cancel()

if p.wasm != nil {
rpl, err = p.wasm.UpdateContainer(ctx, req)
} else {
rpl, err = p.stub.UpdateContainer(ctx, req)
}
rpl, err := p.impl.UpdateContainer(ctx, req)
if err != nil {
if isFatalError(err) {
log.Errorf(ctx, "closing plugin %s, failed to handle UpdateContainer request: %v",
Expand All @@ -636,11 +617,7 @@ func (p *plugin) stopContainer(ctx context.Context, req *StopContainerRequest) (
ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout())
defer cancel()

if p.wasm != nil {
rpl, err = p.wasm.StopContainer(ctx, req)
} else {
rpl, err = p.stub.StopContainer(ctx, req)
}
rpl, err = p.impl.StopContainer(ctx, req)
if err != nil {
if isFatalError(err) {
log.Errorf(ctx, "closing plugin %s, failed to handle StopContainer request: %v",
Expand All @@ -663,12 +640,7 @@ func (p *plugin) StateChange(ctx context.Context, evt *StateChangeEvent) (err er
ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout())
defer cancel()

if p.wasm != nil {
_, err = p.wasm.StateChange(ctx, evt)
} else {
_, err = p.stub.StateChange(ctx, evt)
}
if err != nil {
if err = p.impl.StateChange(ctx, evt); err != nil {
if isFatalError(err) {
log.Errorf(ctx, "closing plugin %s, failed to handle event %d: %v",
p.name(), evt.Event, err)
Expand Down
80 changes: 80 additions & 0 deletions pkg/adaptation/plugin_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
Copyright The containerd Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package adaptation

import (
"context"

"github.com/containerd/nri/pkg/api"
)

type pluginType struct {
wasmImpl api.Plugin
ttrpcImpl api.PluginService
}

func (p *pluginType) isWasm() bool {
return p.wasmImpl != nil
}

func (p *pluginType) isTtrpc() bool {
return p.ttrpcImpl != nil
}

func (p *pluginType) Synchronize(ctx context.Context, req *SynchronizeRequest) (*SynchronizeResponse, error) {
if p.wasmImpl != nil {
return p.wasmImpl.Synchronize(ctx, req)
}
return p.ttrpcImpl.Synchronize(ctx, req)
}

func (p *pluginType) Configure(ctx context.Context, req *ConfigureRequest) (*ConfigureResponse, error) {
if p.wasmImpl != nil {
return p.wasmImpl.Configure(ctx, req)
}
return p.ttrpcImpl.Configure(ctx, req)
}

func (p *pluginType) CreateContainer(ctx context.Context, req *CreateContainerRequest) (*CreateContainerResponse, error) {
if p.wasmImpl != nil {
return p.wasmImpl.CreateContainer(ctx, req)
}
return p.ttrpcImpl.CreateContainer(ctx, req)
}

func (p *pluginType) UpdateContainer(ctx context.Context, req *UpdateContainerRequest) (*UpdateContainerResponse, error) {
if p.wasmImpl != nil {
return p.wasmImpl.UpdateContainer(ctx, req)
}
return p.ttrpcImpl.UpdateContainer(ctx, req)
}

func (p *pluginType) StopContainer(ctx context.Context, req *StopContainerRequest) (*StopContainerResponse, error) {
if p.wasmImpl != nil {
return p.wasmImpl.StopContainer(ctx, req)
}
return p.ttrpcImpl.StopContainer(ctx, req)
}

func (p *pluginType) StateChange(ctx context.Context, req *StateChangeEvent) (err error) {
if p.wasmImpl != nil {
_, err = p.wasmImpl.StateChange(ctx, req)
} else {
_, err = p.ttrpcImpl.StateChange(ctx, req)
}
return err
}