diff --git a/pkg/adaptation/plugin.go b/pkg/adaptation/plugin.go index 2078ecd3..f9e9cffb 100644 --- a/pkg/adaptation/plugin.go +++ b/pkg/adaptation/plugin.go @@ -63,11 +63,10 @@ type plugin struct { rpcs *ttrpc.Server events EventMask closed bool - stub api.PluginService regC chan error closeC chan struct{} r *Adaptation - wasm api.Plugin + impl *pluginType } // SetPluginRegistrationTimeout sets the timeout for plugin registration. @@ -114,7 +113,7 @@ func (r *Adaptation) newLaunchedPlugin(dir, idx, base, cfg string) (p *plugin, r idx: idx, base: base, r: r, - wasm: wasm, + impl: &pluginType{wasmImpl: wasm}, }, nil } @@ -258,7 +257,6 @@ func (p *plugin) connect(conn stdnet.Conn) (retErr error) { rpcc.Close() } }() - stub := api.NewPluginClient(rpcc) rpcs, err := ttrpc.NewServer(p.r.serverOpts...) if err != nil { @@ -279,7 +277,7 @@ func (p *plugin) connect(conn stdnet.Conn) (retErr error) { p.rpcc = rpcc p.rpcl = rpcl p.rpcs = rpcs - p.stub = stub + p.impl = &pluginType{ttrpcImpl: api.NewPluginClient(rpcc)} p.pid, err = getPeerPid(p.mux.Trunk()) if err != nil { @@ -295,7 +293,7 @@ func (p *plugin) connect(conn stdnet.Conn) (retErr error) { func (p *plugin) start(name, version string) (err error) { // skip start for WASM plugins and head right to the registration for // events config - if p.wasm == nil { + if p.impl.isTtrpc() { var ( err error timeout = getPluginRegistrationTimeout() @@ -337,7 +335,7 @@ func (p *plugin) start(name, version string) (err error) { // close a plugin shutting down its multiplexed ttrpc connections. func (p *plugin) close() { - if p.wasm != nil { + if p.impl.isWasm() { return } @@ -362,7 +360,7 @@ func (p *plugin) isClosed() bool { // stop a plugin (if it was launched by us) func (p *plugin) stop() error { - if p.isExternal() || p.cmd.Process == nil || p.wasm != nil { + if p.isExternal() || p.cmd.Process == nil || p.impl.isWasm() { return nil } @@ -436,7 +434,6 @@ func (p *plugin) configure(ctx context.Context, name, version, config string) (e ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout()) defer cancel() - var rpl *api.ConfigureResponse req := &ConfigureRequest{ Config: config, RuntimeName: name, @@ -444,11 +441,8 @@ func (p *plugin) configure(ctx context.Context, name, version, config string) (e RegistrationTimeout: getPluginRegistrationTimeout().Milliseconds(), RequestTimeout: getPluginRequestTimeout().Milliseconds(), } - if p.wasm != nil { - rpl, err = p.wasm.Configure(ctx, req) - } else { - rpl, err = p.stub.Configure(ctx, req) - } + + rpl, err := p.impl.Configure(ctx, req) if err != nil { return fmt.Errorf("failed to configure plugin: %w", err) } @@ -493,12 +487,7 @@ func (p *plugin) synchronize(ctx context.Context, pods []*PodSandbox, containers log.Debugf(ctx, "sending sync message, %d/%d, %d/%d (more: %v)", len(req.Pods), len(podsToSend), len(req.Containers), len(ctrsToSend), req.More) - if p.wasm != nil { - rpl, err = p.wasm.Synchronize(ctx, req) - } else { - rpl, err = p.stub.Synchronize(ctx, req) - } - + rpl, err = p.impl.Synchronize(ctx, req) if err == nil { if !req.More { break @@ -574,7 +563,7 @@ func recalcObjsPerSyncMsg(pods, ctrs int, err error) (int, int, error) { } // Relay CreateContainer request to plugin. -func (p *plugin) createContainer(ctx context.Context, req *CreateContainerRequest) (rpl *CreateContainerResponse, err error) { +func (p *plugin) createContainer(ctx context.Context, req *CreateContainerRequest) (*CreateContainerResponse, error) { if !p.events.IsSet(Event_CREATE_CONTAINER) { return nil, nil } @@ -582,11 +571,7 @@ func (p *plugin) createContainer(ctx context.Context, req *CreateContainerReques ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout()) defer cancel() - if p.wasm != nil { - rpl, err = p.wasm.CreateContainer(ctx, req) - } else { - rpl, err = p.stub.CreateContainer(ctx, req) - } + rpl, err := p.impl.CreateContainer(ctx, req) if err != nil { if isFatalError(err) { log.Errorf(ctx, "closing plugin %s, failed to handle CreateContainer request: %v", @@ -601,7 +586,7 @@ func (p *plugin) createContainer(ctx context.Context, req *CreateContainerReques } // Relay UpdateContainer request to plugin. -func (p *plugin) updateContainer(ctx context.Context, req *UpdateContainerRequest) (rpl *UpdateContainerResponse, err error) { +func (p *plugin) updateContainer(ctx context.Context, req *UpdateContainerRequest) (*UpdateContainerResponse, error) { if !p.events.IsSet(Event_UPDATE_CONTAINER) { return nil, nil } @@ -609,11 +594,7 @@ func (p *plugin) updateContainer(ctx context.Context, req *UpdateContainerReques ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout()) defer cancel() - if p.wasm != nil { - rpl, err = p.wasm.UpdateContainer(ctx, req) - } else { - rpl, err = p.stub.UpdateContainer(ctx, req) - } + rpl, err := p.impl.UpdateContainer(ctx, req) if err != nil { if isFatalError(err) { log.Errorf(ctx, "closing plugin %s, failed to handle UpdateContainer request: %v", @@ -636,11 +617,7 @@ func (p *plugin) stopContainer(ctx context.Context, req *StopContainerRequest) ( ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout()) defer cancel() - if p.wasm != nil { - rpl, err = p.wasm.StopContainer(ctx, req) - } else { - rpl, err = p.stub.StopContainer(ctx, req) - } + rpl, err = p.impl.StopContainer(ctx, req) if err != nil { if isFatalError(err) { log.Errorf(ctx, "closing plugin %s, failed to handle StopContainer request: %v", @@ -663,12 +640,7 @@ func (p *plugin) StateChange(ctx context.Context, evt *StateChangeEvent) (err er ctx, cancel := context.WithTimeout(ctx, getPluginRequestTimeout()) defer cancel() - if p.wasm != nil { - _, err = p.wasm.StateChange(ctx, evt) - } else { - _, err = p.stub.StateChange(ctx, evt) - } - if err != nil { + if err = p.impl.StateChange(ctx, evt); err != nil { if isFatalError(err) { log.Errorf(ctx, "closing plugin %s, failed to handle event %d: %v", p.name(), evt.Event, err) diff --git a/pkg/adaptation/plugin_type.go b/pkg/adaptation/plugin_type.go new file mode 100644 index 00000000..a51dce68 --- /dev/null +++ b/pkg/adaptation/plugin_type.go @@ -0,0 +1,80 @@ +/* + Copyright The containerd Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package adaptation + +import ( + "context" + + "github.com/containerd/nri/pkg/api" +) + +type pluginType struct { + wasmImpl api.Plugin + ttrpcImpl api.PluginService +} + +func (p *pluginType) isWasm() bool { + return p.wasmImpl != nil +} + +func (p *pluginType) isTtrpc() bool { + return p.ttrpcImpl != nil +} + +func (p *pluginType) Synchronize(ctx context.Context, req *SynchronizeRequest) (*SynchronizeResponse, error) { + if p.wasmImpl != nil { + return p.wasmImpl.Synchronize(ctx, req) + } + return p.ttrpcImpl.Synchronize(ctx, req) +} + +func (p *pluginType) Configure(ctx context.Context, req *ConfigureRequest) (*ConfigureResponse, error) { + if p.wasmImpl != nil { + return p.wasmImpl.Configure(ctx, req) + } + return p.ttrpcImpl.Configure(ctx, req) +} + +func (p *pluginType) CreateContainer(ctx context.Context, req *CreateContainerRequest) (*CreateContainerResponse, error) { + if p.wasmImpl != nil { + return p.wasmImpl.CreateContainer(ctx, req) + } + return p.ttrpcImpl.CreateContainer(ctx, req) +} + +func (p *pluginType) UpdateContainer(ctx context.Context, req *UpdateContainerRequest) (*UpdateContainerResponse, error) { + if p.wasmImpl != nil { + return p.wasmImpl.UpdateContainer(ctx, req) + } + return p.ttrpcImpl.UpdateContainer(ctx, req) +} + +func (p *pluginType) StopContainer(ctx context.Context, req *StopContainerRequest) (*StopContainerResponse, error) { + if p.wasmImpl != nil { + return p.wasmImpl.StopContainer(ctx, req) + } + return p.ttrpcImpl.StopContainer(ctx, req) +} + +func (p *pluginType) StateChange(ctx context.Context, req *StateChangeEvent) (err error) { + if p.wasmImpl != nil { + _, err = p.wasmImpl.StateChange(ctx, req) + } else { + _, err = p.ttrpcImpl.StateChange(ctx, req) + } + return err +}